author    | Brendan Abolivier <babolivier@matrix.org> | 2020-06-10 11:42:30 +0100
committer | Brendan Abolivier <babolivier@matrix.org> | 2020-06-10 11:42:30 +0100
commit    | ec0a7b9034806d6b2ba086bae58f5c6b0fd14672 (patch)
tree      | f2af547b1342795e10548f8fb7a9cfc93e03df37
parent    | changelog (diff)
parent    | 1.15.0rc1 (diff)
download  | synapse-ec0a7b9034806d6b2ba086bae58f5c6b0fd14672.tar.xz
Merge branch 'develop' into babolivier/mark_unread
963 files changed, 65965 insertions, 30448 deletions
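The merge can also be fetched and inspected with plain `git` rather than through this web view; a minimal sketch, assuming an up-to-date clone of https://github.com/matrix-org/synapse with the default `origin` remote (the commit hash and branch names are the ones shown in the header above):

```sh
# Fetch the development branch that contains this merge.
git fetch origin develop

# Show the merge commit's metadata and per-file diffstat.
git show --stat ec0a7b9034806d6b2ba086bae58f5c6b0fd14672

# List the two parents of the merge (the feature-branch tip and the 1.15.0rc1 tip).
git log --oneline -1 --parents ec0a7b9034806d6b2ba086bae58f5c6b0fd14672
```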
diff --git a/.buildkite/docker-compose.py35.pg95.yaml b/.buildkite/docker-compose.py35.pg95.yaml
deleted file mode 100644
index 43237b7775..0000000000
--- a/.buildkite/docker-compose.py35.pg95.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-version: '3.1'
-
-services:
-
-  postgres:
-    image: postgres:9.5
-    environment:
-      POSTGRES_PASSWORD: postgres
-    command: -c fsync=off
-
-  testenv:
-    image: python:3.5
-    depends_on:
-      - postgres
-    env_file: .env
-    environment:
-      SYNAPSE_POSTGRES_HOST: postgres
-      SYNAPSE_POSTGRES_USER: postgres
-      SYNAPSE_POSTGRES_PASSWORD: postgres
-    working_dir: /src
-    volumes:
-      - ..:/src
diff --git a/.buildkite/docker-compose.py37.pg11.yaml b/.buildkite/docker-compose.py37.pg11.yaml
deleted file mode 100644
index b767228147..0000000000
--- a/.buildkite/docker-compose.py37.pg11.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-version: '3.1'
-
-services:
-
-  postgres:
-    image: postgres:11
-    environment:
-      POSTGRES_PASSWORD: postgres
-    command: -c fsync=off
-
-  testenv:
-    image: python:3.7
-    depends_on:
-      - postgres
-    env_file: .env
-    environment:
-      SYNAPSE_POSTGRES_HOST: postgres
-      SYNAPSE_POSTGRES_USER: postgres
-      SYNAPSE_POSTGRES_PASSWORD: postgres
-    working_dir: /src
-    volumes:
-      - ..:/src
diff --git a/.buildkite/docker-compose.py37.pg95.yaml b/.buildkite/docker-compose.py37.pg95.yaml
deleted file mode 100644
index 02fcd28304..0000000000
--- a/.buildkite/docker-compose.py37.pg95.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-version: '3.1'
-
-services:
-
-  postgres:
-    image: postgres:9.5
-    environment:
-      POSTGRES_PASSWORD: postgres
-    command: -c fsync=off
-
-  testenv:
-    image: python:3.7
-    depends_on:
-      - postgres
-    env_file: .env
-    environment:
-      SYNAPSE_POSTGRES_HOST: postgres
-      SYNAPSE_POSTGRES_USER: postgres
-      SYNAPSE_POSTGRES_PASSWORD: postgres
-    working_dir: /src
-    volumes:
-      - ..:/src
diff --git a/.buildkite/format_tap.py b/.buildkite/format_tap.py
deleted file mode 100644
index b557a9c38e..0000000000
--- a/.buildkite/format_tap.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import sys
-from tap.parser import Parser
-from tap.line import Result, Unknown, Diagnostic
-
-out = ["### TAP Output for " + sys.argv[2]]
-
-p = Parser()
-
-in_error = False
-
-for line in p.parse_file(sys.argv[1]):
-    if isinstance(line, Result):
-        if in_error:
-            out.append("")
-            out.append("</pre></code></details>")
-            out.append("")
-            out.append("----")
-            out.append("")
-        in_error = False
-
-        if not line.ok and not line.todo:
-            in_error = True
-
-            out.append("FAILURE Test #%d: ``%s``" % (line.number, line.description))
-            out.append("")
-            out.append("<details><summary>Show log</summary><code><pre>")
-
-    elif isinstance(line, Diagnostic) and in_error:
-        out.append(line.text)
-
-if out:
-    for line in out[:-3]:
-        print(line)
diff --git a/.buildkite/merge_base_branch.sh b/.buildkite/merge_base_branch.sh
index eb7219a56d..361440fd1a 100755
--- a/.buildkite/merge_base_branch.sh
+++ b/.buildkite/merge_base_branch.sh
@@ -1,6 +1,6 @@
 #!/usr/bin/env bash
 
-set -ex
+set -e
 
 if [[ "$BUILDKITE_BRANCH" =~ ^(develop|master|dinsic|shhs|release-.*)$ ]]; then
     echo "Not merging forward, as this is a release branch"
@@ -18,6 +18,8 @@ else
     GITBASE=$BUILDKITE_PULL_REQUEST_BASE_BRANCH
 fi
 
+echo "--- merge_base_branch $GITBASE"
+
 # Show what we are before
 git --no-pager show -s
diff --git a/.buildkite/postgres-config.yaml b/.buildkite/postgres-config.yaml
new file mode 100644
index 0000000000..2acbe66f4c
--- /dev/null
+++ b/.buildkite/postgres-config.yaml
@@ -0,0 +1,21 @@
+# Configuration file used for testing the 'synapse_port_db' script.
+# Tells the script to connect to the postgresql database that will be available in the
+# CI's Docker setup at the point where this file is considered.
+server_name: "localhost:8800"
+
+signing_key_path: "/src/.buildkite/test.signing.key"
+
+report_stats: false
+
+database:
+  name: "psycopg2"
+  args:
+    user: postgres
+    host: postgres
+    password: postgres
+    database: synapse
+
+# Suppress the key server warning.
+trusted_key_servers:
+  - server_name: "matrix.org"
+suppress_key_server_warning: true
diff --git a/.buildkite/scripts/create_postgres_db.py b/.buildkite/scripts/create_postgres_db.py
new file mode 100755
index 0000000000..df6082b0ac
--- /dev/null
+++ b/.buildkite/scripts/create_postgres_db.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from synapse.storage.engines import create_engine
+
+logger = logging.getLogger("create_postgres_db")
+
+if __name__ == "__main__":
+    # Create a PostgresEngine.
+    db_engine = create_engine({"name": "psycopg2", "args": {}})
+
+    # Connect to postgres to create the base database.
+    # We use "postgres" as a database because it's bound to exist and the "synapse" one
+    # doesn't exist yet.
+    db_conn = db_engine.module.connect(
+        user="postgres", host="postgres", password="postgres", dbname="postgres"
+    )
+    db_conn.autocommit = True
+    cur = db_conn.cursor()
+    cur.execute("CREATE DATABASE synapse;")
+    cur.close()
+    db_conn.close()
diff --git a/.buildkite/scripts/test_old_deps.sh b/.buildkite/scripts/test_old_deps.sh
new file mode 100755
index 0000000000..cdb77b556c
--- /dev/null
+++ b/.buildkite/scripts/test_old_deps.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+# this script is run by buildkite in a plain `xenial` container; it installs the
+# minimal requirements for tox and hands over to the py35-old tox environment.
+
+set -ex
+
+apt-get update
+apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev tox
+
+export LANG="C.UTF-8"
+
+exec tox -e py35-old,combine
diff --git a/.buildkite/scripts/test_synapse_port_db.sh b/.buildkite/scripts/test_synapse_port_db.sh
new file mode 100755
index 0000000000..9ed2177635
--- /dev/null
+++ b/.buildkite/scripts/test_synapse_port_db.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+#
+# Test script for 'synapse_port_db', which creates a virtualenv, installs Synapse along
+# with additional dependencies needed for the test (such as coverage or the PostgreSQL
+# driver), update the schema of the test SQLite database and run background updates on it,
+# create an empty test database in PostgreSQL, then run the 'synapse_port_db' script to
+# test porting the SQLite database to the PostgreSQL database (with coverage).
+
+set -xe
+cd `dirname $0`/../..
+
+echo "--- Install dependencies"
+
+# Install dependencies for this test.
+pip install psycopg2 coverage coverage-enable-subprocess
+
+# Install Synapse itself. This won't update any libraries.
+pip install -e .
+
+echo "--- Generate the signing key"
+
+# Generate the server's signing key.
+python -m synapse.app.homeserver --generate-keys -c .buildkite/sqlite-config.yaml
+
+echo "--- Prepare the databases"
+
+# Make sure the SQLite3 database is using the latest schema and has no pending background update.
+scripts-dev/update_database --database-config .buildkite/sqlite-config.yaml
+
+# Create the PostgreSQL database.
+./.buildkite/scripts/create_postgres_db.py
+
+echo "+++ Run synapse_port_db"
+
+# Run the script
+coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --postgres-config .buildkite/postgres-config.yaml
diff --git a/.buildkite/sqlite-config.yaml b/.buildkite/sqlite-config.yaml
new file mode 100644
index 0000000000..6d9bf80d84
--- /dev/null
+++ b/.buildkite/sqlite-config.yaml
@@ -0,0 +1,18 @@
+# Configuration file used for testing the 'synapse_port_db' script.
+# Tells the 'update_database' script to connect to the test SQLite database to upgrade its
+# schema and run background updates on it.
+server_name: "localhost:8800"
+
+signing_key_path: "/src/.buildkite/test.signing.key"
+
+report_stats: false
+
+database:
+  name: "sqlite3"
+  args:
+    database: ".buildkite/test_db.db"
+
+# Suppress the key server warning.
+trusted_key_servers:
+  - server_name: "matrix.org"
+suppress_key_server_warning: true
diff --git a/.buildkite/test_db.db b/.buildkite/test_db.db
new file mode 100644
index 0000000000..f20567ba73
--- /dev/null
+++ b/.buildkite/test_db.db
Binary files differ
diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist
index cda5c84e94..fd98cbbaf6 100644
--- a/.buildkite/worker-blacklist
+++ b/.buildkite/worker-blacklist
@@ -5,8 +5,6 @@ Message history can be paginated
 
 Can re-join room if re-invited
 
-/upgrade creates a new room
-
 The only membership state included in an initial sync is for all the senders in the timeline
 
 Local device key changes get to remote servers
@@ -28,3 +26,16 @@ User sees updates to presence from other users in the incremental sync.
 
 Gapped incremental syncs include all state changes
 
 Old members are included in gappy incr LL sync if they start speaking
+
+# new failures as of https://github.com/matrix-org/sytest/pull/732
+Device list doesn't change if remote server is down
+Remote servers cannot set power levels in rooms without existing powerlevels
+Remote servers should reject attempts by non-creators to set the power levels
+
+# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5
+Server correctly handles incoming m.device_list_update
+
+# this fails reliably with a torture level of 100 due to https://github.com/matrix-org/synapse/issues/6536
+Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state
+
+Can get rooms/{roomId}/members at a given point
diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md
new file mode 100644
index 0000000000..1632170c9d
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE.md
@@ -0,0 +1,5 @@
+**If you are looking for support** please ask in **#synapse:matrix.org**
+(using a matrix.org account if necessary). We do not use GitHub issues for
+support.
+
+**If you want to report a security issue** please see https://matrix.org/security-disclosure-policy/
diff --git a/.github/ISSUE_TEMPLATE/BUG_REPORT.md b/.github/ISSUE_TEMPLATE/BUG_REPORT.md
index 5cf844bfb1..75c9b2c9fe 100644
--- a/.github/ISSUE_TEMPLATE/BUG_REPORT.md
+++ b/.github/ISSUE_TEMPLATE/BUG_REPORT.md
@@ -4,11 +4,13 @@ about: Create a report to help us improve
 
 ---
 
-<!--
+**THIS IS NOT A SUPPORT CHANNEL!**
+**IF YOU HAVE SUPPORT QUESTIONS ABOUT RUNNING OR CONFIGURING YOUR OWN HOME SERVER**,
+please ask in **#synapse:matrix.org** (using a matrix.org account if necessary)
 
-**IF YOU HAVE SUPPORT QUESTIONS ABOUT RUNNING OR CONFIGURING YOUR OWN HOME SERVER**:
-You will likely get better support more quickly if you ask in ** #matrix:matrix.org ** ;)
+<!--
+If you want to report a security issue, please see https://matrix.org/security-disclosure-policy/
 
 This is a bug report template. By following the instructions below and
 filling out the sections with your information, you will help the us to get all
@@ -44,22 +46,26 @@ those (please be careful to remove any personal or private data). Please surroun
 
 <!-- IMPORTANT: please answer the following questions, to help us narrow down the problem -->
 
 <!-- Was this issue identified on matrix.org or another homeserver? -->
-- **Homeserver**:
+- **Homeserver**: If not matrix.org:
 
 <!--
-What version of Synapse is running?
-You can find the Synapse version by inspecting the server headers (replace matrix.org with
-your own homeserver domain):
-$ curl -v https://matrix.org/_matrix/client/versions 2>&1 | grep "Server:"
+    What version of Synapse is running?
+
+You can find the Synapse version with this command:
+
+$ curl http://localhost:8008/_synapse/admin/v1/server_version
+
+(You may need to replace `localhost:8008` if Synapse is not configured to
+listen on that port.)
 -->
 
-- **Version**:
+- **Version**:
 
-- **Install method**:
+- **Install method**:
 <!-- examples: package manager/git clone/pip -->
 
-- **Platform**:
+- **Platform**:
 <!--
 Tell us about the environment in which your homeserver is operating
 distro, hardware, if it's running in a vm/container, etc.
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 1ead0d0030..fc22d89426 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -1,7 +1,12 @@
 ### Pull Request Checklist
 
-<!-- Please read CONTRIBUTING.rst before submitting your pull request -->
+<!-- Please read CONTRIBUTING.md before submitting your pull request -->
 
 * [ ] Pull request is based on the develop branch
-* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#changelog)
-* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#sign-off)
+* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#changelog). The entry should:
+  - Be a short description of your change which makes sense to users. "Fixed a bug that prevented receiving messages from other servers." instead of "Moved X method from `EventStore` to `EventWorkerStore`.".
+  - Use markdown where necessary, mostly for `code blocks`.
+  - End with either a period (.) or an exclamation mark (!).
+  - Start with a capital letter.
+* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#sign-off)
+* [ ] Code style is correct (run the [linters](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#code-style))
diff --git a/.gitignore b/.gitignore
index e53d4908d5..af36c00cfa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,9 +7,11 @@
 *.egg-info
 *.lock
 *.pyc
+*.snap
 *.tac
 _trial_temp/
 _trial_temp*/
+/out
 
 # stuff that is likely to exist when you run a server locally
 /*.db
diff --git a/AUTHORS.rst b/AUTHORS.rst
index d8b4a846d8..014f16d4a2 100644
--- a/AUTHORS.rst
+++ b/AUTHORS.rst
@@ -1,34 +1,8 @@
-Erik Johnston <erik at matrix.org>
- * HS core
- * Federation API impl
-
-Mark Haines <mark at matrix.org>
- * HS core
- * Crypto
- * Content repository
- * CS v2 API impl
-
-Kegan Dougal <kegan at matrix.org>
- * HS core
- * CS v1 API impl
- * AS API impl
-
-Paul "LeoNerd" Evans <paul at matrix.org>
- * HS core
- * Presence
- * Typing Notifications
- * Performance metrics and caching layer
-
-Dave Baker <dave at matrix.org>
- * Push notifications
- * Auth CS v2 impl
-
-Matthew Hodgson <matthew at matrix.org>
- * General doc & housekeeping
- * Vertobot/vertobridge matrix<->verto PoC
-
-Emmanuel Rohee <manu at matrix.org>
- * Supporting iOS clients (testability and fallback registration)
+The following is an incomplete list of people outside the core team who have
+contributed to Synapse. It is no longer maintained: more recent contributions
+are listed in the `changelog <CHANGES.md>`_.
+
+----
 
 Turned to Dust <dwinslow86 at gmail.com>
  * ArchLinux installation instructions
@@ -62,16 +36,16 @@ Christoph Witzany <christoph at web.crofting.com>
  * Add LDAP support for authentication
 
 Pierre Jaury <pierre at jaury.eu>
-* Docker packaging
+ * Docker packaging
 
 Serban Constantin <serban.constantin at gmail dot com>
  * Small bug fix
 
-Jason Robinson <jasonr at matrix.org>
- * Minor fixes
-
 Joseph Weston <joseph at weston.cloud>
- + Add admin API for querying HS version
+ * Add admin API for querying HS version
 
 Benjamin Saunders <ben.e.saunders at gmail dot com>
  * Documentation improvements
+
+Werner Sembach <werner.sembach at fau dot de>
+ * Automatically remove a group/community when it is empty
diff --git a/CHANGES.md b/CHANGES.md
index f25c7d0c1a..b45eccf024 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,1513 @@
+Synapse 1.15.0rc1 (2020-06-09)
+==============================
+
+Features
+--------
+
+- Advertise support for Client-Server API r0.6.0 and remove related unstable feature flags. ([\#6585](https://github.com/matrix-org/synapse/issues/6585))
+- Add an option to disable autojoining rooms for guest accounts. ([\#6637](https://github.com/matrix-org/synapse/issues/6637))
+- For SAML authentication, add the ability to pass email addresses to be added to new users' accounts via SAML attributes. Contributed by Christopher Cooper. ([\#7385](https://github.com/matrix-org/synapse/issues/7385))
+- Add admin APIs to allow server admins to manage users' devices. Contributed by @dklimpel. ([\#7481](https://github.com/matrix-org/synapse/issues/7481))
+- Add support for generating thumbnails for WebP images. Previously, users would see an empty box instead of preview image. ([\#7586](https://github.com/matrix-org/synapse/issues/7586))
+- Support the standardized `m.login.sso` user-interactive authentication flow. ([\#7630](https://github.com/matrix-org/synapse/issues/7630))
+
+
+Bugfixes
+--------
+
+- Allow new users to be registered via the admin API even if the monthly active user limit has been reached. Contributed by @dkimpel. ([\#7263](https://github.com/matrix-org/synapse/issues/7263))
+- Fix email notifications not being enabled for new users when created via the Admin API. ([\#7267](https://github.com/matrix-org/synapse/issues/7267))
+- Fix str placeholders in an instance of `PrepareDatabaseException`. Introduced in Synapse v1.8.0. ([\#7575](https://github.com/matrix-org/synapse/issues/7575))
+- Fix a bug in automatic user creation during first time login with `m.login.jwt`. Regression in v1.6.0. Contributed by @olof. ([\#7585](https://github.com/matrix-org/synapse/issues/7585))
+- Fix a bug causing the cross-signing keys to be ignored when resyncing a device list. ([\#7594](https://github.com/matrix-org/synapse/issues/7594))
+- Fix metrics failing when there is a large number of active background processes. ([\#7597](https://github.com/matrix-org/synapse/issues/7597))
+- Fix bug where returning rooms for a group would fail if it included a room that the server was not in. ([\#7599](https://github.com/matrix-org/synapse/issues/7599))
+- Fix duplicate key violation when persisting read markers. ([\#7607](https://github.com/matrix-org/synapse/issues/7607))
+- Prevent an entire iteration of the device list resync loop from failing if one server responds with a malformed result. ([\#7609](https://github.com/matrix-org/synapse/issues/7609))
+- Fix exceptions when fetching events from a remote host fails.
([\#7622](https://github.com/matrix-org/synapse/issues/7622)) +- Make `synctl restart` start synapse if it wasn't running. ([\#7624](https://github.com/matrix-org/synapse/issues/7624)) +- Pass device information through to the login endpoint when using the login fallback. ([\#7629](https://github.com/matrix-org/synapse/issues/7629)) +- Advertise the `m.login.token` login flow when OpenID Connect is enabled. ([\#7631](https://github.com/matrix-org/synapse/issues/7631)) +- Fix bug in account data replication stream. ([\#7656](https://github.com/matrix-org/synapse/issues/7656)) + + +Improved Documentation +---------------------- + +- Update the OpenBSD installation instructions. ([\#7587](https://github.com/matrix-org/synapse/issues/7587)) +- Advertise Python 3.8 support in `setup.py`. ([\#7602](https://github.com/matrix-org/synapse/issues/7602)) +- Add a link to `#synapse:matrix.org` in the troubleshooting section of the README. ([\#7603](https://github.com/matrix-org/synapse/issues/7603)) +- Clarifications to the admin api documentation. ([\#7647](https://github.com/matrix-org/synapse/issues/7647)) + + +Internal Changes +---------------- + +- Convert the identity handler to async/await. ([\#7561](https://github.com/matrix-org/synapse/issues/7561)) +- Improve query performance for fetching state from a PostgreSQL database. ([\#7567](https://github.com/matrix-org/synapse/issues/7567)) +- Speed up processing of federation stream RDATA rows. ([\#7584](https://github.com/matrix-org/synapse/issues/7584)) +- Add comment to systemd example to show postgresql dependency. ([\#7591](https://github.com/matrix-org/synapse/issues/7591)) +- Refactor `Ratelimiter` to limit the amount of expensive config value accesses. ([\#7595](https://github.com/matrix-org/synapse/issues/7595)) +- Convert groups handlers to async/await. ([\#7600](https://github.com/matrix-org/synapse/issues/7600)) +- Clean up exception handling in `SAML2ResponseResource`. ([\#7614](https://github.com/matrix-org/synapse/issues/7614)) +- Check that all asynchronous tasks succeed and general cleanup of `MonthlyActiveUsersTestCase` and `TestMauLimit`. ([\#7619](https://github.com/matrix-org/synapse/issues/7619)) +- Convert `get_user_id_by_threepid` to async/await. ([\#7620](https://github.com/matrix-org/synapse/issues/7620)) +- Switch to upstream `dh-virtualenv` rather than our fork for Debian package builds. ([\#7621](https://github.com/matrix-org/synapse/issues/7621)) +- Update CI scripts to check the number in the newsfile fragment. ([\#7623](https://github.com/matrix-org/synapse/issues/7623)) +- Check if the localpart of a Matrix ID is reserved for guest users earlier in the registration flow, as well as when responding to requests to `/register/available`. ([\#7625](https://github.com/matrix-org/synapse/issues/7625)) +- Minor cleanups to OpenID Connect integration. ([\#7628](https://github.com/matrix-org/synapse/issues/7628)) +- Attempt to fix flaky test: `PhoneHomeStatsTestCase.test_performance_100`. ([\#7634](https://github.com/matrix-org/synapse/issues/7634)) +- Fix typos of `m.olm.curve25519-aes-sha2` and `m.megolm.v1.aes-sha2` in comments, test files. ([\#7637](https://github.com/matrix-org/synapse/issues/7637)) +- Convert user directory, state deltas, and stats handlers to async/await. ([\#7640](https://github.com/matrix-org/synapse/issues/7640)) +- Remove some unused constants. ([\#7644](https://github.com/matrix-org/synapse/issues/7644)) +- Fix type information on `assert_*_is_admin` methods. 
([\#7645](https://github.com/matrix-org/synapse/issues/7645)) +- Convert registration handler to async/await. ([\#7649](https://github.com/matrix-org/synapse/issues/7649)) + + +Synapse 1.14.0 (2020-05-28) +=========================== + +No significant changes. + + +Synapse 1.14.0rc2 (2020-05-27) +============================== + +Bugfixes +-------- + +- Fix cache config to not apply cache factor to event cache. Regression in v1.14.0rc1. ([\#7578](https://github.com/matrix-org/synapse/issues/7578)) +- Fix bug where `ReplicationStreamer` was not always started when replication was enabled. Bug introduced in v1.14.0rc1. ([\#7579](https://github.com/matrix-org/synapse/issues/7579)) +- Fix specifying individual cache factors for caches with special characters in their name. Regression in v1.14.0rc1. ([\#7580](https://github.com/matrix-org/synapse/issues/7580)) + + +Improved Documentation +---------------------- + +- Fix the OIDC `client_auth_method` value in the sample config. ([\#7581](https://github.com/matrix-org/synapse/issues/7581)) + + +Synapse 1.14.0rc1 (2020-05-26) +============================== + +Features +-------- + +- Synapse's cache factor can now be configured in `homeserver.yaml` by the `caches.global_factor` setting. Additionally, `caches.per_cache_factors` controls the cache factors for individual caches. ([\#6391](https://github.com/matrix-org/synapse/issues/6391)) +- Add OpenID Connect login/registration support. Contributed by Quentin Gliech, on behalf of [les Connecteurs](https://connecteu.rs). ([\#7256](https://github.com/matrix-org/synapse/issues/7256), [\#7457](https://github.com/matrix-org/synapse/issues/7457)) +- Add room details admin endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#7317](https://github.com/matrix-org/synapse/issues/7317)) +- Allow for using more than one spam checker module at once. ([\#7435](https://github.com/matrix-org/synapse/issues/7435)) +- Add additional authentication checks for `m.room.power_levels` event per [MSC2209](https://github.com/matrix-org/matrix-doc/pull/2209). ([\#7502](https://github.com/matrix-org/synapse/issues/7502)) +- Implement room version 6 per [MSC2240](https://github.com/matrix-org/matrix-doc/pull/2240). ([\#7506](https://github.com/matrix-org/synapse/issues/7506)) +- Add highly experimental option to move event persistence off master. ([\#7281](https://github.com/matrix-org/synapse/issues/7281), [\#7374](https://github.com/matrix-org/synapse/issues/7374), [\#7436](https://github.com/matrix-org/synapse/issues/7436), [\#7440](https://github.com/matrix-org/synapse/issues/7440), [\#7475](https://github.com/matrix-org/synapse/issues/7475), [\#7490](https://github.com/matrix-org/synapse/issues/7490), [\#7491](https://github.com/matrix-org/synapse/issues/7491), [\#7492](https://github.com/matrix-org/synapse/issues/7492), [\#7493](https://github.com/matrix-org/synapse/issues/7493), [\#7495](https://github.com/matrix-org/synapse/issues/7495), [\#7515](https://github.com/matrix-org/synapse/issues/7515), [\#7516](https://github.com/matrix-org/synapse/issues/7516), [\#7517](https://github.com/matrix-org/synapse/issues/7517), [\#7542](https://github.com/matrix-org/synapse/issues/7542)) + + +Bugfixes +-------- + +- Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind. ([\#7384](https://github.com/matrix-org/synapse/issues/7384)) +- Allow expired user accounts to log out their device sessions. 
([\#7443](https://github.com/matrix-org/synapse/issues/7443)) +- Fix a bug that would cause Synapse not to resync out-of-sync device lists. ([\#7453](https://github.com/matrix-org/synapse/issues/7453)) +- Prevent rooms with 0 members or with invalid version strings from breaking group queries. ([\#7465](https://github.com/matrix-org/synapse/issues/7465)) +- Workaround for an upstream Twisted bug that caused Synapse to become unresponsive after startup. ([\#7473](https://github.com/matrix-org/synapse/issues/7473)) +- Fix Redis reconnection logic that can result in missed updates over replication if master reconnects to Redis without restarting. ([\#7482](https://github.com/matrix-org/synapse/issues/7482)) +- When sending `m.room.member` events, omit `displayname` and `avatar_url` if they aren't set instead of setting them to `null`. Contributed by Aaron Raimist. ([\#7497](https://github.com/matrix-org/synapse/issues/7497)) +- Fix incorrect `method` label on `synapse_http_matrixfederationclient_{requests,responses}` prometheus metrics. ([\#7503](https://github.com/matrix-org/synapse/issues/7503)) +- Ignore incoming presence events from other homeservers if presence is disabled locally. ([\#7508](https://github.com/matrix-org/synapse/issues/7508)) +- Fix a long-standing bug that broke the update remote profile background process. ([\#7511](https://github.com/matrix-org/synapse/issues/7511)) +- Hash passwords as early as possible during password reset. ([\#7538](https://github.com/matrix-org/synapse/issues/7538)) +- Fix bug where a local user leaving a room could fail under rare circumstances. ([\#7548](https://github.com/matrix-org/synapse/issues/7548)) +- Fix "Missing RelayState parameter" error when using user interactive authentication with SAML for some SAML providers. ([\#7552](https://github.com/matrix-org/synapse/issues/7552)) +- Fix exception `'GenericWorkerReplicationHandler' object has no attribute 'send_federation_ack'`, introduced in v1.13.0. ([\#7564](https://github.com/matrix-org/synapse/issues/7564)) +- `synctl` now warns if it was unable to stop Synapse and will not attempt to start Synapse if nothing was stopped. Contributed by Romain Bouyé. ([\#6598](https://github.com/matrix-org/synapse/issues/6598)) + + +Updates to the Docker image +--------------------------- + +- Update docker runtime image to Alpine v3.11. Contributed by @Starbix. ([\#7398](https://github.com/matrix-org/synapse/issues/7398)) + + +Improved Documentation +---------------------- + +- Update information about mapping providers for SAML and OpenID. ([\#7458](https://github.com/matrix-org/synapse/issues/7458)) +- Add additional reverse proxy example for Caddy v2. Contributed by Jeff Peeler. ([\#7463](https://github.com/matrix-org/synapse/issues/7463)) +- Fix copy-paste error in `ServerNoticesConfig` docstring. Contributed by @ptman. ([\#7477](https://github.com/matrix-org/synapse/issues/7477)) +- Improve the formatting of `reverse_proxy.md`. ([\#7514](https://github.com/matrix-org/synapse/issues/7514)) +- Change the systemd worker service to check that the worker config file exists instead of silently failing. Contributed by David Vo. ([\#7528](https://github.com/matrix-org/synapse/issues/7528)) +- Minor clarifications to the TURN docs. ([\#7533](https://github.com/matrix-org/synapse/issues/7533)) + + +Internal Changes +---------------- + +- Add typing annotations in `synapse.federation`. ([\#7382](https://github.com/matrix-org/synapse/issues/7382)) +- Convert the room handler to async/await. 
([\#7396](https://github.com/matrix-org/synapse/issues/7396)) +- Improve performance of `get_e2e_cross_signing_key`. ([\#7428](https://github.com/matrix-org/synapse/issues/7428)) +- Improve performance of `mark_as_sent_devices_by_remote`. ([\#7429](https://github.com/matrix-org/synapse/issues/7429), [\#7562](https://github.com/matrix-org/synapse/issues/7562)) +- Add type hints to the SAML handler. ([\#7445](https://github.com/matrix-org/synapse/issues/7445)) +- Remove storage method `get_hosts_in_room` that is no longer called anywhere. ([\#7448](https://github.com/matrix-org/synapse/issues/7448)) +- Fix some typos in the `notice_expiry` templates. ([\#7449](https://github.com/matrix-org/synapse/issues/7449)) +- Convert the federation handler to async/await. ([\#7459](https://github.com/matrix-org/synapse/issues/7459)) +- Convert the search handler to async/await. ([\#7460](https://github.com/matrix-org/synapse/issues/7460)) +- Add type hints to `synapse.event_auth`. ([\#7505](https://github.com/matrix-org/synapse/issues/7505)) +- Convert the room member handler to async/await. ([\#7507](https://github.com/matrix-org/synapse/issues/7507)) +- Add type hints to room member handler. ([\#7513](https://github.com/matrix-org/synapse/issues/7513)) +- Fix typing annotations in `tests.replication`. ([\#7518](https://github.com/matrix-org/synapse/issues/7518)) +- Remove some redundant Python 2 support code. ([\#7519](https://github.com/matrix-org/synapse/issues/7519)) +- All endpoints now respond with a 200 OK for `OPTIONS` requests. ([\#7534](https://github.com/matrix-org/synapse/issues/7534), [\#7560](https://github.com/matrix-org/synapse/issues/7560)) +- Synapse now exports [detailed allocator statistics](https://doc.pypy.org/en/latest/gc_info.html#gc-get-stats) and basic GC timings as Prometheus metrics (`pypy_gc_time_seconds_total` and `pypy_memory_bytes`) when run under PyPy. Contributed by Ivan Shapovalov. ([\#7536](https://github.com/matrix-org/synapse/issues/7536)) +- Remove Ubuntu Cosmic and Disco from the list of distributions which we provide `.deb`s for, due to end-of-life. ([\#7539](https://github.com/matrix-org/synapse/issues/7539)) +- Make worker processes return a stubbed-out response to `GET /presence` requests. ([\#7545](https://github.com/matrix-org/synapse/issues/7545)) +- Optimise some references to `hs.config`. ([\#7546](https://github.com/matrix-org/synapse/issues/7546)) +- On upgrade room only send canonical alias once. ([\#7547](https://github.com/matrix-org/synapse/issues/7547)) +- Fix some indentation inconsistencies in the sample config. ([\#7550](https://github.com/matrix-org/synapse/issues/7550)) +- Include `synapse.http.site` in type checking. ([\#7553](https://github.com/matrix-org/synapse/issues/7553)) +- Fix some test code to not mangle stacktraces, to make it easier to debug errors. ([\#7554](https://github.com/matrix-org/synapse/issues/7554)) +- Refresh apt cache when building `dh_virtualenv` docker image. ([\#7555](https://github.com/matrix-org/synapse/issues/7555)) +- Stop logging some expected HTTP request errors as exceptions. ([\#7556](https://github.com/matrix-org/synapse/issues/7556), [\#7563](https://github.com/matrix-org/synapse/issues/7563)) +- Convert sending mail to async/await. ([\#7557](https://github.com/matrix-org/synapse/issues/7557)) +- Simplify `reap_monthly_active_users`. 
([\#7558](https://github.com/matrix-org/synapse/issues/7558)) + + +Synapse 1.13.0 (2020-05-19) +=========================== + +This release brings some potential changes necessary for certain +configurations of Synapse: + +* If your Synapse is configured to use SSO and have a custom + `sso_redirect_confirm_template_dir` configuration option set, you will need + to duplicate the new `sso_auth_confirm.html`, `sso_auth_success.html` and + `sso_account_deactivated.html` templates into that directory. +* Synapse plugins using the `complete_sso_login` method of + `synapse.module_api.ModuleApi` should instead switch to the async/await + version, `complete_sso_login_async`, which includes additional checks. The + former version is now deprecated. +* A bug was introduced in Synapse 1.4.0 which could cause the room directory + to be incomplete or empty if Synapse was upgraded directly from v1.2.1 or + earlier, to versions between v1.4.0 and v1.12.x. + +Please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes +and for general upgrade guidance. + + +Notice of change to the default `git` branch for Synapse +-------------------------------------------------------- + +With the release of Synapse 1.13.0, the default `git` branch for Synapse has +changed to `develop`, which is the development tip. This is more consistent with +common practice and modern `git` usage. + +The `master` branch, which tracks the latest release, is still available. It is +recommended that developers and distributors who have scripts which run builds +using the default branch of Synapse should therefore consider pinning their +scripts to `master`. + + +Internal Changes +---------------- + +- Update the version of dh-virtualenv we use to build debs, and add focal to the list of target distributions. ([\#7526](https://github.com/matrix-org/synapse/issues/7526)) + + +Synapse 1.13.0rc3 (2020-05-18) +============================== + +Bugfixes +-------- + +- Hash passwords as early as possible during registration. ([\#7523](https://github.com/matrix-org/synapse/issues/7523)) + + +Synapse 1.13.0rc2 (2020-05-14) +============================== + +Bugfixes +-------- + +- Fix a long-standing bug which could cause messages not to be sent over federation, when state events with state keys matching user IDs (such as custom user statuses) were received. ([\#7376](https://github.com/matrix-org/synapse/issues/7376)) +- Restore compatibility with non-compliant clients during the user interactive authentication process, fixing a problem introduced in v1.13.0rc1. ([\#7483](https://github.com/matrix-org/synapse/issues/7483)) + +Internal Changes +---------------- + +- Fix linting errors in new version of Flake8. ([\#7470](https://github.com/matrix-org/synapse/issues/7470)) + + +Synapse 1.13.0rc1 (2020-05-11) +============================== + +Features +-------- + +- Extend the `web_client_location` option to accept an absolute URL to use as a redirect. Adds a warning when running the web client on the same hostname as homeserver. Contributed by Martin Milata. ([\#7006](https://github.com/matrix-org/synapse/issues/7006)) +- Set `Referrer-Policy` header to `no-referrer` on media downloads. ([\#7009](https://github.com/matrix-org/synapse/issues/7009)) +- Add support for running replication over Redis when using workers. 
([\#7040](https://github.com/matrix-org/synapse/issues/7040), [\#7325](https://github.com/matrix-org/synapse/issues/7325), [\#7352](https://github.com/matrix-org/synapse/issues/7352), [\#7401](https://github.com/matrix-org/synapse/issues/7401), [\#7427](https://github.com/matrix-org/synapse/issues/7427), [\#7439](https://github.com/matrix-org/synapse/issues/7439), [\#7446](https://github.com/matrix-org/synapse/issues/7446), [\#7450](https://github.com/matrix-org/synapse/issues/7450), [\#7454](https://github.com/matrix-org/synapse/issues/7454)) +- Admin API `POST /_synapse/admin/v1/join/<roomIdOrAlias>` to join users to a room like `auto_join_rooms` for creation of users. ([\#7051](https://github.com/matrix-org/synapse/issues/7051)) +- Add options to prevent users from changing their profile or associated 3PIDs. ([\#7096](https://github.com/matrix-org/synapse/issues/7096)) +- Support SSO in the user interactive authentication workflow. ([\#7102](https://github.com/matrix-org/synapse/issues/7102), [\#7186](https://github.com/matrix-org/synapse/issues/7186), [\#7279](https://github.com/matrix-org/synapse/issues/7279), [\#7343](https://github.com/matrix-org/synapse/issues/7343)) +- Allow server admins to define and enforce a password policy ([MSC2000](https://github.com/matrix-org/matrix-doc/issues/2000)). ([\#7118](https://github.com/matrix-org/synapse/issues/7118)) +- Improve the support for SSO authentication on the login fallback page. ([\#7152](https://github.com/matrix-org/synapse/issues/7152), [\#7235](https://github.com/matrix-org/synapse/issues/7235)) +- Always whitelist the login fallback in the SSO configuration if `public_baseurl` is set. ([\#7153](https://github.com/matrix-org/synapse/issues/7153)) +- Admin users are no longer required to be in a room to create an alias for it. ([\#7191](https://github.com/matrix-org/synapse/issues/7191)) +- Require admin privileges to enable room encryption by default. This does not affect existing rooms. ([\#7230](https://github.com/matrix-org/synapse/issues/7230)) +- Add a config option for specifying the value of the Accept-Language HTTP header when generating URL previews. ([\#7265](https://github.com/matrix-org/synapse/issues/7265)) +- Allow `/requestToken` endpoints to hide the existence (or lack thereof) of 3PID associations on the homeserver. ([\#7315](https://github.com/matrix-org/synapse/issues/7315)) +- Add a configuration setting to tweak the threshold for dummy events. ([\#7422](https://github.com/matrix-org/synapse/issues/7422)) + + +Bugfixes +-------- + +- Don't attempt to use an invalid sqlite config if no database configuration is provided. Contributed by @nekatak. ([\#6573](https://github.com/matrix-org/synapse/issues/6573)) +- Fix single-sign on with CAS systems: pass the same service URL when requesting the CAS ticket and when calling the `proxyValidate` URL. Contributed by @Naugrimm. ([\#6634](https://github.com/matrix-org/synapse/issues/6634)) +- Fix missing field `default` when fetching user-defined push rules. ([\#6639](https://github.com/matrix-org/synapse/issues/6639)) +- Improve error responses when accessing remote public room lists. ([\#6899](https://github.com/matrix-org/synapse/issues/6899), [\#7368](https://github.com/matrix-org/synapse/issues/7368)) +- Transfer alias mappings on room upgrade. ([\#6946](https://github.com/matrix-org/synapse/issues/6946)) +- Ensure that a user interactive authentication session is tied to a single request. 
([\#7068](https://github.com/matrix-org/synapse/issues/7068), [\#7455](https://github.com/matrix-org/synapse/issues/7455)) +- Fix a bug in the federation API which could cause occasional "Failed to get PDU" errors. ([\#7089](https://github.com/matrix-org/synapse/issues/7089)) +- Return the proper error (`M_BAD_ALIAS`) when a non-existant canonical alias is provided. ([\#7109](https://github.com/matrix-org/synapse/issues/7109)) +- Fix a bug which meant that groups updates were not correctly replicated between workers. ([\#7117](https://github.com/matrix-org/synapse/issues/7117)) +- Fix starting workers when federation sending not split out. ([\#7133](https://github.com/matrix-org/synapse/issues/7133)) +- Ensure `is_verified` is a boolean in responses to `GET /_matrix/client/r0/room_keys/keys`. Also warn the user if they forgot the `version` query param. ([\#7150](https://github.com/matrix-org/synapse/issues/7150)) +- Fix error page being shown when a custom SAML handler attempted to redirect when processing an auth response. ([\#7151](https://github.com/matrix-org/synapse/issues/7151)) +- Avoid importing `sqlite3` when using the postgres backend. Contributed by David Vo. ([\#7155](https://github.com/matrix-org/synapse/issues/7155)) +- Fix excessive CPU usage by `prune_old_outbound_device_pokes` job. ([\#7159](https://github.com/matrix-org/synapse/issues/7159)) +- Fix a bug which could cause outbound federation traffic to stop working if a client uploaded an incorrect e2e device signature. ([\#7177](https://github.com/matrix-org/synapse/issues/7177)) +- Fix a bug which could cause incorrect 'cyclic dependency' error. ([\#7178](https://github.com/matrix-org/synapse/issues/7178)) +- Fix a bug that could cause a user to be invited to a server notices (aka System Alerts) room without any notice being sent. ([\#7199](https://github.com/matrix-org/synapse/issues/7199)) +- Fix some worker-mode replication handling not being correctly recorded in CPU usage stats. ([\#7203](https://github.com/matrix-org/synapse/issues/7203)) +- Do not allow a deactivated user to login via SSO. ([\#7240](https://github.com/matrix-org/synapse/issues/7240), [\#7259](https://github.com/matrix-org/synapse/issues/7259)) +- Fix --help command-line argument. ([\#7249](https://github.com/matrix-org/synapse/issues/7249)) +- Fix room publish permissions not being checked on room creation. ([\#7260](https://github.com/matrix-org/synapse/issues/7260)) +- Reject unknown session IDs during user interactive authentication instead of silently creating a new session. ([\#7268](https://github.com/matrix-org/synapse/issues/7268)) +- Fix a SQL query introduced in Synapse 1.12.0 which could cause large amounts of logging to the postgres slow-query log. ([\#7274](https://github.com/matrix-org/synapse/issues/7274)) +- Persist user interactive authentication sessions across workers and Synapse restarts. ([\#7302](https://github.com/matrix-org/synapse/issues/7302)) +- Fixed backwards compatibility logic of the first value of `trusted_third_party_id_servers` being used for `account_threepid_delegates.email`, which occurs when the former, deprecated option is set and the latter is not. ([\#7316](https://github.com/matrix-org/synapse/issues/7316)) +- Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind. 
([\#7337](https://github.com/matrix-org/synapse/issues/7337), [\#7358](https://github.com/matrix-org/synapse/issues/7358)) +- Fix bad error handling that would cause Synapse to crash if it's provided with a YAML configuration file that's either empty or doesn't parse into a key-value map. ([\#7341](https://github.com/matrix-org/synapse/issues/7341)) +- Fix incorrect metrics reporting for `renew_attestations` background task. ([\#7344](https://github.com/matrix-org/synapse/issues/7344)) +- Prevent non-federating rooms from appearing in responses to federated `POST /publicRoom` requests when a filter was included. ([\#7367](https://github.com/matrix-org/synapse/issues/7367)) +- Fix a bug which would cause the room durectory to be incorrectly populated if Synapse was upgraded directly from v1.2.1 or earlier to v1.4.0 or later. Note that this fix does not apply retrospectively; see the [upgrade notes](UPGRADE.rst#upgrading-to-v1130) for more information. ([\#7387](https://github.com/matrix-org/synapse/issues/7387)) +- Fix bug in `EventContext.deserialize`. ([\#7393](https://github.com/matrix-org/synapse/issues/7393)) + + +Improved Documentation +---------------------- + +- Update Debian installation instructions to recommend installing the `virtualenv` package instead of `python3-virtualenv`. ([\#6892](https://github.com/matrix-org/synapse/issues/6892)) +- Improve the documentation for database configuration. ([\#6988](https://github.com/matrix-org/synapse/issues/6988)) +- Improve the documentation of application service configuration files. ([\#7091](https://github.com/matrix-org/synapse/issues/7091)) +- Update pre-built package name for FreeBSD. ([\#7107](https://github.com/matrix-org/synapse/issues/7107)) +- Update postgres docs with login troubleshooting information. ([\#7119](https://github.com/matrix-org/synapse/issues/7119)) +- Clean up INSTALL.md a bit. ([\#7141](https://github.com/matrix-org/synapse/issues/7141)) +- Add documentation for running a local CAS server for testing. ([\#7147](https://github.com/matrix-org/synapse/issues/7147)) +- Improve README.md by being explicit about public IP recommendation for TURN relaying. ([\#7167](https://github.com/matrix-org/synapse/issues/7167)) +- Fix a small typo in the `metrics_flags` config option. ([\#7171](https://github.com/matrix-org/synapse/issues/7171)) +- Update the contributed documentation on managing synapse workers with systemd, and bring it into the core distribution. ([\#7234](https://github.com/matrix-org/synapse/issues/7234)) +- Add documentation to the `password_providers` config option. Add known password provider implementations to docs. ([\#7238](https://github.com/matrix-org/synapse/issues/7238), [\#7248](https://github.com/matrix-org/synapse/issues/7248)) +- Modify suggested nginx reverse proxy configuration to match Synapse's default file upload size. Contributed by @ProCycleDev. ([\#7251](https://github.com/matrix-org/synapse/issues/7251)) +- Documentation of media_storage_providers options updated to avoid misunderstandings. Contributed by Tristan Lins. ([\#7272](https://github.com/matrix-org/synapse/issues/7272)) +- Add documentation on monitoring workers with Prometheus. ([\#7357](https://github.com/matrix-org/synapse/issues/7357)) +- Clarify endpoint usage in the users admin api documentation. ([\#7361](https://github.com/matrix-org/synapse/issues/7361)) + + +Deprecations and Removals +------------------------- + +- Remove nonfunctional `captcha_bypass_secret` option from `homeserver.yaml`. 
([\#7137](https://github.com/matrix-org/synapse/issues/7137)) + + +Internal Changes +---------------- + +- Add benchmarks for LruCache. ([\#6446](https://github.com/matrix-org/synapse/issues/6446)) +- Return total number of users and profile attributes in admin users endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#6881](https://github.com/matrix-org/synapse/issues/6881)) +- Change device list streams to have one row per ID. ([\#7010](https://github.com/matrix-org/synapse/issues/7010)) +- Remove concept of a non-limited stream. ([\#7011](https://github.com/matrix-org/synapse/issues/7011)) +- Move catchup of replication streams logic to worker. ([\#7024](https://github.com/matrix-org/synapse/issues/7024), [\#7195](https://github.com/matrix-org/synapse/issues/7195), [\#7226](https://github.com/matrix-org/synapse/issues/7226), [\#7239](https://github.com/matrix-org/synapse/issues/7239), [\#7286](https://github.com/matrix-org/synapse/issues/7286), [\#7290](https://github.com/matrix-org/synapse/issues/7290), [\#7318](https://github.com/matrix-org/synapse/issues/7318), [\#7326](https://github.com/matrix-org/synapse/issues/7326), [\#7378](https://github.com/matrix-org/synapse/issues/7378), [\#7421](https://github.com/matrix-org/synapse/issues/7421)) +- Convert some of synapse.rest.media to async/await. ([\#7110](https://github.com/matrix-org/synapse/issues/7110), [\#7184](https://github.com/matrix-org/synapse/issues/7184), [\#7241](https://github.com/matrix-org/synapse/issues/7241)) +- De-duplicate / remove unused REST code for login and auth. ([\#7115](https://github.com/matrix-org/synapse/issues/7115)) +- Convert `*StreamRow` classes to inner classes. ([\#7116](https://github.com/matrix-org/synapse/issues/7116)) +- Clean up some LoggingContext code. ([\#7120](https://github.com/matrix-org/synapse/issues/7120), [\#7181](https://github.com/matrix-org/synapse/issues/7181), [\#7183](https://github.com/matrix-org/synapse/issues/7183), [\#7408](https://github.com/matrix-org/synapse/issues/7408), [\#7426](https://github.com/matrix-org/synapse/issues/7426)) +- Add explicit `instance_id` for USER_SYNC commands and remove implicit `conn_id` usage. ([\#7128](https://github.com/matrix-org/synapse/issues/7128)) +- Refactored the CAS authentication logic to a separate class. ([\#7136](https://github.com/matrix-org/synapse/issues/7136)) +- Run replication streamers on workers. ([\#7146](https://github.com/matrix-org/synapse/issues/7146)) +- Add tests for outbound device pokes. ([\#7157](https://github.com/matrix-org/synapse/issues/7157)) +- Fix device list update stream ids going backward. ([\#7158](https://github.com/matrix-org/synapse/issues/7158)) +- Use `stream.current_token()` and remove `stream_positions()`. ([\#7172](https://github.com/matrix-org/synapse/issues/7172)) +- Move client command handling out of TCP protocol. ([\#7185](https://github.com/matrix-org/synapse/issues/7185)) +- Move server command handling out of TCP protocol. ([\#7187](https://github.com/matrix-org/synapse/issues/7187)) +- Fix consistency of HTTP status codes reported in log lines. ([\#7188](https://github.com/matrix-org/synapse/issues/7188)) +- Only run one background database update at a time. ([\#7190](https://github.com/matrix-org/synapse/issues/7190)) +- Remove sent outbound device list pokes from the database. ([\#7192](https://github.com/matrix-org/synapse/issues/7192)) +- Add a background database update job to clear out duplicate `device_lists_outbound_pokes`. 
([\#7193](https://github.com/matrix-org/synapse/issues/7193)) +- Remove some extraneous debugging log lines. ([\#7207](https://github.com/matrix-org/synapse/issues/7207)) +- Add explicit Python build tooling as dependencies for the snapcraft build. ([\#7213](https://github.com/matrix-org/synapse/issues/7213)) +- Add typing information to federation server code. ([\#7219](https://github.com/matrix-org/synapse/issues/7219)) +- Extend room admin api (`GET /_synapse/admin/v1/rooms`) with additional attributes. ([\#7225](https://github.com/matrix-org/synapse/issues/7225)) +- Unblacklist '/upgrade creates a new room' sytest for workers. ([\#7228](https://github.com/matrix-org/synapse/issues/7228)) +- Remove redundant checks on `daemonize` from synctl. ([\#7233](https://github.com/matrix-org/synapse/issues/7233)) +- Upgrade jQuery to v3.4.1 on fallback login/registration pages. ([\#7236](https://github.com/matrix-org/synapse/issues/7236)) +- Change log line that told user to implement onLogin/onRegister fallback js functions to a warning, instead of an info, so it's more visible. ([\#7237](https://github.com/matrix-org/synapse/issues/7237)) +- Correct the parameters of a test fixture. Contributed by Isaiah Singletary. ([\#7243](https://github.com/matrix-org/synapse/issues/7243)) +- Convert auth handler to async/await. ([\#7261](https://github.com/matrix-org/synapse/issues/7261)) +- Add some unit tests for replication. ([\#7278](https://github.com/matrix-org/synapse/issues/7278)) +- Improve typing annotations in `synapse.replication.tcp.streams.Stream`. ([\#7291](https://github.com/matrix-org/synapse/issues/7291)) +- Reduce log verbosity of url cache cleanup tasks. ([\#7295](https://github.com/matrix-org/synapse/issues/7295)) +- Fix sample SAML Service Provider configuration. Contributed by @frcl. ([\#7300](https://github.com/matrix-org/synapse/issues/7300)) +- Fix StreamChangeCache to work with multiple entities changing on the same stream id. ([\#7303](https://github.com/matrix-org/synapse/issues/7303)) +- Fix an incorrect import in IdentityHandler. ([\#7319](https://github.com/matrix-org/synapse/issues/7319)) +- Reduce logging verbosity for successful federation requests. ([\#7321](https://github.com/matrix-org/synapse/issues/7321)) +- Convert some federation handler code to async/await. ([\#7338](https://github.com/matrix-org/synapse/issues/7338)) +- Fix collation for postgres for unit tests. ([\#7359](https://github.com/matrix-org/synapse/issues/7359)) +- Convert RegistrationWorkerStore.is_server_admin and dependent code to async/await. ([\#7363](https://github.com/matrix-org/synapse/issues/7363)) +- Add an `instance_name` to `RDATA` and `POSITION` replication commands. ([\#7364](https://github.com/matrix-org/synapse/issues/7364)) +- Thread through instance name to replication client. ([\#7369](https://github.com/matrix-org/synapse/issues/7369)) +- Convert synapse.server_notices to async/await. ([\#7394](https://github.com/matrix-org/synapse/issues/7394)) +- Convert synapse.notifier to async/await. ([\#7395](https://github.com/matrix-org/synapse/issues/7395)) +- Fix issues with the Python package manifest. ([\#7404](https://github.com/matrix-org/synapse/issues/7404)) +- Prevent methods in `synapse.handlers.auth` from polling the homeserver config every request. ([\#7420](https://github.com/matrix-org/synapse/issues/7420)) +- Speed up fetching device lists changes when handling `/sync` requests. 
([\#7423](https://github.com/matrix-org/synapse/issues/7423)) +- Run group attestation renewal in series rather than parallel for performance. ([\#7442](https://github.com/matrix-org/synapse/issues/7442)) + + +Synapse 1.12.4 (2020-04-23) +=========================== + +No significant changes. + + +Synapse 1.12.4rc1 (2020-04-22) +============================== + +Features +-------- + +- Always send users their own device updates. ([\#7160](https://github.com/matrix-org/synapse/issues/7160)) +- Add support for handling GET requests for `account_data` on a worker. ([\#7311](https://github.com/matrix-org/synapse/issues/7311)) + + +Bugfixes +-------- + +- Fix a bug that prevented cross-signing with users on worker-mode synapses. ([\#7255](https://github.com/matrix-org/synapse/issues/7255)) +- Do not treat display names as globs in push rules. ([\#7271](https://github.com/matrix-org/synapse/issues/7271)) +- Fix a bug with cross-signing devices belonging to remote users who did not share a room with any user on the local homeserver. ([\#7289](https://github.com/matrix-org/synapse/issues/7289)) + +Synapse 1.12.3 (2020-04-03) +=========================== + +- Remove the the pin to Pillow 7.0 which was introduced in Synapse 1.12.2, and +correctly fix the issue with building the Debian packages. ([\#7212](https://github.com/matrix-org/synapse/issues/7212)) + +Synapse 1.12.2 (2020-04-02) +=========================== + +This release works around [an issue](https://github.com/matrix-org/synapse/issues/7208) with building the debian packages. + +No other significant changes since 1.12.1. + +Synapse 1.12.1 (2020-04-02) +=========================== + +No significant changes since 1.12.1rc1. + + +Synapse 1.12.1rc1 (2020-03-31) +============================== + +Bugfixes +-------- + +- Fix starting workers when federation sending not split out. ([\#7133](https://github.com/matrix-org/synapse/issues/7133)). Introduced in v1.12.0. +- Avoid importing `sqlite3` when using the postgres backend. Contributed by David Vo. ([\#7155](https://github.com/matrix-org/synapse/issues/7155)). Introduced in v1.12.0rc1. +- Fix a bug which could cause outbound federation traffic to stop working if a client uploaded an incorrect e2e device signature. ([\#7177](https://github.com/matrix-org/synapse/issues/7177)). Introduced in v1.11.0. + +Synapse 1.12.0 (2020-03-23) +=========================== + +Debian packages and Docker images are rebuilt using the latest versions of +dependency libraries, including Twisted 20.3.0. **Please see security advisory +below**. + +Potential slow database update during upgrade +--------------------------------------------- + +Synapse 1.12.0 includes a database update which is run as part of the upgrade, +and which may take some time (several hours in the case of a large +server). Synapse will not respond to HTTP requests while this update is taking +place. For imformation on seeing if you are affected, and workaround if you +are, see the [upgrade notes](UPGRADE.rst#upgrading-to-v1120). + +Security advisory +----------------- + +Synapse may be vulnerable to request-smuggling attacks when it is used with a +reverse-proxy. The vulnerabilties are fixed in Twisted 20.3.0, and are +described in +[CVE-2020-10108](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-10108) +and +[CVE-2020-10109](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-10109). 
+For a good introduction to this class of request-smuggling attacks, see +https://portswigger.net/research/http-desync-attacks-request-smuggling-reborn. + +We are not aware of these vulnerabilities being exploited in the wild, and +do not believe that they are exploitable with current versions of any reverse +proxies. Nevertheless, we recommend that all Synapse administrators ensure that +they have the latest versions of the Twisted library to ensure that their +installation remains secure. + +* Administrators using the [`matrix.org` Docker + image](https://hub.docker.com/r/matrixdotorg/synapse/) or the [Debian/Ubuntu + packages from + `matrix.org`](https://github.com/matrix-org/synapse/blob/master/INSTALL.md#matrixorg-packages) + should ensure that they have version 1.12.0 installed: these images include + Twisted 20.3.0. +* Administrators who have [installed Synapse from + source](https://github.com/matrix-org/synapse/blob/master/INSTALL.md#installing-from-source) + should upgrade Twisted within their virtualenv by running: + ```sh + <path_to_virtualenv>/bin/pip install 'Twisted>=20.3.0' + ``` +* Administrators who have installed Synapse from distribution packages should + consult the information from their distributions. + +The `matrix.org` Synapse instance was not vulnerable to these vulnerabilities. + +Advance notice of change to the default `git` branch for Synapse +---------------------------------------------------------------- + +Currently, the default `git` branch for Synapse is `master`, which tracks the +latest release. + +After the release of Synapse 1.13.0, we intend to change this default to +`develop`, which is the development tip. This is more consistent with common +practice and modern `git` usage. + +Although we try to keep `develop` in a stable state, there may be occasions +where regressions creep in. Developers and distributors who have scripts which +run builds using the default branch of `Synapse` should therefore consider +pinning their scripts to `master`. + + +Synapse 1.12.0rc1 (2020-03-19) +============================== + +Features +-------- + +- Changes related to room alias management ([MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432)): + - Publishing/removing a room from the room directory now requires the user to have a power level capable of modifying the canonical alias, instead of the room aliases. ([\#6965](https://github.com/matrix-org/synapse/issues/6965)) + - Validate the `alt_aliases` property of canonical alias events. ([\#6971](https://github.com/matrix-org/synapse/issues/6971)) + - Users with a power level sufficient to modify the canonical alias of a room can now delete room aliases. ([\#6986](https://github.com/matrix-org/synapse/issues/6986)) + - Implement updated authorization rules and redaction rules for aliases events, from [MSC2261](https://github.com/matrix-org/matrix-doc/pull/2261) and [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432). ([\#7037](https://github.com/matrix-org/synapse/issues/7037)) + - Stop sending m.room.aliases events during room creation and upgrade. ([\#6941](https://github.com/matrix-org/synapse/issues/6941)) + - Synapse no longer uses room alias events to calculate room names for push notifications. ([\#6966](https://github.com/matrix-org/synapse/issues/6966)) + - The room list endpoint no longer returns a list of aliases. 
([\#6970](https://github.com/matrix-org/synapse/issues/6970)) + - Remove special handling of aliases events from [MSC2260](https://github.com/matrix-org/matrix-doc/pull/2260) added in v1.10.0rc1. ([\#7034](https://github.com/matrix-org/synapse/issues/7034)) +- Expose the `synctl`, `hash_password` and `generate_config` commands in the snapcraft package. Contributed by @devec0. ([\#6315](https://github.com/matrix-org/synapse/issues/6315)) +- Check that server_name is correctly set before running database updates. ([\#6982](https://github.com/matrix-org/synapse/issues/6982)) +- Break down monthly active users by `appservice_id` and emit via Prometheus. ([\#7030](https://github.com/matrix-org/synapse/issues/7030)) +- Render a configurable and comprehensible error page if something goes wrong during the SAML2 authentication process. ([\#7058](https://github.com/matrix-org/synapse/issues/7058), [\#7067](https://github.com/matrix-org/synapse/issues/7067)) +- Add an optional parameter to control whether other sessions are logged out when a user's password is modified. ([\#7085](https://github.com/matrix-org/synapse/issues/7085)) +- Add prometheus metrics for the number of active pushers. ([\#7103](https://github.com/matrix-org/synapse/issues/7103), [\#7106](https://github.com/matrix-org/synapse/issues/7106)) +- Improve performance when making HTTPS requests to sygnal, sydent, etc, by sharing the SSL context object between connections. ([\#7094](https://github.com/matrix-org/synapse/issues/7094)) + + +Bugfixes +-------- + +- When a user's profile is updated via the admin API, also generate a displayname/avatar update for that user in each room. ([\#6572](https://github.com/matrix-org/synapse/issues/6572)) +- Fix a couple of bugs in email configuration handling. ([\#6962](https://github.com/matrix-org/synapse/issues/6962)) +- Fix an issue affecting worker-based deployments where replication would stop working, necessitating a full restart, after joining a large room. ([\#6967](https://github.com/matrix-org/synapse/issues/6967)) +- Fix `duplicate key` error which was logged when rejoining a room over federation. ([\#6968](https://github.com/matrix-org/synapse/issues/6968)) +- Prevent user from setting 'deactivated' to anything other than a bool on the v2 PUT /users Admin API. ([\#6990](https://github.com/matrix-org/synapse/issues/6990)) +- Fix py35-old CI by using native tox package. ([\#7018](https://github.com/matrix-org/synapse/issues/7018)) +- Fix a bug causing `org.matrix.dummy_event` to be included in responses from `/sync`. ([\#7035](https://github.com/matrix-org/synapse/issues/7035)) +- Fix a bug that renders UTF-8 text files incorrectly when loaded from media. Contributed by @TheStranjer. ([\#7044](https://github.com/matrix-org/synapse/issues/7044)) +- Fix a bug that would cause Synapse to respond with an error about event visibility if a client tried to request the state of a room at a given token. ([\#7066](https://github.com/matrix-org/synapse/issues/7066)) +- Repair a data-corruption issue which was introduced in Synapse 1.10, and fixed in Synapse 1.11, and which could cause `/sync` to return with 404 errors about missing events and unknown rooms. ([\#7070](https://github.com/matrix-org/synapse/issues/7070)) +- Fix a bug causing account validity renewal emails to be sent even if the feature is turned off in some cases. ([\#7074](https://github.com/matrix-org/synapse/issues/7074)) + + +Improved Documentation +---------------------- + +- Updated CentOS8 install instructions. 
Contributed by Richard Kellner. ([\#6925](https://github.com/matrix-org/synapse/issues/6925)) +- Fix `POSTGRES_INITDB_ARGS` in the `contrib/docker/docker-compose.yml` example docker-compose configuration. ([\#6984](https://github.com/matrix-org/synapse/issues/6984)) +- Change date in [INSTALL.md](./INSTALL.md#tls-certificates) for last date of getting TLS certificates to November 2019. ([\#7015](https://github.com/matrix-org/synapse/issues/7015)) +- Document that the fallback auth endpoints must be routed to the same worker node as the register endpoints. ([\#7048](https://github.com/matrix-org/synapse/issues/7048)) + + +Deprecations and Removals +------------------------- + +- Remove the unused query_auth federation endpoint per [MSC2451](https://github.com/matrix-org/matrix-doc/pull/2451). ([\#7026](https://github.com/matrix-org/synapse/issues/7026)) + + +Internal Changes +---------------- + +- Add type hints to `logging/context.py`. ([\#6309](https://github.com/matrix-org/synapse/issues/6309)) +- Add some clarifications to `README.md` in the database schema directory. ([\#6615](https://github.com/matrix-org/synapse/issues/6615)) +- Refactoring work in preparation for changing the event redaction algorithm. ([\#6874](https://github.com/matrix-org/synapse/issues/6874), [\#6875](https://github.com/matrix-org/synapse/issues/6875), [\#6983](https://github.com/matrix-org/synapse/issues/6983), [\#7003](https://github.com/matrix-org/synapse/issues/7003)) +- Improve performance of v2 state resolution for large rooms. ([\#6952](https://github.com/matrix-org/synapse/issues/6952), [\#7095](https://github.com/matrix-org/synapse/issues/7095)) +- Reduce time spent doing GC, by freezing objects on startup. ([\#6953](https://github.com/matrix-org/synapse/issues/6953)) +- Minor performance fixes to `get_auth_chain_ids`. ([\#6954](https://github.com/matrix-org/synapse/issues/6954)) +- Don't record remote cross-signing keys in the `devices` table. ([\#6956](https://github.com/matrix-org/synapse/issues/6956)) +- Use flake8-comprehensions to enforce good hygiene of list/set/dict comprehensions. ([\#6957](https://github.com/matrix-org/synapse/issues/6957)) +- Merge worker apps together. ([\#6964](https://github.com/matrix-org/synapse/issues/6964), [\#7002](https://github.com/matrix-org/synapse/issues/7002), [\#7055](https://github.com/matrix-org/synapse/issues/7055), [\#7104](https://github.com/matrix-org/synapse/issues/7104)) +- Remove redundant `store_room` call from `FederationHandler._process_received_pdu`. ([\#6979](https://github.com/matrix-org/synapse/issues/6979)) +- Update warning for incorrect database collation/ctype to include link to documentation. ([\#6985](https://github.com/matrix-org/synapse/issues/6985)) +- Add some type annotations to the database storage classes. ([\#6987](https://github.com/matrix-org/synapse/issues/6987)) +- Port `synapse.handlers.presence` to async/await. ([\#6991](https://github.com/matrix-org/synapse/issues/6991), [\#7019](https://github.com/matrix-org/synapse/issues/7019)) +- Add some type annotations to the federation base & client classes. ([\#6995](https://github.com/matrix-org/synapse/issues/6995)) +- Port `synapse.rest.keys` to async/await. ([\#7020](https://github.com/matrix-org/synapse/issues/7020)) +- Add a type check to `is_verified` when processing room keys. ([\#7045](https://github.com/matrix-org/synapse/issues/7045)) +- Add type annotations and comments to the auth handler.
([\#7063](https://github.com/matrix-org/synapse/issues/7063)) + + +Synapse 1.11.1 (2020-03-03) +=========================== + +This release includes a security fix impacting installations using Single Sign-On (i.e. SAML2 or CAS) for authentication. Administrators of such installations are encouraged to upgrade as soon as possible. + +The release also includes fixes for a couple of other bugs. + +Bugfixes +-------- + +- Add a confirmation step to the SSO login flow before redirecting users to the redirect URL. ([b2bd54a2](https://github.com/matrix-org/synapse/commit/b2bd54a2e31d9a248f73fadb184ae9b4cbdb49f9), [65c73cdf](https://github.com/matrix-org/synapse/commit/65c73cdfec1876a9fec2fd2c3a74923cd146fe0b), [a0178df1](https://github.com/matrix-org/synapse/commit/a0178df10422a76fd403b82d2b2a4ed28a9a9d1e)) +- Fix setting a user as an admin with the admin API `PUT /_synapse/admin/v2/users/<user_id>`. Contributed by @dklimpel. ([\#6910](https://github.com/matrix-org/synapse/issues/6910)) +- Fix bug introduced in Synapse 1.11.0 which sometimes caused errors when joining rooms over federation, with `'coroutine' object has no attribute 'event_id'`. ([\#6996](https://github.com/matrix-org/synapse/issues/6996)) + + +Synapse 1.11.0 (2020-02-21) +=========================== + +Improved Documentation +---------------------- + +- Small grammatical fixes to the ACME v1 deprecation notice. ([\#6944](https://github.com/matrix-org/synapse/issues/6944)) + + +Synapse 1.11.0rc1 (2020-02-19) +============================== + +Features +-------- + +- Admin API to add or modify threepids of user accounts. ([\#6769](https://github.com/matrix-org/synapse/issues/6769)) +- Limit the number of events that can be requested by the backfill federation API to 100. ([\#6864](https://github.com/matrix-org/synapse/issues/6864)) +- Add ability to run some group APIs on workers. ([\#6866](https://github.com/matrix-org/synapse/issues/6866)) +- Reject device display names over 100 characters in length to prevent abuse. ([\#6882](https://github.com/matrix-org/synapse/issues/6882)) +- Add ability to route federation user device queries to workers. ([\#6873](https://github.com/matrix-org/synapse/issues/6873)) +- The result of a user directory search can now be filtered via the spam checker. ([\#6888](https://github.com/matrix-org/synapse/issues/6888)) +- Implement new `GET /_matrix/client/unstable/org.matrix.msc2432/rooms/{roomId}/aliases` endpoint as per [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432). ([\#6939](https://github.com/matrix-org/synapse/issues/6939), [\#6948](https://github.com/matrix-org/synapse/issues/6948), [\#6949](https://github.com/matrix-org/synapse/issues/6949)) +- Stop sending `m.room.alias` events when adding / removing aliases. Check `alt_aliases` in the latest `m.room.canonical_alias` event when deleting an alias. ([\#6904](https://github.com/matrix-org/synapse/issues/6904)) +- Change the default power levels of invites, tombstones and server ACLs for new rooms. ([\#6834](https://github.com/matrix-org/synapse/issues/6834)) + +Bugfixes +-------- + +- Fixed third party event rules function `on_create_room`'s return value being ignored. ([\#6781](https://github.com/matrix-org/synapse/issues/6781)) +- Allow URL-encoded User IDs on `/_synapse/admin/v2/users/<user_id>[/admin]` endpoints. Thanks to @NHAS for reporting. ([\#6825](https://github.com/matrix-org/synapse/issues/6825)) +- Fix Synapse refusing to start if `federation_certificate_verification_whitelist` option is blank.
([\#6849](https://github.com/matrix-org/synapse/issues/6849)) +- Fix errors from logging in the purge jobs related to the message retention policies support. ([\#6945](https://github.com/matrix-org/synapse/issues/6945)) +- Return a 404 instead of 200 for querying information of a non-existent user through the admin API. ([\#6901](https://github.com/matrix-org/synapse/issues/6901)) + + +Updates to the Docker image +--------------------------- + +- The deprecated "generate-config-on-the-fly" mode is no longer supported. ([\#6918](https://github.com/matrix-org/synapse/issues/6918)) + + +Improved Documentation +---------------------- + +- Add details of PR merge strategy to contributing docs. ([\#6846](https://github.com/matrix-org/synapse/issues/6846)) +- Spell out that the last event sent to a room won't be deleted by a purge. ([\#6891](https://github.com/matrix-org/synapse/issues/6891)) +- Update Synapse's documentation to warn about the deprecation of ACME v1. ([\#6905](https://github.com/matrix-org/synapse/issues/6905), [\#6907](https://github.com/matrix-org/synapse/issues/6907), [\#6909](https://github.com/matrix-org/synapse/issues/6909)) +- Add documentation for the spam checker. ([\#6906](https://github.com/matrix-org/synapse/issues/6906)) +- Fix worker docs to point `/publicised_groups` API correctly. ([\#6938](https://github.com/matrix-org/synapse/issues/6938)) +- Clean up and update docs on setting up federation. ([\#6940](https://github.com/matrix-org/synapse/issues/6940)) +- Add a warning about indentation to generated configuration files. ([\#6920](https://github.com/matrix-org/synapse/issues/6920)) +- Databases created using the compose file in contrib/docker will now always have correct encoding and locale settings. Contributed by Fridtjof Mund. ([\#6921](https://github.com/matrix-org/synapse/issues/6921)) +- Update pip install directions in readme to avoid error when using zsh. ([\#6855](https://github.com/matrix-org/synapse/issues/6855)) + + +Deprecations and Removals +------------------------- + +- Remove `m.lazy_load_members` from `unstable_features` since lazy loading is in the stable Client-Server API version r0.5.0. ([\#6877](https://github.com/matrix-org/synapse/issues/6877)) + + +Internal Changes +---------------- + +- Add type hints to `SyncHandler`. ([\#6821](https://github.com/matrix-org/synapse/issues/6821)) +- Refactoring work in preparation for changing the event redaction algorithm. ([\#6823](https://github.com/matrix-org/synapse/issues/6823), [\#6827](https://github.com/matrix-org/synapse/issues/6827), [\#6854](https://github.com/matrix-org/synapse/issues/6854), [\#6856](https://github.com/matrix-org/synapse/issues/6856), [\#6857](https://github.com/matrix-org/synapse/issues/6857), [\#6858](https://github.com/matrix-org/synapse/issues/6858)) +- Fix stacktraces when using `ObservableDeferred` and async/await. ([\#6836](https://github.com/matrix-org/synapse/issues/6836)) +- Port much of `synapse.handlers.federation` to async/await. ([\#6837](https://github.com/matrix-org/synapse/issues/6837), [\#6840](https://github.com/matrix-org/synapse/issues/6840)) +- Populate `rooms.room_version` database column at startup, rather than in a background update. ([\#6847](https://github.com/matrix-org/synapse/issues/6847)) +- Reduce amount we log at `INFO` level. ([\#6833](https://github.com/matrix-org/synapse/issues/6833), [\#6862](https://github.com/matrix-org/synapse/issues/6862)) +- Remove unused `get_room_stats_state` method.
([\#6869](https://github.com/matrix-org/synapse/issues/6869)) +- Add typing to `synapse.federation.sender` and port to async/await. ([\#6871](https://github.com/matrix-org/synapse/issues/6871)) +- Refactor `_EventInternalMetadata` object to improve type safety. ([\#6872](https://github.com/matrix-org/synapse/issues/6872)) +- Add an additional entry to the SyTest blacklist for worker mode. ([\#6883](https://github.com/matrix-org/synapse/issues/6883)) +- Fix the use of sed in the linting scripts when using BSD sed. ([\#6887](https://github.com/matrix-org/synapse/issues/6887)) +- Add type hints to the spam checker module. ([\#6915](https://github.com/matrix-org/synapse/issues/6915)) +- Convert the directory handler tests to use HomeserverTestCase. ([\#6919](https://github.com/matrix-org/synapse/issues/6919)) +- Increase DB/CPU perf of `_is_server_still_joined` check. ([\#6936](https://github.com/matrix-org/synapse/issues/6936)) +- Tiny optimisation for incoming HTTP request dispatch. ([\#6950](https://github.com/matrix-org/synapse/issues/6950)) + + +Synapse 1.10.1 (2020-02-17) +=========================== + +Bugfixes +-------- + +- Fix a bug introduced in Synapse 1.10.0 which would cause room state to be cleared in the database if Synapse was upgraded direct from 1.2.1 or earlier to 1.10.0. ([\#6924](https://github.com/matrix-org/synapse/issues/6924)) + + +Synapse 1.10.0 (2020-02-12) +=========================== + +**WARNING to client developers**: As of this release Synapse validates `client_secret` parameters in the Client-Server API as per the spec. See [\#6766](https://github.com/matrix-org/synapse/issues/6766) for details. + +Updates to the Docker image +--------------------------- + +- Update the docker images to Alpine Linux 3.11. ([\#6897](https://github.com/matrix-org/synapse/issues/6897)) + + +Synapse 1.10.0rc5 (2020-02-11) +============================== + +Bugfixes +-------- + +- Fix the filtering introduced in 1.10.0rc3 to also apply to the state blocks returned by `/sync`. ([\#6884](https://github.com/matrix-org/synapse/issues/6884)) + +Synapse 1.10.0rc4 (2020-02-11) +============================== + +This release candidate was built incorrectly and is superseded by 1.10.0rc5. + +Synapse 1.10.0rc3 (2020-02-10) +============================== + +Features +-------- + +- Filter out `m.room.aliases` from the CS API to mitigate abuse while a better solution is specced. ([\#6878](https://github.com/matrix-org/synapse/issues/6878)) + + +Internal Changes +---------------- + +- Fix continuous integration failures with old versions of `pip`, which were introduced by a release of the `zipp` library. ([\#6880](https://github.com/matrix-org/synapse/issues/6880)) + + +Synapse 1.10.0rc2 (2020-02-06) +============================== + +Bugfixes +-------- + +- Fix an issue with cross-signing where device signatures were not sent to remote servers. ([\#6844](https://github.com/matrix-org/synapse/issues/6844)) +- Fix to the unknown remote device detection which was introduced in 1.10.0rc1. ([\#6848](https://github.com/matrix-org/synapse/issues/6848)) + + +Internal Changes +---------------- + +- Detect unexpected sender keys on remote encrypted events and resync device lists. ([\#6850](https://github.com/matrix-org/synapse/issues/6850)) + + +Synapse 1.10.0rc1 (2020-01-31) +============================== + +Features +-------- + +- Add experimental support for updated authorization rules for aliases events, from [MSC2260](https://github.com/matrix-org/matrix-doc/pull/2260).
([\#6787](https://github.com/matrix-org/synapse/issues/6787), [\#6790](https://github.com/matrix-org/synapse/issues/6790), [\#6794](https://github.com/matrix-org/synapse/issues/6794)) + + +Bugfixes +-------- + +- Warn if postgres database has a non-C locale, as that can cause issues when upgrading locales (e.g. due to upgrading OS). ([\#6734](https://github.com/matrix-org/synapse/issues/6734)) +- Minor fixes to `PUT /_synapse/admin/v2/users` admin api. ([\#6761](https://github.com/matrix-org/synapse/issues/6761)) +- Validate `client_secret` parameter using the regex provided by the Client-Server API, temporarily allowing `:` characters for older clients. The `:` character will be removed in a future release. ([\#6767](https://github.com/matrix-org/synapse/issues/6767)) +- Fix persisting redaction events that have been redacted (or otherwise don't have a redacts key). ([\#6771](https://github.com/matrix-org/synapse/issues/6771)) +- Fix outbound federation request metrics. ([\#6795](https://github.com/matrix-org/synapse/issues/6795)) +- Fix bug where querying a remote user's device keys that weren't cached resulted in only returning a single device. ([\#6796](https://github.com/matrix-org/synapse/issues/6796)) +- Fix race in federation sender worker that delayed sending of device updates. ([\#6799](https://github.com/matrix-org/synapse/issues/6799), [\#6800](https://github.com/matrix-org/synapse/issues/6800)) +- Fix bug where Synapse didn't invalidate cache of remote users' devices when Synapse left a room. ([\#6801](https://github.com/matrix-org/synapse/issues/6801)) +- Fix waking up other workers when remote server is detected to have come back online. ([\#6811](https://github.com/matrix-org/synapse/issues/6811)) + + +Improved Documentation +---------------------- + +- Clarify documentation related to `user_dir` and `federation_reader` workers. ([\#6775](https://github.com/matrix-org/synapse/issues/6775)) + + +Internal Changes +---------------- + +- Record room versions in the `rooms` table. ([\#6729](https://github.com/matrix-org/synapse/issues/6729), [\#6788](https://github.com/matrix-org/synapse/issues/6788), [\#6810](https://github.com/matrix-org/synapse/issues/6810)) +- Propagate cache invalidates from workers to other workers. ([\#6748](https://github.com/matrix-org/synapse/issues/6748)) +- Remove some unnecessary admin handler abstraction methods. ([\#6751](https://github.com/matrix-org/synapse/issues/6751)) +- Add some debugging for media storage providers. ([\#6757](https://github.com/matrix-org/synapse/issues/6757)) +- Detect unknown remote devices and mark cache as stale. ([\#6776](https://github.com/matrix-org/synapse/issues/6776), [\#6819](https://github.com/matrix-org/synapse/issues/6819)) +- Attempt to resync remote users' devices when detected as stale. ([\#6786](https://github.com/matrix-org/synapse/issues/6786)) +- Delete current state from the database when server leaves a room. ([\#6792](https://github.com/matrix-org/synapse/issues/6792)) +- When a client asks for a remote user's device keys check if the local cache for that user has been marked as potentially stale. ([\#6797](https://github.com/matrix-org/synapse/issues/6797)) +- Add background update to clean out left rooms from current state. ([\#6802](https://github.com/matrix-org/synapse/issues/6802), [\#6816](https://github.com/matrix-org/synapse/issues/6816)) +- Refactoring work in preparation for changing the event redaction algorithm. 
([\#6803](https://github.com/matrix-org/synapse/issues/6803), [\#6805](https://github.com/matrix-org/synapse/issues/6805), [\#6806](https://github.com/matrix-org/synapse/issues/6806), [\#6807](https://github.com/matrix-org/synapse/issues/6807), [\#6820](https://github.com/matrix-org/synapse/issues/6820)) + + +Synapse 1.9.1 (2020-01-28) +========================== + +Bugfixes +-------- + +- Fix bug where setting `mau_limit_reserved_threepids` config would cause Synapse to refuse to start. ([\#6793](https://github.com/matrix-org/synapse/issues/6793)) + + +Synapse 1.9.0 (2020-01-23) +========================== + +**WARNING**: As of this release, Synapse no longer supports versions of SQLite before 3.11, and will refuse to start when configured to use an older version. Administrators are recommended to migrate their database to Postgres (see instructions [here](docs/postgres.md)). + +If your Synapse deployment uses workers, note that the reverse-proxy configurations for the `synapse.app.media_repository`, `synapse.app.federation_reader` and `synapse.app.event_creator` workers have changed, with the addition of a few paths (see the updated configurations [here](docs/workers.md#available-worker-applications)). Existing configurations will continue to work. + + +Improved Documentation +---------------------- + +- Fix endpoint documentation for the List Rooms admin API. ([\#6770](https://github.com/matrix-org/synapse/issues/6770)) + + +Synapse 1.9.0rc1 (2020-01-22) +============================= + +Features +-------- + +- Allow admin to create or modify a user. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#5742](https://github.com/matrix-org/synapse/issues/5742)) +- Add new quarantine media admin APIs to quarantine by media ID or by user who uploaded the media. ([\#6681](https://github.com/matrix-org/synapse/issues/6681), [\#6756](https://github.com/matrix-org/synapse/issues/6756)) +- Add `org.matrix.e2e_cross_signing` to `unstable_features` in `/versions` as per [MSC1756](https://github.com/matrix-org/matrix-doc/pull/1756). ([\#6712](https://github.com/matrix-org/synapse/issues/6712)) +- Add a new admin API to list and filter rooms on the server. ([\#6720](https://github.com/matrix-org/synapse/issues/6720)) + + +Bugfixes +-------- + +- Correctly proxy HTTP errors due to API calls to remote group servers. ([\#6654](https://github.com/matrix-org/synapse/issues/6654)) +- Fix media repo admin APIs when using a media worker. ([\#6664](https://github.com/matrix-org/synapse/issues/6664)) +- Fix "CRITICAL" errors being logged when a request is received for a uri containing non-ascii characters. ([\#6682](https://github.com/matrix-org/synapse/issues/6682)) +- Fix a bug where we would assign a numeric user ID if somebody tried registering with an empty username. ([\#6690](https://github.com/matrix-org/synapse/issues/6690)) +- Fix `purge_room` admin API. ([\#6711](https://github.com/matrix-org/synapse/issues/6711)) +- Fix a bug causing Synapse to not always purge quiet rooms with a low `max_lifetime` in their message retention policies when running the automated purge jobs. ([\#6714](https://github.com/matrix-org/synapse/issues/6714)) +- Fix the `synapse_port_db` not correctly running background updates. Thanks @tadzik for reporting. ([\#6718](https://github.com/matrix-org/synapse/issues/6718)) +- Fix changing password via user admin API. ([\#6730](https://github.com/matrix-org/synapse/issues/6730)) +- Fix `/events/:event_id` deprecated API. 
([\#6731](https://github.com/matrix-org/synapse/issues/6731)) +- Fix monthly active user limiting support for worker mode, fixes [#4639](https://github.com/matrix-org/synapse/issues/4639). ([\#6742](https://github.com/matrix-org/synapse/issues/6742)) +- Fix bug when setting `account_validity` to an empty block in the config. Thanks to @Sorunome for reporting. ([\#6747](https://github.com/matrix-org/synapse/issues/6747)) +- Fix `AttributeError: 'NoneType' object has no attribute 'get'` in `hash_password` when configuration has an empty `password_config`. Contributed by @ivilata. ([\#6753](https://github.com/matrix-org/synapse/issues/6753)) +- Fix the `docker-compose.yaml` overriding the entire `/etc` folder of the container. Contributed by Fabian Meyer. ([\#6656](https://github.com/matrix-org/synapse/issues/6656)) + + +Improved Documentation +---------------------- + +- Fix a typo in the configuration example for purge jobs in the sample configuration file. ([\#6621](https://github.com/matrix-org/synapse/issues/6621)) +- Add complete documentation of the message retention policies support. ([\#6624](https://github.com/matrix-org/synapse/issues/6624), [\#6665](https://github.com/matrix-org/synapse/issues/6665)) +- Add some helpful tips about changelog entries to the GitHub pull request template. ([\#6663](https://github.com/matrix-org/synapse/issues/6663)) +- Clarify the `account_validity` and `email` sections of the sample configuration. ([\#6685](https://github.com/matrix-org/synapse/issues/6685)) +- Add more endpoints to the documentation for Synapse workers. ([\#6698](https://github.com/matrix-org/synapse/issues/6698)) + + +Deprecations and Removals +------------------------- + +- Synapse no longer supports versions of SQLite before 3.11, and will refuse to start when configured to use an older version. Administrators are recommended to migrate their database to Postgres (see instructions [here](docs/postgres.md)). ([\#6675](https://github.com/matrix-org/synapse/issues/6675)) + + +Internal Changes +---------------- + +- Add `local_current_membership` table for tracking local user membership state in rooms. ([\#6655](https://github.com/matrix-org/synapse/issues/6655), [\#6728](https://github.com/matrix-org/synapse/issues/6728)) +- Port `synapse.replication.tcp` to async/await. ([\#6666](https://github.com/matrix-org/synapse/issues/6666)) +- Fixup `synapse.replication` to pass mypy checks. ([\#6667](https://github.com/matrix-org/synapse/issues/6667)) +- Allow `additional_resources` to implement `IResource` directly. ([\#6686](https://github.com/matrix-org/synapse/issues/6686)) +- Allow REST endpoint implementations to raise a `RedirectException`, which will redirect the user's browser to a given location. ([\#6687](https://github.com/matrix-org/synapse/issues/6687)) +- Updates and extensions to the module API. ([\#6688](https://github.com/matrix-org/synapse/issues/6688)) +- Updates to the SAML mapping provider API. ([\#6689](https://github.com/matrix-org/synapse/issues/6689), [\#6723](https://github.com/matrix-org/synapse/issues/6723)) +- Remove redundant `RegistrationError` class. ([\#6691](https://github.com/matrix-org/synapse/issues/6691)) +- Don't block processing of incoming EDUs behind processing PDUs in the same transaction. ([\#6697](https://github.com/matrix-org/synapse/issues/6697)) +- Remove duplicate check for the `session` query parameter on the `/auth/xxx/fallback/web` Client-Server endpoint. 
([\#6702](https://github.com/matrix-org/synapse/issues/6702)) +- Attempt to retry sending a transaction when we detect a remote server has come back online, rather than waiting for a transaction to be triggered by new data. ([\#6706](https://github.com/matrix-org/synapse/issues/6706)) +- Add `StateMap` type alias to simplify types. ([\#6715](https://github.com/matrix-org/synapse/issues/6715)) +- Add a `DeltaState` to track changes to be made to current state during event persistence. ([\#6716](https://github.com/matrix-org/synapse/issues/6716)) +- Add more logging around message retention policies support. ([\#6717](https://github.com/matrix-org/synapse/issues/6717)) +- When processing a SAML response, log the assertions for easier configuration. ([\#6724](https://github.com/matrix-org/synapse/issues/6724)) +- Fixup `synapse.rest` to pass mypy. ([\#6732](https://github.com/matrix-org/synapse/issues/6732), [\#6764](https://github.com/matrix-org/synapse/issues/6764)) +- Fixup `synapse.api` to pass mypy. ([\#6733](https://github.com/matrix-org/synapse/issues/6733)) +- Allow streaming cache 'invalidate all' to workers. ([\#6749](https://github.com/matrix-org/synapse/issues/6749)) +- Remove unused CI docker compose files. ([\#6754](https://github.com/matrix-org/synapse/issues/6754)) + + +Synapse 1.8.0 (2020-01-09) +========================== + +**WARNING**: As of this release Synapse will refuse to start if the `log_file` config option is specified. Support for the option was removed in v1.3.0. + + +Bugfixes +-------- + +- Fix `GET` request on `/_synapse/admin/v2/users` endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#6563](https://github.com/matrix-org/synapse/issues/6563)) +- Fix incorrect signing of responses from the key server implementation. ([\#6657](https://github.com/matrix-org/synapse/issues/6657)) + + +Synapse 1.8.0rc1 (2020-01-07) +============================= + +Features +-------- + +- Add v2 APIs for the `send_join` and `send_leave` federation endpoints (as described in [MSC1802](https://github.com/matrix-org/matrix-doc/pull/1802)). ([\#6349](https://github.com/matrix-org/synapse/issues/6349)) +- Add a develop script to generate full SQL schemas. ([\#6394](https://github.com/matrix-org/synapse/issues/6394)) +- Add custom SAML username mapping functionality through an external provider plugin. ([\#6411](https://github.com/matrix-org/synapse/issues/6411)) +- Automatically delete empty groups/communities. ([\#6453](https://github.com/matrix-org/synapse/issues/6453)) +- Add option `limit_profile_requests_to_users_who_share_rooms` to prevent requirement of a local user sharing a room with another user to query their profile information. ([\#6523](https://github.com/matrix-org/synapse/issues/6523)) +- Add an `export_signing_key` script to extract the public part of signing keys when rotating them. ([\#6546](https://github.com/matrix-org/synapse/issues/6546)) +- Add experimental config option to specify multiple databases. ([\#6580](https://github.com/matrix-org/synapse/issues/6580)) +- Raise an error if someone tries to use the `log_file` config option. ([\#6626](https://github.com/matrix-org/synapse/issues/6626)) + + +Bugfixes +-------- + +- Prevent redacted events from being returned during message search. ([\#6377](https://github.com/matrix-org/synapse/issues/6377), [\#6522](https://github.com/matrix-org/synapse/issues/6522)) +- Prevent error on trying to search an upgraded room when the server is not in the predecessor room.
([\#6385](https://github.com/matrix-org/synapse/issues/6385)) +- Improve performance of looking up cross-signing keys. ([\#6486](https://github.com/matrix-org/synapse/issues/6486)) +- Fix race which occasionally caused deleted devices to reappear. ([\#6514](https://github.com/matrix-org/synapse/issues/6514)) +- Fix missing row in `device_max_stream_id` that could cause unable to decrypt errors after server restart. ([\#6555](https://github.com/matrix-org/synapse/issues/6555)) +- Fix a bug which meant that we did not send systemd notifications on startup if acme was enabled. ([\#6571](https://github.com/matrix-org/synapse/issues/6571)) +- Fix exception when fetching the `matrix.org:ed25519:auto` key. ([\#6625](https://github.com/matrix-org/synapse/issues/6625)) +- Fix bug where a moderator upgraded a room and became an admin in the new room. ([\#6633](https://github.com/matrix-org/synapse/issues/6633)) +- Fix an error which was thrown by the `PresenceHandler` `_on_shutdown` handler. ([\#6640](https://github.com/matrix-org/synapse/issues/6640)) +- Fix exceptions in the synchrotron worker log when events are rejected. ([\#6645](https://github.com/matrix-org/synapse/issues/6645)) +- Ensure that upgraded rooms are removed from the directory. ([\#6648](https://github.com/matrix-org/synapse/issues/6648)) +- Fix a bug causing Synapse not to fetch missing events when it believes it has every event in the room. ([\#6652](https://github.com/matrix-org/synapse/issues/6652)) + + +Improved Documentation +---------------------- + +- Document the Room Shutdown Admin API. ([\#6541](https://github.com/matrix-org/synapse/issues/6541)) +- Reword sections of [docs/federate.md](docs/federate.md) that explained delegation at time of Synapse 1.0 transition. ([\#6601](https://github.com/matrix-org/synapse/issues/6601)) +- Added the section 'Configuration' in [docs/turn-howto.md](docs/turn-howto.md). ([\#6614](https://github.com/matrix-org/synapse/issues/6614)) + + +Deprecations and Removals +------------------------- + +- Remove redundant code from event authorisation implementation. ([\#6502](https://github.com/matrix-org/synapse/issues/6502)) +- Remove unused, undocumented `/_matrix/content` API. ([\#6628](https://github.com/matrix-org/synapse/issues/6628)) + + +Internal Changes +---------------- + +- Add *experimental* support for multiple physical databases and split out state storage to separate data store. ([\#6245](https://github.com/matrix-org/synapse/issues/6245), [\#6510](https://github.com/matrix-org/synapse/issues/6510), [\#6511](https://github.com/matrix-org/synapse/issues/6511), [\#6513](https://github.com/matrix-org/synapse/issues/6513), [\#6564](https://github.com/matrix-org/synapse/issues/6564), [\#6565](https://github.com/matrix-org/synapse/issues/6565)) +- Port sections of code base to async/await. ([\#6496](https://github.com/matrix-org/synapse/issues/6496), [\#6504](https://github.com/matrix-org/synapse/issues/6504), [\#6505](https://github.com/matrix-org/synapse/issues/6505), [\#6517](https://github.com/matrix-org/synapse/issues/6517), [\#6559](https://github.com/matrix-org/synapse/issues/6559), [\#6647](https://github.com/matrix-org/synapse/issues/6647), [\#6653](https://github.com/matrix-org/synapse/issues/6653)) +- Remove `SnapshotCache` in favour of `ResponseCache`. ([\#6506](https://github.com/matrix-org/synapse/issues/6506)) +- Silence mypy errors for files outside those specified. 
([\#6512](https://github.com/matrix-org/synapse/issues/6512)) +- Clean up some logging when handling incoming events over federation. ([\#6515](https://github.com/matrix-org/synapse/issues/6515)) +- Test more folders against mypy. ([\#6534](https://github.com/matrix-org/synapse/issues/6534)) +- Update `mypy` to new version. ([\#6537](https://github.com/matrix-org/synapse/issues/6537)) +- Adjust the sytest blacklist for worker mode. ([\#6538](https://github.com/matrix-org/synapse/issues/6538)) +- Remove unused `get_pagination_rows` methods from `EventSource` classes. ([\#6557](https://github.com/matrix-org/synapse/issues/6557)) +- Clean up logs from the push notifier at startup. ([\#6558](https://github.com/matrix-org/synapse/issues/6558)) +- Improve diagnostics on database upgrade failure. ([\#6570](https://github.com/matrix-org/synapse/issues/6570)) +- Reduce the reconnect time when worker replication fails, to make it easier to catch up. ([\#6617](https://github.com/matrix-org/synapse/issues/6617)) +- Simplify http handling by removing redundant `SynapseRequestFactory`. ([\#6619](https://github.com/matrix-org/synapse/issues/6619)) +- Add a workaround for synapse raising exceptions when fetching the notary's own key from the notary. ([\#6620](https://github.com/matrix-org/synapse/issues/6620)) +- Automate generation of the sample log config. ([\#6627](https://github.com/matrix-org/synapse/issues/6627)) +- Simplify event creation code by removing redundant queries on the `event_reference_hashes` table. ([\#6629](https://github.com/matrix-org/synapse/issues/6629)) +- Fix errors when `frozen_dicts` are enabled. ([\#6642](https://github.com/matrix-org/synapse/issues/6642)) + + +Synapse 1.7.3 (2019-12-31) +========================== + +This release fixes a long-standing bug in the state resolution algorithm. + +Bugfixes +-------- + +- Fix exceptions caused by state resolution choking on malformed events. ([\#6608](https://github.com/matrix-org/synapse/issues/6608)) + + +Synapse 1.7.2 (2019-12-20) +========================== + +This release fixes some regressions introduced in Synapse 1.7.0 and 1.7.1. + +Bugfixes +-------- + +- Fix a regression introduced in Synapse 1.7.1 which caused errors when attempting to backfill rooms over federation. ([\#6576](https://github.com/matrix-org/synapse/issues/6576)) +- Fix a bug introduced in Synapse 1.7.0 which caused an error on startup when upgrading from versions before 1.3.0. ([\#6578](https://github.com/matrix-org/synapse/issues/6578)) + + +Synapse 1.7.1 (2019-12-18) +========================== + +This release includes several security fixes as well as a fix to a bug exposed by the security fixes. Administrators are encouraged to upgrade as soon as possible. + +Security updates +---------------- + +- Fix a bug which could cause room events to be incorrectly authorized using events from a different room. ([\#6501](https://github.com/matrix-org/synapse/issues/6501), [\#6503](https://github.com/matrix-org/synapse/issues/6503), [\#6521](https://github.com/matrix-org/synapse/issues/6521), [\#6524](https://github.com/matrix-org/synapse/issues/6524), [\#6530](https://github.com/matrix-org/synapse/issues/6530), [\#6531](https://github.com/matrix-org/synapse/issues/6531)) +- Fix a bug causing responses to the `/context` client endpoint to not use the pruned version of the event. ([\#6553](https://github.com/matrix-org/synapse/issues/6553)) +- Fix a cause of state resets in room versions 2 onwards. 
([\#6556](https://github.com/matrix-org/synapse/issues/6556), [\#6560](https://github.com/matrix-org/synapse/issues/6560)) + +Bugfixes +-------- + +- Fix a bug which could cause the federation server to incorrectly return errors when handling certain obscure event graphs. ([\#6526](https://github.com/matrix-org/synapse/issues/6526), [\#6527](https://github.com/matrix-org/synapse/issues/6527)) + +Synapse 1.7.0 (2019-12-13) +========================== + +This release changes the default settings so that only local authenticated users can query the server's room directory. See the [upgrade notes](UPGRADE.rst#upgrading-to-v170) for details. + +Support for SQLite versions before 3.11 is now deprecated. A future release will refuse to start if used with an SQLite version before 3.11. + +Administrators are reminded that SQLite should not be used for production instances. Instructions for migrating to Postgres are available [here](docs/postgres.md). A future release of synapse will, by default, disable federation for servers using SQLite. + +No significant changes since 1.7.0rc2. + + +Synapse 1.7.0rc2 (2019-12-11) +============================= + +Bugfixes +-------- + +- Fix incorrect error message for invalid requests when setting user's avatar URL. ([\#6497](https://github.com/matrix-org/synapse/issues/6497)) +- Fix support for SQLite 3.7. ([\#6499](https://github.com/matrix-org/synapse/issues/6499)) +- Fix regression where sending email push would not work when using a pusher worker. ([\#6507](https://github.com/matrix-org/synapse/issues/6507), [\#6509](https://github.com/matrix-org/synapse/issues/6509)) + + +Synapse 1.7.0rc1 (2019-12-09) +============================= + +Features +-------- + +- Implement per-room message retention policies. ([\#5815](https://github.com/matrix-org/synapse/issues/5815), [\#6436](https://github.com/matrix-org/synapse/issues/6436)) +- Add etag and count fields to key backup endpoints to help clients guess if there are new keys. ([\#5858](https://github.com/matrix-org/synapse/issues/5858)) +- Add `/admin/v2/users` endpoint with pagination. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#5925](https://github.com/matrix-org/synapse/issues/5925)) +- Require User-Interactive Authentication for `/account/3pid/add`, meaning the user's password will be required to add a third-party ID to their account. ([\#6119](https://github.com/matrix-org/synapse/issues/6119)) +- Implement the `/_matrix/federation/unstable/net.atleastfornow/state/<context>` API as drafted in MSC2314. ([\#6176](https://github.com/matrix-org/synapse/issues/6176)) +- Configure privacy-preserving settings by default for the room directory. ([\#6355](https://github.com/matrix-org/synapse/issues/6355)) +- Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). ([\#6409](https://github.com/matrix-org/synapse/issues/6409)) +- Add support for [MSC 2367](https://github.com/matrix-org/matrix-doc/pull/2367), which allows specifying a reason on all membership events. ([\#6434](https://github.com/matrix-org/synapse/issues/6434)) + + +Bugfixes +-------- + +- Transfer non-standard power levels on room upgrade. ([\#6237](https://github.com/matrix-org/synapse/issues/6237)) +- Fix error from the Pillow library when uploading RGBA images. ([\#6241](https://github.com/matrix-org/synapse/issues/6241)) +- Correctly apply the event filter to the `state`, `events_before` and `events_after` fields in the response to `/context` requests. 
([\#6329](https://github.com/matrix-org/synapse/issues/6329)) +- Fix caching devices for remote users when using workers, so that we don't attempt to refetch (and potentially fail) each time a user requests devices. ([\#6332](https://github.com/matrix-org/synapse/issues/6332)) +- Prevent account data syncs getting lost across TCP replication. ([\#6333](https://github.com/matrix-org/synapse/issues/6333)) +- Fix bug: TypeError in `register_user()` while using LDAP auth module. ([\#6406](https://github.com/matrix-org/synapse/issues/6406)) +- Fix an intermittent exception when handling read-receipts. ([\#6408](https://github.com/matrix-org/synapse/issues/6408)) +- Fix broken guest registration when there are existing blocks of numeric user IDs. ([\#6420](https://github.com/matrix-org/synapse/issues/6420)) +- Fix startup error when http proxy is defined. ([\#6421](https://github.com/matrix-org/synapse/issues/6421)) +- Fix error when using synapse_port_db on a vanilla synapse db. ([\#6449](https://github.com/matrix-org/synapse/issues/6449)) +- Fix uploading multiple cross signing signatures for the same user. ([\#6451](https://github.com/matrix-org/synapse/issues/6451)) +- Fix bug which led to exceptions being thrown in a loop when a cross-signed device is deleted. ([\#6462](https://github.com/matrix-org/synapse/issues/6462)) +- Fix `synapse_port_db` not exiting with a 0 code if something went wrong during the port process. ([\#6470](https://github.com/matrix-org/synapse/issues/6470)) +- Improve sanity-checking when receiving events over federation. ([\#6472](https://github.com/matrix-org/synapse/issues/6472)) +- Fix inaccurate per-block Prometheus metrics. ([\#6491](https://github.com/matrix-org/synapse/issues/6491)) +- Fix small performance regression for sending invites. ([\#6493](https://github.com/matrix-org/synapse/issues/6493)) +- Back out cross-signing code added in Synapse 1.5.0, which caused a performance regression. ([\#6494](https://github.com/matrix-org/synapse/issues/6494)) + + +Improved Documentation +---------------------- + +- Update documentation and variables in user contributed systemd reference file. ([\#6369](https://github.com/matrix-org/synapse/issues/6369), [\#6490](https://github.com/matrix-org/synapse/issues/6490)) +- Fix link in the user directory documentation. ([\#6388](https://github.com/matrix-org/synapse/issues/6388)) +- Add build instructions to the docker readme. ([\#6390](https://github.com/matrix-org/synapse/issues/6390)) +- Switch Ubuntu package install recommendation to use python3 packages in INSTALL.md. ([\#6443](https://github.com/matrix-org/synapse/issues/6443)) +- Write some docs for the quarantine_media api. ([\#6458](https://github.com/matrix-org/synapse/issues/6458)) +- Convert CONTRIBUTING.rst to markdown (among other small fixes). ([\#6461](https://github.com/matrix-org/synapse/issues/6461)) + + +Deprecations and Removals +------------------------- + +- Remove admin/v1/users_paginate endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#5925](https://github.com/matrix-org/synapse/issues/5925)) +- Remove fallback for federation with old servers which lack the /federation/v1/state_ids API. ([\#6488](https://github.com/matrix-org/synapse/issues/6488)) + + +Internal Changes +---------------- + +- Add benchmarks for structured logging and improve output performance. ([\#6266](https://github.com/matrix-org/synapse/issues/6266)) +- Improve the performance of outputting structured logging.
([\#6322](https://github.com/matrix-org/synapse/issues/6322)) +- Refactor some code in the event authentication path for clarity. ([\#6343](https://github.com/matrix-org/synapse/issues/6343), [\#6468](https://github.com/matrix-org/synapse/issues/6468), [\#6480](https://github.com/matrix-org/synapse/issues/6480)) +- Clean up some unnecessary quotation marks around the codebase. ([\#6362](https://github.com/matrix-org/synapse/issues/6362)) +- Complain on startup instead of 500'ing during runtime when `public_baseurl` isn't set when necessary. ([\#6379](https://github.com/matrix-org/synapse/issues/6379)) +- Add a test scenario to make sure room history purges don't break `/messages` in the future. ([\#6392](https://github.com/matrix-org/synapse/issues/6392)) +- Clarifications for the email configuration settings. ([\#6423](https://github.com/matrix-org/synapse/issues/6423)) +- Add more tests to the blacklist when running in worker mode. ([\#6429](https://github.com/matrix-org/synapse/issues/6429)) +- Refactor data store layer to support multiple databases in the future. ([\#6454](https://github.com/matrix-org/synapse/issues/6454), [\#6464](https://github.com/matrix-org/synapse/issues/6464), [\#6469](https://github.com/matrix-org/synapse/issues/6469), [\#6487](https://github.com/matrix-org/synapse/issues/6487)) +- Port synapse.rest.client.v1 to async/await. ([\#6482](https://github.com/matrix-org/synapse/issues/6482)) +- Port synapse.rest.client.v2_alpha to async/await. ([\#6483](https://github.com/matrix-org/synapse/issues/6483)) +- Port SyncHandler to async/await. ([\#6484](https://github.com/matrix-org/synapse/issues/6484)) + +Synapse 1.6.1 (2019-11-28) +========================== + +Security updates +---------------- + +This release includes a security fix ([\#6426](https://github.com/matrix-org/synapse/issues/6426), below). Administrators are encouraged to upgrade as soon as possible. + +Bugfixes +-------- + +- Clean up local threepids from user on account deactivation. ([\#6426](https://github.com/matrix-org/synapse/issues/6426)) +- Fix startup error when http proxy is defined. ([\#6421](https://github.com/matrix-org/synapse/issues/6421)) + + +Synapse 1.6.0 (2019-11-26) +========================== + +Bugfixes +-------- + +- Fix phone home stats reporting. ([\#6418](https://github.com/matrix-org/synapse/issues/6418)) + + +Synapse 1.6.0rc2 (2019-11-25) +============================= + +Bugfixes +-------- + +- Fix a bug which could cause the background database update handler for event labels to get stuck in a loop raising exceptions. ([\#6407](https://github.com/matrix-org/synapse/issues/6407)) + + +Synapse 1.6.0rc1 (2019-11-20) +============================= + +Features +-------- + +- Add federation support for cross-signing. ([\#5727](https://github.com/matrix-org/synapse/issues/5727)) +- Increase default room version from 4 to 5, thereby enforcing server key validity period checks. ([\#6220](https://github.com/matrix-org/synapse/issues/6220)) +- Add support for outbound http proxying via http_proxy/HTTPS_PROXY env vars. ([\#6238](https://github.com/matrix-org/synapse/issues/6238)) +- Implement label-based filtering on `/sync` and `/messages` ([MSC2326](https://github.com/matrix-org/matrix-doc/pull/2326)). ([\#6301](https://github.com/matrix-org/synapse/issues/6301), [\#6310](https://github.com/matrix-org/synapse/issues/6310), [\#6340](https://github.com/matrix-org/synapse/issues/6340)) + + +Bugfixes +-------- + +- Fix LruCache callback deduplication for Python 3.8.
Contributed by @V02460. ([\#6213](https://github.com/matrix-org/synapse/issues/6213)) +- Remove a room from a server's public rooms list on room upgrade. ([\#6232](https://github.com/matrix-org/synapse/issues/6232), [\#6235](https://github.com/matrix-org/synapse/issues/6235)) +- Delete keys from key backup when deleting backup versions. ([\#6253](https://github.com/matrix-org/synapse/issues/6253)) +- Make notification of cross-signing signatures work with workers. ([\#6254](https://github.com/matrix-org/synapse/issues/6254)) +- Fix exception when remote servers attempt to join a room that they're not allowed to join. ([\#6278](https://github.com/matrix-org/synapse/issues/6278)) +- Prevent errors from appearing on Synapse startup if `git` is not installed. ([\#6284](https://github.com/matrix-org/synapse/issues/6284)) +- Appservice requests will no longer contain a double slash prefix when the appservice url provided ends in a slash. ([\#6306](https://github.com/matrix-org/synapse/issues/6306)) +- Fix `/purge_room` admin API. ([\#6307](https://github.com/matrix-org/synapse/issues/6307)) +- Fix the `hidden` field in the `devices` table for SQLite versions prior to 3.23.0. ([\#6313](https://github.com/matrix-org/synapse/issues/6313)) +- Fix bug which caused rejected events to be persisted with the wrong room state. ([\#6320](https://github.com/matrix-org/synapse/issues/6320)) +- Fix bug where `rc_login` ratelimiting would prematurely kick in. ([\#6335](https://github.com/matrix-org/synapse/issues/6335)) +- Prevent the server taking a long time to start up when guest registration is enabled. ([\#6338](https://github.com/matrix-org/synapse/issues/6338)) +- Fix bug where upgrading a guest account to a full user would fail when account validity is enabled. ([\#6359](https://github.com/matrix-org/synapse/issues/6359)) +- Fix `to_device` stream ID getting reset every time Synapse restarts, which had the potential to cause unable to decrypt errors. ([\#6363](https://github.com/matrix-org/synapse/issues/6363)) +- Fix permission denied error when trying to generate a config file with the docker image. ([\#6389](https://github.com/matrix-org/synapse/issues/6389)) + + +Improved Documentation +---------------------- + +- Contributor documentation now mentions script to run linters. ([\#6164](https://github.com/matrix-org/synapse/issues/6164)) +- Modify CAPTCHA_SETUP.md to update the terms `private key` and `public key` to `secret key` and `site key` respectively. Contributed by Yash Jipkate. ([\#6257](https://github.com/matrix-org/synapse/issues/6257)) +- Update `INSTALL.md` Email section to talk about `account_threepid_delegates`. ([\#6272](https://github.com/matrix-org/synapse/issues/6272)) +- Fix a small typo in `account_threepid_delegates` configuration option. ([\#6273](https://github.com/matrix-org/synapse/issues/6273)) + + +Internal Changes +---------------- + +- Add a CI job to test the `synapse_port_db` script. ([\#6140](https://github.com/matrix-org/synapse/issues/6140), [\#6276](https://github.com/matrix-org/synapse/issues/6276)) +- Convert EventContext to an attrs. ([\#6218](https://github.com/matrix-org/synapse/issues/6218)) +- Move `persist_events` out from main data store. ([\#6240](https://github.com/matrix-org/synapse/issues/6240), [\#6300](https://github.com/matrix-org/synapse/issues/6300)) +- Reduce verbosity of user/room stats. ([\#6250](https://github.com/matrix-org/synapse/issues/6250)) +- Reduce impact of debug logging.
([\#6251](https://github.com/matrix-org/synapse/issues/6251)) +- Expose some homeserver functionality to spam checkers. ([\#6259](https://github.com/matrix-org/synapse/issues/6259)) +- Change cache descriptors to always return deferreds. ([\#6263](https://github.com/matrix-org/synapse/issues/6263), [\#6291](https://github.com/matrix-org/synapse/issues/6291)) +- Fix incorrect comment regarding the functionality of an `if` statement. ([\#6269](https://github.com/matrix-org/synapse/issues/6269)) +- Update CI to run `isort` over the `scripts` and `scripts-dev` directories. ([\#6270](https://github.com/matrix-org/synapse/issues/6270)) +- Replace every instance of `logger.warn` method with `logger.warning` as the former is deprecated. ([\#6271](https://github.com/matrix-org/synapse/issues/6271), [\#6314](https://github.com/matrix-org/synapse/issues/6314)) +- Port replication http server endpoints to async/await. ([\#6274](https://github.com/matrix-org/synapse/issues/6274)) +- Port room rest handlers to async/await. ([\#6275](https://github.com/matrix-org/synapse/issues/6275)) +- Remove redundant CLI parameters on CI's `flake8` step. ([\#6277](https://github.com/matrix-org/synapse/issues/6277)) +- Port `federation_server.py` to async/await. ([\#6279](https://github.com/matrix-org/synapse/issues/6279)) +- Port receipt and read markers to async/await. ([\#6280](https://github.com/matrix-org/synapse/issues/6280)) +- Split out state storage into separate data store. ([\#6294](https://github.com/matrix-org/synapse/issues/6294), [\#6295](https://github.com/matrix-org/synapse/issues/6295)) +- Refactor EventContext for clarity. ([\#6298](https://github.com/matrix-org/synapse/issues/6298)) +- Update the version of black used to 19.10b0. ([\#6304](https://github.com/matrix-org/synapse/issues/6304)) +- Add some documentation about worker replication. ([\#6305](https://github.com/matrix-org/synapse/issues/6305)) +- Move admin endpoints into separate files. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#6308](https://github.com/matrix-org/synapse/issues/6308)) +- Document the use of `lint.sh` for code style enforcement & extend it to run on specified paths only. ([\#6312](https://github.com/matrix-org/synapse/issues/6312)) +- Add optional python dependencies and dependent binary libraries to snapcraft packaging. ([\#6317](https://github.com/matrix-org/synapse/issues/6317)) +- Remove the dependency on psutil and replace functionality with the stdlib `resource` module. ([\#6318](https://github.com/matrix-org/synapse/issues/6318), [\#6336](https://github.com/matrix-org/synapse/issues/6336)) +- Improve documentation for EventContext fields. ([\#6319](https://github.com/matrix-org/synapse/issues/6319)) +- Add some checks that we aren't using state from rejected events. ([\#6330](https://github.com/matrix-org/synapse/issues/6330)) +- Add continuous integration for python 3.8. ([\#6341](https://github.com/matrix-org/synapse/issues/6341)) +- Correct spacing/case of various instances of the word "homeserver". ([\#6357](https://github.com/matrix-org/synapse/issues/6357)) +- Temporarily blacklist the failing unit test PurgeRoomTestCase.test_purge_room. ([\#6361](https://github.com/matrix-org/synapse/issues/6361)) + + +Synapse 1.5.1 (2019-11-06) +========================== + +Features +-------- + +- Limit the length of data returned by url previews, to prevent DoS attacks.
([\#6331](https://github.com/matrix-org/synapse/issues/6331), [\#6334](https://github.com/matrix-org/synapse/issues/6334)) + + +Synapse 1.5.0 (2019-10-29) +========================== + +Security updates +---------------- + +This release includes a security fix ([\#6262](https://github.com/matrix-org/synapse/issues/6262), below). Administrators are encouraged to upgrade as soon as possible. + +Bugfixes +-------- + +- Fix bug where room directory search was case sensitive. ([\#6268](https://github.com/matrix-org/synapse/issues/6268)) + + +Synapse 1.5.0rc2 (2019-10-28) +============================= + +Bugfixes +-------- + +- Update list of boolean columns in `synapse_port_db`. ([\#6247](https://github.com/matrix-org/synapse/issues/6247)) +- Fix /keys/query API on workers. ([\#6256](https://github.com/matrix-org/synapse/issues/6256)) +- Improve signature checking on some federation APIs. ([\#6262](https://github.com/matrix-org/synapse/issues/6262)) + + +Internal Changes +---------------- + +- Move schema delta files to the correct data store. ([\#6248](https://github.com/matrix-org/synapse/issues/6248)) +- Small performance improvement by removing repeated config lookups in room stats calculation. ([\#6255](https://github.com/matrix-org/synapse/issues/6255)) + + +Synapse 1.5.0rc1 (2019-10-24) +========================== + +Features +-------- + +- Improve quality of thumbnails for 1-bit/8-bit color palette images. ([\#2142](https://github.com/matrix-org/synapse/issues/2142)) +- Add ability to upload cross-signing signatures. ([\#5726](https://github.com/matrix-org/synapse/issues/5726)) +- Allow uploading of cross-signing keys. ([\#5769](https://github.com/matrix-org/synapse/issues/5769)) +- CAS login now provides a default display name for users if a `displayname_attribute` is set in the configuration file. ([\#6114](https://github.com/matrix-org/synapse/issues/6114)) +- Reject all pending invites for a user during deactivation. ([\#6125](https://github.com/matrix-org/synapse/issues/6125)) +- Add config option to suppress client side resource limit alerting. ([\#6173](https://github.com/matrix-org/synapse/issues/6173)) + + +Bugfixes +-------- + +- Return an HTTP 404 instead of 400 when requesting a filter by ID that is unknown to the server. Thanks to @krombel for contributing this! ([\#2380](https://github.com/matrix-org/synapse/issues/2380)) +- Fix a bug where users could be invited twice to the same group. ([\#3436](https://github.com/matrix-org/synapse/issues/3436)) +- Fix `/createRoom` failing with badly-formatted MXIDs in the invitee list. Thanks to @wener291! ([\#4088](https://github.com/matrix-org/synapse/issues/4088)) +- Make the `synapse_port_db` script create the right indexes on a new PostgreSQL database. ([\#6102](https://github.com/matrix-org/synapse/issues/6102), [\#6178](https://github.com/matrix-org/synapse/issues/6178), [\#6243](https://github.com/matrix-org/synapse/issues/6243)) +- Fix bug when uploading a large file: Synapse responds with `M_UNKNOWN` while it should be `M_TOO_LARGE` according to spec. Contributed by Anshul Angaria. ([\#6109](https://github.com/matrix-org/synapse/issues/6109)) +- Fix user push rules being deleted from a room when it is upgraded. ([\#6144](https://github.com/matrix-org/synapse/issues/6144)) +- Don't 500 when trying to exchange a revoked 3PID invite. ([\#6147](https://github.com/matrix-org/synapse/issues/6147)) +- Fix transferring notifications and tags when joining an upgraded room that is new to your server. 
([\#6155](https://github.com/matrix-org/synapse/issues/6155)) +- Fix bug where guest account registration can wedge after restart. ([\#6161](https://github.com/matrix-org/synapse/issues/6161)) +- Fix monthly active user reaping when reserved users are specified. ([\#6168](https://github.com/matrix-org/synapse/issues/6168)) +- Fix `/federation/v1/state` endpoint not supporting newer room versions. ([\#6170](https://github.com/matrix-org/synapse/issues/6170)) +- Fix bug where we were updating censored events as bytes rather than text, occasionally causing invalid JSON to be inserted, breaking APIs that attempted to fetch such events. ([\#6186](https://github.com/matrix-org/synapse/issues/6186)) +- Fix occasional missed updates in the room and user directories. ([\#6187](https://github.com/matrix-org/synapse/issues/6187)) +- Fix tracing of non-JSON APIs, `/media`, `/key` etc. ([\#6195](https://github.com/matrix-org/synapse/issues/6195)) +- Fix bug where presence would not get timed out correctly if a synchrotron worker is used and restarted. ([\#6212](https://github.com/matrix-org/synapse/issues/6212)) +- `synapse_port_db`: Add 2 additional `BOOLEAN_COLUMNS` to be able to convert from database schema v56. ([\#6216](https://github.com/matrix-org/synapse/issues/6216)) +- Fix a bug where the Synapse demo script blacklisted `::1` (IPv6 localhost) from receiving federation traffic. ([\#6229](https://github.com/matrix-org/synapse/issues/6229)) + + +Updates to the Docker image +--------------------------- + +- Fix logging getting lost for the docker image. ([\#6197](https://github.com/matrix-org/synapse/issues/6197)) + + +Internal Changes +---------------- + +- Update `user_filters` table to have a unique index, and non-null columns. Thanks to @pik for contributing this. ([\#1172](https://github.com/matrix-org/synapse/issues/1172), [\#6175](https://github.com/matrix-org/synapse/issues/6175), [\#6184](https://github.com/matrix-org/synapse/issues/6184)) +- Allow devices to be marked as hidden, for use by features such as cross-signing. + This adds a new field with a default value to the devices field in the database, + and so the database upgrade may take a long time depending on how many devices + are in the database. ([\#5759](https://github.com/matrix-org/synapse/issues/5759)) +- Move lookup-related functions from `RoomMemberHandler` to `IdentityHandler`. ([\#5978](https://github.com/matrix-org/synapse/issues/5978)) +- Improve performance of the public room list directory. ([\#6019](https://github.com/matrix-org/synapse/issues/6019), [\#6152](https://github.com/matrix-org/synapse/issues/6152), [\#6153](https://github.com/matrix-org/synapse/issues/6153), [\#6154](https://github.com/matrix-org/synapse/issues/6154)) +- Edit header dicts docstrings in `SimpleHttpClient` to note that `str` or `bytes` can be passed as header keys. ([\#6077](https://github.com/matrix-org/synapse/issues/6077)) +- Add snapcraft packaging information. Contributed by @devec0. ([\#6084](https://github.com/matrix-org/synapse/issues/6084), [\#6191](https://github.com/matrix-org/synapse/issues/6191)) +- Kill off half-implemented password-reset via sms. ([\#6101](https://github.com/matrix-org/synapse/issues/6101)) +- Remove `get_user_by_req` opentracing span and add some tags. ([\#6108](https://github.com/matrix-org/synapse/issues/6108)) +- Drop some unused database tables. ([\#6115](https://github.com/matrix-org/synapse/issues/6115)) +- Add env var to turn on tracking of log context changes. 
([\#6127](https://github.com/matrix-org/synapse/issues/6127)) +- Refactor configuration loading to allow better typechecking. ([\#6137](https://github.com/matrix-org/synapse/issues/6137)) +- Log responder when responding to media request. ([\#6139](https://github.com/matrix-org/synapse/issues/6139)) +- Improve performance of `find_next_generated_user_id` DB query. ([\#6148](https://github.com/matrix-org/synapse/issues/6148)) +- Expand type-checking on modules imported by `synapse.config`. ([\#6150](https://github.com/matrix-org/synapse/issues/6150)) +- Use Postgres ANY for selecting many values. ([\#6156](https://github.com/matrix-org/synapse/issues/6156)) +- Add more caching to `_get_joined_users_from_context` DB query. ([\#6159](https://github.com/matrix-org/synapse/issues/6159)) +- Add some metrics on the federation sender. ([\#6160](https://github.com/matrix-org/synapse/issues/6160)) +- Add some logging to the rooms stats updates, to try to track down a flaky test. ([\#6167](https://github.com/matrix-org/synapse/issues/6167)) +- Remove unused `timeout` parameter from `_get_public_room_list`. ([\#6179](https://github.com/matrix-org/synapse/issues/6179)) +- Reject (accidental) attempts to insert bytes into postgres tables. ([\#6186](https://github.com/matrix-org/synapse/issues/6186)) +- Make `version` optional in body of `PUT /room_keys/version/{version}`, since it's redundant. ([\#6189](https://github.com/matrix-org/synapse/issues/6189)) +- Make storage layer responsible for adding device names to key, rather than the handler. ([\#6193](https://github.com/matrix-org/synapse/issues/6193)) +- Port `synapse.rest.admin` module to use async/await. ([\#6196](https://github.com/matrix-org/synapse/issues/6196)) +- Enforce that all boolean configuration values are lowercase in CI. ([\#6203](https://github.com/matrix-org/synapse/issues/6203)) +- Remove some unused event-auth code. ([\#6214](https://github.com/matrix-org/synapse/issues/6214)) +- Remove `Auth.check` method. ([\#6217](https://github.com/matrix-org/synapse/issues/6217)) +- Remove `format_tap.py` script in favour of a perl reimplementation in Sytest's repo. ([\#6219](https://github.com/matrix-org/synapse/issues/6219)) +- Refactor storage layer in preparation to support having multiple databases. ([\#6231](https://github.com/matrix-org/synapse/issues/6231)) +- Remove some extra quotation marks across the codebase. ([\#6236](https://github.com/matrix-org/synapse/issues/6236)) + + +Synapse 1.4.1 (2019-10-18) +========================== + +No changes since 1.4.1rc1. + + +Synapse 1.4.1rc1 (2019-10-17) +============================= + +Bugfixes +-------- + +- Fix bug where redacted events were sometimes incorrectly censored in the database, breaking APIs that attempted to fetch such events. ([\#6185](https://github.com/matrix-org/synapse/issues/6185), [5b0e9948](https://github.com/matrix-org/synapse/commit/5b0e9948eaae801643e594b5abc8ee4b10bd194e)) + +Synapse 1.4.0 (2019-10-03) +========================== + +Bugfixes +-------- + +- Redact `client_secret` in server logs. ([\#6158](https://github.com/matrix-org/synapse/issues/6158)) + + +Synapse 1.4.0rc2 (2019-10-02) +============================= + +Bugfixes +-------- + +- Fix bug in background update that adds last seen information to the `devices` table, and improve its performance on Postgres. ([\#6135](https://github.com/matrix-org/synapse/issues/6135)) +- Fix bad performance of censoring redactions background task. 
([\#6141](https://github.com/matrix-org/synapse/issues/6141)) +- Fix fetching censored redactions from DB, which caused APIs like initial sync to fail if they tried to include the censored redaction. ([\#6145](https://github.com/matrix-org/synapse/issues/6145)) +- Fix exceptions when storing large retry intervals for down remote servers. ([\#6146](https://github.com/matrix-org/synapse/issues/6146)) + + +Internal Changes +---------------- + +- Fix up sample config entry for `redaction_retention_period` option. ([\#6117](https://github.com/matrix-org/synapse/issues/6117)) + + +Synapse 1.4.0rc1 (2019-09-26) +============================= + +Note that this release includes significant changes around 3pid +verification. Administrators are reminded to review the [upgrade notes](UPGRADE.rst#upgrading-to-v140). + +Features +-------- + +- Changes to 3pid verification: + - Add the ability to send registration emails from the homeserver rather than delegating to an identity server. ([\#5835](https://github.com/matrix-org/synapse/issues/5835), [\#5940](https://github.com/matrix-org/synapse/issues/5940), [\#5993](https://github.com/matrix-org/synapse/issues/5993), [\#5994](https://github.com/matrix-org/synapse/issues/5994), [\#5868](https://github.com/matrix-org/synapse/issues/5868)) + - Replace `trust_identity_server_for_password_resets` config option with `account_threepid_delegates`, and make the `id_server` parameter optional on `*/requestToken` endpoints, as per [MSC2263](https://github.com/matrix-org/matrix-doc/pull/2263). ([\#5876](https://github.com/matrix-org/synapse/issues/5876), [\#5969](https://github.com/matrix-org/synapse/issues/5969), [\#6028](https://github.com/matrix-org/synapse/issues/6028)) + - Switch to using the v2 Identity Service `/lookup` API where available, with fallback to v1. (Implements [MSC2134](https://github.com/matrix-org/matrix-doc/pull/2134) plus `id_access_token` authentication for v2 Identity Service APIs from [MSC2140](https://github.com/matrix-org/matrix-doc/pull/2140)). ([\#5897](https://github.com/matrix-org/synapse/issues/5897)) + - Remove `bind_email` and `bind_msisdn` parameters from `/register`, as per [MSC2140](https://github.com/matrix-org/matrix-doc/pull/2140). ([\#5964](https://github.com/matrix-org/synapse/issues/5964)) + - Add `m.id_access_token` to `unstable_features` in `/versions` as per [MSC2264](https://github.com/matrix-org/matrix-doc/pull/2264). ([\#5974](https://github.com/matrix-org/synapse/issues/5974)) + - Use the v2 Identity Service API for 3PID invites. ([\#5979](https://github.com/matrix-org/synapse/issues/5979)) + - Add `POST /_matrix/client/unstable/account/3pid/unbind` endpoint from [MSC2140](https://github.com/matrix-org/matrix-doc/pull/2140) for unbinding a 3PID from an identity server without removing it from the homeserver user account. ([\#5980](https://github.com/matrix-org/synapse/issues/5980), [\#6062](https://github.com/matrix-org/synapse/issues/6062)) + - Use `account_threepid_delegate.email` and `account_threepid_delegate.msisdn` for validating threepid sessions. ([\#6011](https://github.com/matrix-org/synapse/issues/6011)) + - Allow homeserver to handle or delegate email validation when adding an email to a user's account. ([\#6042](https://github.com/matrix-org/synapse/issues/6042)) + - Implement new Client Server API endpoints `/account/3pid/add` and `/account/3pid/bind` as per [MSC2290](https://github.com/matrix-org/matrix-doc/pull/2290). 
([\#6043](https://github.com/matrix-org/synapse/issues/6043)) + - Add an unstable feature flag for separate add/bind 3pid APIs. ([\#6044](https://github.com/matrix-org/synapse/issues/6044)) + - Remove `bind` parameter from Client Server POST `/account` endpoint as per [MSC2290](https://github.com/matrix-org/matrix-doc/pull/2290/). ([\#6067](https://github.com/matrix-org/synapse/issues/6067)) + - Add `POST /add_threepid/msisdn/submit_token` endpoint for proxying submitToken on an `account_threepid_handler`. ([\#6078](https://github.com/matrix-org/synapse/issues/6078)) + - Add `submit_url` response parameter to `*/msisdn/requestToken` endpoints. ([\#6079](https://github.com/matrix-org/synapse/issues/6079)) + - Add `m.require_identity_server` flag to `unstable_features` in `/versions`. ([\#5972](https://github.com/matrix-org/synapse/issues/5972)) +- Enhancements to OpenTracing support: + - Make OpenTracing work in worker mode. ([\#5771](https://github.com/matrix-org/synapse/issues/5771)) + - Pass OpenTracing contexts between servers when transmitting EDUs. ([\#5852](https://github.com/matrix-org/synapse/issues/5852)) + - OpenTracing for device list updates. ([\#5853](https://github.com/matrix-org/synapse/issues/5853)) + - Add a tag recording a request's authenticated entity and corresponding servlet in OpenTracing. ([\#5856](https://github.com/matrix-org/synapse/issues/5856)) + - Add minimum OpenTracing for client servlets. ([\#5983](https://github.com/matrix-org/synapse/issues/5983)) + - Check at setup that OpenTracing is installed if it's enabled in the config. ([\#5985](https://github.com/matrix-org/synapse/issues/5985)) + - Trace replication send times. ([\#5986](https://github.com/matrix-org/synapse/issues/5986)) + - Include missing OpenTracing contexts in outbound replication requests. ([\#5982](https://github.com/matrix-org/synapse/issues/5982)) + - Fix sending of EDUs when OpenTracing is enabled with an empty whitelist. ([\#5984](https://github.com/matrix-org/synapse/issues/5984)) + - Fix invalid references to None while OpenTracing if the log context slips. ([\#5988](https://github.com/matrix-org/synapse/issues/5988), [\#5991](https://github.com/matrix-org/synapse/issues/5991)) + - OpenTracing for room and e2e keys. ([\#5855](https://github.com/matrix-org/synapse/issues/5855)) + - Add OpenTracing span over HTTP push processing. ([\#6003](https://github.com/matrix-org/synapse/issues/6003)) +- Add an admin API to purge old rooms from the database. ([\#5845](https://github.com/matrix-org/synapse/issues/5845)) +- Retry well-known lookups if we have recently seen a valid well-known record for the server. ([\#5850](https://github.com/matrix-org/synapse/issues/5850)) +- Add support for filtered room-directory search requests over federation ([MSC2197](https://github.com/matrix-org/matrix-doc/pull/2197)), in order to allow upcoming room directory query performance improvements. ([\#5859](https://github.com/matrix-org/synapse/issues/5859)) +- Correctly retry all hosts returned from SRV when we fail to connect. ([\#5864](https://github.com/matrix-org/synapse/issues/5864)) +- Add admin API endpoint for setting whether or not a user is a server administrator. ([\#5878](https://github.com/matrix-org/synapse/issues/5878)) +- Enable cleaning up extremities with dummy events by default to prevent undue build up of forward extremities. ([\#5884](https://github.com/matrix-org/synapse/issues/5884)) +- Add config option to sign remote key query responses with a separate key. 
([\#5895](https://github.com/matrix-org/synapse/issues/5895)) +- Add support for config templating. ([\#5900](https://github.com/matrix-org/synapse/issues/5900)) +- Users with the type of "support" or "bot" are no longer required to consent. ([\#5902](https://github.com/matrix-org/synapse/issues/5902)) +- Let synctl accept a directory of config files. ([\#5904](https://github.com/matrix-org/synapse/issues/5904)) +- Increase max display name size to 256. ([\#5906](https://github.com/matrix-org/synapse/issues/5906)) +- Add admin API endpoint for getting whether or not a user is a server administrator. ([\#5914](https://github.com/matrix-org/synapse/issues/5914)) +- Redact events in the database that have been redacted for a week. ([\#5934](https://github.com/matrix-org/synapse/issues/5934)) +- New prometheus metrics: + - `synapse_federation_known_servers`: represents the total number of servers your server knows about (i.e. is in rooms with), including itself. Enable by setting `metrics_flags.known_servers` to True in the configuration.([\#5981](https://github.com/matrix-org/synapse/issues/5981)) + - `synapse_build_info`: exposes the Python version, OS version, and Synapse version of the running server. ([\#6005](https://github.com/matrix-org/synapse/issues/6005)) +- Give appropriate exit codes when synctl fails. ([\#5992](https://github.com/matrix-org/synapse/issues/5992)) +- Apply the federation blacklist to requests to identity servers. ([\#6000](https://github.com/matrix-org/synapse/issues/6000)) +- Add `report_stats_endpoint` option to configure where stats are reported to, if enabled. Contributed by @Sorunome. ([\#6012](https://github.com/matrix-org/synapse/issues/6012)) +- Add config option to increase ratelimits for room admins redacting messages. ([\#6015](https://github.com/matrix-org/synapse/issues/6015)) +- Stop sending federation transactions to servers which have been down for a long time. ([\#6026](https://github.com/matrix-org/synapse/issues/6026)) +- Make the process for mapping SAML2 users to matrix IDs more flexible. ([\#6037](https://github.com/matrix-org/synapse/issues/6037)) +- Return a clearer error message when a timeout occurs when attempting to contact an identity server. ([\#6073](https://github.com/matrix-org/synapse/issues/6073)) +- Prevent password reset's submit_token endpoint from accepting trailing slashes. ([\#6074](https://github.com/matrix-org/synapse/issues/6074)) +- Return 403 on `/register/available` if registration has been disabled. ([\#6082](https://github.com/matrix-org/synapse/issues/6082)) +- Explicitly log when a homeserver does not have the `trusted_key_servers` config field configured. ([\#6090](https://github.com/matrix-org/synapse/issues/6090)) +- Add support for pruning old rows in `user_ips` table. ([\#6098](https://github.com/matrix-org/synapse/issues/6098)) + +Bugfixes +-------- + +- Don't create broken room when `power_level_content_override.users` does not contain `creator_id`. ([\#5633](https://github.com/matrix-org/synapse/issues/5633)) +- Fix database index so that different backup versions can have the same sessions. ([\#5857](https://github.com/matrix-org/synapse/issues/5857)) +- Fix Synapse looking for config options `password_reset_failure_template` and `password_reset_success_template`, when they are actually `password_reset_template_failure_html`, `password_reset_template_success_html`. ([\#5863](https://github.com/matrix-org/synapse/issues/5863)) +- Fix stack overflow when recovering an appservice which had an outage. 
([\#5885](https://github.com/matrix-org/synapse/issues/5885)) +- Fix error message which referred to `public_base_url` instead of `public_baseurl`. Thanks to @aaronraimist for the fix! ([\#5909](https://github.com/matrix-org/synapse/issues/5909)) +- Fix 404 for thumbnail download when `dynamic_thumbnails` is `false` and the thumbnail was dynamically generated. Fix reported by rkfg. ([\#5915](https://github.com/matrix-org/synapse/issues/5915)) +- Fix a cache-invalidation bug for worker-based deployments. ([\#5920](https://github.com/matrix-org/synapse/issues/5920)) +- Fix admin API for listing media in a room not being available with an external media repo. ([\#5966](https://github.com/matrix-org/synapse/issues/5966)) +- Fix list media admin API always returning an error. ([\#5967](https://github.com/matrix-org/synapse/issues/5967)) +- Fix room and user stats tracking. ([\#5971](https://github.com/matrix-org/synapse/issues/5971), [\#5998](https://github.com/matrix-org/synapse/issues/5998), [\#6029](https://github.com/matrix-org/synapse/issues/6029)) +- Return a `M_MISSING_PARAM` if `sid` is not provided to `/account/3pid`. ([\#5995](https://github.com/matrix-org/synapse/issues/5995)) +- `federation_certificate_verification_whitelist` now will not cause `TypeErrors` to be raised (a regression in 1.3). Additionally, it now supports internationalised domain names in their non-canonical representation. ([\#5996](https://github.com/matrix-org/synapse/issues/5996)) +- Only count real users when checking for auto-creation of auto-join room. ([\#6004](https://github.com/matrix-org/synapse/issues/6004)) +- Ensure support users can be registered even if MAU limit is reached. ([\#6020](https://github.com/matrix-org/synapse/issues/6020)) +- Fix bug where login error was shown incorrectly on SSO fallback login. ([\#6024](https://github.com/matrix-org/synapse/issues/6024)) +- Fix bug in calculating the federation retry backoff period. ([\#6025](https://github.com/matrix-org/synapse/issues/6025)) +- Prevent exceptions being logged when extremity-cleanup events fail due to lack of user consent to the terms of service. ([\#6053](https://github.com/matrix-org/synapse/issues/6053)) +- Remove POST method from password-reset `submit_token` endpoint until we implement `submit_url` functionality. ([\#6056](https://github.com/matrix-org/synapse/issues/6056)) +- Fix logcontext spam on non-Linux platforms. ([\#6059](https://github.com/matrix-org/synapse/issues/6059)) +- Ensure query parameters in email validation links are URL-encoded. ([\#6063](https://github.com/matrix-org/synapse/issues/6063)) +- Fix a bug which caused SAML attribute maps to be overridden by defaults. ([\#6069](https://github.com/matrix-org/synapse/issues/6069)) +- Fix the logged number of updated items for the `users_set_deactivated_flag` background update. ([\#6092](https://github.com/matrix-org/synapse/issues/6092)) +- Add `sid` to `next_link` for email validation. ([\#6097](https://github.com/matrix-org/synapse/issues/6097)) +- Threepid validity checks on msisdns should not be dependent on `threepid_behaviour_email`. ([\#6104](https://github.com/matrix-org/synapse/issues/6104)) +- Ensure that servers which are not configured to support email address verification do not offer it in the registration flows. ([\#6107](https://github.com/matrix-org/synapse/issues/6107)) + + +Updates to the Docker image +--------------------------- + +- Avoid changing `UID/GID` if they are already correct. 
([\#5970](https://github.com/matrix-org/synapse/issues/5970)) +- Provide `SYNAPSE_WORKER` envvar to specify python module. ([\#6058](https://github.com/matrix-org/synapse/issues/6058)) + + +Improved Documentation +---------------------- + +- Convert documentation to markdown (from rst) ([\#5849](https://github.com/matrix-org/synapse/issues/5849)) +- Update `INSTALL.md` to say that Python 2 is no longer supported. ([\#5953](https://github.com/matrix-org/synapse/issues/5953)) +- Add developer documentation for using SAML2. ([\#6032](https://github.com/matrix-org/synapse/issues/6032)) +- Add some notes on rolling back to v1.3.1. ([\#6049](https://github.com/matrix-org/synapse/issues/6049)) +- Update the upgrade notes. ([\#6050](https://github.com/matrix-org/synapse/issues/6050)) + + +Deprecations and Removals +------------------------- + +- Remove shared-secret registration from `/_matrix/client/r0/register` endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#5877](https://github.com/matrix-org/synapse/issues/5877)) +- Deprecate the `trusted_third_party_id_servers` option. ([\#5875](https://github.com/matrix-org/synapse/issues/5875)) + + +Internal Changes +---------------- + +- Lay the groundwork for structured logging output. ([\#5680](https://github.com/matrix-org/synapse/issues/5680)) +- Retry well-known lookup before the cache expires, giving a grace period where the remote well-known can be down but we still use the old result. ([\#5844](https://github.com/matrix-org/synapse/issues/5844)) +- Remove log line for debugging issue #5407. ([\#5860](https://github.com/matrix-org/synapse/issues/5860)) +- Refactor the Appservice scheduler code. ([\#5886](https://github.com/matrix-org/synapse/issues/5886)) +- Compatibility with v2 Identity Service APIs other than /lookup. ([\#5892](https://github.com/matrix-org/synapse/issues/5892), [\#6013](https://github.com/matrix-org/synapse/issues/6013)) +- Stop populating some unused tables. ([\#5893](https://github.com/matrix-org/synapse/issues/5893), [\#6047](https://github.com/matrix-org/synapse/issues/6047)) +- Add missing index on `users_in_public_rooms` to improve the performance of directory queries. ([\#5894](https://github.com/matrix-org/synapse/issues/5894)) +- Improve the logging when we have an error when fetching signing keys. ([\#5896](https://github.com/matrix-org/synapse/issues/5896)) +- Add support for database engine-specific schema deltas, based on file extension. ([\#5911](https://github.com/matrix-org/synapse/issues/5911)) +- Update Buildkite pipeline to use plugins instead of buildkite-agent commands. ([\#5922](https://github.com/matrix-org/synapse/issues/5922)) +- Add link in sample config to the logging config schema. ([\#5926](https://github.com/matrix-org/synapse/issues/5926)) +- Remove unnecessary parentheses in return statements. ([\#5931](https://github.com/matrix-org/synapse/issues/5931)) +- Remove unused `jenkins/prepare_sytest.sh` file. ([\#5938](https://github.com/matrix-org/synapse/issues/5938)) +- Move Buildkite pipeline config to the pipelines repo. ([\#5943](https://github.com/matrix-org/synapse/issues/5943)) +- Remove unnecessary return statements in the codebase which were the result of a regex run. ([\#5962](https://github.com/matrix-org/synapse/issues/5962)) +- Remove left-over methods from v1 registration API. ([\#5963](https://github.com/matrix-org/synapse/issues/5963)) +- Cleanup event auth type initialisation. 
([\#5975](https://github.com/matrix-org/synapse/issues/5975)) +- Clean up dependency checking at setup. ([\#5989](https://github.com/matrix-org/synapse/issues/5989)) +- Update OpenTracing docs to use the unified `trace` method. ([\#5776](https://github.com/matrix-org/synapse/issues/5776)) +- Small refactor of function arguments and docstrings in` RoomMemberHandler`. ([\#6009](https://github.com/matrix-org/synapse/issues/6009)) +- Remove unused `origin` argument on `FederationHandler.add_display_name_to_third_party_invite`. ([\#6010](https://github.com/matrix-org/synapse/issues/6010)) +- Add a `failure_ts` column to the `destinations` database table. ([\#6016](https://github.com/matrix-org/synapse/issues/6016), [\#6072](https://github.com/matrix-org/synapse/issues/6072)) +- Clean up some code in the retry logic. ([\#6017](https://github.com/matrix-org/synapse/issues/6017)) +- Fix the structured logging tests stomping on the global log configuration for subsequent tests. ([\#6023](https://github.com/matrix-org/synapse/issues/6023)) +- Clean up the sample config for SAML authentication. ([\#6064](https://github.com/matrix-org/synapse/issues/6064)) +- Change mailer logging to reflect Synapse doesn't just do chat notifications by email now. ([\#6075](https://github.com/matrix-org/synapse/issues/6075)) +- Move last-seen info into devices table. ([\#6089](https://github.com/matrix-org/synapse/issues/6089)) +- Remove unused parameter to `get_user_id_by_threepid`. ([\#6099](https://github.com/matrix-org/synapse/issues/6099)) +- Refactor the user-interactive auth handling. ([\#6105](https://github.com/matrix-org/synapse/issues/6105)) +- Refactor code for calculating registration flows. ([\#6106](https://github.com/matrix-org/synapse/issues/6106)) + + Synapse 1.3.1 (2019-08-17) ========================== diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..062413e925 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,268 @@ +# Contributing code to Synapse + +Everyone is welcome to contribute code to [matrix.org +projects](https://github.com/matrix-org), provided that they are willing to +license their contributions under the same license as the project itself. We +follow a simple 'inbound=outbound' model for contributions: the act of +submitting an 'inbound' contribution means that the contributor agrees to +license the code under the same terms as the project's overall 'outbound' +license - in our case, this is almost always Apache Software License v2 (see +[LICENSE](LICENSE)). + +## How to contribute + +The preferred and easiest way to contribute changes is to fork the relevant +project on github, and then [create a pull request]( +https://help.github.com/articles/using-pull-requests/) to ask us to pull your +changes into our repo. + +Some other points to follow: + + * Please base your changes on the `develop` branch. + + * Please follow the [code style requirements](#code-style). + + * Please include a [changelog entry](#changelog) with each PR. + + * Please [sign off](#sign-off) your contribution. + + * Please keep an eye on the pull request for feedback from the [continuous + integration system](#continuous-integration-and-testing) and try to fix any + errors that come up. + + * If you need to [update your PR](#updating-your-pull-request), just add new + commits to your branch rather than rebasing. + +## Code style + +Synapse's code style is documented [here](docs/code_style.md). 
Please follow +it, including the conventions for the [sample configuration +file](docs/code_style.md#configuration-file-format). + +Many of the conventions are enforced by scripts which are run as part of the +[continuous integration system](#continuous-integration-and-testing). To help +check if you have followed the code style, you can run `scripts-dev/lint.sh` +locally. You'll need python 3.6 or later, and to install a number of tools: + +``` +# Install the dependencies +pip install -U black flake8 flake8-comprehensions isort + +# Run the linter script +./scripts-dev/lint.sh +``` + +**Note that the script does not just test/check, but also reformats code, so you +may wish to ensure any new code is committed first**. + +By default, this script checks all files and can take some time; if you alter +only certain files, you might wish to specify paths as arguments to reduce the +run-time: + +``` +./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder +``` + +Before pushing new changes, ensure they don't produce linting errors. Commit any +files that were corrected. + +Please ensure your changes match the cosmetic style of the existing project, +and **never** mix cosmetic and functional changes in the same commit, as it +makes it horribly hard to review otherwise. + +## Changelog + +All changes, even minor ones, need a corresponding changelog / newsfragment +entry. These are managed by [Towncrier](https://github.com/hawkowl/towncrier). + +To create a changelog entry, make a new file in the `changelog.d` directory named +in the format of `PRnumber.type`. The type can be one of the following: + +* `feature` +* `bugfix` +* `docker` (for updates to the Docker image) +* `doc` (for updates to the documentation) +* `removal` (also used for deprecations) +* `misc` (for internal-only changes) + +This file will become part of our [changelog]( +https://github.com/matrix-org/synapse/blob/master/CHANGES.md) at the next +release, so the content of the file should be a short description of your +change in the same style as the rest of the changelog. The file can contain Markdown +formatting, and should end with a full stop (.) or an exclamation mark (!) for +consistency. + +Adding credits to the changelog is encouraged, we value your +contributions and would like to have you shouted out in the release notes! + +For example, a fix in PR #1234 would have its changelog entry in +`changelog.d/1234.bugfix`, and contain content like: + +> The security levels of Florbs are now validated when received +> via the `/federation/florb` endpoint. Contributed by Jane Matrix. + +If there are multiple pull requests involved in a single bugfix/feature/etc, +then the content for each `changelog.d` file should be the same. Towncrier will +merge the matching files together into a single changelog entry when we come to +release. + +### How do I know what to call the changelog file before I create the PR? + +Obviously, you don't know if you should call your newsfile +`1234.bugfix` or `5678.bugfix` until you create the PR, which leads to a +chicken-and-egg problem. + +There are two options for solving this: + + 1. Open the PR without a changelog file, see what number you got, and *then* + add the changelog file to your branch (see [Updating your pull + request](#updating-your-pull-request)), or: + + 1. Look at the [list of all + issues/PRs](https://github.com/matrix-org/synapse/issues?q=), add one to the + highest number you see, and quickly open the PR before somebody else claims + your number. 
+ + [This + script](https://github.com/richvdh/scripts/blob/master/next_github_number.sh) + might be helpful if you find yourself doing this a lot. + +Sorry, we know it's a bit fiddly, but it's *really* helpful for us when we come +to put together a release! + +### Debian changelog + +Changes which affect the debian packaging files (in `debian`) are an +exception to the rule that all changes require a `changelog.d` file. + +In this case, you will need to add an entry to the debian changelog for the +next release. For this, run the following command: + +``` +dch +``` + +This will make up a new version number (if there isn't already an unreleased +version in flight), and open an editor where you can add a new changelog entry. +(Our release process will ensure that the version number and maintainer name is +corrected for the release.) + +If your change affects both the debian packaging *and* files outside the debian +directory, you will need both a regular newsfragment *and* an entry in the +debian changelog. (Though typically such changes should be submitted as two +separate pull requests.) + +## Sign off + +In order to have a concrete record that your contribution is intentional +and you agree to license it under the same terms as the project's license, we've adopted the +same lightweight approach that the Linux Kernel +[submitting patches process]( +https://www.kernel.org/doc/html/latest/process/submitting-patches.html#sign-your-work-the-developer-s-certificate-of-origin>), +[Docker](https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other +projects use: the DCO (Developer Certificate of Origin: +http://developercertificate.org/). This is a simple declaration that you wrote +the contribution or otherwise have the right to contribute it to Matrix: + +``` +Developer Certificate of Origin +Version 1.1 + +Copyright (C) 2004, 2006 The Linux Foundation and its contributors. +660 York Street, Suite 102, +San Francisco, CA 94110 USA + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. + +Developer's Certificate of Origin 1.1 + +By making a contribution to this project, I certify that: + +(a) The contribution was created in whole or in part by me and I + have the right to submit it under the open source license + indicated in the file; or + +(b) The contribution is based upon previous work that, to the best + of my knowledge, is covered under an appropriate open source + license and I have the right under that license to submit that + work with modifications, whether created in whole or in part + by me, under the same open source license (unless I am + permitted to submit under a different license), as indicated + in the file; or + +(c) The contribution was provided directly to me by some other + person who certified (a), (b) or (c) and I have not modified + it. + +(d) I understand and agree that this project and the contribution + are public and that a record of the contribution (including all + personal information I submit with it, including my sign-off) is + maintained indefinitely and may be redistributed consistent with + this project or the open source license(s) involved. 
+``` + +If you agree to this for your contribution, then all that's needed is to +include the line in your commit or pull request comment: + +``` +Signed-off-by: Your Name <your@email.example.org> +``` + +We accept contributions under a legally identifiable name, such as +your name on government documentation or common-law names (names +claimed by legitimate usage or repute). Unfortunately, we cannot +accept anonymous contributions at this time. + +Git allows you to add this signoff automatically when using the `-s` +flag to `git commit`, which uses the name and email set in your +`user.name` and `user.email` git configs. + +## Continuous integration and testing + +[Buildkite](https://buildkite.com/matrix-dot-org/synapse) will automatically +run a series of checks and tests against any PR which is opened against the +project; if your change breaks the build, this will be shown in GitHub, with +links to the build results. If your build fails, please try to fix the errors +and update your branch. + +To run unit tests in a local development environment, you can use: + +- ``tox -e py35`` (requires tox to be installed by ``pip install tox``) + for SQLite-backed Synapse on Python 3.5. +- ``tox -e py36`` for SQLite-backed Synapse on Python 3.6. +- ``tox -e py36-postgres`` for PostgreSQL-backed Synapse on Python 3.6 + (requires a running local PostgreSQL with access to create databases). +- ``./test_postgresql.sh`` for PostgreSQL-backed Synapse on Python 3.5 + (requires Docker). Entirely self-contained, recommended if you don't want to + set up PostgreSQL yourself. + +Docker images are available for running the integration tests (SyTest) locally, +see the [documentation in the SyTest repo]( +https://github.com/matrix-org/sytest/blob/develop/docker/README.md) for more +information. + +## Updating your pull request + +If you decide to make changes to your pull request - perhaps to address issues +raised in a review, or to fix problems highlighted by [continuous +integration](#continuous-integration-and-testing) - just add new commits to your +branch, and push to GitHub. The pull request will automatically be updated. + +Please **avoid** rebasing your branch, especially once the PR has been +reviewed: doing so makes it very difficult for a reviewer to see what has +changed since a previous review. + +## Notes for maintainers on merging PRs etc + +There are some notes for those with commit access to the project on how we +manage git [here](docs/dev/git.md). + +## Conclusion + +That's it! Matrix is a very open and collaborative project as you might expect +given our obsession with open communication. If we're going to successfully +matrix together all the fragmented communication technologies out there we are +reliant on contributions and collaboration from the community to do so. So +please get involved - and we hope you have as much fun hacking on Matrix as we +do! diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst deleted file mode 100644 index 620dc88ce2..0000000000 --- a/CONTRIBUTING.rst +++ /dev/null @@ -1,198 +0,0 @@ -Contributing code to Matrix -=========================== - -Everyone is welcome to contribute code to Matrix -(https://github.com/matrix-org), provided that they are willing to license -their contributions under the same license as the project itself. 
We follow a -simple 'inbound=outbound' model for contributions: the act of submitting an -'inbound' contribution means that the contributor agrees to license the code -under the same terms as the project's overall 'outbound' license - in our -case, this is almost always Apache Software License v2 (see LICENSE). - -How to contribute -~~~~~~~~~~~~~~~~~ - -The preferred and easiest way to contribute changes to Matrix is to fork the -relevant project on github, and then create a pull request to ask us to pull -your changes into our repo -(https://help.github.com/articles/using-pull-requests/) - -**The single biggest thing you need to know is: please base your changes on -the develop branch - /not/ master.** - -We use the master branch to track the most recent release, so that folks who -blindly clone the repo and automatically check out master get something that -works. Develop is the unstable branch where all the development actually -happens: the workflow is that contributors should fork the develop branch to -make a 'feature' branch for a particular contribution, and then make a pull -request to merge this back into the matrix.org 'official' develop branch. We -use github's pull request workflow to review the contribution, and either ask -you to make any refinements needed or merge it and make them ourselves. The -changes will then land on master when we next do a release. - -We use `Buildkite <https://buildkite.com/matrix-dot-org/synapse>`_ for -continuous integration. Buildkite builds need to be authorised by a -maintainer. If your change breaks the build, this will be shown in GitHub, so -please keep an eye on the pull request for feedback. - -To run unit tests in a local development environment, you can use: - -- ``tox -e py35`` (requires tox to be installed by ``pip install tox``) - for SQLite-backed Synapse on Python 3.5. -- ``tox -e py36`` for SQLite-backed Synapse on Python 3.6. -- ``tox -e py36-postgres`` for PostgreSQL-backed Synapse on Python 3.6 - (requires a running local PostgreSQL with access to create databases). -- ``./test_postgresql.sh`` for PostgreSQL-backed Synapse on Python 3.5 - (requires Docker). Entirely self-contained, recommended if you don't want to - set up PostgreSQL yourself. - -Docker images are available for running the integration tests (SyTest) locally, -see the `documentation in the SyTest repo -<https://github.com/matrix-org/sytest/blob/develop/docker/README.md>`_ for more -information. - -Code style -~~~~~~~~~~ - -All Matrix projects have a well-defined code-style - and sometimes we've even -got as far as documenting it... For instance, synapse's code style doc lives -at https://github.com/matrix-org/synapse/tree/master/docs/code_style.md. - -Please ensure your changes match the cosmetic style of the existing project, -and **never** mix cosmetic and functional changes in the same commit, as it -makes it horribly hard to review otherwise. - -Changelog -~~~~~~~~~ - -All changes, even minor ones, need a corresponding changelog / newsfragment -entry. These are managed by Towncrier -(https://github.com/hawkowl/towncrier). - -To create a changelog entry, make a new file in the ``changelog.d`` file named -in the format of ``PRnumber.type``. The type can be one of the following: - -* ``feature``. -* ``bugfix``. -* ``docker`` (for updates to the Docker image). -* ``doc`` (for updates to the documentation). -* ``removal`` (also used for deprecations). -* ``misc`` (for internal-only changes). 
- -The content of the file is your changelog entry, which should be a short -description of your change in the same style as the rest of our `changelog -<https://github.com/matrix-org/synapse/blob/master/CHANGES.md>`_. The file can -contain Markdown formatting, and should end with a full stop ('.') for -consistency. - -Adding credits to the changelog is encouraged, we value your -contributions and would like to have you shouted out in the release notes! - -For example, a fix in PR #1234 would have its changelog entry in -``changelog.d/1234.bugfix``, and contain content like "The security levels of -Florbs are now validated when recieved over federation. Contributed by Jane -Matrix.". - -Debian changelog ----------------- - -Changes which affect the debian packaging files (in ``debian``) are an -exception. - -In this case, you will need to add an entry to the debian changelog for the -next release. For this, run the following command:: - - dch - -This will make up a new version number (if there isn't already an unreleased -version in flight), and open an editor where you can add a new changelog entry. -(Our release process will ensure that the version number and maintainer name is -corrected for the release.) - -If your change affects both the debian packaging *and* files outside the debian -directory, you will need both a regular newsfragment *and* an entry in the -debian changelog. (Though typically such changes should be submitted as two -separate pull requests.) - -Attribution -~~~~~~~~~~~ - -Everyone who contributes anything to Matrix is welcome to be listed in the -AUTHORS.rst file for the project in question. Please feel free to include a -change to AUTHORS.rst in your pull request to list yourself and a short -description of the area(s) you've worked on. Also, we sometimes have swag to -give away to contributors - if you feel that Matrix-branded apparel is missing -from your life, please mail us your shipping address to matrix at matrix.org and -we'll try to fix it :) - -Sign off -~~~~~~~~ - -In order to have a concrete record that your contribution is intentional -and you agree to license it under the same terms as the project's license, we've adopted the -same lightweight approach that the Linux Kernel -`submitting patches process <https://www.kernel.org/doc/html/latest/process/submitting-patches.html#sign-your-work-the-developer-s-certificate-of-origin>`_, Docker -(https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other -projects use: the DCO (Developer Certificate of Origin: -http://developercertificate.org/). This is a simple declaration that you wrote -the contribution or otherwise have the right to contribute it to Matrix:: - - Developer Certificate of Origin - Version 1.1 - - Copyright (C) 2004, 2006 The Linux Foundation and its contributors. - 660 York Street, Suite 102, - San Francisco, CA 94110 USA - - Everyone is permitted to copy and distribute verbatim copies of this - license document, but changing it is not allowed. 
- - Developer's Certificate of Origin 1.1 - - By making a contribution to this project, I certify that: - - (a) The contribution was created in whole or in part by me and I - have the right to submit it under the open source license - indicated in the file; or - - (b) The contribution is based upon previous work that, to the best - of my knowledge, is covered under an appropriate open source - license and I have the right under that license to submit that - work with modifications, whether created in whole or in part - by me, under the same open source license (unless I am - permitted to submit under a different license), as indicated - in the file; or - - (c) The contribution was provided directly to me by some other - person who certified (a), (b) or (c) and I have not modified - it. - - (d) I understand and agree that this project and the contribution - are public and that a record of the contribution (including all - personal information I submit with it, including my sign-off) is - maintained indefinitely and may be redistributed consistent with - this project or the open source license(s) involved. - -If you agree to this for your contribution, then all that's needed is to -include the line in your commit or pull request comment:: - - Signed-off-by: Your Name <your@email.example.org> - -We accept contributions under a legally identifiable name, such as -your name on government documentation or common-law names (names -claimed by legitimate usage or repute). Unfortunately, we cannot -accept anonymous contributions at this time. - -Git allows you to add this signoff automatically when using the ``-s`` -flag to ``git commit``, which uses the name and email set in your -``user.name`` and ``user.email`` git configs. - -Conclusion -~~~~~~~~~~ - -That's it! Matrix is a very open and collaborative project as you might expect -given our obsession with open communication. If we're going to successfully -matrix together all the fragmented communication technologies out there we are -reliant on contributions and collaboration from the community to do so. So -please get involved - and we hope you have as much fun hacking on Matrix as we -do! diff --git a/INSTALL.md b/INSTALL.md index 3eb979c362..ef80a26c3f 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -2,7 +2,6 @@ - [Installing Synapse](#installing-synapse) - [Installing from source](#installing-from-source) - [Platform-Specific Instructions](#platform-specific-instructions) - - [Troubleshooting Installation](#troubleshooting-installation) - [Prebuilt packages](#prebuilt-packages) - [Setting up Synapse](#setting-up-synapse) - [TLS certificates](#tls-certificates) @@ -10,6 +9,7 @@ - [Registering a user](#registering-a-user) - [Setting up a TURN server](#setting-up-a-turn-server) - [URL previews](#url-previews) +- [Troubleshooting Installation](#troubleshooting-installation) # Choosing your server name @@ -36,7 +36,7 @@ that your email address is probably `user@example.com` rather than System requirements: - POSIX-compliant system (tested on Linux & OS X) -- Python 3.5, 3.6, or 3.7 +- Python 3.5.2 or later, up to Python 3.8. - At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org Synapse is written in Python but some of the libraries it uses are written in @@ -70,7 +70,7 @@ pip install -U matrix-synapse ``` Before you can start Synapse, you will need to generate a configuration -file. To do this, run (in your virtualenv, as before):: +file. 
To do this, run (in your virtualenv, as before): ``` cd ~/synapse @@ -84,22 +84,24 @@ python -m synapse.app.homeserver \ ... substituting an appropriate value for `--server-name`. This command will generate you a config file that you can then customise, but it will -also generate a set of keys for you. These keys will allow your Home Server to -identify itself to other Home Servers, so don't lose or delete them. It would be +also generate a set of keys for you. These keys will allow your homeserver to +identify itself to other homeservers, so don't lose or delete them. It would be wise to back them up somewhere safe. (If, for whatever reason, you do need to -change your Home Server's keys, you may find that other Home Servers have the +change your homeserver's keys, you may find that other homeservers have the old key cached. If you update the signing key, you should change the name of the key in the `<server name>.signing.key` file (the second word) to something different. See the [spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys) -for more information on key management.) +for more information on key management). To actually run your new homeserver, pick a working directory for Synapse to -run (e.g. `~/synapse`), and:: +run (e.g. `~/synapse`), and: - cd ~/synapse - source env/bin/activate - synctl start +``` +cd ~/synapse +source env/bin/activate +synctl start +``` ### Platform-Specific Instructions @@ -109,8 +111,8 @@ Installing prerequisites on Ubuntu or Debian: ``` sudo apt-get install build-essential python3-dev libffi-dev \ - python-pip python-setuptools sqlite3 \ - libssl-dev python-virtualenv libjpeg-dev libxslt1-dev + python3-pip python3-setuptools sqlite3 \ + libssl-dev virtualenv libjpeg-dev libxslt1-dev ``` #### ArchLinux @@ -124,18 +126,32 @@ sudo pacman -S base-devel python python-pip \ #### CentOS/Fedora -Installing prerequisites on CentOS 7 or Fedora 25: +Installing prerequisites on CentOS 8 or Fedora>26: + +``` +sudo dnf install libtiff-devel libjpeg-devel libzip-devel freetype-devel \ + libwebp-devel tk-devel redhat-rpm-config \ + python3-virtualenv libffi-devel openssl-devel +sudo dnf groupinstall "Development Tools" +``` + +Installing prerequisites on CentOS 7 or Fedora<=25: ``` sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \ lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \ - python-virtualenv libffi-devel openssl-devel + python3-virtualenv libffi-devel openssl-devel sudo yum groupinstall "Development Tools" ``` -#### Mac OS X +Note that Synapse does not support versions of SQLite before 3.11, and CentOS 7 +uses SQLite 3.7. You may be able to work around this by installing a more +recent SQLite version, but it is recommended that you instead use a Postgres +database: see [docs/postgres.md](docs/postgres.md). + +#### macOS -Installing prerequisites on Mac OS X: +Installing prerequisites on macOS: ``` xcode-select --install @@ -144,6 +160,14 @@ sudo pip install virtualenv brew install pkg-config libffi ``` +On macOS Catalina (10.15) you may need to explicitly install OpenSSL +via brew and inform `pip` about it so that `psycopg2` builds: + +``` +brew install openssl@1.1 +export LDFLAGS=-L/usr/local/Cellar/openssl\@1.1/1.1.1d/lib/ +``` + #### OpenSUSE Installing prerequisites on openSUSE: @@ -156,35 +180,41 @@ sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \ #### OpenBSD -Installing prerequisites on OpenBSD: +A port of Synapse is available under `net/synapse`. 
The filesystem +underlying the homeserver directory (defaults to `/var/synapse`) has to be +mounted with `wxallowed` (cf. `mount(8)`), so creating a separate filesystem +and mounting it to `/var/synapse` should be taken into consideration. + +To be able to build Synapse's dependency on python the `WRKOBJDIR` +(cf. `bsd.port.mk(5)`) for building python, too, needs to be on a filesystem +mounted with `wxallowed` (cf. `mount(8)`). + +Creating a `WRKOBJDIR` for building python under `/usr/local` (which on a +default OpenBSD installation is mounted with `wxallowed`): ``` -doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \ - libxslt jpeg +doas mkdir /usr/local/pobj_wxallowed ``` -There is currently no port for OpenBSD. Additionally, OpenBSD's security -settings require a slightly more difficult installation process. +Assuming `PORTS_PRIVSEP=Yes` (cf. `bsd.port.mk(5)`) and `SUDO=doas` are +configured in `/etc/mk.conf`: -XXX: I suspect this is out of date. +``` +doas chown _pbuild:_pbuild /usr/local/pobj_wxallowed +``` + +Setting the `WRKOBJDIR` for building python: -1. Create a new directory in `/usr/local` called `_synapse`. Also, create a - new user called `_synapse` and set that directory as the new user's home. - This is required because, by default, OpenBSD only allows binaries which need - write and execute permissions on the same memory space to be run from - `/usr/local`. -2. `su` to the new `_synapse` user and change to their home directory. -3. Create a new virtualenv: `virtualenv -p python2.7 ~/.synapse` -4. Source the virtualenv configuration located at - `/usr/local/_synapse/.synapse/bin/activate`. This is done in `ksh` by - using the `.` command, rather than `bash`'s `source`. -5. Optionally, use `pip` to install `lxml`, which Synapse needs to parse - webpages for their titles. -6. Use `pip` to install this repository: `pip install matrix-synapse` -7. Optionally, change `_synapse`'s shell to `/bin/false` to reduce the - chance of a compromised Synapse server being used to take over your box. +``` +echo WRKOBJDIR_lang/python/3.7=/usr/local/pobj_wxallowed \\nWRKOBJDIR_lang/python/2.7=/usr/local/pobj_wxallowed >> /etc/mk.conf +``` -After this, you may proceed with the rest of the install directions. +Building Synapse: + +``` +cd /usr/ports/net/synapse +make install +``` #### Windows @@ -195,45 +225,6 @@ be found at https://docs.microsoft.com/en-us/windows/wsl/install-win10 for Windows 10 and https://docs.microsoft.com/en-us/windows/wsl/install-on-server for Windows Server. -### Troubleshooting Installation - -XXX a bunch of this is no longer relevant. - -Synapse requires pip 8 or later, so if your OS provides too old a version you -may need to manually upgrade it:: - - sudo pip install --upgrade pip - -Installing may fail with `Could not find any downloads that satisfy the requirement pymacaroons-pynacl (from matrix-synapse==0.12.0)`. -You can fix this by manually upgrading pip and virtualenv:: - - sudo pip install --upgrade virtualenv - -You can next rerun `virtualenv -p python3 synapse` to update the virtual env. - -Installing may fail during installing virtualenv with `InsecurePlatformWarning: A true SSLContext object is not available. This prevents urllib3 from configuring SSL appropriately and may cause certain SSL connections to fail. 
For more information, see https://urllib3.readthedocs.org/en/latest/security.html#insecureplatformwarning.` -You can fix this by manually installing ndg-httpsclient:: - - pip install --upgrade ndg-httpsclient - -Installing may fail with `mock requires setuptools>=17.1. Aborting installation`. -You can fix this by upgrading setuptools:: - - pip install --upgrade setuptools - -If pip crashes mid-installation for reason (e.g. lost terminal), pip may -refuse to run until you remove the temporary installation directory it -created. To reset the installation:: - - rm -rf /tmp/pip_install_matrix - -pip seems to leak *lots* of memory during installation. For instance, a Linux -host with 512MB of RAM may run out of memory whilst installing Twisted. If this -happens, you will have to individually install the dependencies which are -failing, e.g.:: - - pip install twisted - ## Prebuilt packages As an alternative to installing from source, prebuilt packages are available @@ -292,7 +283,7 @@ For `buster` and `sid`, Synapse is available in the Debian repositories and it should be possible to install it with simply: ``` - sudo apt install matrix-synapse +sudo apt install matrix-synapse ``` There is also a version of `matrix-synapse` in `stretch-backports`. Please see @@ -349,13 +340,34 @@ sudo pip uninstall py-bcrypt sudo pip install py-bcrypt ``` +### Void Linux + +Synapse can be found in the void repositories as 'synapse': + +``` +xbps-install -Su +xbps-install -S synapse +``` + ### FreeBSD Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from: - Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean` - - Packages: `pkg install py27-matrix-synapse` + - Packages: `pkg install py37-matrix-synapse` +### OpenBSD + +As of OpenBSD 6.7 Synapse is available as a pre-compiled binary. The filesystem +underlying the homeserver directory (defaults to `/var/synapse`) has to be +mounted with `wxallowed` (cf. `mount(8)`), so creating a separate filesystem +and mounting it to `/var/synapse` should be taken into consideration. + +Installing Synapse: + +``` +doas pkg_add synapse +``` ### NixOS @@ -368,15 +380,17 @@ Once you have installed synapse as above, you will need to configure it. ## TLS certificates -The default configuration exposes a single HTTP port: http://localhost:8008. It -is suitable for local testing, but for any practical use, you will either need -to enable a reverse proxy, or configure Synapse to expose an HTTPS port. +The default configuration exposes a single HTTP port on the local +interface: `http://localhost:8008`. It is suitable for local testing, +but for any practical use, you will need Synapse's APIs to be served +over HTTPS. -For information on using a reverse proxy, see +The recommended way to do so is to set up a reverse proxy on port +`8448`. You can find documentation on doing so in [docs/reverse_proxy.md](docs/reverse_proxy.md). -To configure Synapse to expose an HTTPS port, you will need to edit -`homeserver.yaml`, as follows: +Alternatively, you can configure Synapse to expose an HTTPS port. To do +so, you will need to edit `homeserver.yaml`, as follows: * First, under the `listeners` section, uncomment the configuration for the TLS-enabled listener. 
(Remove the hash sign (`#`) at the start of @@ -389,33 +403,39 @@ To configure Synapse to expose an HTTPS port, you will need to edit resources: - names: [client, federation] ``` + * You will also need to uncomment the `tls_certificate_path` and `tls_private_key_path` lines under the `TLS` section. You can either point these settings at an existing certificate and key, or you can enable Synapse's built-in ACME (Let's Encrypt) support. Instructions for having Synapse automatically provision and renew federation - certificates through ACME can be found at [ACME.md](docs/ACME.md). If you - are using your own certificate, be sure to use a `.pem` file that includes - the full certificate chain including any intermediate certificates (for - instance, if using certbot, use `fullchain.pem` as your certificate, not + certificates through ACME can be found at [ACME.md](docs/ACME.md). + Note that, as pointed out in that document, this feature will not + work with installs set up after November 2019. + + If you are using your own certificate, be sure to use a `.pem` file that + includes the full certificate chain including any intermediate certificates + (for instance, if using certbot, use `fullchain.pem` as your certificate, not `cert.pem`). For a more detailed guide to configuring your server for federation, see -[federate.md](docs/federate.md) +[federate.md](docs/federate.md). ## Email -It is desirable for Synapse to have the capability to send email. For example, -this is required to support the 'password reset' feature. +It is desirable for Synapse to have the capability to send email. This allows +Synapse to send password reset emails, send verifications when an email address +is added to a user's account, and send email notifications to users when they +receive new messages. To configure an SMTP server for Synapse, modify the configuration section -headed ``email``, and be sure to have at least the ``smtp_host``, ``smtp_port`` -and ``notif_from`` fields filled out. You may also need to set ``smtp_user``, -``smtp_pass``, and ``require_transport_security``. +headed `email`, and be sure to have at least the `smtp_host`, `smtp_port` +and `notif_from` fields filled out. You may also need to set `smtp_user`, +`smtp_pass`, and `require_transport_security`. -If Synapse is not configured with an SMTP server, password reset via email will - be disabled by default. +If email is not configured, password reset, registration and notifications via +email will be disabled. ## Registering a user @@ -446,7 +466,7 @@ on your server even if `enable_registration` is `false`. ## Setting up a TURN server For reliable VoIP calls to be routed via this homeserver, you MUST configure -a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details. +a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details. ## URL previews @@ -455,10 +475,24 @@ turn it on you must enable the `url_preview_enabled: True` config parameter and explicitly specify the IP ranges that Synapse is not allowed to spider for previewing in the `url_preview_ip_range_blacklist` configuration parameter. This is critical from a security perspective to stop arbitrary Matrix users -spidering 'internal' URLs on your network. At the very least we recommend that +spidering 'internal' URLs on your network. At the very least we recommend that your loopback and RFC1918 IP addresses are blacklisted. -This also requires the optional lxml and netaddr python dependencies to be -installed. 
This in turn requires the libxml2 library to be available - on +This also requires the optional `lxml` and `netaddr` python dependencies to be +installed. This in turn requires the `libxml2` library to be available - on Debian/Ubuntu this means `apt-get install libxml2-dev`, or equivalent for your OS. + +# Troubleshooting Installation + +`pip` seems to leak *lots* of memory during installation. For instance, a Linux +host with 512MB of RAM may run out of memory whilst installing Twisted. If this +happens, you will have to individually install the dependencies which are +failing, e.g.: + +``` +pip install twisted +``` + +If you have any other problems, feel free to ask in +[#synapse:matrix.org](https://matrix.to/#/#synapse:matrix.org). diff --git a/MANIFEST.in b/MANIFEST.in index 9c2902b8d2..120ce5b776 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -8,11 +8,12 @@ include demo/demo.tls.dh include demo/*.py include demo/*.sh -recursive-include synapse/storage/schema *.sql -recursive-include synapse/storage/schema *.sql.postgres -recursive-include synapse/storage/schema *.sql.sqlite -recursive-include synapse/storage/schema *.py -recursive-include synapse/storage/schema *.txt +recursive-include synapse/storage *.sql +recursive-include synapse/storage *.sql.postgres +recursive-include synapse/storage *.sql.sqlite +recursive-include synapse/storage *.py +recursive-include synapse/storage *.txt +recursive-include synapse/storage *.md recursive-include docs * recursive-include scripts * @@ -29,25 +30,24 @@ recursive-include synapse/static *.gif recursive-include synapse/static *.html recursive-include synapse/static *.js -exclude Dockerfile +exclude .codecov.yml +exclude .coveragerc exclude .dockerignore -exclude test_postgresql.sh exclude .editorconfig +exclude Dockerfile +exclude mypy.ini exclude sytest-blacklist +exclude test_postgresql.sh include pyproject.toml recursive-include changelog.d * prune .buildkite prune .circleci -prune .codecov.yml -prune .coveragerc prune .github +prune contrib prune debian prune demo/etc prune docker -prune mypy.ini +prune snap prune stubs - -exclude jenkins* -recursive-exclude jenkins *.sh diff --git a/README.rst b/README.rst index 2948fd0765..31d375d19b 100644 --- a/README.rst +++ b/README.rst @@ -1,3 +1,11 @@ +================ +Synapse |shield| +================ + +.. |shield| image:: https://img.shields.io/matrix/synapse:matrix.org?label=support&logo=matrix + :alt: (get support on #synapse:matrix.org) + :target: https://matrix.to/#/#synapse:matrix.org + .. contents:: Introduction @@ -77,6 +85,17 @@ Thanks for using Matrix! [1] End-to-end encryption is currently in beta: `blog post <https://matrix.org/blog/2016/11/21/matrixs-olm-end-to-end-encryption-security-assessment-released-and-implemented-cross-platform-on-riot-at-last>`_. +Support +======= + +For support installing or managing Synapse, please join |room|_ (from a matrix.org +account if necessary) and ask questions there. We do not use GitHub issues for +support requests, only for bug reports and feature requests. + +.. |room| replace:: ``#synapse:matrix.org`` +.. 
_room: https://matrix.to/#/#synapse:matrix.org + + Synapse Installation ==================== @@ -248,7 +267,7 @@ First calculate the hash of the new password:: Confirm password: $2a$12$xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx -Then update the `users` table in the database:: +Then update the ``users`` table in the database:: UPDATE users SET password_hash='$2a$12$xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx' WHERE name='@test:test.com'; @@ -272,7 +291,7 @@ to install using pip and a virtualenv:: virtualenv -p python3 env source env/bin/activate - python -m pip install --no-use-pep517 -e .[all] + python -m pip install --no-use-pep517 -e ".[all]" This will run a process of downloading and installing all the needed dependencies into a virtual env. @@ -316,6 +335,9 @@ Building internal API documentation:: Troubleshooting =============== +Need help? Join our community support room on Matrix: +`#synapse:matrix.org <https://matrix.to/#/#synapse:matrix.org>`_ + Running out of File Handles --------------------------- @@ -381,3 +403,16 @@ indicate that your server is also issuing far more outgoing federation requests than can be accounted for by your users' activity, this is a likely cause. The misbehavior can be worked around by setting ``use_presence: false`` in the Synapse config file. + +People can't accept room invitations from me +-------------------------------------------- + +The typical failure mode here is that you send an invitation to someone +to join a room or direct chat, but when they go to accept it, they get an +error (typically along the lines of "Invalid signature"). They might see +something like the following in their logs:: + + 2019-09-11 19:32:04,271 - synapse.federation.transport.server - 288 - WARNING - GET-11752 - authenticate_request failed: 401: Invalid signature for server <server> with key ed25519:a_EqML: Unable to verify signature for <server> + +This is normally caused by a misconfiguration in your reverse-proxy. See +`<docs/reverse_proxy.md>`_ and double-check that your settings are correct. diff --git a/UPGRADE.rst b/UPGRADE.rst index 5aaf804902..3b5627e852 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -2,102 +2,446 @@ Upgrading Synapse ================= Before upgrading check if any special steps are required to upgrade from the -what you currently have installed to current version of synapse. The extra +version you currently have installed to the current version of Synapse. The extra instructions that may be required are listed later in this document. -1. If synapse was installed in a virtualenv then activate that virtualenv before - upgrading. If synapse is installed in a virtualenv in ``~/synapse/env`` then - run: +* If Synapse was installed using `prebuilt packages + <INSTALL.md#prebuilt-packages>`_, you will need to follow the normal process + for upgrading those packages. - .. code:: bash +* If Synapse was installed from source, then: - source ~/synapse/env/bin/activate + 1. Activate the virtualenv before upgrading. For example, if Synapse is + installed in a virtualenv in ``~/synapse/env`` then run: -2. If synapse was installed using pip then upgrade to the latest version by - running: + .. code:: bash - .. code:: bash + source ~/synapse/env/bin/activate - pip install --upgrade matrix-synapse[all] + 2. If Synapse was installed using pip then upgrade to the latest version by + running: - # restart synapse - synctl restart + .. 
code:: bash + pip install --upgrade matrix-synapse - If synapse was installed using git then upgrade to the latest version by - running: + If Synapse was installed using git then upgrade to the latest version by + running: - .. code:: bash + .. code:: bash - # Pull the latest version of the master branch. git pull + pip install --upgrade . - # Update synapse and its python dependencies. - pip install --upgrade .[all] + 3. Restart Synapse: - # restart synapse - ./synctl restart + .. code:: bash + ./synctl restart -To check whether your update was successful, you can check the Server header -returned by the Client-Server API: +To check whether your update was successful, you can check the running server +version with: .. code:: bash - # replace <host.name> with the hostname of your synapse homeserver. - # You may need to specify a port (eg, :8448) if your server is not - # configured on port 443. - curl -kv https://<host.name>/_matrix/client/versions 2>&1 | grep "Server:" + # you may need to replace 'localhost:8008' if synapse is not configured + # to listen on port 8008. + + curl http://localhost:8008/_synapse/admin/v1/server_version + +Rolling back to older versions +------------------------------ + +Rolling back to previous releases can be difficult, due to database schema +changes between releases. Where we have been able to test the rollback process, +this will be noted below. + +In general, you will need to undo any changes made during the upgrade process, +for example: + +* pip: + + .. code:: bash + + source env/bin/activate + # replace `1.3.0` accordingly: + pip install matrix-synapse==1.3.0 + +* Debian: + + .. code:: bash + + # replace `1.3.0` and `stretch` accordingly: + wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb + dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb + +Upgrading to v1.14.0 +==================== + +This version includes a database update which is run as part of the upgrade, +and which may take a couple of minutes in the case of a large server. Synapse +will not respond to HTTP requests while this update is taking place. + +Upgrading to v1.13.0 +==================== + +Incorrect database migration in old synapse versions +---------------------------------------------------- + +A bug was introduced in Synapse 1.4.0 which could cause the room directory to +be incomplete or empty if Synapse was upgraded directly from v1.2.1 or +earlier, to versions between v1.4.0 and v1.12.x. + +This will *not* be a problem for Synapse installations which were: + * created at v1.4.0 or later, + * upgraded via v1.3.x, or + * upgraded straight from v1.2.1 or earlier to v1.13.0 or later. + +If completeness of the room directory is a concern, installations which are +affected can be repaired as follows: + +1. Run the following sql from a `psql` or `sqlite3` console: + + .. code:: sql + + INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_rooms', '{}', 'current_state_events_membership'); + + INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_users', '{}', 'populate_stats_process_rooms'); + +2. Restart synapse. + +New Single Sign-on HTML Templates +--------------------------------- + +New templates (``sso_auth_confirm.html``, ``sso_auth_success.html``, and +``sso_account_deactivated.html``) were added to Synapse. 
If your Synapse is +configured to use SSO and a custom ``sso_redirect_confirm_template_dir`` +configuration then these templates will need to be copied from +`synapse/res/templates <synapse/res/templates>`_ into that directory. + +Synapse SSO Plugins Method Deprecation +-------------------------------------- + +Plugins using the ``complete_sso_login`` method of +``synapse.module_api.ModuleApi`` should update to using the async/await +version ``complete_sso_login_async`` which includes additional checks. The +non-async version is considered deprecated. + +Rolling back to v1.12.4 after a failed upgrade +---------------------------------------------- + +v1.13.0 includes a lot of large changes. If something problematic occurs, you +may want to roll-back to a previous version of Synapse. Because v1.13.0 also +includes a new database schema version, reverting that version is also required +alongside the generic rollback instructions mentioned above. In short, to roll +back to v1.12.4 you need to: + +1. Stop the server +2. Decrease the schema version in the database: + + .. code:: sql + + UPDATE schema_version SET version = 57; + +3. Downgrade Synapse by following the instructions for your installation method + in the "Rolling back to older versions" section above. + + +Upgrading to v1.12.0 +==================== + +This version includes a database update which is run as part of the upgrade, +and which may take some time (several hours in the case of a large +server). Synapse will not respond to HTTP requests while this update is taking +place. + +This is only likely to be a problem in the case of a server which is +participating in many rooms. + +0. As with all upgrades, it is recommended that you have a recent backup of + your database which can be used for recovery in the event of any problems. + +1. As an initial check to see if you will be affected, you can try running the + following query from the `psql` or `sqlite3` console. It is safe to run it + while Synapse is still running. + + .. code:: sql + + SELECT MAX(q.v) FROM ( + SELECT ( + SELECT ej.json AS v + FROM state_events se INNER JOIN event_json ej USING (event_id) + WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key='' + LIMIT 1 + ) FROM rooms WHERE rooms.room_version IS NULL + ) q; + + This query will take about the same amount of time as the upgrade process: ie, + if it takes 5 minutes, then it is likely that Synapse will be unresponsive for + 5 minutes during the upgrade. + + If you consider an outage of this duration to be acceptable, no further + action is necessary and you can simply start Synapse 1.12.0. + + If you would prefer to reduce the downtime, continue with the steps below. + +2. The easiest workaround for this issue is to manually + create a new index before upgrading. On PostgreSQL, his can be done as follows: + + .. code:: sql + + CREATE INDEX CONCURRENTLY tmp_upgrade_1_12_0_index + ON state_events(room_id) WHERE type = 'm.room.create'; + + The above query may take some time, but is also safe to run while Synapse is + running. + + We assume that no SQLite users have databases large enough to be + affected. If you *are* affected, you can run a similar query, omitting the + ``CONCURRENTLY`` keyword. Note however that this operation may in itself cause + Synapse to stop running for some time. Synapse admins are reminded that + `SQLite is not recommended for use outside a test + environment <https://github.com/matrix-org/synapse/blob/master/README.rst#using-postgresql>`_. + +3. 
Once the index has been created, the ``SELECT`` query in step 1 above should + complete quickly. It is therefore safe to upgrade to Synapse 1.12.0. + +4. Once Synapse 1.12.0 has successfully started and is responding to HTTP + requests, the temporary index can be removed: + + .. code:: sql + + DROP INDEX tmp_upgrade_1_12_0_index; + +Upgrading to v1.10.0 +==================== + +Synapse will now log a warning on start up if used with a PostgreSQL database +that has a non-recommended locale set. + +See `docs/postgres.md <docs/postgres.md>`_ for details. + + +Upgrading to v1.8.0 +=================== + +Specifying a ``log_file`` config option will now cause Synapse to refuse to +start, and should be replaced by with the ``log_config`` option. Support for +the ``log_file`` option was removed in v1.3.0 and has since had no effect. + + +Upgrading to v1.7.0 +=================== + +In an attempt to configure Synapse in a privacy preserving way, the default +behaviours of ``allow_public_rooms_without_auth`` and +``allow_public_rooms_over_federation`` have been inverted. This means that by +default, only authenticated users querying the Client/Server API will be able +to query the room directory, and relatedly that the server will not share +room directory information with other servers over federation. + +If your installation does not explicitly set these settings one way or the other +and you want either setting to be ``true`` then it will necessary to update +your homeserver configuration file accordingly. + +For more details on the surrounding context see our `explainer +<https://matrix.org/blog/2019/11/09/avoiding-unwelcome-visitors-on-private-matrix-servers>`_. + + +Upgrading to v1.5.0 +=================== + +This release includes a database migration which may take several minutes to +complete if there are a large number (more than a million or so) of entries in +the ``devices`` table. This is only likely to a be a problem on very large +installations. + Upgrading to v1.4.0 =================== -Config options --------------- +New custom templates +-------------------- -**Note: Registration by email address or phone number will not work in this release unless -some config options are changed from their defaults.** +If you have configured a custom template directory with the +``email.template_dir`` option, be aware that there are new templates regarding +registration and threepid management (see below) that must be included. + +* ``registration.html`` and ``registration.txt`` +* ``registration_success.html`` and ``registration_failure.html`` +* ``add_threepid.html`` and ``add_threepid.txt`` +* ``add_threepid_failure.html`` and ``add_threepid_success.html`` + +Synapse will expect these files to exist inside the configured template +directory, and **will fail to start** if they are absent. +To view the default templates, see `synapse/res/templates +<https://github.com/matrix-org/synapse/tree/master/synapse/res/templates>`_. -This is due to Synapse v1.4.0 now defaulting to sending registration and password reset tokens -itself. This is for security reasons as well as putting less reliance on identity servers. -However, currently Synapse only supports sending emails, and does not have support for -phone-based password reset or account registration. If Synapse is configured to handle these on -its own, phone-based password resets and registration will be disabled. For Synapse to send -emails, the ``email`` block of the config must be filled out. 
If not, then password resets and -registration via email will be disabled entirely. +3pid verification changes +------------------------- + +**Note: As of this release, users will be unable to add phone numbers or email +addresses to their accounts, without changes to the Synapse configuration. This +includes adding an email address during registration.** + +It is possible for a user to associate an email address or phone number +with their account, for a number of reasons: + +* for use when logging in, as an alternative to the user id. +* in the case of email, as an alternative contact to help with account recovery. +* in the case of email, to receive notifications of missed messages. + +Before an email address or phone number can be added to a user's account, +or before such an address is used to carry out a password-reset, Synapse must +confirm the operation with the owner of the email address or phone number. +It does this by sending an email or text giving the user a link or token to confirm +receipt. This process is known as '3pid verification'. ('3pid', or 'threepid', +stands for third-party identifier, and we use it to refer to external +identifiers such as email addresses and phone numbers.) + +Previous versions of Synapse delegated the task of 3pid verification to an +identity server by default. In most cases this server is ``vector.im`` or +``matrix.org``. + +In Synapse 1.4.0, for security and privacy reasons, the homeserver will no +longer delegate this task to an identity server by default. Instead, +the server administrator will need to explicitly decide how they would like the +verification messages to be sent. + +In the medium term, the ``vector.im`` and ``matrix.org`` identity servers will +disable support for delegated 3pid verification entirely. However, in order to +ease the transition, they will retain the capability for a limited +period. Delegated email verification will be disabled on Monday 2nd December +2019 (giving roughly 2 months notice). Disabling delegated SMS verification +will follow some time after that once SMS verification support lands in +Synapse. + +Once delegated 3pid verification support has been disabled in the ``vector.im`` and +``matrix.org`` identity servers, all Synapse versions that depend on those +instances will be unable to verify email and phone numbers through them. There +are no imminent plans to remove delegated 3pid verification from Sydent +generally. (Sydent is the identity server project that backs the ``vector.im`` and +``matrix.org`` instances). + +Email +~~~~~ +Following upgrade, to continue verifying email (e.g. as part of the +registration process), admins can either:- + +* Configure Synapse to use an email server. +* Run or choose an identity server which allows delegated email verification + and delegate to it. + +Configure SMTP in Synapse ++++++++++++++++++++++++++ + +To configure an SMTP server for Synapse, modify the configuration section +headed ``email``, and be sure to have at least the ``smtp_host, smtp_port`` +and ``notif_from`` fields filled out. -This release also deprecates the ``email.trust_identity_server_for_password_resets`` option and -replaces it with the ``account_threepid_delegates`` dictionary. This option defines whether the -homeserver should delegate an external server (typically an `identity server -<https://matrix.org/docs/spec/identity_service/r0.2.1>`_) to handle sending password reset or -registration messages via email and SMS. 
+You may also need to set ``smtp_user``, ``smtp_pass``, and +``require_transport_security``. -If ``email.trust_identity_server_for_password_resets`` is set to ``true``, and +See the `sample configuration file <docs/sample_config.yaml>`_ for more details +on these settings. + +Delegate email to an identity server +++++++++++++++++++++++++++++++++++++ + +Some admins will wish to continue using email verification as part of the +registration process, but will not immediately have an appropriate SMTP server +at hand. + +To this end, we will continue to support email verification delegation via the +``vector.im`` and ``matrix.org`` identity servers for two months. Support for +delegated email verification will be disabled on Monday 2nd December. + +The ``account_threepid_delegates`` dictionary defines whether the homeserver +should delegate an external server (typically an `identity server +<https://matrix.org/docs/spec/identity_service/r0.2.1>`_) to handle sending +confirmation messages via email and SMS. + +So to delegate email verification, in ``homeserver.yaml``, set +``account_threepid_delegates.email`` to the base URL of an identity server. For +example: + +.. code:: yaml + + account_threepid_delegates: + email: https://example.com # Delegate email sending to example.com + +Note that ``account_threepid_delegates.email`` replaces the deprecated +``email.trust_identity_server_for_password_resets``: if +``email.trust_identity_server_for_password_resets`` is set to ``true``, and ``account_threepid_delegates.email`` is not set, then the first entry in -``trusted_third_party_id_servers`` will be used as the account threepid delegate for email. -This is to ensure compatibility with existing Synapse installs that set up external server -handling for these tasks before v1.4.0. If ``email.trust_identity_server_for_password_resets`` -is ``true`` and no trusted identity server domains are configured, Synapse will throw an error. +``trusted_third_party_id_servers`` will be used as the +``account_threepid_delegate`` for email. This is to ensure compatibility with +existing Synapse installs that set up external server handling for these tasks +before v1.4.0. If ``email.trust_identity_server_for_password_resets`` is +``true`` and no trusted identity server domains are configured, Synapse will +report an error and refuse to start. + +If ``email.trust_identity_server_for_password_resets`` is ``false`` or absent +and no ``email`` delegate is configured in ``account_threepid_delegates``, +then Synapse will send email verification messages itself, using the configured +SMTP server (see above). +that type. + +Phone numbers +~~~~~~~~~~~~~ + +Synapse does not support phone-number verification itself, so the only way to +maintain the ability for users to add phone numbers to their accounts will be +by continuing to delegate phone number verification to the ``matrix.org`` and +``vector.im`` identity servers (or another identity server that supports SMS +sending). + +The ``account_threepid_delegates`` dictionary defines whether the homeserver +should delegate an external server (typically an `identity server +<https://matrix.org/docs/spec/identity_service/r0.2.1>`_) to handle sending +confirmation messages via email and SMS. + +So to delegate phone number verification, in ``homeserver.yaml``, set +``account_threepid_delegates.msisdn`` to the base URL of an identity +server. For example: + +.. 
code:: yaml + + account_threepid_delegates: + msisdn: https://example.com # Delegate sms sending to example.com + +The ``matrix.org`` and ``vector.im`` identity servers will continue to support +delegated phone number verification via SMS until such time as it is possible +for admins to configure their servers to perform phone number verification +directly. More details will follow in a future release. + +Rolling back to v1.3.1 +---------------------- -If ``email.trust_identity_server_for_password_resets`` is ``false`` or absent and a threepid -type in ``account_threepid_delegates`` is not set to a domain, then Synapse will attempt to -send password reset and registration messages for that type. +If you encounter problems with v1.4.0, it should be possible to roll back to +v1.3.1, subject to the following: -Email templates ---------------- +* The 'room statistics' engine was heavily reworked in this release (see + `#5971 <https://github.com/matrix-org/synapse/pull/5971>`_), including + significant changes to the database schema, which are not easily + reverted. This will cause the room statistics engine to stop updating when + you downgrade. -If you have configured a custom template directory with the ``email.template_dir`` option, be -aware that there are new templates regarding registration. ``registration.html`` and -``registration.txt`` have been added and contain the content that is sent to a client upon -registering via an email address. + The room statistics are essentially unused in v1.3.1 (in future versions of + Synapse, they will be used to populate the room directory), so there should + be no loss of functionality. However, the statistics engine will write errors + to the logs, which can be avoided by setting the following in + `homeserver.yaml`: -``registration_success.html`` and ``registration_failure.html`` are also new HTML templates -that will be shown to the user when they click the link in their registration emai , either -showing them a success or failure page (assuming a redirect URL is not configured). + .. code:: yaml -Synapse will expect these files to exist inside the configured template directory. To view the -default templates, see `synapse/res/templates -<https://github.com/matrix-org/synapse/tree/master/synapse/res/templates>`_. + stats: + enabled: false + + Don't forget to re-enable it when you upgrade again, in preparation for its + use in the room directory! Upgrading to v1.2.0 =================== diff --git a/changelog.d/5633.bugfix b/changelog.d/5633.bugfix deleted file mode 100644 index b2ff803b9d..0000000000 --- a/changelog.d/5633.bugfix +++ /dev/null @@ -1 +0,0 @@ -Don't create broken room when power_level_content_override.users does not contain creator_id. \ No newline at end of file diff --git a/changelog.d/5680.misc b/changelog.d/5680.misc deleted file mode 100644 index 46a403a188..0000000000 --- a/changelog.d/5680.misc +++ /dev/null @@ -1 +0,0 @@ -Lay the groundwork for structured logging output. diff --git a/changelog.d/5771.feature b/changelog.d/5771.feature deleted file mode 100644 index f2f4de1fdd..0000000000 --- a/changelog.d/5771.feature +++ /dev/null @@ -1 +0,0 @@ -Make Opentracing work in worker mode. diff --git a/changelog.d/5776.misc b/changelog.d/5776.misc deleted file mode 100644 index 1fb1b9c152..0000000000 --- a/changelog.d/5776.misc +++ /dev/null @@ -1 +0,0 @@ -Update opentracing docs to use the unified `trace` method. 
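The v1.4.0 upgrade notes above give separate `account_threepid_delegates` examples for email and for SMS. If an admin wants to delegate both kinds of verification to the same identity server, the two keys can sit in one block; a minimal sketch for `homeserver.yaml`, reusing the placeholder `https://example.com` from those examples:

```
account_threepid_delegates:
  email: https://example.com   # delegate email verification to example.com
  msisdn: https://example.com  # delegate SMS verification to example.com
```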
diff --git a/changelog.d/5835.feature b/changelog.d/5835.feature deleted file mode 100644 index 3e8bf5068d..0000000000 --- a/changelog.d/5835.feature +++ /dev/null @@ -1 +0,0 @@ -Add the ability to send registration emails from the homeserver rather than delegating to an identity server. diff --git a/changelog.d/5844.misc b/changelog.d/5844.misc deleted file mode 100644 index a0826af0d2..0000000000 --- a/changelog.d/5844.misc +++ /dev/null @@ -1 +0,0 @@ -Retry well-known lookup before the cache expires, giving a grace period where the remote well-known can be down but we still use the old result. diff --git a/changelog.d/5845.feature b/changelog.d/5845.feature deleted file mode 100644 index 7b0dc9a95e..0000000000 --- a/changelog.d/5845.feature +++ /dev/null @@ -1 +0,0 @@ -Add an admin API to purge old rooms from the database. diff --git a/changelog.d/5849.doc b/changelog.d/5849.doc deleted file mode 100644 index fbe62e8633..0000000000 --- a/changelog.d/5849.doc +++ /dev/null @@ -1 +0,0 @@ -Convert documentation to markdown (from rst) diff --git a/changelog.d/5850.feature b/changelog.d/5850.feature deleted file mode 100644 index b565929a54..0000000000 --- a/changelog.d/5850.feature +++ /dev/null @@ -1 +0,0 @@ -Add retry to well-known lookups if we have recently seen a valid well-known record for the server. diff --git a/changelog.d/5852.feature b/changelog.d/5852.feature deleted file mode 100644 index 4a0fc6c542..0000000000 --- a/changelog.d/5852.feature +++ /dev/null @@ -1 +0,0 @@ -Pass opentracing contexts between servers when transmitting EDUs. diff --git a/changelog.d/5853.feature b/changelog.d/5853.feature deleted file mode 100644 index 80a04ae2ee..0000000000 --- a/changelog.d/5853.feature +++ /dev/null @@ -1 +0,0 @@ -Opentracing for device list updates. diff --git a/changelog.d/5855.misc b/changelog.d/5855.misc deleted file mode 100644 index 32db7fbe37..0000000000 --- a/changelog.d/5855.misc +++ /dev/null @@ -1 +0,0 @@ -Opentracing for room and e2e keys. diff --git a/changelog.d/5856.feature b/changelog.d/5856.feature deleted file mode 100644 index f4310b9244..0000000000 --- a/changelog.d/5856.feature +++ /dev/null @@ -1 +0,0 @@ -Add a tag recording a request's authenticated entity and corresponding servlet in opentracing. diff --git a/changelog.d/5857.bugfix b/changelog.d/5857.bugfix deleted file mode 100644 index 008799ccbb..0000000000 --- a/changelog.d/5857.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix database index so that different backup versions can have the same sessions. diff --git a/changelog.d/5859.feature b/changelog.d/5859.feature deleted file mode 100644 index 52df7fc81b..0000000000 --- a/changelog.d/5859.feature +++ /dev/null @@ -1 +0,0 @@ -Add unstable support for MSC2197 (filtered search requests over federation), in order to allow upcoming room directory query performance improvements. diff --git a/changelog.d/5860.misc b/changelog.d/5860.misc deleted file mode 100644 index f9960b17b4..0000000000 --- a/changelog.d/5860.misc +++ /dev/null @@ -1 +0,0 @@ -Remove log line for debugging issue #5407. diff --git a/changelog.d/5863.bugfix b/changelog.d/5863.bugfix deleted file mode 100644 index bceae5be67..0000000000 --- a/changelog.d/5863.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix Synapse looking for config options `password_reset_failure_template` and `password_reset_success_template`, when they are actually `password_reset_template_failure_html`, `password_reset_template_success_html`. 
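As the changelog entry above points out, the password-reset page templates are configured with `password_reset_template_failure_html` and `password_reset_template_success_html`, not the shorter names Synapse previously looked for. A minimal sketch of how these options might be combined with a custom template directory, assuming (as in the v1.4.0 upgrade notes above) that they sit in the `email` section and that the filenames below are placeholders resolved against `email.template_dir`:

```
email:
  template_dir: /path/to/custom/templates          # placeholder path
  password_reset_template_success_html: password_reset_success.html
  password_reset_template_failure_html: password_reset_failure.html
```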
diff --git a/changelog.d/5864.feature b/changelog.d/5864.feature deleted file mode 100644 index 40ac11db64..0000000000 --- a/changelog.d/5864.feature +++ /dev/null @@ -1 +0,0 @@ -Correctly retry all hosts returned from SRV when we fail to connect. diff --git a/changelog.d/5868.feature b/changelog.d/5868.feature deleted file mode 100644 index 69605c1ae1..0000000000 --- a/changelog.d/5868.feature +++ /dev/null @@ -1 +0,0 @@ -Add `m.require_identity_server` key to `/versions`'s `unstable_features` section. \ No newline at end of file diff --git a/changelog.d/5875.misc b/changelog.d/5875.misc deleted file mode 100644 index e188c28d2f..0000000000 --- a/changelog.d/5875.misc +++ /dev/null @@ -1 +0,0 @@ -Deprecate the `trusted_third_party_id_servers` option. \ No newline at end of file diff --git a/changelog.d/5876.feature b/changelog.d/5876.feature deleted file mode 100644 index df88193fbd..0000000000 --- a/changelog.d/5876.feature +++ /dev/null @@ -1 +0,0 @@ -Replace `trust_identity_server_for_password_resets` config option with `account_threepid_delegates`. \ No newline at end of file diff --git a/changelog.d/5877.removal b/changelog.d/5877.removal deleted file mode 100644 index b6d84fb401..0000000000 --- a/changelog.d/5877.removal +++ /dev/null @@ -1 +0,0 @@ -Remove shared secret registration from client/r0/register endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. diff --git a/changelog.d/5878.feature b/changelog.d/5878.feature deleted file mode 100644 index d9d6df880e..0000000000 --- a/changelog.d/5878.feature +++ /dev/null @@ -1 +0,0 @@ -Add admin API endpoint for setting whether or not a user is a server administrator. diff --git a/changelog.d/5885.bugfix b/changelog.d/5885.bugfix deleted file mode 100644 index 411d925fd4..0000000000 --- a/changelog.d/5885.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix stack overflow when recovering an appservice which had an outage. diff --git a/changelog.d/5886.misc b/changelog.d/5886.misc deleted file mode 100644 index 22adba3d85..0000000000 --- a/changelog.d/5886.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor the Appservice scheduler code. diff --git a/changelog.d/5892.misc b/changelog.d/5892.misc deleted file mode 100644 index 939fe8c655..0000000000 --- a/changelog.d/5892.misc +++ /dev/null @@ -1 +0,0 @@ -Compatibility with v2 Identity Service APIs other than /lookup. \ No newline at end of file diff --git a/changelog.d/5893.misc b/changelog.d/5893.misc deleted file mode 100644 index 07ee4888dc..0000000000 --- a/changelog.d/5893.misc +++ /dev/null @@ -1 +0,0 @@ -Drop some unused tables. diff --git a/changelog.d/5894.misc b/changelog.d/5894.misc deleted file mode 100644 index fca4485ff7..0000000000 --- a/changelog.d/5894.misc +++ /dev/null @@ -1 +0,0 @@ -Add missing index on users_in_public_rooms to improve the performance of directory queries. diff --git a/changelog.d/5895.feature b/changelog.d/5895.feature deleted file mode 100644 index c394a3772c..0000000000 --- a/changelog.d/5895.feature +++ /dev/null @@ -1 +0,0 @@ -Add config option to sign remote key query responses with a separate key. diff --git a/changelog.d/5896.misc b/changelog.d/5896.misc deleted file mode 100644 index ed47c747bd..0000000000 --- a/changelog.d/5896.misc +++ /dev/null @@ -1 +0,0 @@ -Improve the logging when we have an error when fetching signing keys. 
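Two changelog entries above (deprecating `trusted_third_party_id_servers` and replacing `trust_identity_server_for_password_resets` with `account_threepid_delegates`) match the v1.4.0 upgrade notes earlier in this document. A before-and-after sketch of the relevant `homeserver.yaml` fragment, with `https://example.com` standing in for an identity server:

```
# Deprecated pre-1.4.0 configuration:
#email:
#  trust_identity_server_for_password_resets: true

# Replacement described in the v1.4.0 upgrade notes:
account_threepid_delegates:
  email: https://example.com
```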
diff --git a/changelog.d/5897.feature b/changelog.d/5897.feature deleted file mode 100644 index 1557e559e8..0000000000 --- a/changelog.d/5897.feature +++ /dev/null @@ -1 +0,0 @@ -Switch to using the v2 Identity Service `/lookup` API where available, with fallback to v1. (Implements [MSC2134](https://github.com/matrix-org/matrix-doc/pull/2134) plus id_access_token authentication for v2 Identity Service APIs from [MSC2140](https://github.com/matrix-org/matrix-doc/pull/2140)). diff --git a/changelog.d/5900.feature b/changelog.d/5900.feature deleted file mode 100644 index b62d88a76b..0000000000 --- a/changelog.d/5900.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for config templating. diff --git a/changelog.d/5902.feature b/changelog.d/5902.feature deleted file mode 100644 index 0660f65cfa..0000000000 --- a/changelog.d/5902.feature +++ /dev/null @@ -1 +0,0 @@ -Users with the type of "support" or "bot" are no longer required to consent. \ No newline at end of file diff --git a/changelog.d/5904.feature b/changelog.d/5904.feature deleted file mode 100644 index 43b5304f39..0000000000 --- a/changelog.d/5904.feature +++ /dev/null @@ -1 +0,0 @@ -Let synctl accept a directory of config files. diff --git a/changelog.d/5906.feature b/changelog.d/5906.feature deleted file mode 100644 index 7c789510a6..0000000000 --- a/changelog.d/5906.feature +++ /dev/null @@ -1 +0,0 @@ -Increase max display name size to 256. diff --git a/changelog.d/5909.misc b/changelog.d/5909.misc deleted file mode 100644 index 03d0c4367b..0000000000 --- a/changelog.d/5909.misc +++ /dev/null @@ -1 +0,0 @@ -Fix error message which referred to public_base_url instead of public_baseurl. Thanks to @aaronraimist for the fix! diff --git a/changelog.d/5911.misc b/changelog.d/5911.misc deleted file mode 100644 index fe5a8fd59c..0000000000 --- a/changelog.d/5911.misc +++ /dev/null @@ -1 +0,0 @@ -Add support for database engine-specific schema deltas, based on file extension. \ No newline at end of file diff --git a/changelog.d/5914.feature b/changelog.d/5914.feature deleted file mode 100644 index 85c7bf5963..0000000000 --- a/changelog.d/5914.feature +++ /dev/null @@ -1 +0,0 @@ -Add admin API endpoint for getting whether or not a user is a server administrator. diff --git a/changelog.d/5915.bugfix b/changelog.d/5915.bugfix deleted file mode 100644 index bf5b99fedc..0000000000 --- a/changelog.d/5915.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix 404 for thumbnail download when `dynamic_thumbnails` is `false` and the thumbnail was dynamically generated. Fix reported by rkfg. diff --git a/changelog.d/5920.bugfix b/changelog.d/5920.bugfix deleted file mode 100644 index e45eb0ffee..0000000000 --- a/changelog.d/5920.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a cache-invalidation bug for worker-based deployments. diff --git a/changelog.d/5922.misc b/changelog.d/5922.misc deleted file mode 100644 index 2cc864897e..0000000000 --- a/changelog.d/5922.misc +++ /dev/null @@ -1 +0,0 @@ -Update Buildkite pipeline to use plugins instead of buildkite-agent commands. diff --git a/changelog.d/5926.misc b/changelog.d/5926.misc deleted file mode 100644 index 4383c302ec..0000000000 --- a/changelog.d/5926.misc +++ /dev/null @@ -1 +0,0 @@ -Add link in sample config to the logging config schema. diff --git a/changelog.d/5931.misc b/changelog.d/5931.misc deleted file mode 100644 index ac8e74f5b9..0000000000 --- a/changelog.d/5931.misc +++ /dev/null @@ -1 +0,0 @@ -Remove unnecessary parentheses in return statements. 
\ No newline at end of file diff --git a/changelog.d/5934.feature b/changelog.d/5934.feature deleted file mode 100644 index eae969a52a..0000000000 --- a/changelog.d/5934.feature +++ /dev/null @@ -1 +0,0 @@ -Redact events in the database that have been redacted for a month. diff --git a/changelog.d/5938.misc b/changelog.d/5938.misc deleted file mode 100644 index b5a3b6ee3b..0000000000 --- a/changelog.d/5938.misc +++ /dev/null @@ -1 +0,0 @@ -Remove unused jenkins/prepare_sytest.sh file. diff --git a/changelog.d/5940.feature b/changelog.d/5940.feature deleted file mode 100644 index 5b69b97fe7..0000000000 --- a/changelog.d/5940.feature +++ /dev/null @@ -1 +0,0 @@ -Add the ability to send registration emails from the homeserver rather than delegating to an identity server. \ No newline at end of file diff --git a/changelog.d/5943.misc b/changelog.d/5943.misc deleted file mode 100644 index 6545e1244a..0000000000 --- a/changelog.d/5943.misc +++ /dev/null @@ -1 +0,0 @@ -Move Buildkite pipeline config to the pipelines repo. diff --git a/changelog.d/5953.misc b/changelog.d/5953.misc deleted file mode 100644 index 38e885f42a..0000000000 --- a/changelog.d/5953.misc +++ /dev/null @@ -1 +0,0 @@ -Update INSTALL.md to say that Python 2 is no longer supported. diff --git a/changelog.d/5962.misc b/changelog.d/5962.misc deleted file mode 100644 index d97d376c36..0000000000 --- a/changelog.d/5962.misc +++ /dev/null @@ -1 +0,0 @@ -Remove unnecessary return statements in the codebase which were the result of a regex run. \ No newline at end of file diff --git a/changelog.d/5963.misc b/changelog.d/5963.misc deleted file mode 100644 index 0d6c3c3d65..0000000000 --- a/changelog.d/5963.misc +++ /dev/null @@ -1 +0,0 @@ -Remove left-over methods from C/S registration API. \ No newline at end of file diff --git a/changelog.d/5964.feature b/changelog.d/5964.feature deleted file mode 100644 index 273c9df026..0000000000 --- a/changelog.d/5964.feature +++ /dev/null @@ -1 +0,0 @@ -Remove `bind_email` and `bind_msisdn` parameters from /register ala MSC2140. \ No newline at end of file diff --git a/changelog.d/5966.bugfix b/changelog.d/5966.bugfix deleted file mode 100644 index b8ef5a7819..0000000000 --- a/changelog.d/5966.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix admin API for listing media in a room not being available with an external media repo. diff --git a/changelog.d/5967.bugfix b/changelog.d/5967.bugfix deleted file mode 100644 index 8d7bf5c2e9..0000000000 --- a/changelog.d/5967.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix list media admin API always returning an error. diff --git a/changelog.d/5969.feature b/changelog.d/5969.feature deleted file mode 100644 index cf603fa0c6..0000000000 --- a/changelog.d/5969.feature +++ /dev/null @@ -1 +0,0 @@ -Replace `trust_identity_server_for_password_resets` config option with `account_threepid_delegates`. diff --git a/changelog.d/5970.docker b/changelog.d/5970.docker deleted file mode 100644 index c9d04da9cd..0000000000 --- a/changelog.d/5970.docker +++ /dev/null @@ -1 +0,0 @@ -Avoid changing UID/GID if they are already correct. diff --git a/changelog.d/5971.bugfix b/changelog.d/5971.bugfix deleted file mode 100644 index 9ea095103b..0000000000 --- a/changelog.d/5971.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix room and user stats tracking. diff --git a/changelog.d/5975.misc b/changelog.d/5975.misc deleted file mode 100644 index 5fcd229b89..0000000000 --- a/changelog.d/5975.misc +++ /dev/null @@ -1 +0,0 @@ -Cleanup event auth type initialisation. 
\ No newline at end of file diff --git a/changelog.d/5979.feature b/changelog.d/5979.feature deleted file mode 100644 index 94888aa2d3..0000000000 --- a/changelog.d/5979.feature +++ /dev/null @@ -1 +0,0 @@ -Use the v2 Identity Service API for 3PID invites. \ No newline at end of file diff --git a/changelog.d/5980.feature b/changelog.d/5980.feature deleted file mode 100644 index f25d8d81d9..0000000000 --- a/changelog.d/5980.feature +++ /dev/null @@ -1 +0,0 @@ -Add POST /_matrix/client/r0/account/3pid/unbind endpoint from MSC2140 for unbinding a 3PID from an identity server without removing it from the homeserver user account. \ No newline at end of file diff --git a/changelog.d/5981.feature b/changelog.d/5981.feature deleted file mode 100644 index e39514273d..0000000000 --- a/changelog.d/5981.feature +++ /dev/null @@ -1 +0,0 @@ -Setting metrics_flags.known_servers to True in the configuration will publish the synapse_federation_known_servers metric over Prometheus. This represents the total number of servers your server knows about (i.e. is in rooms with), including itself. diff --git a/changelog.d/5982.bugfix b/changelog.d/5982.bugfix deleted file mode 100644 index 3ea281a3a0..0000000000 --- a/changelog.d/5982.bugfix +++ /dev/null @@ -1 +0,0 @@ -Include missing opentracing contexts in outbout replication requests. diff --git a/changelog.d/5983.feature b/changelog.d/5983.feature deleted file mode 100644 index aa23ee6dcd..0000000000 --- a/changelog.d/5983.feature +++ /dev/null @@ -1 +0,0 @@ -Add minimum opentracing for client servlets. diff --git a/changelog.d/5984.bugfix b/changelog.d/5984.bugfix deleted file mode 100644 index 3387bf82bb..0000000000 --- a/changelog.d/5984.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix sending of EDUs when opentracing is enabled with an empty whitelist. diff --git a/changelog.d/5985.feature b/changelog.d/5985.feature deleted file mode 100644 index e5e29504af..0000000000 --- a/changelog.d/5985.feature +++ /dev/null @@ -1 +0,0 @@ -Check at setup that opentracing is installed if it's enabled in the config. diff --git a/changelog.d/5986.feature b/changelog.d/5986.feature deleted file mode 100644 index f56aec1b32..0000000000 --- a/changelog.d/5986.feature +++ /dev/null @@ -1 +0,0 @@ -Trace replication send times. diff --git a/changelog.d/5988.bugfix b/changelog.d/5988.bugfix deleted file mode 100644 index 5c3597cb53..0000000000 --- a/changelog.d/5988.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix invalid references to None while opentracing if the log context slips. diff --git a/changelog.d/5989.misc b/changelog.d/5989.misc deleted file mode 100644 index 9f2525fd3e..0000000000 --- a/changelog.d/5989.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up dependency checking at setup. diff --git a/changelog.d/5991.bugfix b/changelog.d/5991.bugfix deleted file mode 100644 index 5c3597cb53..0000000000 --- a/changelog.d/5991.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix invalid references to None while opentracing if the log context slips. diff --git a/changelog.d/5992.feature b/changelog.d/5992.feature deleted file mode 100644 index 31866c2925..0000000000 --- a/changelog.d/5992.feature +++ /dev/null @@ -1 +0,0 @@ -Give appropriate exit codes when synctl fails. diff --git a/changelog.d/5993.feature b/changelog.d/5993.feature deleted file mode 100644 index 3e8bf5068d..0000000000 --- a/changelog.d/5993.feature +++ /dev/null @@ -1 +0,0 @@ -Add the ability to send registration emails from the homeserver rather than delegating to an identity server. 
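The changelog entry above, on sending registration emails from the homeserver rather than an identity server, relies on the `email` configuration covered in the Email sections earlier in this document: at least `smtp_host`, `smtp_port` and `notif_from` must be set, and `smtp_user`, `smtp_pass` and `require_transport_security` may also be needed. A minimal sketch with placeholder values:

```
email:
  smtp_host: mail.example.com        # placeholder SMTP server
  smtp_port: 587
  smtp_user: synapse                 # only if the SMTP server requires auth
  smtp_pass: secretpassword          # placeholder
  require_transport_security: true
  notif_from: "Your homeserver <noreply@example.com>"
```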
diff --git a/changelog.d/5994.feature b/changelog.d/5994.feature deleted file mode 100644 index 5b69b97fe7..0000000000 --- a/changelog.d/5994.feature +++ /dev/null @@ -1 +0,0 @@ -Add the ability to send registration emails from the homeserver rather than delegating to an identity server. \ No newline at end of file diff --git a/changelog.d/5995.bugfix b/changelog.d/5995.bugfix deleted file mode 100644 index e03ab98bc6..0000000000 --- a/changelog.d/5995.bugfix +++ /dev/null @@ -1 +0,0 @@ -Return a M_MISSING_PARAM if `sid` is not provided to `/account/3pid`. \ No newline at end of file diff --git a/changelog.d/5996.bugfix b/changelog.d/5996.bugfix deleted file mode 100644 index 05e31faaa2..0000000000 --- a/changelog.d/5996.bugfix +++ /dev/null @@ -1 +0,0 @@ -federation_certificate_verification_whitelist now will not cause TypeErrors to be raised (a regression in 1.3). Additionally, it now supports internationalised domain names in their non-canonical representation. diff --git a/changelog.d/5998.bugfix b/changelog.d/5998.bugfix deleted file mode 100644 index 9ea095103b..0000000000 --- a/changelog.d/5998.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix room and user stats tracking. diff --git a/changelog.d/6003.misc b/changelog.d/6003.misc deleted file mode 100644 index 4152d05f87..0000000000 --- a/changelog.d/6003.misc +++ /dev/null @@ -1 +0,0 @@ -Add opentracing span over HTTP push processing. diff --git a/changelog.d/6004.bugfix b/changelog.d/6004.bugfix deleted file mode 100644 index 45c179c8fd..0000000000 --- a/changelog.d/6004.bugfix +++ /dev/null @@ -1 +0,0 @@ -Only count real users when checking for auto-creation of auto-join room. diff --git a/changelog.d/6005.feature b/changelog.d/6005.feature deleted file mode 100644 index ed6491d3e4..0000000000 --- a/changelog.d/6005.feature +++ /dev/null @@ -1 +0,0 @@ -The new Prometheus metric `synapse_build_info` exposes the Python version, OS version, and Synapse version of the running server. diff --git a/changelog.d/6009.misc b/changelog.d/6009.misc deleted file mode 100644 index fea479e1dd..0000000000 --- a/changelog.d/6009.misc +++ /dev/null @@ -1 +0,0 @@ -Small refactor of function arguments and docstrings in RoomMemberHandler. \ No newline at end of file diff --git a/changelog.d/6010.misc b/changelog.d/6010.misc deleted file mode 100644 index 0659f12ebd..0000000000 --- a/changelog.d/6010.misc +++ /dev/null @@ -1 +0,0 @@ -Remove unused `origin` argument on FederationHandler.add_display_name_to_third_party_invite. \ No newline at end of file diff --git a/changelog.d/6011.feature b/changelog.d/6011.feature deleted file mode 100644 index ad16acb12b..0000000000 --- a/changelog.d/6011.feature +++ /dev/null @@ -1 +0,0 @@ -Use account_threepid_delegate.email and account_threepid_delegate.msisdn for validating threepid sessions. \ No newline at end of file diff --git a/changelog.d/6012.feature b/changelog.d/6012.feature deleted file mode 100644 index 25425510c6..0000000000 --- a/changelog.d/6012.feature +++ /dev/null @@ -1 +0,0 @@ -Add report_stats_endpoint option to configure where stats are reported to, if enabled. Contributed by @Sorunome. diff --git a/changelog.d/6013.misc b/changelog.d/6013.misc deleted file mode 100644 index 939fe8c655..0000000000 --- a/changelog.d/6013.misc +++ /dev/null @@ -1 +0,0 @@ -Compatibility with v2 Identity Service APIs other than /lookup. 
\ No newline at end of file diff --git a/changelog.d/6015.feature b/changelog.d/6015.feature deleted file mode 100644 index 42aaffced9..0000000000 --- a/changelog.d/6015.feature +++ /dev/null @@ -1 +0,0 @@ -Add config option to increase ratelimits for room admins redacting messages. diff --git a/changelog.d/6016.misc b/changelog.d/6016.misc deleted file mode 100644 index 91cf164714..0000000000 --- a/changelog.d/6016.misc +++ /dev/null @@ -1 +0,0 @@ -Add a 'failure_ts' column to the 'destinations' database table. diff --git a/changelog.d/6017.misc b/changelog.d/6017.misc deleted file mode 100644 index 5ccab9c6ca..0000000000 --- a/changelog.d/6017.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up some code in the retry logic. diff --git a/changelog.d/6020.bugfix b/changelog.d/6020.bugfix deleted file mode 100644 index 58a7deba9d..0000000000 --- a/changelog.d/6020.bugfix +++ /dev/null @@ -1 +0,0 @@ -Ensure support users can be registered even if MAU limit is reached. diff --git a/changelog.d/6023.misc b/changelog.d/6023.misc deleted file mode 100644 index d80410c22c..0000000000 --- a/changelog.d/6023.misc +++ /dev/null @@ -1 +0,0 @@ -Fix the structured logging tests stomping on the global log configuration for subsequent tests. diff --git a/changelog.d/6024.bugfix b/changelog.d/6024.bugfix deleted file mode 100644 index ddad34595b..0000000000 --- a/changelog.d/6024.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix bug where login error was shown incorrectly on SSO fallback login. diff --git a/changelog.d/6025.bugfix b/changelog.d/6025.bugfix deleted file mode 100644 index 50d7f9aab5..0000000000 --- a/changelog.d/6025.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix bug in calculating the federation retry backoff period. \ No newline at end of file diff --git a/changelog.d/6026.feature b/changelog.d/6026.feature deleted file mode 100644 index 2489ff09b5..0000000000 --- a/changelog.d/6026.feature +++ /dev/null @@ -1 +0,0 @@ -Stop sending federation transactions to servers which have been down for a long time. diff --git a/changelog.d/6028.feature b/changelog.d/6028.feature deleted file mode 100644 index cf603fa0c6..0000000000 --- a/changelog.d/6028.feature +++ /dev/null @@ -1 +0,0 @@ -Replace `trust_identity_server_for_password_resets` config option with `account_threepid_delegates`. diff --git a/changelog.d/6029.bugfix b/changelog.d/6029.bugfix deleted file mode 100644 index 9ea095103b..0000000000 --- a/changelog.d/6029.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix room and user stats tracking. diff --git a/changelog.d/6032.misc b/changelog.d/6032.misc deleted file mode 100644 index ec5b5eb881..0000000000 --- a/changelog.d/6032.misc +++ /dev/null @@ -1 +0,0 @@ -Add developer documentation for using SAML2. diff --git a/changelog.d/6059.bugfix b/changelog.d/6059.bugfix deleted file mode 100644 index 49d5bd3fa0..0000000000 --- a/changelog.d/6059.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix logcontext spam on non-Linux platforms. diff --git a/contrib/docker/README.md b/contrib/docker/README.md index af102f7594..89c1518bd0 100644 --- a/contrib/docker/README.md +++ b/contrib/docker/README.md @@ -1,39 +1,26 @@ -# Synapse Docker - -FIXME: this is out-of-date as of -https://github.com/matrix-org/synapse/issues/5518. Contributions to bring it up -to date would be welcome. - -### Automated configuration - -It is recommended that you use Docker Compose to run your containers, including -this image and a Postgres server. A sample ``docker-compose.yml`` is provided, -including example labels for reverse proxying and other artifacts. 
- -Read the section about environment variables and set at least mandatory variables, -then run the server: - -``` -docker-compose up -d -``` -If secrets are not specified in the environment variables, they will be generated -as part of the startup. Please ensure these secrets are kept between launches of the -Docker container, as their loss may require users to log in again. +# Synapse Docker -### Manual configuration +### Configuration A sample ``docker-compose.yml`` is provided, including example labels for reverse proxying and other artifacts. The docker-compose file is an example, please comment/uncomment sections that are not suitable for your usecase. Specify a ``SYNAPSE_CONFIG_PATH``, preferably to a persistent path, -to use manual configuration. To generate a fresh ``homeserver.yaml``, simply run: +to use manual configuration. + +To generate a fresh `homeserver.yaml`, you can use the `generate` command. +(See the [documentation](../../docker/README.md#generating-a-configuration-file) +for more information.) You will need to specify appropriate values for at least the +`SYNAPSE_SERVER_NAME` and `SYNAPSE_REPORT_STATS` environment variables. For example: ``` -docker-compose run --rm -e SYNAPSE_SERVER_NAME=my.matrix.host synapse generate +docker-compose run --rm -e SYNAPSE_SERVER_NAME=my.matrix.host -e SYNAPSE_REPORT_STATS=yes synapse generate ``` +(This will also generate necessary signing keys.) + Then, customize your configuration and run the server: ``` diff --git a/contrib/docker/docker-compose.yml b/contrib/docker/docker-compose.yml index 1e4ee43758..17354b6610 100644 --- a/contrib/docker/docker-compose.yml +++ b/contrib/docker/docker-compose.yml @@ -15,11 +15,7 @@ services: restart: unless-stopped # See the readme for a full documentation of the environment settings environment: - - SYNAPSE_SERVER_NAME=my.matrix.host - - SYNAPSE_REPORT_STATS=no - - SYNAPSE_ENABLE_REGISTRATION=yes - - SYNAPSE_LOG_LEVEL=INFO - - POSTGRES_PASSWORD=changeme + - SYNAPSE_CONFIG_PATH=/data/homeserver.yaml volumes: # You may either store all the files in a local folder - ./files:/data @@ -35,9 +31,23 @@ services: - 8448:8448/tcp # ... or use a reverse proxy, here is an example for traefik: labels: + # The following lines are valid for Traefik version 1.x: - traefik.enable=true - traefik.frontend.rule=Host:my.matrix.Host - traefik.port=8008 + # Alternatively, for Traefik version 2.0: + - traefik.enable=true + - traefik.http.routers.http-synapse.entryPoints=http + - traefik.http.routers.http-synapse.rule=Host(`my.matrix.host`) + - traefik.http.middlewares.https_redirect.redirectscheme.scheme=https + - traefik.http.middlewares.https_redirect.redirectscheme.permanent=true + - traefik.http.routers.http-synapse.middlewares=https_redirect + - traefik.http.routers.https-synapse.entryPoints=https + - traefik.http.routers.https-synapse.rule=Host(`my.matrix.host`) + - traefik.http.routers.https-synapse.service=synapse + - traefik.http.routers.https-synapse.tls=true + - traefik.http.services.synapse.loadbalancer.server.port=8008 + - traefik.http.routers.https-synapse.tls.certResolver=le-ssl db: image: docker.io/postgres:10-alpine @@ -45,6 +55,9 @@ services: environment: - POSTGRES_USER=synapse - POSTGRES_PASSWORD=changeme + # ensure the database gets created correctly + # https://github.com/matrix-org/synapse/blob/master/docs/postgres.md#set-up-database + - POSTGRES_INITDB_ARGS=--encoding=UTF-8 --lc-collate=C --lc-ctype=C volumes: # You may store the database tables in a local folder.. 
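A quick aside on the updated `contrib/docker` README above: once `synapse generate` has produced `homeserver.yaml` and the signing keys, the server is brought up the same way the replaced text described. A minimal sketch, assuming the example compose file's `./files:/data` volume mount:

```
# Generate the configuration once (written to ./files, which is mounted at /data)
docker-compose run --rm -e SYNAPSE_SERVER_NAME=my.matrix.host -e SYNAPSE_REPORT_STATS=yes synapse generate

# Edit ./files/homeserver.yaml as needed, then start the stack
docker-compose up -d
```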
- ./schemas:/var/lib/postgresql/data diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py index 5ef140ae48..3bbbcfa1b4 100644 --- a/contrib/experiments/test_messaging.py +++ b/contrib/experiments/test_messaging.py @@ -78,7 +78,7 @@ class InputOutput(object): m = re.match("^join (\S+)$", line) if m: # The `sender` wants to join a room. - room_name, = m.groups() + (room_name,) = m.groups() self.print_line("%s joining %s" % (self.user, room_name)) self.server.join_room(room_name, self.user, self.user) # self.print_line("OK.") @@ -105,7 +105,7 @@ class InputOutput(object): m = re.match("^backfill (\S+)$", line) if m: # we want to backfill a room - room_name, = m.groups() + (room_name,) = m.groups() self.print_line("backfill %s" % room_name) self.server.backfill(room_name) return @@ -339,7 +339,7 @@ def main(stdscr): root_logger = logging.getLogger() formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(lineno)d - " "%(levelname)s - %(message)s" + "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" ) if not os.path.exists("logs"): os.makedirs("logs") diff --git a/contrib/grafana/README.md b/contrib/grafana/README.md index 6a6cc0bed4..ca780d412e 100644 --- a/contrib/grafana/README.md +++ b/contrib/grafana/README.md @@ -1,6 +1,6 @@ # Using the Synapse Grafana dashboard 0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/ -1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst +1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md 2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/ 3. 
Set up additional recording rules diff --git a/contrib/grafana/synapse.json b/contrib/grafana/synapse.json index 5b1bfd1679..30a8681f5a 100644 --- a/contrib/grafana/synapse.json +++ b/contrib/grafana/synapse.json @@ -18,7 +18,7 @@ "gnetId": null, "graphTooltip": 0, "id": 1, - "iteration": 1561447718159, + "iteration": 1591098104645, "links": [ { "asDropdown": true, @@ -34,6 +34,7 @@ "panels": [ { "collapsed": false, + "datasource": null, "gridPos": { "h": 1, "w": 24, @@ -52,12 +53,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 0, "y": 1 }, + "hiddenSeries": false, "id": 75, "legend": { "avg": false, @@ -72,7 +75,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -151,6 +156,7 @@ "editable": true, "error": false, "fill": 1, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 9, @@ -158,6 +164,7 @@ "x": 12, "y": 1 }, + "hiddenSeries": false, "id": 33, "legend": { "avg": false, @@ -172,7 +179,9 @@ "linewidth": 2, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -302,12 +311,14 @@ "dashes": false, "datasource": "$datasource", "fill": 0, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 12, "y": 10 }, + "hiddenSeries": false, "id": 107, "legend": { "avg": false, @@ -322,7 +333,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -425,12 +438,14 @@ "dashes": false, "datasource": "$datasource", "fill": 0, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 0, "y": 19 }, + "hiddenSeries": false, "id": 118, "legend": { "avg": false, @@ -445,7 +460,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -542,6 +559,7 @@ }, { "collapsed": true, + "datasource": null, "gridPos": { "h": 1, "w": 24, @@ -559,13 +577,15 @@ "editable": true, "error": false, "fill": 1, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 7, "w": 12, "x": 0, - "y": 29 + "y": 2 }, + "hiddenSeries": false, "id": 5, "legend": { "alignAsTable": false, @@ -584,6 +604,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -686,12 +709,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 7, "w": 12, "x": 12, - "y": 29 + "y": 2 }, + "hiddenSeries": false, "id": 37, "legend": { "avg": false, @@ -706,6 +731,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -792,13 +820,15 @@ "editable": true, "error": false, "fill": 0, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 7, "w": 12, "x": 0, - "y": 36 + "y": 9 }, + "hiddenSeries": false, "id": 34, "legend": { "avg": false, @@ -813,6 +843,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -880,13 +913,16 @@ "datasource": "$datasource", "description": "Shows the time in which the given percentage of reactor ticks completed, over the sampled timespan", "fill": 1, + "fillGradient": 0, "gridPos": { 
"h": 7, "w": 12, "x": 12, - "y": 36 + "y": 9 }, + "hiddenSeries": false, "id": 105, + "interval": "", "legend": { "avg": false, "current": false, @@ -900,6 +936,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -934,9 +973,10 @@ "refId": "C" }, { - "expr": "", + "expr": "rate(python_twisted_reactor_tick_time_sum{index=~\"$index\",instance=\"$instance\",job=~\"$job\"}[$bucket_size]) / rate(python_twisted_reactor_tick_time_count{index=~\"$index\",instance=\"$instance\",job=~\"$job\"}[$bucket_size])", "format": "time_series", "intervalFactor": 1, + "legendFormat": "{{job}}-{{index}} mean", "refId": "D" } ], @@ -987,14 +1027,16 @@ "dashLength": 10, "dashes": false, "datasource": "$datasource", - "fill": 1, + "fill": 0, + "fillGradient": 0, "gridPos": { "h": 7, "w": 12, "x": 0, - "y": 43 + "y": 16 }, - "id": 50, + "hiddenSeries": false, + "id": 53, "legend": { "avg": false, "current": false, @@ -1008,6 +1050,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -1019,20 +1064,18 @@ "steppedLine": false, "targets": [ { - "expr": "rate(python_twisted_reactor_tick_time_sum{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])/rate(python_twisted_reactor_tick_time_count[$bucket_size])", + "expr": "min_over_time(up{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", "format": "time_series", - "interval": "", "intervalFactor": 2, "legendFormat": "{{job}}-{{index}}", - "refId": "A", - "step": 20 + "refId": "A" } ], "thresholds": [], "timeFrom": null, "timeRegions": [], "timeShift": null, - "title": "Avg reactor tick time", + "title": "Up", "tooltip": { "shared": true, "sort": 0, @@ -1048,7 +1091,7 @@ }, "yaxes": [ { - "format": "s", + "format": "short", "label": null, "logBase": 1, "max": null, @@ -1061,7 +1104,7 @@ "logBase": 1, "max": null, "min": null, - "show": false + "show": true } ], "yaxis": { @@ -1076,12 +1119,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 7, "w": 12, "x": 12, - "y": 43 + "y": 16 }, + "hiddenSeries": false, "id": 49, "legend": { "avg": false, @@ -1096,6 +1141,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -1170,14 +1218,17 @@ "dashLength": 10, "dashes": false, "datasource": "$datasource", - "fill": 0, + "fill": 1, + "fillGradient": 0, "gridPos": { "h": 7, "w": 12, "x": 0, - "y": 50 + "y": 23 }, - "id": 53, + "hiddenSeries": false, + "id": 136, + "interval": "", "legend": { "avg": false, "current": false, @@ -1189,11 +1240,12 @@ }, "lines": true, "linewidth": 1, - "links": [], "nullPointMode": "null", - "paceLength": 10, + "options": { + "dataLinks": [] + }, "percentage": false, - "pointradius": 5, + "pointradius": 2, "points": false, "renderer": "flot", "seriesOverrides": [], @@ -1202,18 +1254,21 @@ "steppedLine": false, "targets": [ { - "expr": "min_over_time(up{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", - "format": "time_series", - "intervalFactor": 2, - "legendFormat": "{{job}}-{{index}}", + "expr": "rate(synapse_http_client_requests{job=~\"$job\",index=~\"$index\",instance=\"$instance\"}[$bucket_size])", + "legendFormat": "{{job}}-{{index}} {{method}}", "refId": "A" + }, + { + "expr": 
"rate(synapse_http_matrixfederationclient_requests{job=~\"$job\",index=~\"$index\",instance=\"$instance\"}[$bucket_size])", + "legendFormat": "{{job}}-{{index}} {{method}} (federation)", + "refId": "B" } ], "thresholds": [], "timeFrom": null, "timeRegions": [], "timeShift": null, - "title": "Up", + "title": "Outgoing HTTP request rate", "tooltip": { "shared": true, "sort": 0, @@ -1229,7 +1284,7 @@ }, "yaxes": [ { - "format": "short", + "format": "reqps", "label": null, "logBase": 1, "max": null, @@ -1257,12 +1312,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 7, "w": 12, "x": 12, - "y": 50 + "y": 23 }, + "hiddenSeries": false, "id": 120, "legend": { "avg": false, @@ -1277,6 +1334,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 2, "points": false, @@ -1289,6 +1349,7 @@ { "expr": "rate(synapse_http_server_response_ru_utime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])+rate(synapse_http_server_response_ru_stime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", "format": "time_series", + "hide": false, "instant": false, "intervalFactor": 1, "legendFormat": "{{job}}-{{index}} {{method}} {{servlet}} {{tag}}", @@ -1297,6 +1358,7 @@ { "expr": "rate(synapse_background_process_ru_utime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])+rate(synapse_background_process_ru_stime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", "format": "time_series", + "hide": false, "instant": false, "interval": "", "intervalFactor": 1, @@ -1361,6 +1423,7 @@ }, { "collapsed": true, + "datasource": null, "gridPos": { "h": 1, "w": 24, @@ -1732,6 +1795,7 @@ }, { "collapsed": true, + "datasource": null, "gridPos": { "h": 1, "w": 24, @@ -1750,6 +1814,7 @@ "editable": true, "error": false, "fill": 2, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 8, @@ -1757,6 +1822,7 @@ "x": 0, "y": 31 }, + "hiddenSeries": false, "id": 4, "legend": { "alignAsTable": true, @@ -1775,7 +1841,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -1858,6 +1926,7 @@ "editable": true, "error": false, "fill": 1, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 8, @@ -1865,6 +1934,7 @@ "x": 12, "y": 31 }, + "hiddenSeries": false, "id": 32, "legend": { "avg": false, @@ -1879,7 +1949,9 @@ "linewidth": 2, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -1948,6 +2020,7 @@ "editable": true, "error": false, "fill": 2, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 8, @@ -1955,7 +2028,8 @@ "x": 0, "y": 39 }, - "id": 23, + "hiddenSeries": false, + "id": 139, "legend": { "alignAsTable": true, "avg": false, @@ -1973,7 +2047,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -1984,7 +2060,7 @@ "steppedLine": false, "targets": [ { - "expr": "rate(synapse_http_server_response_ru_utime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])+rate(synapse_http_server_response_ru_stime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", + "expr": 
"rate(synapse_http_server_in_flight_requests_ru_utime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])+rate(synapse_http_server_in_flight_requests_ru_stime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", "format": "time_series", "interval": "", "intervalFactor": 1, @@ -2059,6 +2135,7 @@ "editable": true, "error": false, "fill": 2, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 8, @@ -2066,6 +2143,7 @@ "x": 12, "y": 39 }, + "hiddenSeries": false, "id": 52, "legend": { "alignAsTable": true, @@ -2084,7 +2162,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -2095,7 +2175,7 @@ "steppedLine": false, "targets": [ { - "expr": "(rate(synapse_http_server_response_ru_utime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])+rate(synapse_http_server_response_ru_stime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])) / rate(synapse_http_server_response_count{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", + "expr": "(rate(synapse_http_server_in_flight_requests_ru_utime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])+rate(synapse_http_server_in_flight_requests_ru_stime_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])) / rate(synapse_http_server_requests_received{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", "format": "time_series", "interval": "", "intervalFactor": 2, @@ -2167,6 +2247,7 @@ "editable": true, "error": false, "fill": 1, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 8, @@ -2174,6 +2255,7 @@ "x": 0, "y": 47 }, + "hiddenSeries": false, "id": 7, "legend": { "alignAsTable": true, @@ -2191,7 +2273,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -2202,7 +2286,7 @@ "steppedLine": false, "targets": [ { - "expr": "rate(synapse_http_server_response_db_txn_duration_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", + "expr": "rate(synapse_http_server_in_flight_requests_db_txn_duration_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", "format": "time_series", "interval": "", "intervalFactor": 2, @@ -2260,6 +2344,7 @@ "editable": true, "error": false, "fill": 2, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 8, @@ -2267,6 +2352,7 @@ "x": 12, "y": 47 }, + "hiddenSeries": false, "id": 47, "legend": { "alignAsTable": true, @@ -2285,7 +2371,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -2352,12 +2440,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 0, "y": 55 }, + "hiddenSeries": false, "id": 103, "legend": { "avg": false, @@ -2372,7 +2462,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -2439,6 +2531,7 @@ }, { "collapsed": true, + "datasource": null, "gridPos": { "h": 1, "w": 24, @@ -2454,12 +2547,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 0, "y": 32 }, + "hiddenSeries": 
false, "id": 99, "legend": { "avg": false, @@ -2474,7 +2569,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -2542,12 +2639,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 12, "y": 32 }, + "hiddenSeries": false, "id": 101, "legend": { "avg": false, @@ -2562,7 +2661,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -2628,6 +2729,93 @@ "align": false, "alignLevel": null } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "$datasource", + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 41 + }, + "hiddenSeries": false, + "id": 138, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "dataLinks": [] + }, + "percentage": false, + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "synapse_background_process_in_flight_count{job=~\"$job\",index=~\"$index\",instance=\"$instance\"}", + "legendFormat": "{{job}}-{{index}} {{name}}", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Background jobs in flight", + "tooltip": { + "shared": false, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } } ], "title": "Background jobs", @@ -2635,6 +2823,7 @@ }, { "collapsed": true, + "datasource": null, "gridPos": { "h": 1, "w": 24, @@ -2650,12 +2839,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 0, - "y": 61 + "y": 33 }, + "hiddenSeries": false, "id": 79, "legend": { "avg": false, @@ -2670,6 +2861,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -2684,8 +2878,13 @@ "expr": "sum(rate(synapse_federation_client_sent_transactions{instance=\"$instance\"}[$bucket_size]))", "format": "time_series", "intervalFactor": 1, - "legendFormat": "txn rate", + "legendFormat": "successful txn rate", "refId": "A" + }, + { + "expr": "sum(rate(synapse_util_metrics_block_count{block_name=\"_send_new_transaction\",instance=\"$instance\"}[$bucket_size]) - ignoring (block_name) rate(synapse_federation_client_sent_transactions{instance=\"$instance\"}[$bucket_size]))", + "legendFormat": "failed txn rate", + "refId": "B" } ], "thresholds": [], @@ -2736,12 +2935,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 12, - "y": 61 + "y": 33 }, + "hiddenSeries": false, "id": 83, "legend": { "avg": false, @@ -2756,6 +2957,9 @@ "linewidth": 1, "links": 
[], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -2829,12 +3033,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 0, - "y": 70 + "y": 42 }, + "hiddenSeries": false, "id": 109, "legend": { "avg": false, @@ -2849,6 +3055,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -2923,12 +3132,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 12, - "y": 70 + "y": 42 }, + "hiddenSeries": false, "id": 111, "legend": { "avg": false, @@ -2943,6 +3154,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -3002,6 +3216,144 @@ "align": false, "alignLevel": null } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "$datasource", + "description": "", + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 9, + "w": 12, + "x": 0, + "y": 51 + }, + "hiddenSeries": false, + "id": 140, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null", + "options": { + "dataLinks": [] + }, + "paceLength": 10, + "percentage": false, + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "synapse_federation_send_queue_presence_changed_size{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}", + "format": "time_series", + "interval": "", + "intervalFactor": 1, + "legendFormat": "presence changed", + "refId": "A" + }, + { + "expr": "synapse_federation_send_queue_presence_map_size{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}", + "format": "time_series", + "hide": false, + "interval": "", + "intervalFactor": 1, + "legendFormat": "presence map", + "refId": "B" + }, + { + "expr": "synapse_federation_send_queue_presence_destinations_size{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}", + "format": "time_series", + "hide": false, + "interval": "", + "intervalFactor": 1, + "legendFormat": "presence destinations", + "refId": "E" + }, + { + "expr": "synapse_federation_send_queue_keyed_edu_size{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}", + "format": "time_series", + "hide": false, + "interval": "", + "intervalFactor": 1, + "legendFormat": "keyed edus", + "refId": "C" + }, + { + "expr": "synapse_federation_send_queue_edus_size{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}", + "format": "time_series", + "hide": false, + "interval": "", + "intervalFactor": 1, + "legendFormat": "other edus", + "refId": "D" + }, + { + "expr": "synapse_federation_send_queue_pos_time_size{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}", + "format": "time_series", + "hide": false, + "interval": "", + "intervalFactor": 1, + "legendFormat": "stream positions", + "refId": "F" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Outgoing EDU queues on master", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": 
null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "none", + "label": null, + "logBase": 1, + "max": null, + "min": "0", + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } } ], "title": "Federation", @@ -3009,6 +3361,7 @@ }, { "collapsed": true, + "datasource": null, "gridPos": { "h": 1, "w": 24, @@ -3024,12 +3377,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { - "h": 7, + "h": 8, "w": 12, "x": 0, - "y": 62 + "y": 34 }, + "hiddenSeries": false, "id": 51, "legend": { "avg": false, @@ -3044,6 +3399,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -3112,6 +3470,95 @@ "align": false, "alignLevel": null } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "$datasource", + "description": "", + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 34 + }, + "hiddenSeries": false, + "id": 134, + "legend": { + "avg": false, + "current": false, + "hideZero": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "dataLinks": [] + }, + "percentage": false, + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "topk(10,synapse_pushers{job=~\"$job\",index=~\"$index\", instance=\"$instance\"})", + "legendFormat": "{{kind}} {{app_id}}", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Active pusher instances by app", + "tooltip": { + "shared": false, + "sort": 2, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } } ], "repeat": null, @@ -3120,6 +3567,7 @@ }, { "collapsed": true, + "datasource": null, "gridPos": { "h": 1, "w": 24, @@ -3135,12 +3583,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 7, "w": 12, "x": 0, - "y": 35 + "y": 52 }, + "hiddenSeries": false, "id": 48, "legend": { "avg": false, @@ -3155,7 +3605,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -3225,12 +3677,14 @@ "datasource": "$datasource", "description": "Shows the time in which the given percentage of database queries were scheduled, over the sampled timespan", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 7, "w": 12, "x": 12, - "y": 35 + "y": 52 }, + "hiddenSeries": false, "id": 104, "legend": { "alignAsTable": true, @@ -3246,7 +3700,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -3340,13 +3796,15 @@ "editable": true, "error": false, "fill": 0, + 
"fillGradient": 0, "grid": {}, "gridPos": { "h": 7, "w": 12, "x": 0, - "y": 42 + "y": 59 }, + "hiddenSeries": false, "id": 10, "legend": { "avg": false, @@ -3363,7 +3821,9 @@ "linewidth": 2, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -3432,13 +3892,15 @@ "editable": true, "error": false, "fill": 1, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 7, "w": 12, "x": 12, - "y": 42 + "y": 59 }, + "hiddenSeries": false, "id": 11, "legend": { "avg": false, @@ -3455,7 +3917,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -3523,6 +3987,7 @@ }, { "collapsed": true, + "datasource": null, "gridPos": { "h": 1, "w": 24, @@ -3540,13 +4005,15 @@ "editable": true, "error": false, "fill": 1, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 13, "w": 12, "x": 0, - "y": 36 + "y": 67 }, + "hiddenSeries": false, "id": 12, "legend": { "alignAsTable": true, @@ -3562,6 +4029,9 @@ "linewidth": 2, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -3630,13 +4100,15 @@ "editable": true, "error": false, "fill": 1, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 13, "w": 12, "x": 12, - "y": 36 + "y": 67 }, + "hiddenSeries": false, "id": 26, "legend": { "alignAsTable": true, @@ -3652,6 +4124,9 @@ "linewidth": 2, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -3720,13 +4195,15 @@ "editable": true, "error": false, "fill": 1, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 13, "w": 12, "x": 0, - "y": 49 + "y": 80 }, + "hiddenSeries": false, "id": 13, "legend": { "alignAsTable": true, @@ -3742,6 +4219,9 @@ "linewidth": 2, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -3753,7 +4233,7 @@ "steppedLine": false, "targets": [ { - "expr": "rate(synapse_util_metrics_block_db_txn_duration_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\",block_name!=\"wrapped_request_handler\"}[$bucket_size])", + "expr": "rate(synapse_util_metrics_block_db_txn_duration_seconds{instance=\"$instance\",job=~\"$job\",index=~\"$index\"}[$bucket_size])", "format": "time_series", "interval": "", "intervalFactor": 2, @@ -3807,16 +4287,19 @@ "dashLength": 10, "dashes": false, "datasource": "$datasource", + "description": "The time each database transaction takes to execute, on average, broken down by metrics block.", "editable": true, "error": false, "fill": 1, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 13, "w": 12, "x": 12, - "y": 49 + "y": 80 }, + "hiddenSeries": false, "id": 27, "legend": { "alignAsTable": true, @@ -3832,6 +4315,9 @@ "linewidth": 2, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -3856,7 +4342,7 @@ "timeFrom": null, "timeRegions": [], "timeShift": null, - "title": "Average Database Time per Block", + "title": "Average Database Transaction time, by Block", "tooltip": { "shared": false, "sort": 0, @@ -3900,13 +4386,15 @@ "editable": true, "error": false, "fill": 1, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 13, "w": 12, "x": 0, - "y": 62 + "y": 93 }, + "hiddenSeries": false, "id": 28, "legend": { "avg": false, @@ -3921,6 
+4409,9 @@ "linewidth": 2, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -3989,13 +4480,15 @@ "editable": true, "error": false, "fill": 1, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 13, "w": 12, "x": 12, - "y": 62 + "y": 93 }, + "hiddenSeries": false, "id": 25, "legend": { "avg": false, @@ -4010,6 +4503,9 @@ "linewidth": 2, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -4076,6 +4572,7 @@ }, { "collapsed": true, + "datasource": null, "gridPos": { "h": 1, "w": 24, @@ -4094,6 +4591,7 @@ "editable": true, "error": false, "fill": 0, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 10, @@ -4101,6 +4599,7 @@ "x": 0, "y": 37 }, + "hiddenSeries": false, "id": 1, "legend": { "alignAsTable": true, @@ -4118,6 +4617,9 @@ "linewidth": 2, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -4187,6 +4689,7 @@ "editable": true, "error": false, "fill": 1, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 10, @@ -4194,6 +4697,7 @@ "x": 12, "y": 37 }, + "hiddenSeries": false, "id": 8, "legend": { "alignAsTable": true, @@ -4210,6 +4714,9 @@ "linewidth": 2, "links": [], "nullPointMode": "connected", + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -4278,6 +4785,7 @@ "editable": true, "error": false, "fill": 1, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 10, @@ -4285,6 +4793,7 @@ "x": 0, "y": 47 }, + "hiddenSeries": false, "id": 38, "legend": { "alignAsTable": true, @@ -4301,6 +4810,9 @@ "linewidth": 2, "links": [], "nullPointMode": "connected", + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -4366,12 +4878,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 10, "w": 12, "x": 12, "y": 47 }, + "hiddenSeries": false, "id": 39, "legend": { "alignAsTable": true, @@ -4387,6 +4901,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -4453,12 +4970,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 0, "y": 57 }, + "hiddenSeries": false, "id": 65, "legend": { "alignAsTable": true, @@ -4474,6 +4993,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -4540,6 +5062,7 @@ }, { "collapsed": true, + "datasource": null, "gridPos": { "h": 1, "w": 24, @@ -4555,12 +5078,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 0, "y": 66 }, + "hiddenSeries": false, "id": 91, "legend": { "avg": false, @@ -4575,6 +5100,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -4645,6 +5173,7 @@ "editable": true, "error": false, "fill": 1, + "fillGradient": 0, "grid": {}, "gridPos": { "h": 9, @@ -4652,6 +5181,7 @@ "x": 12, "y": 66 }, + "hiddenSeries": false, "id": 21, "legend": { "alignAsTable": true, @@ -4667,6 +5197,9 @@ "linewidth": 2, "links": [], "nullPointMode": "null as zero", + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -4733,12 
+5266,14 @@ "datasource": "$datasource", "description": "'gen 0' shows the number of objects allocated since the last gen0 GC.\n'gen 1' / 'gen 2' show the number of gen0/gen1 GCs since the last gen1/gen2 GC.", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 0, "y": 75 }, + "hiddenSeries": false, "id": 89, "legend": { "avg": false, @@ -4755,6 +5290,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -4826,12 +5364,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 12, "y": 75 }, + "hiddenSeries": false, "id": 93, "legend": { "avg": false, @@ -4846,6 +5386,9 @@ "linewidth": 1, "links": [], "nullPointMode": "connected", + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -4911,12 +5454,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 0, "y": 84 }, + "hiddenSeries": false, "id": 95, "legend": { "avg": false, @@ -4931,6 +5476,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 5, "points": false, @@ -5019,6 +5567,7 @@ "show": true }, "links": [], + "options": {}, "reverseYBuckets": false, "targets": [ { @@ -5060,6 +5609,7 @@ }, { "collapsed": true, + "datasource": null, "gridPos": { "h": 1, "w": 24, @@ -5075,12 +5625,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 7, "w": 12, "x": 0, - "y": 67 + "y": 39 }, + "hiddenSeries": false, "id": 2, "legend": { "avg": false, @@ -5095,6 +5647,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -5194,12 +5749,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 7, "w": 12, "x": 12, - "y": 67 + "y": 39 }, + "hiddenSeries": false, "id": 41, "legend": { "avg": false, @@ -5214,6 +5771,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -5282,12 +5842,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 7, "w": 12, "x": 0, - "y": 74 + "y": 46 }, + "hiddenSeries": false, "id": 42, "legend": { "avg": false, @@ -5302,6 +5864,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -5369,12 +5934,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 7, "w": 12, "x": 12, - "y": 74 + "y": 46 }, + "hiddenSeries": false, "id": 43, "legend": { "avg": false, @@ -5389,6 +5956,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -5456,12 +6026,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 7, "w": 12, "x": 0, - "y": 81 + "y": 53 }, + "hiddenSeries": false, "id": 113, "legend": { "avg": false, @@ -5476,6 +6048,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -5492,6 +6067,13 @@ "intervalFactor": 1, "legendFormat": "{{job}}-{{index}} 
{{stream_name}}", "refId": "A" + }, + { + "expr": "synapse_replication_tcp_resource_total_connections{job=~\"$job\",index=~\"$index\",instance=\"$instance\"}", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "{{job}}-{{index}}", + "refId": "B" } ], "thresholds": [], @@ -5518,7 +6100,7 @@ "label": null, "logBase": 1, "max": null, - "min": null, + "min": "0", "show": true }, { @@ -5542,12 +6124,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 7, "w": 12, "x": 12, - "y": 81 + "y": 53 }, + "hiddenSeries": false, "id": 115, "legend": { "avg": false, @@ -5562,6 +6146,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -5573,7 +6160,7 @@ "steppedLine": false, "targets": [ { - "expr": "rate(synapse_replication_tcp_protocol_close_reason{job=\"$job\",index=~\"$index\",instance=\"$instance\"}[$bucket_size])", + "expr": "rate(synapse_replication_tcp_protocol_close_reason{job=~\"$job\",index=~\"$index\",instance=\"$instance\"}[$bucket_size])", "format": "time_series", "intervalFactor": 1, "legendFormat": "{{job}}-{{index}} {{reason_type}}", @@ -5628,6 +6215,7 @@ }, { "collapsed": true, + "datasource": null, "gridPos": { "h": 1, "w": 24, @@ -5643,12 +6231,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 0, - "y": 13 + "y": 58 }, + "hiddenSeries": false, "id": 67, "legend": { "avg": false, @@ -5663,7 +6253,9 @@ "linewidth": 1, "links": [], "nullPointMode": "connected", - "options": {}, + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -5679,7 +6271,7 @@ "format": "time_series", "interval": "", "intervalFactor": 1, - "legendFormat": "{{job}}-{{index}} ", + "legendFormat": "{{job}}-{{index}} {{name}}", "refId": "A" } ], @@ -5731,12 +6323,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 12, - "y": 13 + "y": 58 }, + "hiddenSeries": false, "id": 71, "legend": { "avg": false, @@ -5751,7 +6345,9 @@ "linewidth": 1, "links": [], "nullPointMode": "connected", - "options": {}, + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -5819,12 +6415,14 @@ "dashes": false, "datasource": "$datasource", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 9, "w": 12, "x": 0, - "y": 22 + "y": 67 }, + "hiddenSeries": false, "id": 121, "interval": "", "legend": { @@ -5840,7 +6438,9 @@ "linewidth": 1, "links": [], "nullPointMode": "connected", - "options": {}, + "options": { + "dataLinks": [] + }, "paceLength": 10, "percentage": false, "pointradius": 5, @@ -5909,6 +6509,7 @@ }, { "collapsed": true, + "datasource": null, "gridPos": { "h": 1, "w": 24, @@ -5938,7 +6539,7 @@ "h": 8, "w": 12, "x": 0, - "y": 14 + "y": 41 }, "heatmap": {}, "hideZeroBuckets": true, @@ -5993,12 +6594,14 @@ "datasource": "$datasource", "description": "Number of rooms with the given number of forward extremities or fewer.\n\nThis is only updated once an hour.", "fill": 0, + "fillGradient": 0, "gridPos": { "h": 8, "w": 12, "x": 12, - "y": 14 + "y": 41 }, + "hiddenSeries": false, "id": 124, "interval": "", "legend": { @@ -6014,7 +6617,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 2, "points": false, @@ -6095,7 +6700,7 @@ "h": 8, "w": 12, "x": 0, - "y": 22 
+ "y": 49 }, "heatmap": {}, "hideZeroBuckets": true, @@ -6150,12 +6755,14 @@ "datasource": "$datasource", "description": "For a given percentage P, the number X where P% of events were persisted to rooms with X forward extremities or fewer.", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 8, "w": 12, "x": 12, - "y": 22 + "y": 49 }, + "hiddenSeries": false, "id": 128, "legend": { "avg": false, @@ -6170,7 +6777,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 2, "points": false, @@ -6270,7 +6879,7 @@ "h": 8, "w": 12, "x": 0, - "y": 30 + "y": 57 }, "heatmap": {}, "hideZeroBuckets": true, @@ -6325,12 +6934,14 @@ "datasource": "$datasource", "description": "For given percentage P, the number X where P% of events were persisted to rooms with X stale forward extremities or fewer.\n\nStale forward extremities are those that were in the previous set of extremities as well as the new.", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 8, "w": 12, "x": 12, - "y": 30 + "y": 57 }, + "hiddenSeries": false, "id": 130, "legend": { "avg": false, @@ -6345,7 +6956,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 2, "points": false, @@ -6445,7 +7058,7 @@ "h": 8, "w": 12, "x": 0, - "y": 38 + "y": 65 }, "heatmap": {}, "hideZeroBuckets": true, @@ -6500,12 +7113,14 @@ "datasource": "$datasource", "description": "For a given percentage P, the number X where P% of state resolution operations took place over X state groups or fewer.", "fill": 1, + "fillGradient": 0, "gridPos": { "h": 8, "w": 12, "x": 12, - "y": 38 + "y": 65 }, + "hiddenSeries": false, "id": 132, "interval": "", "legend": { @@ -6521,7 +7136,9 @@ "linewidth": 1, "links": [], "nullPointMode": "null", - "options": {}, + "options": { + "dataLinks": [] + }, "percentage": false, "pointradius": 2, "points": false, @@ -6607,7 +7224,7 @@ } ], "refresh": "5m", - "schemaVersion": 18, + "schemaVersion": 22, "style": "dark", "tags": [ "matrix" @@ -6616,7 +7233,6 @@ "list": [ { "current": { - "tags": [], "text": "Prometheus", "value": "Prometheus" }, @@ -6638,6 +7254,7 @@ "auto_count": 100, "auto_min": "30s", "current": { + "selected": false, "text": "auto", "value": "$__auto_interval_bucket_size" }, @@ -6719,9 +7336,9 @@ "allFormat": "regex wildcard", "allValue": "", "current": { - "text": "All", + "text": "synapse", "value": [ - "$__all" + "synapse" ] }, "datasource": "$datasource", @@ -6750,6 +7367,7 @@ "allFormat": "regex wildcard", "allValue": ".*", "current": { + "selected": false, "text": "All", "value": "$__all" }, @@ -6810,5 +7428,5 @@ "timezone": "", "title": "Synapse", "uid": "000000012", - "version": 10 + "version": 29 } \ No newline at end of file diff --git a/contrib/graph/graph2.py b/contrib/graph/graph2.py index 9db8725eee..4619f0e3c1 100644 --- a/contrib/graph/graph2.py +++ b/contrib/graph/graph2.py @@ -36,7 +36,7 @@ def make_graph(db_name, room_id, file_prefix, limit): args = [room_id] if limit: - sql += " ORDER BY topological_ordering DESC, stream_ordering DESC " "LIMIT ?" + sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?" 
args.append(limit) @@ -53,7 +53,7 @@ def make_graph(db_name, room_id, file_prefix, limit): for event in events: c = conn.execute( - "SELECT state_group FROM event_to_state_groups " "WHERE event_id = ?", + "SELECT state_group FROM event_to_state_groups WHERE event_id = ?", (event.event_id,), ) diff --git a/contrib/systemd-with-workers/README.md b/contrib/systemd-with-workers/README.md index 74b261e9fb..8d21d532bd 100644 --- a/contrib/systemd-with-workers/README.md +++ b/contrib/systemd-with-workers/README.md @@ -1,150 +1,2 @@ -# Setup Synapse with Workers and Systemd - -This is a setup for managing synapse with systemd including support for -managing workers. It provides a `matrix-synapse`, as well as a -`matrix-synapse-worker@` service for any workers you require. Additionally to -group the required services it sets up a `matrix.target`. You can use this to -automatically start any bot- or bridge-services. More on this in -[Bots and Bridges](#bots-and-bridges). - -See the folder [system](system) for any service and target files. - -The folder [workers](workers) contains an example configuration for the -`federation_reader` worker. Pay special attention to the name of the -configuration file. In order to work with the `matrix-synapse-worker@.service` -service, it needs to have the exact same name as the worker app. - -This setup expects neither the homeserver nor any workers to fork. Forking is -handled by systemd. - -## Setup - -1. Adjust your matrix configs. Make sure that the worker config files have the -exact same name as the worker app. Compare `matrix-synapse-worker@.service` for -why. You can find an example worker config in the [workers](workers) folder. See -below for relevant settings in the `homeserver.yaml`. -2. Copy the `*.service` and `*.target` files in [system](system) to -`/etc/systemd/system`. -3. `systemctl enable matrix-synapse.service` this adds the homeserver -app to the `matrix.target` -4. *Optional.* `systemctl enable -matrix-synapse-worker@federation_reader.service` this adds the federation_reader -app to the `matrix-synapse.service` -5. *Optional.* Repeat step 4 for any additional workers you require. -6. *Optional.* Add any bots or bridges by enabling them. -7. Start all matrix related services via `systemctl start matrix.target` -8. *Optional.* Enable autostart of all matrix related services on system boot -via `systemctl enable matrix.target` - -## Usage - -After you have setup you can use the following commands to manage your synapse -installation: - -``` -# Start matrix-synapse, all workers and any enabled bots or bridges. -systemctl start matrix.target - -# Restart matrix-synapse and all workers (not necessarily restarting bots -# or bridges, see "Bots and Bridges") -systemctl restart matrix-synapse.service - -# Stop matrix-synapse and all workers (not necessarily restarting bots -# or bridges, see "Bots and Bridges") -systemctl stop matrix-synapse.service - -# Restart a specific worker (i. e. federation_reader), the homeserver is -# unaffected by this. -systemctl restart matrix-synapse-worker@federation_reader.service - -# Add a new worker (assuming all configs are setup already) -systemctl enable matrix-synapse-worker@federation_writer.service -systemctl restart matrix-synapse.service -``` - -## The Configs - -Make sure the `worker_app` is set in the `homeserver.yaml` and it does not fork. - -``` -worker_app: synapse.app.homeserver -daemonize: false -``` - -None of the workers should fork, as forking is handled by systemd. 
Hence make -sure this is present in all worker config files. - -``` -worker_daemonize: false -``` - -The config files of all workers are expected to be located in -`/etc/matrix-synapse/workers`. If you want to use a different location you have -to edit the provided `*.service` files accordingly. - -## Bots and Bridges - -Most bots and bridges do not care if the homeserver goes down or is restarted. -Depending on the implementation this may crash them though. So look up the docs -or ask the community of the specific bridge or bot you want to run to make sure -you choose the correct setup. - -Whichever configuration you choose, after the setup the following will enable -automatically starting (and potentially restarting) your bot/bridge with the -`matrix.target`. - -``` -systemctl enable <yourBotOrBridgeName>.service -``` - -**Note** that from an inactive synapse the bots/bridges will only be started with -synapse if you start the `matrix.target`, not if you start the -`matrix-synapse.service`. This is on purpose. Think of `matrix-synapse.service` -as *just* synapse, but `matrix.target` being anything matrix related, including -synapse and any and all enabled bots and bridges. - -### Start with synapse but ignore synapse going down - -If the bridge can handle shutdowns of the homeserver you'll want to install the -service in the `matrix.target` and optionally add a -`After=matrix-synapse.service` dependency to have the bot/bridge start after -synapse on starting everything. - -In this case the service file should look like this. - -``` -[Unit] -# ... -# Optional, this will only ensure that if you start everything, synapse will -# be started before the bot/bridge will be started. -After=matrix-synapse.service - -[Service] -# ... - -[Install] -WantedBy=matrix.target -``` - -### Stop/restart when synapse stops/restarts - -If the bridge can't handle shutdowns of the homeserver you'll still want to -install the service in the `matrix.target` but also have to specify the -`After=matrix-synapse.service` *and* `BindsTo=matrix-synapse.service` -dependencies to have the bot/bridge stop/restart with synapse. - -In this case the service file should look like this. - -``` -[Unit] -# ... -# Mandatory -After=matrix-synapse.service -BindsTo=matrix-synapse.service - -[Service] -# ... - -[Install] -WantedBy=matrix.target -``` +The documentation for using systemd to manage synapse workers is now part of +the main synapse distribution. See [docs/systemd-with-workers](../../docs/systemd-with-workers). 
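For anyone still on the contrib setup that the hunk above retires (the maintained documentation now lives in `docs/systemd-with-workers`), the workflow it described boils down to a handful of `systemctl` calls. This is a sketch only, reusing the `federation_reader` worker from the removed example; the template instance name must match the worker app and its config file, as that text explained:

```
# Enable the homeserver and one worker (the template unit is instantiated per worker app)
systemctl enable matrix-synapse.service
systemctl enable matrix-synapse-worker@federation_reader.service

# Start everything grouped under matrix.target, and optionally autostart it on boot
systemctl start matrix.target
systemctl enable matrix.target
```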
diff --git a/contrib/systemd-with-workers/system/matrix-synapse-worker@.service b/contrib/systemd-with-workers/system/matrix-synapse-worker@.service deleted file mode 100644 index 3507e2e989..0000000000 --- a/contrib/systemd-with-workers/system/matrix-synapse-worker@.service +++ /dev/null @@ -1,19 +0,0 @@ -[Unit] -Description=Synapse Matrix Worker -After=matrix-synapse.service -BindsTo=matrix-synapse.service - -[Service] -Type=notify -NotifyAccess=main -User=matrix-synapse -WorkingDirectory=/var/lib/matrix-synapse -EnvironmentFile=/etc/default/matrix-synapse -ExecStart=/opt/venvs/matrix-synapse/bin/python -m synapse.app.%i --config-path=/etc/matrix-synapse/homeserver.yaml --config-path=/etc/matrix-synapse/conf.d/ --config-path=/etc/matrix-synapse/workers/%i.yaml -ExecReload=/bin/kill -HUP $MAINPID -Restart=always -RestartSec=3 -SyslogIdentifier=matrix-synapse-%i - -[Install] -WantedBy=matrix-synapse.service diff --git a/contrib/systemd-with-workers/system/matrix.target b/contrib/systemd-with-workers/system/matrix.target deleted file mode 100644 index aff97d03ef..0000000000 --- a/contrib/systemd-with-workers/system/matrix.target +++ /dev/null @@ -1,7 +0,0 @@ -[Unit] -Description=Contains matrix services like synapse, bridges and bots -After=network.target -AllowIsolate=no - -[Install] -WantedBy=multi-user.target diff --git a/contrib/systemd/README.md b/contrib/systemd/README.md new file mode 100644 index 0000000000..5d42b3464f --- /dev/null +++ b/contrib/systemd/README.md @@ -0,0 +1,17 @@ +# Setup Synapse with Systemd +This is a setup for managing synapse with a user contributed systemd unit +file. It provides a `matrix-synapse` systemd unit file that should be tailored +to accommodate your installation in accordance with the installation +instructions provided in [installation instructions](../../INSTALL.md). + +## Setup +1. Under the service section, ensure the `User` variable matches which user +you installed synapse under and wish to run it as. +2. Under the service section, ensure the `WorkingDirectory` variable matches +where you have installed synapse. +3. Under the service section, ensure the `ExecStart` variable matches the +appropriate locations of your installation. +4. Copy the `matrix-synapse.service` to `/etc/systemd/system/` +5. Start Synapse: `sudo systemctl start matrix-synapse` +6. Verify Synapse is running: `sudo systemctl status matrix-synapse` +7. *optional* Enable Synapse to start at system boot: `sudo systemctl enable matrix-synapse` diff --git a/contrib/systemd/matrix-synapse.service b/contrib/systemd/matrix-synapse.service index 38d369ea3d..a754078410 100644 --- a/contrib/systemd/matrix-synapse.service +++ b/contrib/systemd/matrix-synapse.service @@ -4,14 +4,20 @@ # systemctl enable matrix-synapse # systemctl start matrix-synapse # +# This assumes that Synapse has been installed by a user named +# synapse. +# # This assumes that Synapse has been installed in a virtualenv in -# /opt/synapse/env. +# the user's home directory: `/home/synapse/synapse/env`. # # **NOTE:** This is an example service file that may change in the future. If you # wish to use this please copy rather than symlink it. [Unit] Description=Synapse Matrix homeserver +# If you are using postgresql to persist data, uncomment this line to make sure +# synapse starts after the postgresql service. 
+# After=postgresql.service [Service] Type=notify @@ -22,8 +28,8 @@ Restart=on-abort User=synapse Group=nogroup -WorkingDirectory=/opt/synapse -ExecStart=/opt/synapse/env/bin/python -m synapse.app.homeserver --config-path=/opt/synapse/homeserver.yaml +WorkingDirectory=/home/synapse/synapse +ExecStart=/home/synapse/synapse/env/bin/python -m synapse.app.homeserver --config-path=/home/synapse/synapse/homeserver.yaml SyslogIdentifier=matrix-synapse # adjust the cache factor if necessary diff --git a/debian/build_virtualenv b/debian/build_virtualenv index 2791896052..4c9aabcac3 100755 --- a/debian/build_virtualenv +++ b/debian/build_virtualenv @@ -36,7 +36,6 @@ esac dh_virtualenv \ --install-suffix "matrix-synapse" \ --builtin-venv \ - --setuptools \ --python "$SNAKE" \ --upgrade-pip \ --preinstall="lxml" \ @@ -85,6 +84,9 @@ PYTHONPATH="$tmpdir" \ ' > "${PACKAGE_BUILD_DIR}/etc/matrix-synapse/homeserver.yaml" +# build the log config file +"${TARGET_PYTHON}" -B "${VIRTUALENV_DIR}/bin/generate_log_config" \ + --output-file="${PACKAGE_BUILD_DIR}/etc/matrix-synapse/log.yaml" # add a dependency on the right version of python to substvars. PYPKG=`basename $SNAKE` diff --git a/debian/changelog b/debian/changelog index 76efc442d7..f50a102b04 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,160 @@ +matrix-synapse-py3 (1.14.0) stable; urgency=medium + + * New synapse release 1.14.0. + + -- Synapse Packaging team <packages@matrix.org> Thu, 28 May 2020 10:37:27 +0000 + +matrix-synapse-py3 (1.13.0) stable; urgency=medium + + [ Patrick Cloke ] + * Add information about .well-known files to Debian installation scripts. + + [ Synapse Packaging team ] + * New synapse release 1.13.0. + + -- Synapse Packaging team <packages@matrix.org> Tue, 19 May 2020 09:16:56 -0400 + +matrix-synapse-py3 (1.12.4) stable; urgency=medium + + * New synapse release 1.12.4. + + -- Synapse Packaging team <packages@matrix.org> Thu, 23 Apr 2020 10:58:14 -0400 + +matrix-synapse-py3 (1.12.3) stable; urgency=medium + + [ Richard van der Hoff ] + * Update the Debian build scripts to handle the new installation paths + for the support libraries introduced by Pillow 7.1.1. + + [ Synapse Packaging team ] + * New synapse release 1.12.3. + + -- Synapse Packaging team <packages@matrix.org> Fri, 03 Apr 2020 10:55:03 +0100 + +matrix-synapse-py3 (1.12.2) stable; urgency=medium + + * New synapse release 1.12.2. + + -- Synapse Packaging team <packages@matrix.org> Mon, 02 Apr 2020 19:02:17 +0000 + +matrix-synapse-py3 (1.12.1) stable; urgency=medium + + * New synapse release 1.12.1. + + -- Synapse Packaging team <packages@matrix.org> Mon, 02 Apr 2020 11:30:47 +0000 + +matrix-synapse-py3 (1.12.0) stable; urgency=medium + + * New synapse release 1.12.0. + + -- Synapse Packaging team <packages@matrix.org> Mon, 23 Mar 2020 12:13:03 +0000 + +matrix-synapse-py3 (1.11.1) stable; urgency=medium + + * New synapse release 1.11.1. + + -- Synapse Packaging team <packages@matrix.org> Tue, 03 Mar 2020 15:01:22 +0000 + +matrix-synapse-py3 (1.11.0) stable; urgency=medium + + * New synapse release 1.11.0. + + -- Synapse Packaging team <packages@matrix.org> Fri, 21 Feb 2020 08:54:34 +0000 + +matrix-synapse-py3 (1.10.1) stable; urgency=medium + + * New synapse release 1.10.1. + + -- Synapse Packaging team <packages@matrix.org> Mon, 17 Feb 2020 16:27:28 +0000 + +matrix-synapse-py3 (1.10.0) stable; urgency=medium + + * New synapse release 1.10.0. 
+ + -- Synapse Packaging team <packages@matrix.org> Wed, 12 Feb 2020 12:18:54 +0000 + +matrix-synapse-py3 (1.9.1) stable; urgency=medium + + * New synapse release 1.9.1. + + -- Synapse Packaging team <packages@matrix.org> Tue, 28 Jan 2020 13:09:23 +0000 + +matrix-synapse-py3 (1.9.0) stable; urgency=medium + + * New synapse release 1.9.0. + + -- Synapse Packaging team <packages@matrix.org> Thu, 23 Jan 2020 12:56:31 +0000 + +matrix-synapse-py3 (1.8.0) stable; urgency=medium + + [ Richard van der Hoff ] + * Automate generation of the default log configuration file. + + [ Synapse Packaging team ] + * New synapse release 1.8.0. + + -- Synapse Packaging team <packages@matrix.org> Thu, 09 Jan 2020 11:39:27 +0000 + +matrix-synapse-py3 (1.7.3) stable; urgency=medium + + * New synapse release 1.7.3. + + -- Synapse Packaging team <packages@matrix.org> Tue, 31 Dec 2019 10:45:04 +0000 + +matrix-synapse-py3 (1.7.2) stable; urgency=medium + + * New synapse release 1.7.2. + + -- Synapse Packaging team <packages@matrix.org> Fri, 20 Dec 2019 10:56:50 +0000 + +matrix-synapse-py3 (1.7.1) stable; urgency=medium + + * New synapse release 1.7.1. + + -- Synapse Packaging team <packages@matrix.org> Wed, 18 Dec 2019 09:37:59 +0000 + +matrix-synapse-py3 (1.7.0) stable; urgency=medium + + * New synapse release 1.7.0. + + -- Synapse Packaging team <packages@matrix.org> Fri, 13 Dec 2019 10:19:38 +0000 + +matrix-synapse-py3 (1.6.1) stable; urgency=medium + + * New synapse release 1.6.1. + + -- Synapse Packaging team <packages@matrix.org> Thu, 28 Nov 2019 11:10:40 +0000 + +matrix-synapse-py3 (1.6.0) stable; urgency=medium + + * New synapse release 1.6.0. + + -- Synapse Packaging team <packages@matrix.org> Tue, 26 Nov 2019 12:15:40 +0000 + +matrix-synapse-py3 (1.5.1) stable; urgency=medium + + * New synapse release 1.5.1. + + -- Synapse Packaging team <packages@matrix.org> Wed, 06 Nov 2019 10:02:14 +0000 + +matrix-synapse-py3 (1.5.0) stable; urgency=medium + + * New synapse release 1.5.0. + + -- Synapse Packaging team <packages@matrix.org> Tue, 29 Oct 2019 14:28:41 +0000 + +matrix-synapse-py3 (1.4.1) stable; urgency=medium + + * New synapse release 1.4.1. + + -- Synapse Packaging team <packages@matrix.org> Fri, 18 Oct 2019 10:13:27 +0100 + +matrix-synapse-py3 (1.4.0) stable; urgency=medium + + * New synapse release 1.4.0. + + -- Synapse Packaging team <packages@matrix.org> Thu, 03 Oct 2019 13:22:25 +0100 + matrix-synapse-py3 (1.3.1) stable; urgency=medium * New synapse release 1.3.1. 
diff --git a/debian/install b/debian/install index 43dc8c6904..da8b726a2b 100644 --- a/debian/install +++ b/debian/install @@ -1,2 +1 @@ -debian/log.yaml etc/matrix-synapse debian/manage_debconf.pl /opt/venvs/matrix-synapse/lib/ diff --git a/debian/log.yaml b/debian/log.yaml deleted file mode 100644 index 95b655dd35..0000000000 --- a/debian/log.yaml +++ /dev/null @@ -1,36 +0,0 @@ - -version: 1 - -formatters: - precise: - format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s- %(message)s' - -filters: - context: - (): synapse.logging.context.LoggingContextFilter - request: "" - -handlers: - file: - class: logging.handlers.RotatingFileHandler - formatter: precise - filename: /var/log/matrix-synapse/homeserver.log - maxBytes: 104857600 - backupCount: 10 - filters: [context] - encoding: utf8 - console: - class: logging.StreamHandler - formatter: precise - level: WARN - -loggers: - synapse: - level: INFO - - synapse.storage.SQL: - level: INFO - -root: - level: INFO - handlers: [file, console] diff --git a/debian/po/templates.pot b/debian/po/templates.pot index 84d960761a..f0af9e70fb 100644 --- a/debian/po/templates.pot +++ b/debian/po/templates.pot @@ -1,14 +1,14 @@ # SOME DESCRIPTIVE TITLE. # Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER -# This file is distributed under the same license as the matrix-synapse package. +# This file is distributed under the same license as the matrix-synapse-py3 package. # FIRST AUTHOR <EMAIL@ADDRESS>, YEAR. # #, fuzzy msgid "" msgstr "" -"Project-Id-Version: matrix-synapse\n" -"Report-Msgid-Bugs-To: matrix-synapse@packages.debian.org\n" -"POT-Creation-Date: 2017-02-21 07:51+0000\n" +"Project-Id-Version: matrix-synapse-py3\n" +"Report-Msgid-Bugs-To: matrix-synapse-py3@packages.debian.org\n" +"POT-Creation-Date: 2020-04-06 16:39-0400\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME <EMAIL@ADDRESS>\n" "Language-Team: LANGUAGE <LL@li.org>\n" @@ -28,7 +28,10 @@ msgstr "" #: ../templates:1001 msgid "" "The name that this homeserver will appear as, to clients and other servers " -"via federation. This name should match the SRV record published in DNS." +"via federation. This is normally the public hostname of the server running " +"synapse, but can be different if you set up delegation. Please refer to the " +"delegation documentation in this case: https://github.com/matrix-org/synapse/" +"blob/master/docs/delegate.md." msgstr "" #. Type: boolean diff --git a/debian/rules b/debian/rules index a4d2ce2ba4..c744060a57 100755 --- a/debian/rules +++ b/debian/rules @@ -15,17 +15,38 @@ override_dh_installinit: # we don't really want to strip the symbols from our object files. override_dh_strip: +# dh_shlibdeps calls dpkg-shlibdeps, which finds all the binary files +# (executables and shared libs) in the package, and looks for the shared +# libraries that they depend on. It then adds a dependency on the package that +# contains that library to the package. +# +# We make two modifications to that process... +# override_dh_shlibdeps: - # make the postgres package's dependencies a recommendation - # rather than a hard dependency. + # Firstly, postgres is not a hard dependency for us, so we want to make + # the things that psycopg2 depends on (such as libpq) be + # recommendations rather than hard dependencies. We do so by + # running dpkg-shlibdeps manually on psycopg2's libs. 
+ # find debian/$(PACKAGE_NAME)/ -path '*/site-packages/psycopg2/*.so' | \ xargs dpkg-shlibdeps -Tdebian/$(PACKAGE_NAME).substvars \ -pshlibs1 -dRecommends - # all the other dependencies can be normal 'Depends' requirements, - # except for PIL's, which is self-contained and which confuses - # dpkg-shlibdeps. - dh_shlibdeps -X site-packages/PIL/.libs -X site-packages/psycopg2 + # secondly, we exclude PIL's libraries from the process. They are known + # to be self-contained, but they have interdependencies and + # dpkg-shlibdeps doesn't know how to resolve them. + # + # As of Pillow 7.1.0, these libraries are in + # site-packages/Pillow.libs. Previously, they were in + # site-packages/PIL/.libs. + # + # (we also need to exclude psycopg2, of course, since we've already + # dealt with that.) + # + dh_shlibdeps \ + -X site-packages/PIL/.libs \ + -X site-packages/Pillow.libs \ + -X site-packages/psycopg2 override_dh_virtualenv: ./debian/build_virtualenv diff --git a/debian/templates b/debian/templates index 647358731c..458fe8bbe9 100644 --- a/debian/templates +++ b/debian/templates @@ -2,8 +2,10 @@ Template: matrix-synapse/server-name Type: string _Description: Name of the server: The name that this homeserver will appear as, to clients and other - servers via federation. This name should match the SRV record - published in DNS. + servers via federation. This is normally the public hostname of the + server running synapse, but can be different if you set up delegation. + Please refer to the delegation documentation in this case: + https://github.com/matrix-org/synapse/blob/master/docs/delegate.md. Template: matrix-synapse/report-stats Type: boolean diff --git a/demo/start.sh b/demo/start.sh index eccaa2abeb..83396e5c33 100755 --- a/demo/start.sh +++ b/demo/start.sh @@ -77,14 +77,13 @@ for port in 8080 8081 8082; do # Reduce the blacklist blacklist=$(cat <<-BLACK - # Set the blacklist so that it doesn't include 127.0.0.1 + # Set the blacklist so that it doesn't include 127.0.0.1, ::1 federation_ip_range_blacklist: - '10.0.0.0/8' - '172.16.0.0/12' - '192.168.0.0/16' - '100.64.0.0/10' - '169.254.0.0/16' - - '::1/128' - 'fe80::/64' - 'fc00::/7' BLACK diff --git a/docker/Dockerfile b/docker/Dockerfile index e5a0d6d5f6..9a3cf7b3f5 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -16,7 +16,7 @@ ARG PYTHON_VERSION=3.7 ### ### Stage 0: builder ### -FROM docker.io/python:${PYTHON_VERSION}-alpine3.10 as builder +FROM docker.io/python:${PYTHON_VERSION}-alpine3.11 as builder # install the OS build deps @@ -55,7 +55,7 @@ RUN pip install --prefix="/install" --no-warn-script-location \ ### Stage 1: runtime ### -FROM docker.io/python:${PYTHON_VERSION}-alpine3.10 +FROM docker.io/python:${PYTHON_VERSION}-alpine3.11 # xmlsec is required for saml support RUN apk add --no-cache --virtual .runtime_deps \ diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv index ac9ebcfd88..619585d5fa 100644 --- a/docker/Dockerfile-dhvirtualenv +++ b/docker/Dockerfile-dhvirtualenv @@ -27,15 +27,18 @@ RUN env DEBIAN_FRONTEND=noninteractive apt-get install \ wget # fetch and unpack the package -RUN wget -q -O /dh-virtuenv-1.1.tar.gz https://github.com/spotify/dh-virtualenv/archive/1.1.tar.gz -RUN tar xvf /dh-virtuenv-1.1.tar.gz +RUN mkdir /dh-virtualenv +RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/ac6e1b1.tar.gz +RUN tar -xv --strip-components=1 -C /dh-virtualenv -f /dh-virtualenv.tar.gz -# install its build deps -RUN cd dh-virtualenv-1.1/ \ - && env 
DEBIAN_FRONTEND=noninteractive mk-build-deps -ri -t "apt-get -yqq --no-install-recommends" +# install its build deps. We do another apt-cache-update here, because we might +# be using a stale cache from docker build. +RUN apt-get update -qq -o Acquire::Languages=none \ + && cd /dh-virtualenv \ + && env DEBIAN_FRONTEND=noninteractive mk-build-deps -ri -t "apt-get -y --no-install-recommends" # build it -RUN cd dh-virtualenv-1.1 && dpkg-buildpackage -us -uc -b +RUN cd /dh-virtualenv && dpkg-buildpackage -us -uc -b ### ### Stage 1 @@ -68,12 +71,12 @@ RUN apt-get update -qq -o Acquire::Languages=none \ sqlite3 \ libpq-dev -COPY --from=builder /dh-virtualenv_1.1-1_all.deb / +COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb / # install dhvirtualenv. Update the apt cache again first, in case we got a # cached cache from docker the first time. RUN apt-get update -qq -o Acquire::Languages=none \ - && apt-get install -yq /dh-virtualenv_1.1-1_all.deb + && apt-get install -yq /dh-virtualenv_1.2~dev-1_all.deb WORKDIR /synapse/source ENTRYPOINT ["bash","/synapse/source/docker/build_debian.sh"] diff --git a/docker/README.md b/docker/README.md index d5879c2f2c..8c337149ca 100644 --- a/docker/README.md +++ b/docker/README.md @@ -89,6 +89,8 @@ The following environment variables are supported in run mode: `/data`. * `SYNAPSE_CONFIG_PATH`: path to the config file. Defaults to `<SYNAPSE_CONFIG_DIR>/homeserver.yaml`. +* `SYNAPSE_WORKER`: module to execute, used when running synapse with workers. + Defaults to `synapse.app.homeserver`, which is suitable for non-worker mode. * `UID`, `GID`: the user and group id to run Synapse as. Defaults to `991`, `991`. * `TZ`: the [timezone](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) the container will run with. Defaults to `UTC`. @@ -99,7 +101,7 @@ is suitable for local testing, but for any practical use, you will either need to use a reverse proxy, or configure Synapse to expose an HTTPS port. For documentation on using a reverse proxy, see -https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.rst. +https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.md. For more information on enabling TLS support in synapse itself, see https://github.com/matrix-org/synapse/blob/master/INSTALL.md#tls-certificates. Of @@ -108,12 +110,12 @@ argument to `docker run`. ## Legacy dynamic configuration file support -For backwards-compatibility only, the docker image supports creating a dynamic -configuration file based on environment variables. This is now deprecated, but -is enabled when the `SYNAPSE_SERVER_NAME` variable is set (and `generate` is -not given). +The docker image used to support creating a dynamic configuration file based +on environment variables. This is no longer supported, and an error will be +raised if you try to run synapse without a config file. -To migrate from a dynamic configuration file to a static one, run the docker +It is, however, possible to generate a static configuration file based on +the environment variables that were previously used. To do this, run the docker container once with the environment variables set, and `migrate_config` command line option. For example: @@ -125,6 +127,23 @@ docker run -it --rm \ matrixdotorg/synapse:latest migrate_config ``` -This will generate the same configuration file as the legacy mode used, but -will store it in `/data/homeserver.yaml` instead of a temporary location. You -can then use it as shown above at [Running synapse](#running-synapse). 
+This will generate the same configuration file as the legacy mode used, and +will store it in `/data/homeserver.yaml`. You can then use it as shown above at +[Running synapse](#running-synapse). + +Note that the defaults used in this configuration file may be different to +those when generating a new config file with `generate`: for example, TLS is +enabled by default in this mode. You are encouraged to inspect the generated +configuration file and edit it to ensure it meets your needs. + +## Building the image + +If you need to build the image from a Synapse checkout, use the following `docker + build` command from the repo's root: + +``` +docker build -t matrixdotorg/synapse -f docker/Dockerfile . +``` + +You can choose to build a different docker image by changing the value of the `-f` flag to +point to another Dockerfile. diff --git a/docker/conf/log.config b/docker/conf/log.config index db35e475a4..ed418a57cd 100644 --- a/docker/conf/log.config +++ b/docker/conf/log.config @@ -24,3 +24,5 @@ loggers: root: level: {{ SYNAPSE_LOG_LEVEL or "INFO" }} handlers: [console] + +disable_existing_loggers: false diff --git a/docker/start.py b/docker/start.py index 260f2d9943..2a25c9380e 100755 --- a/docker/start.py +++ b/docker/start.py @@ -169,11 +169,11 @@ def run_generate_config(environ, ownership): # log("running %s" % (args, )) if ownership is not None: - args = ["su-exec", ownership] + args - os.execv("/sbin/su-exec", args) - # make sure that synapse has perms to write to the data dir. subprocess.check_output(["chown", ownership, data_dir]) + + args = ["su-exec", ownership] + args + os.execv("/sbin/su-exec", args) else: os.execv("/usr/local/bin/python", args) @@ -182,16 +182,12 @@ def main(args, environ): mode = args[1] if len(args) > 1 else None desired_uid = int(environ.get("UID", "991")) desired_gid = int(environ.get("GID", "991")) + synapse_worker = environ.get("SYNAPSE_WORKER", "synapse.app.homeserver") if (desired_uid == os.getuid()) and (desired_gid == os.getgid()): ownership = None else: ownership = "{}:{}".format(desired_uid, desired_gid) - log( - "Container running as UserID %s:%s, ENV (or defaults) requests %s:%s" - % (os.getuid(), os.getgid(), desired_uid, desired_gid) - ) - if ownership is None: log("Will not perform chmod/su-exec as UserID already matches request") @@ -212,40 +208,33 @@ def main(args, environ): if mode is not None: error("Unknown execution mode '%s'" % (mode,)) - if "SYNAPSE_SERVER_NAME" in environ: - # backwards-compatibility generate-a-config-on-the-fly mode - if "SYNAPSE_CONFIG_PATH" in environ: + config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data") + config_path = environ.get("SYNAPSE_CONFIG_PATH", config_dir + "/homeserver.yaml") + + if not os.path.exists(config_path): + if "SYNAPSE_SERVER_NAME" in environ: error( - "SYNAPSE_SERVER_NAME and SYNAPSE_CONFIG_PATH are mutually exclusive " - "except in `generate` or `migrate_config` mode." + """\ +Config file '%s' does not exist. + +The synapse docker image no longer supports generating a config file on-the-fly +based on environment variables. You can migrate to a static config file by +running with 'migrate_config'. See the README for more details. +""" + % (config_path,) ) - config_path = "/compiled/homeserver.yaml" - log( - "Generating config file '%s' on-the-fly from environment variables.\n" - "Note that this mode is deprecated. You can migrate to a static config\n" - "file by running with 'migrate_config'. See the README for more details." + error( + "Config file '%s' does not exist. 
You should either create a new " + "config file by running with the `generate` argument (and then edit " + "the resulting file before restarting) or specify the path to an " + "existing config file with the SYNAPSE_CONFIG_PATH variable." % (config_path,) ) - generate_config_from_template("/compiled", config_path, environ, ownership) - else: - config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data") - config_path = environ.get( - "SYNAPSE_CONFIG_PATH", config_dir + "/homeserver.yaml" - ) - if not os.path.exists(config_path): - error( - "Config file '%s' does not exist. You should either create a new " - "config file by running with the `generate` argument (and then edit " - "the resulting file before restarting) or specify the path to an " - "existing config file with the SYNAPSE_CONFIG_PATH variable." - % (config_path,) - ) - log("Starting synapse with config file " + config_path) - args = ["python", "-m", "synapse.app.homeserver", "--config-path", config_path] + args = ["python", "-m", synapse_worker, "--config-path", config_path] if ownership is not None: args = ["su-exec", ownership] + args os.execv("/sbin/su-exec", args) diff --git a/docs/.sample_config_header.yaml b/docs/.sample_config_header.yaml index e001ef5983..35a591d042 100644 --- a/docs/.sample_config_header.yaml +++ b/docs/.sample_config_header.yaml @@ -1,4 +1,4 @@ -# The config is maintained as an up-to-date snapshot of the default +# This file is maintained as an up-to-date snapshot of the default # homeserver.yaml configuration generated by Synapse. # # It is intended to act as a reference for the default configuration, @@ -10,3 +10,5 @@ # homeserver.yaml. Instead, if you are starting from scratch, please generate # a fresh config using Synapse by following the instructions in INSTALL.md. +################################################################################ + diff --git a/docs/ACME.md b/docs/ACME.md index 9eb18a9cf5..f4c4740476 100644 --- a/docs/ACME.md +++ b/docs/ACME.md @@ -1,12 +1,48 @@ # ACME -Synapse v1.0 will require valid TLS certificates for communication between -servers (port `8448` by default) in addition to those that are client-facing -(port `443`). If you do not already have a valid certificate for your domain, -the easiest way to get one is with Synapse's new ACME support, which will use -the ACME protocol to provision a certificate automatically. Synapse v0.99.0+ -will provision server-to-server certificates automatically for you for free -through [Let's Encrypt](https://letsencrypt.org/) if you tell it to. +From version 1.0 (June 2019) onwards, Synapse requires valid TLS +certificates for communication between servers (by default on port +`8448`) in addition to those that are client-facing (port `443`). To +help homeserver admins fulfil this new requirement, Synapse v0.99.0 +introduced support for automatically provisioning certificates through +[Let's Encrypt](https://letsencrypt.org/) using the ACME protocol. + +## Deprecation of ACME v1 + +In [March 2019](https://community.letsencrypt.org/t/end-of-life-plan-for-acmev1/88430), +Let's Encrypt announced that they were deprecating version 1 of the ACME +protocol, with the plan to disable the use of it for new accounts in +November 2019, and for existing accounts in June 2020. + +Synapse doesn't currently support version 2 of the ACME protocol, which +means that: + +* for existing installs, Synapse's built-in ACME support will continue + to work until June 2020. +* for new installs, this feature will not work at all. 
+ +Either way, it is recommended to move from Synapse's ACME support +feature to an external automated tool such as [certbot](https://github.com/certbot/certbot) +(or browse [this list](https://letsencrypt.org/fr/docs/client-options/) +for an alternative ACME client). + +It's also recommended to use a reverse proxy for the server-facing +communications (more documentation about this can be found +[here](/docs/reverse_proxy.md)) as well as the client-facing ones and +have it serve the certificates. + +In case you can't do that and need Synapse to serve them itself, make +sure to set the `tls_certificate_path` configuration setting to the path +of the certificate (make sure to use the certificate containing the full +certification chain, e.g. `fullchain.pem` if using certbot) and +`tls_private_key_path` to the path of the matching private key. Note +that in this case you will need to restart Synapse after each +certificate renewal so that Synapse stops using the old certificate. + +If you still want to use Synapse's built-in ACME support, the rest of +this document explains how to set it up. + +## Initial setup In the case that your `server_name` config variable is the same as the hostname that the client connects to, then the same certificate can be @@ -32,11 +68,6 @@ If you already have certificates, you will need to back up or delete them (files `example.com.tls.crt` and `example.com.tls.key` in Synapse's root directory), Synapse's ACME implementation will not overwrite them. -You may wish to use alternate methods such as Certbot to obtain a certificate -from Let's Encrypt, depending on your server configuration. Of course, if you -already have a valid certificate for your homeserver's domain, that can be -placed in Synapse's config directory without the need for any ACME setup. - ## ACME setup The main steps for enabling ACME support in short summary are: diff --git a/docs/CAPTCHA_SETUP.md b/docs/CAPTCHA_SETUP.md index 5f9057530b..331e5d059a 100644 --- a/docs/CAPTCHA_SETUP.md +++ b/docs/CAPTCHA_SETUP.md @@ -4,7 +4,7 @@ The captcha mechanism used is Google's ReCaptcha. This requires API keys from Go ## Getting keys -Requires a public/private key pair from: +Requires a site/secret key pair from: <https://developers.google.com/recaptcha/> @@ -15,8 +15,8 @@ Must be a reCAPTCHA v2 key using the "I'm not a robot" Checkbox option The keys are a config option on the home server config. If they are not visible, you can generate them via `--generate-config`. Set the following value: - recaptcha_public_key: YOUR_PUBLIC_KEY - recaptcha_private_key: YOUR_PRIVATE_KEY + recaptcha_public_key: YOUR_SITE_KEY + recaptcha_private_key: YOUR_SECRET_KEY In addition, you MUST enable captchas via: diff --git a/docs/admin_api/README.rst b/docs/admin_api/README.rst index d4f564cfae..9587bee0ce 100644 --- a/docs/admin_api/README.rst +++ b/docs/admin_api/README.rst @@ -4,9 +4,25 @@ Admin APIs This directory includes documentation for the various synapse specific admin APIs available. -Only users that are server admins can use these APIs. A user can be marked as a -server admin by updating the database directly, e.g.: +Authenticating as a server admin +-------------------------------- -``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'`` +Many of the API calls in the admin api will require an `access_token` for a +server admin. (Note that a server admin is distinct from a room admin.) -Restarting may be required for the changes to register. 
+A user can be marked as a server admin by updating the database directly, e.g.: + +.. code-block:: sql + + UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'; + +A new server admin user can also be created using the +``register_new_matrix_user`` script. + +Finding your user's `access_token` is client-dependent, but will usually be shown in the client's settings. + +Once you have your `access_token`, to include it in a request, the best option is to add the token to a request header: + +``curl --header "Authorization: Bearer <access_token>" <the_rest_of_your_API_request>`` + +Fore more details, please refer to the complete `matrix spec documentation <https://matrix.org/docs/spec/client_server/r0.5.0#using-access-tokens>`_. diff --git a/docs/admin_api/delete_group.md b/docs/admin_api/delete_group.md index 1710488ea8..c061678e75 100644 --- a/docs/admin_api/delete_group.md +++ b/docs/admin_api/delete_group.md @@ -4,11 +4,11 @@ This API lets a server admin delete a local group. Doing so will kick all users out of the group so that their clients will correctly handle the group being deleted. - The API is: ``` POST /_synapse/admin/v1/delete_group/<group_id> ``` -including an `access_token` of a server admin. +To use it, you will need to authenticate by providing an `access_token` for a +server admin: see [README.rst](README.rst). diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index 5e9f8e5d84..26948770d8 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -6,9 +6,10 @@ The API is: ``` GET /_synapse/admin/v1/room/<room_id>/media ``` -including an `access_token` of a server admin. +To use it, you will need to authenticate by providing an `access_token` for a +server admin: see [README.rst](README.rst). -It returns a JSON body like the following: +The API returns a JSON body like the following: ``` { "local": [ @@ -21,3 +22,81 @@ It returns a JSON body like the following: ] } ``` + +# Quarantine media + +Quarantining media means that it is marked as inaccessible by users. It applies +to any local media, and any locally-cached copies of remote media. + +The media file itself (and any thumbnails) is not deleted from the server. + +## Quarantining media by ID + +This API quarantines a single piece of local or remote media. + +Request: + +``` +POST /_synapse/admin/v1/media/quarantine/<server_name>/<media_id> + +{} +``` + +Where `server_name` is in the form of `example.org`, and `media_id` is in the +form of `abcdefg12345...`. + +Response: + +``` +{} +``` + +## Quarantining media in a room + +This API quarantines all local and remote media in a room. + +Request: + +``` +POST /_synapse/admin/v1/room/<room_id>/media/quarantine + +{} +``` + +Where `room_id` is in the form of `!roomid12345:example.org`. + +Response: + +``` +{ + "num_quarantined": 10 # The number of media items successfully quarantined +} +``` + +Note that there is a legacy endpoint, `POST +/_synapse/admin/v1/quarantine_media/<room_id >`, that operates the same. +However, it is deprecated and may be removed in a future release. + +## Quarantining all media of a user + +This API quarantines all *local* media that a *local* user has uploaded. That is to say, if +you would like to quarantine media uploaded by a user on a remote homeserver, you should +instead use one of the other APIs. + +Request: + +``` +POST /_synapse/admin/v1/user/<user_id>/media/quarantine + +{} +``` + +Where `user_id` is in the form of `@bob:example.org`. 
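As a concrete illustration of the quarantine-by-user request above, a curl sketch might look like the following. It assumes the admin API is reachable at `http://localhost:8008` and that `$TOKEN` holds a server admin's `access_token`; both are assumptions for illustration, not part of the documented API.

```
# Hypothetical host/port; substitute your own homeserver and admin token.
curl -X POST \
  --header "Authorization: Bearer $TOKEN" \
  --data '{}' \
  'http://localhost:8008/_synapse/admin/v1/user/@bob:example.org/media/quarantine'
```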
+ +Response: + +``` +{ + "num_quarantined": 10 # The number of media items successfully quarantined +} +``` diff --git a/docs/admin_api/purge_history_api.rst b/docs/admin_api/purge_history_api.rst index f7be226fd9..92cd05f2a0 100644 --- a/docs/admin_api/purge_history_api.rst +++ b/docs/admin_api/purge_history_api.rst @@ -8,11 +8,15 @@ Depending on the amount of history being purged a call to the API may take several minutes or longer. During this period users will not be able to paginate further back in the room from the point being purged from. +Note that Synapse requires at least one message in each room, so it will never +delete the last message in a room. + The API is: ``POST /_synapse/admin/v1/purge_history/<room_id>[/<event_id>]`` -including an ``access_token`` of a server admin. +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. By default, events sent by local users are not deleted, as they may represent the only copies of this content in existence. (Events sent by remote users are @@ -51,8 +55,10 @@ It is possible to poll for updates on recent purges with a second API; ``GET /_synapse/admin/v1/purge_history_status/<purge_id>`` -(again, with a suitable ``access_token``). This API returns a JSON body like -the following: +Again, you will need to authenticate by providing an ``access_token`` for a +server admin. + +This API returns a JSON body like the following: .. code:: json diff --git a/docs/admin_api/purge_remote_media.rst b/docs/admin_api/purge_remote_media.rst index dacd5bc8fb..00cb6b0589 100644 --- a/docs/admin_api/purge_remote_media.rst +++ b/docs/admin_api/purge_remote_media.rst @@ -6,12 +6,15 @@ media. The API is:: - POST /_synapse/admin/v1/purge_media_cache?before_ts=<unix_timestamp_in_ms>&access_token=<access_token> + POST /_synapse/admin/v1/purge_media_cache?before_ts=<unix_timestamp_in_ms> {} -Which will remove all cached media that was last accessed before +\... which will remove all cached media that was last accessed before ``<unix_timestamp_in_ms>``. +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + If the user re-requests purged remote media, synapse will re-request the media from the originating server. diff --git a/docs/admin_api/room_membership.md b/docs/admin_api/room_membership.md new file mode 100644 index 0000000000..b6746ff5e4 --- /dev/null +++ b/docs/admin_api/room_membership.md @@ -0,0 +1,35 @@ +# Edit Room Membership API + +This API allows an administrator to join an user account with a given `user_id` +to a room with a given `room_id_or_alias`. You can only modify the membership of +local users. The server administrator must be in the room and have permission to +invite users. + +## Parameters + +The following parameters are available: + +* `user_id` - Fully qualified user: for example, `@user:server.com`. +* `room_id_or_alias` - The room identifier or alias to join: for example, + `!636q39766251:server.com`. + +## Usage + +``` +POST /_synapse/admin/v1/join/<room_id_or_alias> + +{ + "user_id": "@user:server.com" +} +``` + +To use it, you will need to authenticate by providing an `access_token` for a +server admin: see [README.rst](README.rst). 
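For example, the join request above could be issued with curl roughly as follows. This is a sketch only: the `localhost:8008` base URL and the `$TOKEN` variable are assumptions, while the room ID and user ID are the examples from the parameter list.

```
# Assumes a local homeserver and a server admin's access_token in $TOKEN.
curl -X POST \
  --header "Authorization: Bearer $TOKEN" \
  --data '{"user_id": "@user:server.com"}' \
  'http://localhost:8008/_synapse/admin/v1/join/!636q39766251:server.com'
```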
+ +Response: + +``` +{ + "room_id": "!636q39766251:server.com" +} +``` diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md new file mode 100644 index 0000000000..624e7745ba --- /dev/null +++ b/docs/admin_api/rooms.md @@ -0,0 +1,320 @@ +# List Room API + +The List Room admin API allows server admins to get a list of rooms on their +server. There are various parameters available that allow for filtering and +sorting the returned list. This API supports pagination. + +## Parameters + +The following query parameters are available: + +* `from` - Offset in the returned list. Defaults to `0`. +* `limit` - Maximum amount of rooms to return. Defaults to `100`. +* `order_by` - The method in which to sort the returned list of rooms. Valid values are: + - `alphabetical` - Same as `name`. This is deprecated. + - `size` - Same as `joined_members`. This is deprecated. + - `name` - Rooms are ordered alphabetically by room name. This is the default. + - `canonical_alias` - Rooms are ordered alphabetically by main alias address of the room. + - `joined_members` - Rooms are ordered by the number of members. Largest to smallest. + - `joined_local_members` - Rooms are ordered by the number of local members. Largest to smallest. + - `version` - Rooms are ordered by room version. Largest to smallest. + - `creator` - Rooms are ordered alphabetically by creator of the room. + - `encryption` - Rooms are ordered alphabetically by the end-to-end encryption algorithm. + - `federatable` - Rooms are ordered by whether the room is federatable. + - `public` - Rooms are ordered by visibility in room list. + - `join_rules` - Rooms are ordered alphabetically by join rules of the room. + - `guest_access` - Rooms are ordered alphabetically by guest access option of the room. + - `history_visibility` - Rooms are ordered alphabetically by visibility of history of the room. + - `state_events` - Rooms are ordered by number of state events. Largest to smallest. +* `dir` - Direction of room order. Either `f` for forwards or `b` for backwards. Setting + this value to `b` will reverse the above sort order. Defaults to `f`. +* `search_term` - Filter rooms by their room name. Search term can be contained in any + part of the room name. Defaults to no filtering. + +The following fields are possible in the JSON response body: + +* `rooms` - An array of objects, each containing information about a room. + - Room objects contain the following fields: + - `room_id` - The ID of the room. + - `name` - The name of the room. + - `canonical_alias` - The canonical (main) alias address of the room. + - `joined_members` - How many users are currently in the room. + - `joined_local_members` - How many local users are currently in the room. + - `version` - The version of the room as a string. + - `creator` - The `user_id` of the room creator. + - `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active. + - `federatable` - Whether users on other servers can join this room. + - `public` - Whether the room is visible in room directory. + - `join_rules` - The type of rules used for users wishing to join this room. One of: ["public", "knock", "invite", "private"]. + - `guest_access` - Whether guests can join the room. One of: ["can_join", "forbidden"]. + - `history_visibility` - Who can see the room history. One of: ["invited", "joined", "shared", "world_readable"]. + - `state_events` - Total number of state_events of a room. Complexity of the room. 
+* `offset` - The current pagination offset in rooms. This parameter should be + used instead of `next_token` for room offset as `next_token` is + not intended to be parsed. +* `total_rooms` - The total number of rooms this query can return. Using this + and `offset`, you have enough information to know the current + progression through the list. +* `next_batch` - If this field is present, we know that there are potentially + more rooms on the server that did not all fit into this response. + We can use `next_batch` to get the "next page" of results. To do + so, simply repeat your request, setting the `from` parameter to + the value of `next_batch`. +* `prev_batch` - If this field is present, it is possible to paginate backwards. + Use `prev_batch` for the `from` value in the next request to + get the "previous page" of results. + +## Usage + +A standard request with no filtering: + +``` +GET /_synapse/admin/v1/rooms + +{} +``` + +Response: + +``` +{ + "rooms": [ + { + "room_id": "!OGEhHVWSdvArJzumhm:matrix.org", + "name": "Matrix HQ", + "canonical_alias": "#matrix:matrix.org", + "joined_members": 8326, + "joined_local_members": 2, + "version": "1", + "creator": "@foo:matrix.org", + "encryption": null, + "federatable": true, + "public": true, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 93534 + }, + ... (8 hidden items) ... + { + "room_id": "!xYvNcQPhnkrdUmYczI:matrix.org", + "name": "This Week In Matrix (TWIM)", + "canonical_alias": "#twim:matrix.org", + "joined_members": 314, + "joined_local_members": 20, + "version": "4", + "creator": "@foo:matrix.org", + "encryption": "m.megolm.v1.aes-sha2", + "federatable": true, + "public": false, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 8345 + } + ], + "offset": 0, + "total_rooms": 10 +} +``` + +Filtering by room name: + +``` +GET /_synapse/admin/v1/rooms?search_term=TWIM + +{} +``` + +Response: + +``` +{ + "rooms": [ + { + "room_id": "!xYvNcQPhnkrdUmYczI:matrix.org", + "name": "This Week In Matrix (TWIM)", + "canonical_alias": "#twim:matrix.org", + "joined_members": 314, + "joined_local_members": 20, + "version": "4", + "creator": "@foo:matrix.org", + "encryption": "m.megolm.v1.aes-sha2", + "federatable": true, + "public": false, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 8 + } + ], + "offset": 0, + "total_rooms": 1 +} +``` + +Paginating through a list of rooms: + +``` +GET /_synapse/admin/v1/rooms?order_by=size + +{} +``` + +Response: + +``` +{ + "rooms": [ + { + "room_id": "!OGEhHVWSdvArJzumhm:matrix.org", + "name": "Matrix HQ", + "canonical_alias": "#matrix:matrix.org", + "joined_members": 8326, + "joined_local_members": 2, + "version": "1", + "creator": "@foo:matrix.org", + "encryption": null, + "federatable": true, + "public": true, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 93534 + }, + ... (98 hidden items) ... 
+ { + "room_id": "!xYvNcQPhnkrdUmYczI:matrix.org", + "name": "This Week In Matrix (TWIM)", + "canonical_alias": "#twim:matrix.org", + "joined_members": 314, + "joined_local_members": 20, + "version": "4", + "creator": "@foo:matrix.org", + "encryption": "m.megolm.v1.aes-sha2", + "federatable": true, + "public": false, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 8345 + } + ], + "offset": 0, + "total_rooms": 150 + "next_token": 100 +} +``` + +The presence of the `next_token` parameter tells us that there are more rooms +than returned in this request, and we need to make another request to get them. +To get the next batch of room results, we repeat our request, setting the `from` +parameter to the value of `next_token`. + +``` +GET /_synapse/admin/v1/rooms?order_by=size&from=100 + +{} +``` + +Response: + +``` +{ + "rooms": [ + { + "room_id": "!mscvqgqpHYjBGDxNym:matrix.org", + "name": "Music Theory", + "canonical_alias": "#musictheory:matrix.org", + "joined_members": 127 + "joined_local_members": 2, + "version": "1", + "creator": "@foo:matrix.org", + "encryption": null, + "federatable": true, + "public": true, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 93534 + }, + ... (48 hidden items) ... + { + "room_id": "!twcBhHVdZlQWuuxBhN:termina.org.uk", + "name": "weechat-matrix", + "canonical_alias": "#weechat-matrix:termina.org.uk", + "joined_members": 137 + "joined_local_members": 20, + "version": "4", + "creator": "@foo:termina.org.uk", + "encryption": null, + "federatable": true, + "public": true, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 8345 + } + ], + "offset": 100, + "prev_batch": 0, + "total_rooms": 150 +} +``` + +Once the `next_token` parameter is no longer present, we know we've reached the +end of the list. + +# DRAFT: Room Details API + +The Room Details admin API allows server admins to get all details of a room. + +This API is still a draft and details might change! + +The following fields are possible in the JSON response body: + +* `room_id` - The ID of the room. +* `name` - The name of the room. +* `canonical_alias` - The canonical (main) alias address of the room. +* `joined_members` - How many users are currently in the room. +* `joined_local_members` - How many local users are currently in the room. +* `version` - The version of the room as a string. +* `creator` - The `user_id` of the room creator. +* `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active. +* `federatable` - Whether users on other servers can join this room. +* `public` - Whether the room is visible in room directory. +* `join_rules` - The type of rules used for users wishing to join this room. One of: ["public", "knock", "invite", "private"]. +* `guest_access` - Whether guests can join the room. One of: ["can_join", "forbidden"]. +* `history_visibility` - Who can see the room history. One of: ["invited", "joined", "shared", "world_readable"]. +* `state_events` - Total number of state_events of a room. Complexity of the room. 
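Like the other admin APIs described above, requests to this draft endpoint need a server admin's access token; the Usage section that follows shows the bare request and response. As a hedged curl sketch (the host and `$TOKEN` are assumptions, and the room ID is the one used in the pagination example above):

```
# Fetch the details of a single room; host and token are placeholders.
curl --header "Authorization: Bearer $TOKEN" \
  'http://localhost:8008/_synapse/admin/v1/rooms/!mscvqgqpHYjBGDxNym:matrix.org'
```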
+ +## Usage + +A standard request: + +``` +GET /_synapse/admin/v1/rooms/<room_id> + +{} +``` + +Response: + +``` +{ + "room_id": "!mscvqgqpHYjBGDxNym:matrix.org", + "name": "Music Theory", + "canonical_alias": "#musictheory:matrix.org", + "joined_members": 127 + "joined_local_members": 2, + "version": "1", + "creator": "@foo:matrix.org", + "encryption": null, + "federatable": true, + "public": true, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 93534 +} +``` diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md new file mode 100644 index 0000000000..54ce1cd234 --- /dev/null +++ b/docs/admin_api/shutdown_room.md @@ -0,0 +1,72 @@ +# Shutdown room API + +Shuts down a room, preventing new joins and moves local users and room aliases automatically +to a new room. The new room will be created with the user specified by the +`new_room_user_id` parameter as room administrator and will contain a message +explaining what happened. Users invited to the new room will have power level +-10 by default, and thus be unable to speak. The old room's power levels will be changed to +disallow any further invites or joins. + +The local server will only have the power to move local user and room aliases to +the new room. Users on other servers will be unaffected. + +## API + +You will need to authenticate with an access token for an admin user. + +### URL + +`POST /_synapse/admin/v1/shutdown_room/{room_id}` + +### URL Parameters + +* `room_id` - The ID of the room (e.g `!someroom:example.com`) + +### JSON Body Parameters + +* `new_room_user_id` - Required. A string representing the user ID of the user that will admin + the new room that all users in the old room will be moved to. +* `room_name` - Optional. A string representing the name of the room that new users will be + invited to. +* `message` - Optional. A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly convey why the + original room was shut down. + +If not specified, the default value of `room_name` is "Content Violation +Notification". The default value of `message` is "Sharing illegal content on +othis server is not permitted and rooms in violation will be blocked." + +### Response Parameters + +* `kicked_users` - An integer number representing the number of users that + were kicked. +* `failed_to_kick_users` - An integer number representing the number of users + that were not kicked. +* `local_aliases` - An array of strings representing the local aliases that were migrated from + the old room to the new. +* `new_room_id` - A string representing the room ID of the new room. + +## Example + +Request: + +``` +POST /_synapse/admin/v1/shutdown_room/!somebadroom%3Aexample.com + +{ + "new_room_user_id": "@someuser:example.com", + "room_name": "Content Violation Notification", + "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service." +} +``` + +Response: + +``` +{ + "kicked_users": 5, + "failed_to_kick_users": 0, + "local_aliases": ["#badroom:example.com", "#evilsaloon:example.com], + "new_room_id": "!newroomid:example.com", +}, +``` diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index d0871f9438..7b030a6285 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -1,13 +1,176 @@ -Query Account -============= +.. 
contents:: + +Query User Account +================== This API returns information about a specific user account. The api is:: + GET /_synapse/admin/v2/users/<user_id> + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + +It returns a JSON body like the following: + +.. code:: json + + { + "displayname": "User", + "threepids": [ + { + "medium": "email", + "address": "<user_mail_1>" + }, + { + "medium": "email", + "address": "<user_mail_2>" + } + ], + "avatar_url": "<avatar_url>", + "admin": false, + "deactivated": false + } + +URL parameters: + +- ``user_id``: fully-qualified user id: for example, ``@user:server.com``. + +Create or modify Account +======================== + +This API allows an administrator to create or modify a user account with a +specific ``user_id``. + +This api is:: + + PUT /_synapse/admin/v2/users/<user_id> + +with a body of: + +.. code:: json + + { + "password": "user_password", + "displayname": "User", + "threepids": [ + { + "medium": "email", + "address": "<user_mail_1>" + }, + { + "medium": "email", + "address": "<user_mail_2>" + } + ], + "avatar_url": "<avatar_url>", + "admin": false, + "deactivated": false + } + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + +URL parameters: + +- ``user_id``: fully-qualified user id: for example, ``@user:server.com``. + +Body parameters: + +- ``password``, optional. If provided, the user's password is updated and all + devices are logged out. + +- ``displayname``, optional, defaults to the value of ``user_id``. + +- ``threepids``, optional, allows setting the third-party IDs (email, msisdn) + belonging to a user. + +- ``avatar_url``, optional, must be a + `MXC URI <https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris>`_. + +- ``admin``, optional, defaults to ``false``. + +- ``deactivated``, optional, defaults to ``false``. + +If the user already exists then optional parameters default to the current value. + +List Accounts +============= + +This API returns all local user accounts. + +The api is:: + + GET /_synapse/admin/v2/users?from=0&limit=10&guests=false + +To use it, you will need to authenticate by providing an `access_token` for a +server admin: see `README.rst <README.rst>`_. + +The parameter ``from`` is optional but used for pagination, denoting the +offset in the returned results. This should be treated as an opaque value and +not explicitly set to anything other than the return value of ``next_token`` +from a previous call. + +The parameter ``limit`` is optional but is used for pagination, denoting the +maximum number of items to return in this call. Defaults to ``100``. + +The parameter ``user_id`` is optional and filters to only users with user IDs +that contain this value. + +The parameter ``guests`` is optional and if ``false`` will **exclude** guest users. +Defaults to ``true`` to include guest users. + +The parameter ``deactivated`` is optional and if ``true`` will **include** deactivated users. +Defaults to ``false`` to exclude deactivated users. + +A JSON body is returned with the following shape: + +.. 
code:: json + + { + "users": [ + { + "name": "<user_id1>", + "password_hash": "<password_hash1>", + "is_guest": 0, + "admin": 0, + "user_type": null, + "deactivated": 0, + "displayname": "<User One>", + "avatar_url": null + }, { + "name": "<user_id2>", + "password_hash": "<password_hash2>", + "is_guest": 0, + "admin": 1, + "user_type": null, + "deactivated": 0, + "displayname": "<User Two>", + "avatar_url": "<avatar_url>" + } + ], + "next_token": "100", + "total": 200 + } + +To paginate, check for ``next_token`` and if present, call the endpoint again +with ``from`` set to the value of ``next_token``. This will return a new page. + +If the endpoint does not return a ``next_token`` then there are no more users +to paginate through. + +Query current sessions for a user +================================= + +This API returns information about the active sessions for a specific user. + +The api is:: + GET /_synapse/admin/v1/whois/<user_id> -including an ``access_token`` of a server admin. +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. It returns a JSON body like the following: @@ -60,9 +223,10 @@ with a body of: "erase": true } -including an ``access_token`` of a server admin. +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. -The erase parameter is optional and defaults to 'false'. +The erase parameter is optional and defaults to ``false``. An empty body may be passed for backwards compatibility. @@ -80,11 +244,15 @@ with a body of: .. code:: json { - "new_password": "<secret>" + "new_password": "<secret>", + "logout_devices": true, } -including an ``access_token`` of a server admin. +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. +The parameter ``new_password`` is required. +The parameter ``logout_devices`` is optional and defaults to ``true``. Get whether a user is a server administrator or not =================================================== @@ -94,7 +262,8 @@ The api is:: GET /_synapse/admin/v1/users/<user_id>/admin -including an ``access_token`` of a server admin. +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. A response body like the following is returned: @@ -122,4 +291,191 @@ with a body of: "admin": true } -including an ``access_token`` of a server admin. +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + + +User devices +============ + +List all devices +---------------- +Gets information about all devices for a specific ``user_id``. + +The API is:: + + GET /_synapse/admin/v2/users/<user_id>/devices + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + +A response body like the following is returned: + +.. code:: json + + { + "devices": [ + { + "device_id": "QBUAZIFURK", + "display_name": "android", + "last_seen_ip": "1.2.3.4", + "last_seen_ts": 1474491775024, + "user_id": "<user_id>" + }, + { + "device_id": "AUIECTSRND", + "display_name": "ios", + "last_seen_ip": "1.2.3.5", + "last_seen_ts": 1474491775025, + "user_id": "<user_id>" + } + ] + } + +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - fully qualified: for example, ``@user:server.com``. 
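For instance, the device list above could be fetched with a curl call along these lines (a sketch only: host and `$TOKEN` are assumptions; the user ID is the example from the parameter list):

```
# List all devices of a user; host and token are placeholders.
curl --header "Authorization: Bearer $TOKEN" \
  'http://localhost:8008/_synapse/admin/v2/users/@user:server.com/devices'
```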
+ +**Response** + +The following fields are returned in the JSON response body: + +- ``devices`` - An array of objects, each containing information about a device. + Device objects contain the following fields: + + - ``device_id`` - Identifier of device. + - ``display_name`` - Display name set by the user for this device. + Absent if no name has been set. + - ``last_seen_ip`` - The IP address where this device was last seen. + (May be a few minutes out of date, for efficiency reasons). + - ``last_seen_ts`` - The timestamp (in milliseconds since the unix epoch) when this + devices was last seen. (May be a few minutes out of date, for efficiency reasons). + - ``user_id`` - Owner of device. + +Delete multiple devices +------------------ +Deletes the given devices for a specific ``user_id``, and invalidates +any access token associated with them. + +The API is:: + + POST /_synapse/admin/v2/users/<user_id>/delete_devices + + { + "devices": [ + "QBUAZIFURK", + "AUIECTSRND" + ], + } + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + +An empty JSON dict is returned. + +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - fully qualified: for example, ``@user:server.com``. + +The following fields are required in the JSON request body: + +- ``devices`` - The list of device IDs to delete. + +Show a device +--------------- +Gets information on a single device, by ``device_id`` for a specific ``user_id``. + +The API is:: + + GET /_synapse/admin/v2/users/<user_id>/devices/<device_id> + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + +A response body like the following is returned: + +.. code:: json + + { + "device_id": "<device_id>", + "display_name": "android", + "last_seen_ip": "1.2.3.4", + "last_seen_ts": 1474491775024, + "user_id": "<user_id>" + } + +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - fully qualified: for example, ``@user:server.com``. +- ``device_id`` - The device to retrieve. + +**Response** + +The following fields are returned in the JSON response body: + +- ``device_id`` - Identifier of device. +- ``display_name`` - Display name set by the user for this device. + Absent if no name has been set. +- ``last_seen_ip`` - The IP address where this device was last seen. + (May be a few minutes out of date, for efficiency reasons). +- ``last_seen_ts`` - The timestamp (in milliseconds since the unix epoch) when this + devices was last seen. (May be a few minutes out of date, for efficiency reasons). +- ``user_id`` - Owner of device. + +Update a device +--------------- +Updates the metadata on the given ``device_id`` for a specific ``user_id``. + +The API is:: + + PUT /_synapse/admin/v2/users/<user_id>/devices/<device_id> + + { + "display_name": "My other phone" + } + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + +An empty JSON dict is returned. + +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - fully qualified: for example, ``@user:server.com``. +- ``device_id`` - The device to update. + +The following fields are required in the JSON request body: + +- ``display_name`` - The new display name for this device. If not given, + the display name is unchanged. 
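Putting the pieces together, updating a device's display name could look like this in curl. A sketch under stated assumptions: the host and `$TOKEN` are placeholders, while the user ID, device ID and body are the examples used earlier in this document.

```
# Rename a device; host and token are placeholders.
curl -X PUT \
  --header "Authorization: Bearer $TOKEN" \
  --data '{"display_name": "My other phone"}' \
  'http://localhost:8008/_synapse/admin/v2/users/@user:server.com/devices/QBUAZIFURK'
```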
+ +Delete a device +--------------- +Deletes the given ``device_id`` for a specific ``user_id``, +and invalidates any access token associated with it. + +The API is:: + + DELETE /_synapse/admin/v2/users/<user_id>/devices/<device_id> + + {} + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + +An empty JSON dict is returned. + +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - fully qualified: for example, ``@user:server.com``. +- ``device_id`` - The device to delete. diff --git a/docs/application_services.md b/docs/application_services.md index 06cb79f1f9..e4592010a2 100644 --- a/docs/application_services.md +++ b/docs/application_services.md @@ -23,9 +23,13 @@ namespaces: users: # List of users we're interested in - exclusive: <bool> regex: <regex> + group_id: <group> - ... aliases: [] # List of aliases we're interested in rooms: [] # List of room ids we're interested in ``` +`exclusive`: If enabled, only this application service is allowed to register users in its namespace(s). +`group_id`: All users of this application service are dynamically joined to this group. This is useful for e.g user organisation or flairs. + See the [spec](https://matrix.org/docs/spec/application_service/unstable.html) for further details on how application services work. diff --git a/docs/code_style.md b/docs/code_style.md index f983f72d6c..6ef6f80290 100644 --- a/docs/code_style.md +++ b/docs/code_style.md @@ -30,7 +30,7 @@ The necessary tools are detailed below. Install `flake8` with: - pip install --upgrade flake8 + pip install --upgrade flake8 flake8-comprehensions Check all application and test code with: @@ -137,6 +137,7 @@ Some guidelines follow: correctly handles the top-level option being set to `None` (as it will be if no sub-options are enabled). - Lines should be wrapped at 80 characters. +- Use two-space indents. Example: @@ -155,13 +156,13 @@ Example: # Settings for the frobber # frobber: - # frobbing speed. Defaults to 1. - # - #speed: 10 + # frobbing speed. Defaults to 1. + # + #speed: 10 - # frobbing distance. Defaults to 1000. - # - #distance: 100 + # frobbing distance. Defaults to 1000. + # + #distance: 100 Note that the sample configuration is generated from the synapse code and is maintained by a script, `scripts-dev/generate_sample_config`. diff --git a/docs/delegate.md b/docs/delegate.md new file mode 100644 index 0000000000..208ddb6277 --- /dev/null +++ b/docs/delegate.md @@ -0,0 +1,94 @@ +# Delegation + +By default, other homeservers will expect to be able to reach yours via +your `server_name`, on port 8448. For example, if you set your `server_name` +to `example.com` (so that your user names look like `@user:example.com`), +other servers will try to connect to yours at `https://example.com:8448/`. + +Delegation is a Matrix feature allowing a homeserver admin to retain a +`server_name` of `example.com` so that user IDs, room aliases, etc continue +to look like `*:example.com`, whilst having federation traffic routed +to a different server and/or port (e.g. `synapse.example.com:443`). + +## .well-known delegation + +To use this method, you need to be able to alter the +`server_name` 's https server to serve the `/.well-known/matrix/server` +URL. Having an active server (with a valid TLS certificate) serving your +`server_name` domain is out of the scope of this documentation. 
+ +The URL `https://<server_name>/.well-known/matrix/server` should +return a JSON structure containing the key `m.server` like so: + +```json +{ + "m.server": "<synapse.server.name>[:<yourport>]" +} +``` + +In our example, this would mean that URL `https://example.com/.well-known/matrix/server` +should return: + +```json +{ + "m.server": "synapse.example.com:443" +} +``` + +Note, specifying a port is optional. If no port is specified, then it defaults +to 8448. + +With .well-known delegation, federating servers will check for a valid TLS +certificate for the delegated hostname (in our example: `synapse.example.com`). + +## SRV DNS record delegation + +It is also possible to do delegation using a SRV DNS record. However, that is +considered an advanced topic since it's a bit complex to set up, and `.well-known` +delegation is already enough in most cases. + +However, if you really need it, you can find some documentation on how such a +record should look like and how Synapse will use it in [the Matrix +specification](https://matrix.org/docs/spec/server_server/latest#resolving-server-names). + +## Delegation FAQ + +### When do I need delegation? + +If your homeserver's APIs are accessible on the default federation port (8448) +and the domain your `server_name` points to, you do not need any delegation. + +For instance, if you registered `example.com` and pointed its DNS A record at a +fresh server, you could install Synapse on that host, giving it a `server_name` +of `example.com`, and once a reverse proxy has been set up to proxy all requests +sent to the port `8448` and serve TLS certificates for `example.com`, you +wouldn't need any delegation set up. + +**However**, if your homeserver's APIs aren't accessible on port 8448 and on the +domain `server_name` points to, you will need to let other servers know how to +find it using delegation. + +### Do you still recommend against using a reverse proxy on the federation port? + +We no longer actively recommend against using a reverse proxy. Many admins will +find it easier to direct federation traffic to a reverse proxy and manage their +own TLS certificates, and this is a supported configuration. + +See [reverse_proxy.md](reverse_proxy.md) for information on setting up a +reverse proxy. + +### Do I still need to give my TLS certificates to Synapse if I am using a reverse proxy? + +This is no longer necessary. If you are using a reverse proxy for all of your +TLS traffic, then you can set `no_tls: True` in the Synapse config. + +In that case, the only reason Synapse needs the certificate is to populate a legacy +`tls_fingerprints` field in the federation API. This is ignored by Synapse 0.99.0 +and later, and the only time pre-0.99 Synapses will check it is when attempting to +fetch the server keys - and generally this is delegated via `matrix.org`, which +is running a modern version of Synapse. + +### Do I need the same certificate for the client and federation port? + +No. There is nothing stopping you from using different certificates, +particularly if you are using a reverse proxy. \ No newline at end of file diff --git a/docs/dev/cas.md b/docs/dev/cas.md new file mode 100644 index 0000000000..f8d02cc82c --- /dev/null +++ b/docs/dev/cas.md @@ -0,0 +1,64 @@ +# How to test CAS as a developer without a server + +The [django-mama-cas](https://github.com/jbittel/django-mama-cas) project is an +easy to run CAS implementation built on top of Django. + +## Prerequisites + +1. Create a new virtualenv: `python3 -m venv <your virtualenv>` +2. 
Activate your virtualenv: `source /path/to/your/virtualenv/bin/activate` +3. Install Django and django-mama-cas: + ``` + python -m pip install "django<3" "django-mama-cas==2.4.0" + ``` +4. Create a Django project in the current directory: + ``` + django-admin startproject cas_test . + ``` +5. Follow the [install directions](https://django-mama-cas.readthedocs.io/en/latest/installation.html#configuring) for django-mama-cas +6. Setup the SQLite database: `python manage.py migrate` +7. Create a user: + ``` + python manage.py createsuperuser + ``` + 1. Use whatever you want as the username and password. + 2. Leave the other fields blank. +8. Use the built-in Django test server to serve the CAS endpoints on port 8000: + ``` + python manage.py runserver + ``` + +You should now have a Django project configured to serve CAS authentication with +a single user created. + +## Configure Synapse (and Riot) to use CAS + +1. Modify your `homeserver.yaml` to enable CAS and point it to your locally + running Django test server: + ```yaml + cas_config: + enabled: true + server_url: "http://localhost:8000" + service_url: "http://localhost:8081" + #displayname_attribute: name + #required_attributes: + # name: value + ``` +2. Restart Synapse. + +Note that the above configuration assumes the homeserver is running on port 8081 +and that the CAS server is on port 8000, both on localhost. + +## Testing the configuration + +Then in Riot: + +1. Visit the login page with a Riot pointing at your homeserver. +2. Click the Single Sign-On button. +3. Login using the credentials created with `createsuperuser`. +4. You should be logged in. + +If you want to repeat this process you'll need to manually logout first: + +1. http://localhost:8000/admin/ +2. Click "logout" in the top right. diff --git a/docs/dev/git.md b/docs/dev/git.md new file mode 100644 index 0000000000..b747ff20c9 --- /dev/null +++ b/docs/dev/git.md @@ -0,0 +1,148 @@ +Some notes on how we use git +============================ + +On keeping the commit history clean +----------------------------------- + +In an ideal world, our git commit history would be a linear progression of +commits each of which contains a single change building on what came +before. Here, by way of an arbitrary example, is the top of `git log --graph +b2dba0607`: + +<img src="git/clean.png" alt="clean git graph" width="500px"> + +Note how the commit comment explains clearly what is changing and why. Also +note the *absence* of merge commits, as well as the absence of commits called +things like (to pick a few culprits): +[“pep8”](https://github.com/matrix-org/synapse/commit/84691da6c), [“fix broken +test”](https://github.com/matrix-org/synapse/commit/474810d9d), +[“oops”](https://github.com/matrix-org/synapse/commit/c9d72e457), +[“typo”](https://github.com/matrix-org/synapse/commit/836358823), or [“Who's +the president?”](https://github.com/matrix-org/synapse/commit/707374d5d). + +There are a number of reasons why keeping a clean commit history is a good +thing: + + * From time to time, after a change lands, it turns out to be necessary to + revert it, or to backport it to a release branch. Those operations are + *much* easier when the change is contained in a single commit. + + * Similarly, it's much easier to answer questions like “is the fix for + `/publicRooms` on the release branch?” if that change consists of a single + commit. + + * Likewise: “what has changed on this branch in the last week?” is much + clearer without merges and “pep8” commits everywhere. 
+ + * Sometimes we need to figure out where a bug got introduced, or some + behaviour changed. One way of doing that is with `git bisect`: pick an + arbitrary commit between the known good point and the known bad point, and + see how the code behaves. However, that strategy fails if the commit you + chose is the middle of someone's epic branch in which they broke the world + before putting it back together again. + +One counterargument is that it is sometimes useful to see how a PR evolved as +it went through review cycles. This is true, but that information is always +available via the GitHub UI (or via the little-known [refs/pull +namespace](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/checking-out-pull-requests-locally)). + + +Of course, in reality, things are more complicated than that. We have release +branches as well as `develop` and `master`, and we deliberately merge changes +between them. Bugs often slip through and have to be fixed later. That's all +fine: this not a cast-iron rule which must be obeyed, but an ideal to aim +towards. + +Merges, squashes, rebases: wtf? +------------------------------- + +Ok, so that's what we'd like to achieve. How do we achieve it? + +The TL;DR is: when you come to merge a pull request, you *probably* want to +“squash and merge”: + +![squash and merge](git/squash.png). + +(This applies whether you are merging your own PR, or that of another +contributor.) + +“Squash and merge”<sup id="a1">[1](#f1)</sup> takes all of the changes in the +PR, and bundles them into a single commit. GitHub gives you the opportunity to +edit the commit message before you confirm, and normally you should do so, +because the default will be useless (again: `* woops typo` is not a useful +thing to keep in the historical record). + +The main problem with this approach comes when you have a series of pull +requests which build on top of one another: as soon as you squash-merge the +first PR, you'll end up with a stack of conflicts to resolve in all of the +others. In general, it's best to avoid this situation in the first place by +trying not to have multiple related PRs in flight at the same time. Still, +sometimes that's not possible and doing a regular merge is the lesser evil. + +Another occasion in which a regular merge makes more sense is a PR where you've +deliberately created a series of commits each of which makes sense in its own +right. For example: [a PR which gradually propagates a refactoring operation +through the codebase](https://github.com/matrix-org/synapse/pull/6837), or [a +PR which is the culmination of several other +PRs](https://github.com/matrix-org/synapse/pull/5987). In this case the ability +to figure out when a particular change/bug was introduced could be very useful. + +Ultimately: **this is not a hard-and-fast-rule**. If in doubt, ask yourself “do +each of the commits I am about to merge make sense in their own right”, but +remember that we're just doing our best to balance “keeping the commit history +clean” with other factors. + +Git branching model +------------------- + +A [lot](https://nvie.com/posts/a-successful-git-branching-model/) +[of](http://scottchacon.com/2011/08/31/github-flow.html) +[words](https://www.endoflineblog.com/gitflow-considered-harmful) have been +written in the past about git branching models (no really, [a +lot](https://martinfowler.com/articles/branching-patterns.html)). I tend to +think the whole thing is overblown. Fundamentally, it's not that +complicated. Here's how we do it. 
+ +Let's start with a picture: + +![branching model](git/branches.jpg) + +It looks complicated, but it's really not. There's one basic rule: *anyone* is +free to merge from *any* more-stable branch to *any* less-stable branch at +*any* time<sup id="a2">[2](#f2)</sup>. (The principle behind this is that if a +change is good enough for the more-stable branch, then it's also good enough go +put in a less-stable branch.) + +Meanwhile, merging (or squashing, as per the above) from a less-stable to a +more-stable branch is a deliberate action in which you want to publish a change +or a set of changes to (some subset of) the world: for example, this happens +when a PR is landed, or as part of our release process. + +So, what counts as a more- or less-stable branch? A little reflection will show +that our active branches are ordered thus, from more-stable to less-stable: + + * `master` (tracks our last release). + * `release-vX.Y.Z` (the branch where we prepare the next release)<sup + id="a3">[3](#f3)</sup>. + * PR branches which are targeting the release. + * `develop` (our "mainline" branch containing our bleeding-edge). + * regular PR branches. + +The corollary is: if you have a bugfix that needs to land in both +`release-vX.Y.Z` *and* `develop`, then you should base your PR on +`release-vX.Y.Z`, get it merged there, and then merge from `release-vX.Y.Z` to +`develop`. (If a fix lands in `develop` and we later need it in a +release-branch, we can of course cherry-pick it, but landing it in the release +branch first helps reduce the chance of annoying conflicts.) + +--- + +<b id="f1">[1]</b>: “Squash and merge” is GitHub's term for this +operation. Given that there is no merge involved, I'm not convinced it's the +most intuitive name. [^](#a1) + +<b id="f2">[2]</b>: Well, anyone with commit access.[^](#a2) + +<b id="f3">[3]</b>: Very, very occasionally (I think this has happened once in +the history of Synapse), we've had two releases in flight at once. Obviously, +`release-v1.2.3` is more-stable than `release-v1.3.0`. [^](#a3) diff --git a/docs/dev/git/branches.jpg b/docs/dev/git/branches.jpg new file mode 100644 index 0000000000..715ecc8cd0 --- /dev/null +++ b/docs/dev/git/branches.jpg Binary files differdiff --git a/docs/dev/git/clean.png b/docs/dev/git/clean.png new file mode 100644 index 0000000000..3accd7ccef --- /dev/null +++ b/docs/dev/git/clean.png Binary files differdiff --git a/docs/dev/git/squash.png b/docs/dev/git/squash.png new file mode 100644 index 0000000000..234caca3e4 --- /dev/null +++ b/docs/dev/git/squash.png Binary files differdiff --git a/docs/dev/saml.md b/docs/dev/saml.md index f41aadce47..a9bfd2dc05 100644 --- a/docs/dev/saml.md +++ b/docs/dev/saml.md @@ -18,9 +18,13 @@ To make Synapse (and therefore Riot) use it: metadata: local: ["samling.xml"] ``` -5. Run `apt-get install xmlsec1` and `pip install --upgrade --force 'pysaml2>=4.5.0'` to ensure +5. Ensure that your `homeserver.yaml` has a setting for `public_baseurl`: + ```yaml + public_baseurl: http://localhost:8080/ + ``` +6. Run `apt-get install xmlsec1` and `pip install --upgrade --force 'pysaml2>=4.5.0'` to ensure the dependencies are installed and ready to go. -6. Restart Synapse. +7. Restart Synapse. 
Then in Riot: diff --git a/docs/federate.md b/docs/federate.md index 193e2d2dfe..a0786b9cf7 100644 --- a/docs/federate.md +++ b/docs/federate.md @@ -1,181 +1,41 @@ -Setting up Federation +Setting up federation ===================== Federation is the process by which users on different servers can participate in the same room. For this to work, those other servers must be able to contact yours to send messages. -The ``server_name`` configured in the Synapse configuration file (often -``homeserver.yaml``) defines how resources (users, rooms, etc.) will be -identified (eg: ``@user:example.com``, ``#room:example.com``). By -default, it is also the domain that other servers will use to -try to reach your server (via port 8448). This is easy to set -up and will work provided you set the ``server_name`` to match your -machine's public DNS hostname, and provide Synapse with a TLS certificate -which is valid for your ``server_name``. +The `server_name` configured in the Synapse configuration file (often +`homeserver.yaml`) defines how resources (users, rooms, etc.) will be +identified (eg: `@user:example.com`, `#room:example.com`). By default, +it is also the domain that other servers will use to try to reach your +server (via port 8448). This is easy to set up and will work provided +you set the `server_name` to match your machine's public DNS hostname. + +For this default configuration to work, you will need to listen for TLS +connections on port 8448. The preferred way to do that is by using a +reverse proxy: see [reverse_proxy.md](<reverse_proxy.md>) for instructions +on how to correctly set one up. + +In some cases you might not want to run Synapse on the machine that has +the `server_name` as its public DNS hostname, or you might want federation +traffic to use a different port than 8448. For example, you might want to +have your user names look like `@user:example.com`, but you want to run +Synapse on `synapse.example.com` on port 443. This can be done using +delegation, which allows an admin to control where federation traffic should +be sent. See [delegate.md](delegate.md) for instructions on how to set this up. Once federation has been configured, you should be able to join a room over -federation. A good place to start is ``#synapse:matrix.org`` - a room for +federation. A good place to start is `#synapse:matrix.org` - a room for Synapse admins. - -## Delegation - -For a more flexible configuration, you can have ``server_name`` -resources (eg: ``@user:example.com``) served by a different host and -port (eg: ``synapse.example.com:443``). There are two ways to do this: - -- adding a ``/.well-known/matrix/server`` URL served on ``https://example.com``. -- adding a DNS ``SRV`` record in the DNS zone of domain - ``example.com``. - -Without configuring delegation, the matrix federation will -expect to find your server via ``example.com:8448``. The following methods -allow you retain a `server_name` of `example.com` so that your user IDs, room -aliases, etc continue to look like `*:example.com`, whilst having your -federation traffic routed to a different server. - -### .well-known delegation - -To use this method, you need to be able to alter the -``server_name`` 's https server to serve the ``/.well-known/matrix/server`` -URL. Having an active server (with a valid TLS certificate) serving your -``server_name`` domain is out of the scope of this documentation. 
- -The URL ``https://<server_name>/.well-known/matrix/server`` should -return a JSON structure containing the key ``m.server`` like so: - - { - "m.server": "<synapse.server.name>[:<yourport>]" - } - -In our example, this would mean that URL ``https://example.com/.well-known/matrix/server`` -should return: - - { - "m.server": "synapse.example.com:443" - } - -Note, specifying a port is optional. If a port is not specified an SRV lookup -is performed, as described below. If the target of the -delegation does not have an SRV record, then the port defaults to 8448. - -Most installations will not need to configure .well-known. However, it can be -useful in cases where the admin is hosting on behalf of someone else and -therefore cannot gain access to the necessary certificate. With .well-known, -federation servers will check for a valid TLS certificate for the delegated -hostname (in our example: ``synapse.example.com``). - -.well-known support first appeared in Synapse v0.99.0. To federate with older -servers you may need to additionally configure SRV delegation. Alternatively, -encourage the server admin in question to upgrade :). - -### DNS SRV delegation - -To use this delegation method, you need to have write access to your -``server_name`` 's domain zone DNS records (in our example it would be -``example.com`` DNS zone). - -This method requires the target server to provide a -valid TLS certificate for the original ``server_name``. - -You need to add a SRV record in your ``server_name`` 's DNS zone with -this format: - - _matrix._tcp.<yourdomain.com> <ttl> IN SRV <priority> <weight> <port> <synapse.server.name> - -In our example, we would need to add this SRV record in the -``example.com`` DNS zone: - - _matrix._tcp.example.com. 3600 IN SRV 10 5 443 synapse.example.com. - -Once done and set up, you can check the DNS record with ``dig -t srv -_matrix._tcp.<server_name>``. In our example, we would expect this: - - $ dig -t srv _matrix._tcp.example.com - _matrix._tcp.example.com. 3600 IN SRV 10 0 443 synapse.example.com. - -Note that the target of a SRV record cannot be an alias (CNAME record): it has to point -directly to the server hosting the synapse instance. - -### Delegation FAQ -#### When do I need a SRV record or .well-known URI? - -If your homeserver listens on the default federation port (8448), and your -`server_name` points to the host that your homeserver runs on, you do not need an SRV -record or `.well-known/matrix/server` URI. - -For instance, if you registered `example.com` and pointed its DNS A record at a -fresh server, you could install Synapse on that host, -giving it a `server_name` of `example.com`, and once [ACME](acme.md) support is enabled, -it would automatically generate a valid TLS certificate for you via Let's Encrypt -and no SRV record or .well-known URI would be needed. - -This is the common case, although you can add an SRV record or -`.well-known/matrix/server` URI for completeness if you wish. - -**However**, if your server does not listen on port 8448, or if your `server_name` -does not point to the host that your homeserver runs on, you will need to let -other servers know how to find it. The way to do this is via .well-known or an -SRV record. - -#### I have created a .well-known URI. Do I still need an SRV record? - -As of Synapse 0.99, Synapse will first check for the existence of a .well-known -URI and follow any delegation it suggests. It will only then check for the -existence of an SRV record. 
- -That means that the SRV record will often be redundant. However, you should -remember that there may still be older versions of Synapse in the federation -which do not understand .well-known URIs, so if you removed your SRV record -you would no longer be able to federate with them. - -It is therefore best to leave the SRV record in place for now. Synapse 0.34 and -earlier will follow the SRV record (and not care about the invalid -certificate). Synapse 0.99 and later will follow the .well-known URI, with the -correct certificate chain. - -#### Can I manage my own certificates rather than having Synapse renew certificates itself? - -Yes, you are welcome to manage your certificates yourself. Synapse will only -attempt to obtain certificates from Let's Encrypt if you configure it to do -so.The only requirement is that there is a valid TLS cert present for -federation end points. - -#### Do you still recommend against using a reverse proxy on the federation port? - -We no longer actively recommend against using a reverse proxy. Many admins will -find it easier to direct federation traffic to a reverse proxy and manage their -own TLS certificates, and this is a supported configuration. - -See [reverse_proxy.md](reverse_proxy.md) for information on setting up a -reverse proxy. - -#### Do I still need to give my TLS certificates to Synapse if I am using a reverse proxy? - -Practically speaking, this is no longer necessary. - -If you are using a reverse proxy for all of your TLS traffic, then you can set -`no_tls: True` in the Synapse config. In that case, the only reason Synapse -needs the certificate is to populate a legacy `tls_fingerprints` field in the -federation API. This is ignored by Synapse 0.99.0 and later, and the only time -pre-0.99 Synapses will check it is when attempting to fetch the server keys - -and generally this is delegated via `matrix.org`, which will be running a modern -version of Synapse. - -#### Do I need the same certificate for the client and federation port? - -No. There is nothing stopping you from using different certificates, -particularly if you are using a reverse proxy. However, Synapse will use the -same certificate on any ports where TLS is configured. - ## Troubleshooting -You can use the [federation tester]( -<https://matrix.org/federationtester>) to check if your homeserver is -configured correctly. Alternatively try the [JSON API used by the federation tester](https://matrix.org/federationtester/api/report?server_name=DOMAIN). -Note that you'll have to modify this URL to replace ``DOMAIN`` with your -``server_name``. Hitting the API directly provides extra detail. +You can use the [federation tester](https://matrix.org/federationtester) +to check if your homeserver is configured correctly. Alternatively try the +[JSON API used by the federation tester](https://matrix.org/federationtester/api/report?server_name=DOMAIN). +Note that you'll have to modify this URL to replace `DOMAIN` with your +`server_name`. Hitting the API directly provides extra detail. The typical failure mode for federation is that when the server tries to join a room, it is rejected with "401: Unauthorized". Generally this means that other @@ -187,8 +47,8 @@ you invite them to. This can be caused by an incorrectly-configured reverse proxy: see [reverse_proxy.md](<reverse_proxy.md>) for instructions on how to correctly configure a reverse proxy. 
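
If you prefer to script the check, the tester's JSON API can be queried
directly. The following is only a sketch: `example.com` stands in for your
`server_name`, and the Python `requests` library is an assumption rather than
anything the tester requires.

```python
# Rough sketch: fetch the federation tester's JSON report for a server.
# "example.com" is a placeholder server_name.
import requests

report = requests.get(
    "https://matrix.org/federationtester/api/report",
    params={"server_name": "example.com"},
).json()

# The exact shape of the report is defined by the tester itself; dump it so
# any failing checks can be inspected.
print(report)
```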
-## Running a Demo Federation of Synapses +## Running a demo federation of Synapses If you want to get up and running quickly with a trio of homeservers in a -private federation, there is a script in the ``demo`` directory. This is mainly +private federation, there is a script in the `demo` directory. This is mainly useful just for development purposes. See [demo/README](<../demo/README>). diff --git a/docs/log_contexts.md b/docs/log_contexts.md index 5331e8c88b..fe30ca2791 100644 --- a/docs/log_contexts.md +++ b/docs/log_contexts.md @@ -29,14 +29,13 @@ from synapse.logging import context # omitted from future snippets def handle_request(request_id): request_context = context.LoggingContext() - calling_context = context.LoggingContext.current_context() - context.LoggingContext.set_current_context(request_context) + calling_context = context.set_current_context(request_context) try: request_context.request = request_id do_request_handling() logger.debug("finished") finally: - context.LoggingContext.set_current_context(calling_context) + context.set_current_context(calling_context) def do_request_handling(): logger.debug("phew") # this will be logged against request_id diff --git a/docs/message_retention_policies.md b/docs/message_retention_policies.md new file mode 100644 index 0000000000..1dd60bdad9 --- /dev/null +++ b/docs/message_retention_policies.md @@ -0,0 +1,195 @@ +# Message retention policies + +Synapse admins can enable support for message retention policies on +their homeserver. Message retention policies exist at a room level, +follow the semantics described in +[MSC1763](https://github.com/matrix-org/matrix-doc/blob/matthew/msc1763/proposals/1763-configurable-retention-periods.md), +and allow server and room admins to configure how long messages should +be kept in a homeserver's database before being purged from it. +**Please note that, as this feature isn't part of the Matrix +specification yet, this implementation is to be considered as +experimental.** + +A message retention policy is mainly defined by its `max_lifetime` +parameter, which defines how long a message can be kept around after +it was sent to the room. If a room doesn't have a message retention +policy, and there's no default one for a given server, then no message +sent in that room is ever purged on that server. + +MSC1763 also specifies semantics for a `min_lifetime` parameter which +defines the amount of time after which an event _can_ get purged (after +it was sent to the room), but Synapse doesn't currently support it +beyond registering it. + +Both `max_lifetime` and `min_lifetime` are optional parameters. + +Note that message retention policies don't apply to state events. + +Once an event reaches its expiry date (defined as the time it was sent +plus the value for `max_lifetime` in the room), two things happen: + +* Synapse stops serving the event to clients via any endpoint. +* The message gets picked up by the next purge job (see the "Purge jobs" + section) and is removed from Synapse's database. + +Since purge jobs don't run continuously, this means that an event might +stay in a server's database for longer than the value for `max_lifetime` +in the room would allow, though hidden from clients. 
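
To make the expiry calculation concrete, here is an illustrative (and
non-normative) sketch: an event's expiry is the time it was sent
(`origin_server_ts`) plus the policy's `max_lifetime`, both expressed in
milliseconds.

```python
# Illustrative only: when has an event expired under a retention policy?
# origin_server_ts and max_lifetime are both in milliseconds.
import time


def is_expired(origin_server_ts: int, max_lifetime: int) -> bool:
    expiry_ts = origin_server_ts + max_lifetime
    now_ms = int(time.time() * 1000)
    return now_ms >= expiry_ts


# An event sent a week ago, in a room whose policy sets max_lifetime to 3 days,
# has expired: it is hidden from clients and will be removed by the next purge job.
week_ms = 7 * 24 * 60 * 60 * 1000
three_days_ms = 3 * 24 * 60 * 60 * 1000
print(is_expired(int(time.time() * 1000) - week_ms, three_days_ms))  # True
```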
+ +Similarly, if a server (with support for message retention policies +enabled) receives from another server an event that should have been +purged according to its room's policy, then the receiving server will +process and store that event until it's picked up by the next purge job, +though it will always hide it from clients. + +Synapse requires at least one message in each room, so it will never +delete the last message in a room. It will, however, hide it from +clients. + + +## Server configuration + +Support for this feature can be enabled and configured in the +`retention` section of the Synapse configuration file (see the +[sample file](https://github.com/matrix-org/synapse/blob/v1.7.3/docs/sample_config.yaml#L332-L393)). + +To enable support for message retention policies, set the setting +`enabled` in this section to `true`. + + +### Default policy + +A default message retention policy is a policy defined in Synapse's +configuration that is used by Synapse for every room that doesn't have a +message retention policy configured in its state. This allows server +admins to ensure that messages are never kept indefinitely in a server's +database. + +A default policy can be defined as such, in the `retention` section of +the configuration file: + +```yaml + default_policy: + min_lifetime: 1d + max_lifetime: 1y +``` + +Here, `min_lifetime` and `max_lifetime` have the same meaning and level +of support as previously described. They can be expressed either as a +duration (using the units `s` (seconds), `m` (minutes), `h` (hours), +`d` (days), `w` (weeks) and `y` (years)) or as a number of milliseconds. + + +### Purge jobs + +Purge jobs are the jobs that Synapse runs in the background to purge +expired events from the database. They are only run if support for +message retention policies is enabled in the server's configuration. If +no configuration for purge jobs is configured by the server admin, +Synapse will use a default configuration, which is described in the +[sample configuration file](https://github.com/matrix-org/synapse/blob/master/docs/sample_config.yaml#L332-L393). + +Some server admins might want a finer control on when events are removed +depending on an event's room's policy. This can be done by setting the +`purge_jobs` sub-section in the `retention` section of the configuration +file. An example of such configuration could be: + +```yaml + purge_jobs: + - longest_max_lifetime: 3d + interval: 12h + - shortest_max_lifetime: 3d + longest_max_lifetime: 1w + interval: 1d + - shortest_max_lifetime: 1w + interval: 2d +``` + +In this example, we define three jobs: + +* one that runs twice a day (every 12 hours) and purges events in rooms + which policy's `max_lifetime` is lower or equal to 3 days. +* one that runs once a day and purges events in rooms which policy's + `max_lifetime` is between 3 days and a week. +* one that runs once every 2 days and purges events in rooms which + policy's `max_lifetime` is greater than a week. + +Note that this example is tailored to show different configurations and +features slightly more jobs than it's probably necessary (in practice, a +server admin would probably consider it better to replace the two last +jobs with one that runs once a day and handles rooms which which +policy's `max_lifetime` is greater than 3 days). + +Keep in mind, when configuring these jobs, that a purge job can become +quite heavy on the server if it targets many rooms, therefore prefer +having jobs with a low interval that target a limited set of rooms. 
Also +make sure to include a job with no minimum and one with no maximum to +make sure your configuration handles every policy. + +As previously mentioned in this documentation, while a purge job that +runs e.g. every day means that an expired event might stay in the +database for up to a day after its expiry, Synapse hides expired events +from clients as soon as they expire, so the event is not visible to +local users between its expiry date and the moment it gets purged from +the server's database. + + +### Lifetime limits + +**Note: this feature is mainly useful within a closed federation or on +servers that don't federate, because there currently is no way to +enforce these limits in an open federation.** + +Server admins can restrict the values their local users are allowed to +use for both `min_lifetime` and `max_lifetime`. These limits can be +defined as such in the `retention` section of the configuration file: + +```yaml + allowed_lifetime_min: 1d + allowed_lifetime_max: 1y +``` + +Here, `allowed_lifetime_min` is the lowest value a local user can set +for both `min_lifetime` and `max_lifetime`, and `allowed_lifetime_max` +is the highest value. Both parameters are optional (e.g. setting +`allowed_lifetime_min` but not `allowed_lifetime_max` only enforces a +minimum and no maximum). + +Like other settings in this section, these parameters can be expressed +either as a duration or as a number of milliseconds. + + +## Room configuration + +To configure a room's message retention policy, a room's admin or +moderator needs to send a state event in that room with the type +`m.room.retention` and the following content: + +```json +{ + "max_lifetime": ... +} +``` + +In this event's content, the `max_lifetime` parameter has the same +meaning as previously described, and needs to be expressed in +milliseconds. The event's content can also include a `min_lifetime` +parameter, which has the same meaning and limited support as previously +described. + +Note that over every server in the room, only the ones with support for +message retention policies will actually remove expired events. This +support is currently not enabled by default in Synapse. + + +## Note on reclaiming disk space + +While purge jobs actually delete data from the database, the disk space +used by the database might not decrease immediately on the database's +host. However, even though the database engine won't free up the disk +space, it will start writing new data into where the purged data was. + +If you want to reclaim the freed disk space anyway and return it to the +operating system, the server admin needs to run `VACUUM FULL;` (or +`VACUUM;` for SQLite databases) on Synapse's database (see the related +[PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-vacuum.html)). diff --git a/docs/metrics-howto.md b/docs/metrics-howto.md index 32abb9f44e..cf69938a2a 100644 --- a/docs/metrics-howto.md +++ b/docs/metrics-howto.md @@ -60,6 +60,31 @@ 1. Restart Prometheus. +## Monitoring workers + +To monitor a Synapse installation using +[workers](https://github.com/matrix-org/synapse/blob/master/docs/workers.md), +every worker needs to be monitored independently, in addition to +the main homeserver process. This is because workers don't send +their metrics to the main homeserver process, but expose them +directly (if they are configured to do so). 
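
Once a worker has a `metrics` listener configured (as shown below), one quick
way to confirm it is responding is to fetch the endpoint directly. This is only
a sketch: the address and port match the example that follows, and the
`/_synapse/metrics` path is an assumption based on how Synapse normally exposes
Prometheus metrics.

```python
# Rough sanity check that a worker's metrics listener is reachable.
# Assumes the listener from the example below (127.0.0.1, port 9101) and
# that metrics are served on the usual /_synapse/metrics path.
import requests

resp = requests.get("http://127.0.0.1:9101/_synapse/metrics")
resp.raise_for_status()

# The body is in the Prometheus exposition format; print a few lines of it.
print("\n".join(resp.text.splitlines()[:5]))
```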
+ +To allow collecting metrics from a worker, you need to add a +`metrics` listener to its configuration, by adding the following +under `worker_listeners`: + +```yaml + - type: metrics + bind_address: '' + port: 9101 +``` + +The `bind_address` and `port` parameters should be set so that +the resulting listener can be reached by prometheus, and they +don't clash with an existing worker. +With this example, the worker's metrics would then be available +on `http://127.0.0.1:9101`. + ## Renaming of metrics & deprecation of old names in 1.2 Synapse 1.2 updates the Prometheus metrics to match the naming diff --git a/docs/openid.md b/docs/openid.md new file mode 100644 index 0000000000..688379ddd9 --- /dev/null +++ b/docs/openid.md @@ -0,0 +1,206 @@ +# Configuring Synapse to authenticate against an OpenID Connect provider + +Synapse can be configured to use an OpenID Connect Provider (OP) for +authentication, instead of its own local password database. + +Any OP should work with Synapse, as long as it supports the authorization code +flow. There are a few options for that: + + - start a local OP. Synapse has been tested with [Hydra][hydra] and + [Dex][dex-idp]. Note that for an OP to work, it should be served under a + secure (HTTPS) origin. A certificate signed with a self-signed, locally + trusted CA should work. In that case, start Synapse with a `SSL_CERT_FILE` + environment variable set to the path of the CA. + + - set up a SaaS OP, like [Google][google-idp], [Auth0][auth0] or + [Okta][okta]. Synapse has been tested with Auth0 and Google. + +It may also be possible to use other OAuth2 providers which provide the +[authorization code grant type](https://tools.ietf.org/html/rfc6749#section-4.1), +such as [Github][github-idp]. + +[google-idp]: https://developers.google.com/identity/protocols/oauth2/openid-connect +[auth0]: https://auth0.com/ +[okta]: https://www.okta.com/ +[dex-idp]: https://github.com/dexidp/dex +[hydra]: https://www.ory.sh/docs/hydra/ +[github-idp]: https://developer.github.com/apps/building-oauth-apps/authorizing-oauth-apps + +## Preparing Synapse + +The OpenID integration in Synapse uses the +[`authlib`](https://pypi.org/project/Authlib/) library, which must be installed +as follows: + + * The relevant libraries are included in the Docker images and Debian packages + provided by `matrix.org` so no further action is needed. + + * If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip + install synapse[oidc]` to install the necessary dependencies. + + * For other installation mechanisms, see the documentation provided by the + maintainer. + +To enable the OpenID integration, you should then add an `oidc_config` section +to your configuration file (or uncomment the `enabled: true` line in the +existing section). See [sample_config.yaml](./sample_config.yaml) for some +sample settings, as well as the text below for example configurations for +specific providers. + +## Sample configs + +Here are a few configs for providers that should work with Synapse. + +### [Dex][dex-idp] + +[Dex][dex-idp] is a simple, open-source, certified OpenID Connect Provider. +Although it is designed to help building a full-blown provider with an +external database, it can be configured with static passwords in a config file. + +Follow the [Getting Started +guide](https://github.com/dexidp/dex/blob/master/Documentation/getting-started.md) +to install Dex. 
+ +Edit `examples/config-dev.yaml` config file from the Dex repo to add a client: + +```yaml +staticClients: +- id: synapse + secret: secret + redirectURIs: + - '[synapse public baseurl]/_synapse/oidc/callback' + name: 'Synapse' +``` + +Run with `dex serve examples/config-dex.yaml`. + +Synapse config: + +```yaml +oidc_config: + enabled: true + skip_verification: true # This is needed as Dex is served on an insecure endpoint + issuer: "http://127.0.0.1:5556/dex" + client_id: "synapse" + client_secret: "secret" + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.name }}" + display_name_template: "{{ user.name|capitalize }}" +``` + +### [Auth0][auth0] + +1. Create a regular web application for Synapse +2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/oidc/callback` +3. Add a rule to add the `preferred_username` claim. + <details> + <summary>Code sample</summary> + + ```js + function addPersistenceAttribute(user, context, callback) { + user.user_metadata = user.user_metadata || {}; + user.user_metadata.preferred_username = user.user_metadata.preferred_username || user.user_id; + context.idToken.preferred_username = user.user_metadata.preferred_username; + + auth0.users.updateUserMetadata(user.user_id, user.user_metadata) + .then(function(){ + callback(null, user, context); + }) + .catch(function(err){ + callback(err); + }); + } + ``` + </details> + +Synapse config: + +```yaml +oidc_config: + enabled: true + issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" +``` + +### GitHub + +GitHub is a bit special as it is not an OpenID Connect compliant provider, but +just a regular OAuth2 provider. + +The [`/user` API endpoint](https://developer.github.com/v3/users/#get-the-authenticated-user) +can be used to retrieve information on the authenticated user. As the Synaspse +login mechanism needs an attribute to uniquely identify users, and that endpoint +does not return a `sub` property, an alternative `subject_claim` has to be set. + +1. Create a new OAuth application: https://github.com/settings/applications/new. +2. Set the callback URL to `[synapse public baseurl]/_synapse/oidc/callback`. + +Synapse config: + +```yaml +oidc_config: + enabled: true + discover: false + issuer: "https://github.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + authorization_endpoint: "https://github.com/login/oauth/authorize" + token_endpoint: "https://github.com/login/oauth/access_token" + userinfo_endpoint: "https://api.github.com/user" + scopes: ["read:user"] + user_mapping_provider: + config: + subject_claim: "id" + localpart_template: "{{ user.login }}" + display_name_template: "{{ user.name }}" +``` + +### [Google][google-idp] + +1. Set up a project in the Google API Console (see + https://developers.google.com/identity/protocols/oauth2/openid-connect#appsetup). +2. add an "OAuth Client ID" for a Web Application under "Credentials". +3. 
Copy the Client ID and Client Secret, and add the following to your synapse config: + ```yaml + oidc_config: + enabled: true + issuer: "https://accounts.google.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.given_name|lower }}" + display_name_template: "{{ user.name }}" + ``` +4. Back in the Google console, add this Authorized redirect URI: `[synapse + public baseurl]/_synapse/oidc/callback`. + +### Twitch + +1. Setup a developer account on [Twitch](https://dev.twitch.tv/) +2. Obtain the OAuth 2.0 credentials by [creating an app](https://dev.twitch.tv/console/apps/) +3. Add this OAuth Redirect URL: `[synapse public baseurl]/_synapse/oidc/callback` + +Synapse config: + +```yaml +oidc_config: + enabled: true + issuer: "https://id.twitch.tv/oauth2/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + user_mapping_provider: + config: + localpart_template: '{{ user.preferred_username }}' + display_name_template: '{{ user.name }}' +``` diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md index 0db1a3804a..5d9ae67041 100644 --- a/docs/password_auth_providers.md +++ b/docs/password_auth_providers.md @@ -9,7 +9,11 @@ into Synapse, and provides a number of methods by which it can integrate with the authentication system. This document serves as a reference for those looking to implement their -own password auth providers. +own password auth providers. Additionally, here is a list of known +password auth provider module implementations: + +* [matrix-synapse-ldap3](https://github.com/matrix-org/matrix-synapse-ldap3/) +* [matrix-synapse-shared-secret-auth](https://github.com/devture/matrix-synapse-shared-secret-auth) ## Required methods diff --git a/docs/postgres.md b/docs/postgres.md index 29cf762858..70fe29cdcc 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -27,17 +27,21 @@ connect to a postgres database. ## Set up database -Assuming your PostgreSQL database user is called `postgres`, create a -user `synapse_user` with: +Assuming your PostgreSQL database user is called `postgres`, first authenticate as the database user with: su - postgres + # Or, if your system uses sudo to get administrative rights + sudo -u postgres bash + +Then, create a user ``synapse_user`` with: + createuser --pwprompt synapse_user Before you can authenticate with the `synapse_user`, you must create a database that it can access. To create a database, first connect to the database with your database user: - su - postgres + su - postgres # Or: sudo -u postgres bash psql and then run: @@ -57,7 +61,50 @@ Note that the PostgreSQL database *must* have the correct encoding set You may need to enable password authentication so `synapse_user` can connect to the database. See -<https://www.postgresql.org/docs/11/auth-pg-hba-conf.html>. +<https://www.postgresql.org/docs/current/auth-pg-hba-conf.html>. + +If you get an error along the lines of `FATAL: Ident authentication failed for +user "synapse_user"`, you may need to use an authentication method other than +`ident`: + +* If the `synapse_user` user has a password, add the password to the `database:` + section of `homeserver.yaml`. 
Then add the following to `pg_hba.conf`: + + ``` + host synapse synapse_user ::1/128 md5 # or `scram-sha-256` instead of `md5` if you use that + ``` + +* If the `synapse_user` user does not have a password, then a password doesn't + have to be added to `homeserver.yaml`. But the following does need to be added + to `pg_hba.conf`: + + ``` + host synapse synapse_user ::1/128 trust + ``` + +Note that line order matters in `pg_hba.conf`, so make sure that if you do add a +new line, it is inserted before: + +``` +host all all ::1/128 ident +``` + +### Fixing incorrect `COLLATE` or `CTYPE` + +Synapse will refuse to set up a new database if it has the wrong values of +`COLLATE` and `CTYPE` set, and will log warnings on existing databases. Using +different locales can cause issues if the locale library is updated from +underneath the database, or if a different version of the locale is used on any +replicas. + +The safest way to fix the issue is to take a dump and recreate the database with +the correct `COLLATE` and `CTYPE` parameters (as shown above). It is also possible to change the +parameters on a live database and run a `REINDEX` on the entire database, +however extreme care must be taken to avoid database corruption. + +Note that the above may fail with an error about duplicate rows if corruption +has already occurred, and such duplicate rows will need to be manually removed. + ## Tuning Postgres @@ -83,19 +130,41 @@ of free memory the database host has available. When you are ready to start using PostgreSQL, edit the `database` section in your config file to match the following lines: - database: - name: psycopg2 - args: - user: <user> - password: <pass> - database: <db> - host: <host> - cp_min: 5 - cp_max: 10 +```yaml +database: + name: psycopg2 + args: + user: <user> + password: <pass> + database: <db> + host: <host> + cp_min: 5 + cp_max: 10 +``` All key, values in `args` are passed to the `psycopg2.connect(..)` function, except keys beginning with `cp_`, which are consumed by the -twisted adbapi connection pool. +twisted adbapi connection pool. See the [libpq +documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS) +for a list of options which can be passed. + +You should consider tuning the `args.keepalives_*` options if there is any danger of +the connection between your homeserver and database dropping, otherwise Synapse +may block for an extended period while it waits for a response from the +database server. Example values might be: + +```yaml +# seconds of inactivity after which TCP should send a keepalive message to the server +keepalives_idle: 10 + +# the number of seconds after which a TCP keepalive message that is not +# acknowledged by the server should be retransmitted +keepalives_interval: 10 + +# the number of TCP keepalives that can be lost before the client's connection +# to the server is considered dead +keepalives_count: 3 +``` ## Porting from SQLite diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md index dcfc5c64aa..cbb8269568 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md @@ -9,7 +9,7 @@ of doing so is that it means that you can expose the default https port (443) to Matrix clients without needing to run Synapse with root privileges. -> **NOTE**: Your reverse proxy must not `canonicalise` or `normalise` +**NOTE**: Your reverse proxy must not `canonicalise` or `normalise` the requested URI in any way (for example, by decoding `%xx` escapes). 
Beware that Apache *will* canonicalise URIs unless you specifify `nocanon`. @@ -18,99 +18,123 @@ When setting up a reverse proxy, remember that Matrix clients and other Matrix servers do not necessarily need to connect to your server via the same server name or port. Indeed, clients will use port 443 by default, whereas servers default to port 8448. Where these are different, we -refer to the 'client port' and the \'federation port\'. See [Setting -up federation](federate.md) for more details of the algorithm used for -federation connections. +refer to the 'client port' and the 'federation port'. See [the Matrix +specification](https://matrix.org/docs/spec/server_server/latest#resolving-server-names) +for more details of the algorithm used for federation connections, and +[delegate.md](<delegate.md>) for instructions on setting up delegation. Let's assume that we expect clients to connect to our server at `https://matrix.example.com`, and other servers to connect at `https://example.com:8448`. The following sections detail the configuration of the reverse proxy and the homeserver. -## Webserver configuration examples +## Reverse-proxy configuration examples -> **NOTE**: You only need one of these. +**NOTE**: You only need one of these. ### nginx - server { - listen 443 ssl; - listen [::]:443 ssl; - server_name matrix.example.com; - - location /_matrix { - proxy_pass http://localhost:8008; - proxy_set_header X-Forwarded-For $remote_addr; - } - } - - server { - listen 8448 ssl default_server; - listen [::]:8448 ssl default_server; - server_name example.com; - - location / { - proxy_pass http://localhost:8008; - proxy_set_header X-Forwarded-For $remote_addr; - } - } - -> **NOTE**: Do not add a `/` after the port in `proxy_pass`, otherwise nginx will +``` +server { + listen 443 ssl; + listen [::]:443 ssl; + server_name matrix.example.com; + + location /_matrix { + proxy_pass http://localhost:8008; + proxy_set_header X-Forwarded-For $remote_addr; + # Nginx by default only allows file uploads up to 1M in size + # Increase client_max_body_size to match max_upload_size defined in homeserver.yaml + client_max_body_size 10M; + } +} + +server { + listen 8448 ssl default_server; + listen [::]:8448 ssl default_server; + server_name example.com; + + location / { + proxy_pass http://localhost:8008; + proxy_set_header X-Forwarded-For $remote_addr; + } +} +``` + +**NOTE**: Do not add a path after the port in `proxy_pass`, otherwise nginx will canonicalise/normalise the URI. 
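
Whichever webserver you choose, it can be useful to sanity-check the proxy once
it is running. The sketch below uses the placeholder hostnames from the examples
in this document and two unauthenticated endpoints; it is an illustration, not
part of any proxy configuration.

```python
# Illustrative check that the reverse proxy is forwarding Matrix traffic.
# matrix.example.com / example.com:8448 are the placeholder names used above.
import requests

# Client port: should return the supported client-server API versions.
print(requests.get("https://matrix.example.com/_matrix/client/versions").json())

# Federation port: should return this server's published signing keys.
print(requests.get("https://example.com:8448/_matrix/key/v2/server").json())
```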
-### Caddy +### Caddy 1 - matrix.example.com { - proxy /_matrix http://localhost:8008 { - transparent - } - } +``` +matrix.example.com { + proxy /_matrix http://localhost:8008 { + transparent + } +} - example.com:8448 { - proxy / http://localhost:8008 { - transparent - } - } +example.com:8448 { + proxy / http://localhost:8008 { + transparent + } +} +``` + +### Caddy 2 + +``` +matrix.example.com { + reverse_proxy /_matrix/* http://localhost:8008 +} + +example.com:8448 { + reverse_proxy http://localhost:8008 +} +``` ### Apache - <VirtualHost *:443> - SSLEngine on - ServerName matrix.example.com; +``` +<VirtualHost *:443> + SSLEngine on + ServerName matrix.example.com; - AllowEncodedSlashes NoDecode - ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon - ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix - </VirtualHost> + AllowEncodedSlashes NoDecode + ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon + ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix +</VirtualHost> - <VirtualHost *:8448> - SSLEngine on - ServerName example.com; +<VirtualHost *:8448> + SSLEngine on + ServerName example.com; - AllowEncodedSlashes NoDecode - ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon - ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix - </VirtualHost> + AllowEncodedSlashes NoDecode + ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon + ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix +</VirtualHost> +``` -> **NOTE**: ensure the `nocanon` options are included. +**NOTE**: ensure the `nocanon` options are included. ### HAProxy - frontend https - bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1 +``` +frontend https + bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1 - # Matrix client traffic - acl matrix-host hdr(host) -i matrix.example.com - acl matrix-path path_beg /_matrix + # Matrix client traffic + acl matrix-host hdr(host) -i matrix.example.com + acl matrix-path path_beg /_matrix - use_backend matrix if matrix-host matrix-path + use_backend matrix if matrix-host matrix-path - frontend matrix-federation - bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1 - default_backend matrix +frontend matrix-federation + bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1 + default_backend matrix - backend matrix - server matrix 127.0.0.1:8008 +backend matrix + server matrix 127.0.0.1:8008 +``` ## Homeserver Configuration diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 3e4edc6b0b..94e1ec698f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1,4 +1,4 @@ -# The config is maintained as an up-to-date snapshot of the default +# This file is maintained as an up-to-date snapshot of the default # homeserver.yaml configuration generated by Synapse. # # It is intended to act as a reference for the default configuration, @@ -10,6 +10,16 @@ # homeserver.yaml. Instead, if you are starting from scratch, please generate # a fresh config using Synapse by following the instructions in INSTALL.md. +################################################################################ + +# Configuration file for Synapse. +# +# This is a YAML file: see [1] for a quick introduction. Note in particular +# that *indentation is important*: all the elements of a list or dictionary +# should have the same indentation. 
+# +# [1] https://docs.ansible.com/ansible/latest/reference_appendices/YAMLSyntax.html + ## Server ## # The domain name of the server, with optional explicit port. @@ -23,10 +33,15 @@ server_name: "SERVERNAME" # pid_file: DATADIR/homeserver.pid -# The path to the web client which will be served at /_matrix/client/ -# if 'webclient' is configured under the 'listeners' configuration. +# The absolute URL to the web client which /_matrix/client will redirect +# to if 'webclient' is configured under the 'listeners' configuration. # -#web_client_location: "/path/to/web/root" +# This option can be also set to the filesystem path to the web client +# which will be served at /_matrix/client/ if 'webclient' is configured +# under the 'listeners' configuration, however this is a security risk: +# https://github.com/matrix-org/synapse#security-note +# +#web_client_location: https://riot.example.com/ # The public-facing base URL that clients use to access this HS # (not including _matrix/...). This is the same URL a user would @@ -54,15 +69,23 @@ pid_file: DATADIR/homeserver.pid # #require_auth_for_profile_requests: true -# If set to 'false', requires authentication to access the server's public rooms -# directory through the client API. Defaults to 'true'. +# Uncomment to require a user to share a room with another user in order +# to retrieve their profile information. Only checked on Client-Server +# requests. Profile requests from other servers should be checked by the +# requesting server. Defaults to 'false'. +# +#limit_profile_requests_to_users_who_share_rooms: true + +# If set to 'true', removes the need for authentication to access the server's +# public rooms directory through the client API, meaning that anyone can +# query the room directory. Defaults to 'false'. # -#allow_public_rooms_without_auth: false +#allow_public_rooms_without_auth: true -# If set to 'false', forbids any other homeserver to fetch the server's public -# rooms directory via federation. Defaults to 'true'. +# If set to 'true', allows any other homeserver to fetch the server's public +# rooms directory via federation. Defaults to 'false'. # -#allow_public_rooms_over_federation: false +#allow_public_rooms_over_federation: true # The default room version for newly created rooms. # @@ -72,7 +95,7 @@ pid_file: DATADIR/homeserver.pid # For example, for room version 1, default_room_version should be set # to "1". # -#default_room_version: "4" +#default_room_version: "5" # The GC threshold parameters to pass to `gc.set_threshold`, if defined # @@ -86,7 +109,7 @@ pid_file: DATADIR/homeserver.pid # Whether room invites to users on this server should be blocked # (except those sent by local server admins). The default is False. # -#block_non_admin_invites: True +#block_non_admin_invites: true # Room searching # @@ -110,6 +133,9 @@ pid_file: DATADIR/homeserver.pid # blacklist IP address CIDR ranges. If this option is not specified, or # specified with an empty list, no ip range blacklist will be enforced. # +# As of Synapse v1.4.0 this option also affects any outbound requests to identity +# servers provided by user input. +# # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly # listed here, since they correspond to unroutable addresses.) # @@ -227,6 +253,18 @@ listeners: # bind_addresses: ['::1', '127.0.0.1'] # type: manhole +# Forward extremities can build up in a room due to networking delays between +# homeservers. 
Once this happens in a large room, calculation of the state of +# that room can become quite expensive. To mitigate this, once the number of +# forward extremities reaches a given threshold, Synapse will send an +# org.matrix.dummy_event event, which will reduce the forward extremities +# in the room. +# +# This setting defines the threshold (i.e. number of forward extremities in the +# room) at which dummy events are sent. The default value is 10. +# +#dummy_events_threshold: 5 + ## Homeserver blocking ## @@ -236,9 +274,8 @@ listeners: # Global blocking # -#hs_disabled: False +#hs_disabled: false #hs_disabled_message: 'Human readable reason for why the HS is blocked' -#hs_disabled_limit_type: 'error code(str), to help clients decode reason' # Monthly Active User Blocking # @@ -258,15 +295,22 @@ listeners: # sign up in a short space of time never to return after their initial # session. # -#limit_usage_by_mau: False +# 'mau_limit_alerting' is a means of limiting client side alerting +# should the mau limit be reached. This is useful for small instances +# where the admin has 5 mau seats (say) for 5 specific people and no +# interest increasing the mau limit further. Defaults to True, which +# means that alerting is enabled +# +#limit_usage_by_mau: false #max_mau_value: 50 #mau_trial_days: 2 +#mau_limit_alerting: false # If enabled, the metrics for the number of monthly active users will # be populated, however no one will be limited. If limit_usage_by_mau # is true, this is implied to be true. # -#mau_stats_only: False +#mau_stats_only: false # Sometimes the server admin will want to ensure certain accounts are # never blocked by mau checking. These accounts are specified here. @@ -278,22 +322,27 @@ listeners: # Used by phonehome stats to group together related servers. #server_context: context -# Resource-constrained Homeserver Settings +# Resource-constrained homeserver settings # -# If limit_remote_rooms.enabled is True, the room complexity will be -# checked before a user joins a new remote room. If it is above -# limit_remote_rooms.complexity, it will disallow joining or -# instantly leave. +# When this is enabled, the room "complexity" will be checked before a user +# joins a new remote room. If it is above the complexity limit, the server will +# disallow joining, or will instantly leave. # -# limit_remote_rooms.complexity_error can be set to customise the text -# displayed to the user when a room above the complexity threshold has -# its join cancelled. +# Room complexity is an arbitrary measure based on factors such as the number of +# users in the room. # -# Uncomment the below lines to enable: -#limit_remote_rooms: -# enabled: True -# complexity: 1.0 -# complexity_error: "This room is too complex." +limit_remote_rooms: + # Uncomment to enable room complexity checking. + # + #enabled: true + + # the limit above which rooms cannot be joined. The default is 1.0. + # + #complexity: 0.5 + + # override the error which is returned when the room is too complex. + # + #complexity_error: "This room is too complex." # Whether to require a user to be in the room to add an alias to it. # Defaults to 'true'. @@ -311,7 +360,86 @@ listeners: # # Defaults to `7d`. Set to `null` to disable. # -redaction_retention_period: 7d +#redaction_retention_period: 28d + +# How long to track users' last seen time and IPs in the database. +# +# Defaults to `28d`. Set to `null` to disable clearing out of old rows. +# +#user_ips_max_age: 14d + +# Message retention policy at the server level. 
+# +# Room admins and mods can define a retention period for their rooms using the +# 'm.room.retention' state event, and server admins can cap this period by setting +# the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. +# +# If this feature is enabled, Synapse will regularly look for and purge events +# which are older than the room's maximum retention period. Synapse will also +# filter events received over federation so that events that should have been +# purged are ignored and not stored again. +# +retention: + # The message retention policies feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # Default retention policy. If set, Synapse will apply it to rooms that lack the + # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't + # matter much because Synapse doesn't take it into account yet. + # + #default_policy: + # min_lifetime: 1d + # max_lifetime: 1y + + # Retention policy limits. If set, a user won't be able to send a + # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' + # that's not within this range. This is especially useful in closed federations, + # in which server admins can make sure every federating server applies the same + # rules. + # + #allowed_lifetime_min: 1d + #allowed_lifetime_max: 1y + + # Server admins can define the settings of the background jobs purging the + # events which lifetime has expired under the 'purge_jobs' section. + # + # If no configuration is provided, a single job will be set up to delete expired + # events in every room daily. + # + # Each job's configuration defines which range of message lifetimes the job + # takes care of. For example, if 'shortest_max_lifetime' is '2d' and + # 'longest_max_lifetime' is '3d', the job will handle purging expired events in + # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and + # lower than or equal to 3 days. Both the minimum and the maximum value of a + # range are optional, e.g. a job with no 'shortest_max_lifetime' and a + # 'longest_max_lifetime' of '3d' will handle every room with a retention policy + # which 'max_lifetime' is lower than or equal to three days. + # + # The rationale for this per-job configuration is that some rooms might have a + # retention policy with a low 'max_lifetime', where history needs to be purged + # of outdated messages on a more frequent basis than for the rest of the rooms + # (e.g. every 12h), but not want that purge to be performed by a job that's + # iterating over every room it knows, which could be heavy on the server. + # + #purge_jobs: + # - shortest_max_lifetime: 1d + # longest_max_lifetime: 3d + # interval: 12h + # - shortest_max_lifetime: 3d + # longest_max_lifetime: 1y + # interval: 1d + +# Inhibits the /requestToken endpoints from returning an error that might leak +# information about whether an e-mail address is in use or not on this +# homeserver. +# Note that for some endpoints the error situation is the e-mail already being +# used, and for others the error is entering the e-mail being unused. +# If this option is enabled, instead of returning an error, these endpoints will +# act as if no error happened and return a fake session ID ('sid') to clients. +# +#request_token_inhibit_3pid_errors: true ## TLS ## @@ -380,6 +508,11 @@ redaction_retention_period: 7d # ACME support: This will configure Synapse to request a valid TLS certificate # for your configured `server_name` via Let's Encrypt. 
# +# Note that ACME v1 is now deprecated, and Synapse currently doesn't support +# ACME v2. This means that this feature currently won't work with installs set +# up after November 2019. For more info, and alternative solutions, see +# https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1 +# # Note that provisioning a certificate in this way requires port 80 to be # routed to Synapse so that it can complete the http-01 ACME challenge. # By default, if you enable ACME support, Synapse will attempt to listen on @@ -402,7 +535,7 @@ acme: # ACME support is disabled by default. Set this to `true` and uncomment # tls_certificate_path and tls_private_key_path above to enable it. # - enabled: False + enabled: false # Endpoint to use to request certificates. If you only want to test, # use Let's Encrypt's staging url: @@ -475,20 +608,94 @@ acme: -## Database ## +## Caching ## -database: - # The database engine name - name: "sqlite3" - # Arguments to pass to the engine - args: - # Path to the database - database: "DATADIR/homeserver.db" +# Caching can be configured through the following options. +# +# A cache 'factor' is a multiplier that can be applied to each of +# Synapse's caches in order to increase or decrease the maximum +# number of entries that can be stored. -# Number of events to cache in memory. +# The number of events to cache in memory. Not affected by +# caches.global_factor. # #event_cache_size: 10K +caches: + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 + + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + # Some caches have '*' and other characters that are not + # alphanumeric or underscores. These caches can be named with or + # without the special characters stripped. For example, to specify + # the cache factor for `*stateGroupCache*` via an environment + # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`. + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + + +## Database ## + +# The 'database' setting defines the database that synapse uses to store all of +# its data. +# +# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or +# 'psycopg2' (for PostgreSQL). +# +# 'args' gives options which are passed through to the database engine, +# except for options starting 'cp_', which are used to configure the Twisted +# connection pool. 
For a reference to valid arguments, see: +# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect +# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS +# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__ +# +# +# Example SQLite configuration: +# +#database: +# name: sqlite3 +# args: +# database: /path/to/homeserver.db +# +# +# Example Postgres configuration: +# +#database: +# name: psycopg2 +# args: +# user: synapse +# password: secretpassword +# database: synapse +# host: localhost +# cp_min: 5 +# cp_max: 10 +# +# For more information on using Synapse with Postgres, see `docs/postgres.md`. +# +database: + name: sqlite3 + args: + database: DATADIR/homeserver.db + ## Logging ## @@ -596,20 +803,15 @@ media_store_path: "DATADIR/media_store" # #media_storage_providers: # - module: file_system -# # Whether to write new local files. +# # Whether to store newly uploaded local files # store_local: false -# # Whether to write new remote media +# # Whether to store newly downloaded remote files # store_remote: false -# # Whether to block upload requests waiting for write to this -# # provider to complete +# # Whether to wait for successful storage for local uploads # store_synchronous: false # config: # directory: /mnt/some/other/directory -# Directory where in-progress uploads are stored. -# -uploads_path: "DATADIR/uploads" - # The largest allowed upload size in bytes # #max_upload_size: 10M @@ -724,31 +926,55 @@ uploads_path: "DATADIR/uploads" # #max_spider_size: 10M +# A list of values for the Accept-Language HTTP header used when +# downloading webpages during URL preview generation. This allows +# Synapse to specify the preferred languages that URL previews should +# be in when communicating with remote servers. +# +# Each value is a IETF language tag; a 2-3 letter identifier for a +# language, optionally followed by subtags separated by '-', specifying +# a country or region variant. +# +# Multiple values can be provided, and a weight can be added to each by +# using quality value syntax (;q=). '*' translates to any language. +# +# Defaults to "en". +# +# Example: +# +# url_preview_accept_language: +# - en-UK +# - en-US;q=0.9 +# - fr;q=0.8 +# - *;q=0.7 +# +url_preview_accept_language: +# - en + ## Captcha ## -# See docs/CAPTCHA_SETUP for full details of configuring this. +# See docs/CAPTCHA_SETUP.md for full details of configuring this. -# This Home Server's ReCAPTCHA public key. +# This homeserver's ReCAPTCHA public key. Must be specified if +# enable_registration_captcha is enabled. # #recaptcha_public_key: "YOUR_PUBLIC_KEY" -# This Home Server's ReCAPTCHA private key. +# This homeserver's ReCAPTCHA private key. Must be specified if +# enable_registration_captcha is enabled. # #recaptcha_private_key: "YOUR_PRIVATE_KEY" -# Enables ReCaptcha checks when registering, preventing signup +# Uncomment to enable ReCaptcha checks when registering, preventing signup # unless a captcha is answered. Requires a valid ReCaptcha -# public/private key. -# -#enable_registration_captcha: false - -# A secret key used to bypass the captcha test entirely. +# public/private key. Defaults to 'false'. # -#captcha_bypass_secret: "YOUR_SECRET_HERE" +#enable_registration_captcha: true # The API endpoint to use for verifying m.login.recaptcha responses. +# Defaults to "https://www.recaptcha.net/recaptcha/api/siteverify". 
# -#recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify" +#recaptcha_siteverify_api: "https://my.recaptcha.site" ## TURN ## @@ -777,7 +1003,7 @@ uploads_path: "DATADIR/uploads" # connect to arbitrary endpoints without having first signed up for a # valid account (e.g. by passing a CAPTCHA). # -#turn_allow_guests: True +#turn_allow_guests: true ## Registration ## @@ -792,23 +1018,6 @@ uploads_path: "DATADIR/uploads" # Optional account validity configuration. This allows for accounts to be denied # any request after a given period. # -# ``enabled`` defines whether the account validity feature is enabled. Defaults -# to False. -# -# ``period`` allows setting the period after which an account is valid -# after its registration. When renewing the account, its validity period -# will be extended by this amount of time. This parameter is required when using -# the account validity feature. -# -# ``renew_at`` is the amount of time before an account's expiry date at which -# Synapse will send an email to the account's email address with a renewal link. -# This needs the ``email`` and ``public_baseurl`` configuration sections to be -# filled. -# -# ``renew_email_subject`` is the subject of the email sent out with the renewal -# link. ``%(app)s`` can be used as a placeholder for the ``app_name`` parameter -# from the ``email`` section. -# # Once this feature is enabled, Synapse will look for registered users without an # expiration date at startup and will add one to every account it found using the # current settings at that time. @@ -819,21 +1028,55 @@ uploads_path: "DATADIR/uploads" # date will be randomly selected within a range [now + period - d ; now + period], # where d is equal to 10% of the validity period. # -#account_validity: -# enabled: True -# period: 6w -# renew_at: 1w -# renew_email_subject: "Renew your %(app)s account" -# # Directory in which Synapse will try to find the HTML files to serve to the -# # user when trying to renew an account. Optional, defaults to -# # synapse/res/templates. -# template_dir: "res/templates" -# # HTML to be displayed to the user after they successfully renewed their -# # account. Optional. -# account_renewed_html_path: "account_renewed.html" -# # HTML to be displayed when the user tries to renew an account with an invalid -# # renewal token. Optional. -# invalid_token_html_path: "invalid_token.html" +account_validity: + # The account validity feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # The period after which an account is valid after its registration. When + # renewing the account, its validity period will be extended by this amount + # of time. This parameter is required when using the account validity + # feature. + # + #period: 6w + + # The amount of time before an account's expiry date at which Synapse will + # send an email to the account's email address with a renewal link. By + # default, no such emails are sent. + # + # If you enable this setting, you will also need to fill out the 'email' and + # 'public_baseurl' configuration sections. + # + #renew_at: 1w + + # The subject of the email sent out with the renewal link. '%(app)s' can be + # used as a placeholder for the 'app_name' parameter from the 'email' + # section. + # + # Note that the placeholder must be written '%(app)s', including the + # trailing 's'. + # + # If this is not set, a default value is used. 
+ # + #renew_email_subject: "Renew your %(app)s account" + + # Directory in which Synapse will try to find templates for the HTML files to + # serve to the user when trying to renew an account. If not set, default + # templates from within the Synapse package will be used. + # + #template_dir: "res/templates" + + # File within 'template_dir' giving the HTML to be displayed to the user after + # they successfully renewed their account. If not set, default text is used. + # + #account_renewed_html_path: "account_renewed.html" + + # File within 'template_dir' giving the HTML to be displayed when the user + # tries to renew an account with an invalid renewal token. If not set, + # default text is used. + # + #invalid_token_html_path: "invalid_token.html" # Time that a user's session remains valid for, after they log in. # @@ -875,7 +1118,7 @@ uploads_path: "DATADIR/uploads" # If set, allows registration of standard or admin accounts by anyone who # has the shared secret, even if registration is otherwise disabled. # -# registration_shared_secret: <PRIVATE STRING> +#registration_shared_secret: <PRIVATE STRING> # Set the number of bcrypt rounds used to generate password hash. # Larger numbers increase the work factor needed to generate the hash. @@ -937,10 +1180,35 @@ uploads_path: "DATADIR/uploads" # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # +# If a delegate is specified, the config option public_baseurl must also be filled out. +# account_threepid_delegates: - #email: https://example.com # Delegate email sending to example.org + #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process +# Whether users are allowed to change their displayname after it has +# been initially set. Useful when provisioning users based on the +# contents of a third-party directory. +# +# Does not apply to server administrators. Defaults to 'true' +# +#enable_set_displayname: false + +# Whether users are allowed to change their avatar after it has been +# initially set. Useful when provisioning users based on the contents +# of a third-party directory. +# +# Does not apply to server administrators. Defaults to 'true' +# +#enable_set_avatar_url: false + +# Whether users can change the 3PIDs associated with their accounts +# (email address and msisdn). +# +# Defaults to 'true' +# +#enable_3pid_changes: false + # Users who register on this homeserver will automatically be joined # to these rooms # @@ -955,12 +1223,19 @@ account_threepid_delegates: # #autocreate_auto_join_rooms: true +# When auto_join_rooms is specified, setting this flag to false prevents +# guest accounts from being automatically joined to the rooms. +# +# Defaults to true. +# +#auto_join_rooms_for_guests: false + ## Metrics ### # Enable collection and rendering of performance metrics # -#enable_metrics: False +#enable_metrics: false # Enable sentry integration # NOTE: While attempts are made to ensure that the logs don't contain @@ -976,14 +1251,15 @@ account_threepid_delegates: # enabled by default, either for performance reasons or limited use. # metrics_flags: - # Publish synapse_federation_known_servers, a g auge of the number of + # Publish synapse_federation_known_servers, a gauge of the number of # servers this homeserver knows about, including itself. May cause # performance problems on large homeservers. # #known_servers: true # Whether or not to report anonymized homeserver usage statistics. 
-# report_stats: true|false +# +#report_stats: true|false # The endpoint to report the anonymized homeserver usage statistics to. # Defaults to https://matrix.org/report-usage-stats/push @@ -1012,20 +1288,20 @@ metrics_flags: # Uncomment to enable tracking of application service IP addresses. Implicitly # enables MAU tracking for application service users. # -#track_appservice_user_ips: True +#track_appservice_user_ips: true # a secret which is used to sign access tokens. If none is specified, # the registration_shared_secret is used, if one is given; otherwise, # a secret key is derived from the signing key. # -# macaroon_secret_key: <PRIVATE STRING> +#macaroon_secret_key: <PRIVATE STRING> # a secret which is used to calculate HMACs for form values, to stop # falsification of values. Must be specified for the User Consent # forms to work. # -# form_secret: <PRIVATE STRING> +#form_secret: <PRIVATE STRING> ## Signing Keys ## @@ -1034,14 +1310,19 @@ metrics_flags: signing_key_path: "CONFDIR/SERVERNAME.signing.key" # The keys that the server used to sign messages with but won't use -# to sign new messages. E.g. it has lost its private key +# to sign new messages. # -#old_signing_keys: -# "ed25519:auto": -# # Base64 encoded public key -# key: "The public part of your old signing key." -# # Millisecond POSIX timestamp when the key expired. -# expired_ts: 123456789123 +old_signing_keys: + # For each key, `key` should be the base64-encoded public key, and + # `expired_ts`should be the time (in milliseconds since the unix epoch) that + # it was last used. + # + # It is possible to build an entry from an old signing.key file using the + # `export_signing_key` script which is provided with synapse. + # + # For example: + # + #"ed25519:id": { key: "base64string", expired_ts: 123456789123 } # How long key response published by this server is valid for. # Used to set the valid_until_ts in /key/v2 APIs. @@ -1061,6 +1342,10 @@ signing_key_path: "CONFDIR/SERVERNAME.signing.key" # This setting supercedes an older setting named `perspectives`. The old format # is still supported for backwards-compatibility, but it is deprecated. # +# 'trusted_key_servers' defaults to matrix.org, but using it will generate a +# warning on start-up. To suppress this warning, set +# 'suppress_key_server_warning' to true. +# # Options for each entry in the list include: # # server_name: the name of the server. required. @@ -1085,11 +1370,13 @@ signing_key_path: "CONFDIR/SERVERNAME.signing.key" # "ed25519:auto": "abcdefghijklmnopqrstuvwxyzabcdefghijklmopqr" # - server_name: "my_other_trusted_server.example.com" # -# The default configuration is: -# -#trusted_key_servers: -# - server_name: "matrix.org" +trusted_key_servers: + - server_name: "matrix.org" + +# Uncomment the following to disable the warning that is emitted when the +# trusted_key_servers include 'matrix.org'. See above. # +#suppress_key_server_warning: true # The signing keys to use when acting as a trusted key server. If not specified # defaults to the server signing key. @@ -1099,14 +1386,17 @@ signing_key_path: "CONFDIR/SERVERNAME.signing.key" #key_server_signing_keys_path: "key_server_signing_keys.key" +## Single sign-on integration ## + # Enable SAML2 for registration and login. Uses pysaml2. # -# `sp_config` is the configuration for the pysaml2 Service Provider. -# See pysaml2 docs for format of config. +# At least one of `sp_config` or `config_path` must be set in this section to +# enable SAML login. 
# -# Default values will be used for the 'entityid' and 'service' settings, -# so it is not normally necessary to specify them unless you need to -# override them. +# (You will probably also want to set the following options to `false` to +# disable the regular login/registration flows: +# * enable_registration +# * password_config.enabled # # Once SAML support is enabled, a metadata file will be exposed at # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to @@ -1114,52 +1404,250 @@ signing_key_path: "CONFDIR/SERVERNAME.signing.key" # the IdP to use an ACS location of # https://<server>:<port>/_matrix/saml2/authn_response. # -#saml2_config: -# sp_config: -# # point this to the IdP's metadata. You can use either a local file or -# # (preferably) a URL. -# metadata: -# #local: ["saml2/idp.xml"] -# remote: -# - url: https://our_idp/metadata.xml -# -# # By default, the user has to go to our login page first. If you'd like to -# # allow IdP-initiated login, set 'allow_unsolicited: True' in a -# # 'service.sp' section: -# # -# #service: -# # sp: -# # allow_unsolicited: True -# -# # The examples below are just used to generate our metadata xml, and you -# # may well not need it, depending on your setup. Alternatively you -# # may need a whole lot more detail - see the pysaml2 docs! -# -# description: ["My awesome SP", "en"] -# name: ["Test SP", "en"] -# -# organization: -# name: Example com -# display_name: -# - ["Example co", "en"] -# url: "http://example.com" -# -# contact_person: -# - given_name: Bob -# sur_name: "the Sysadmin" -# email_address": ["admin@example.com"] -# contact_type": technical -# -# # Instead of putting the config inline as above, you can specify a -# # separate pysaml2 configuration file: -# # -# config_path: "CONFDIR/sp_conf.py" -# -# # the lifetime of a SAML session. This defines how long a user has to -# # complete the authentication process, if allow_unsolicited is unset. -# # The default is 5 minutes. -# # -# # saml_session_lifetime: 5m +saml2_config: + # `sp_config` is the configuration for the pysaml2 Service Provider. + # See pysaml2 docs for format of config. + # + # Default values will be used for the 'entityid' and 'service' settings, + # so it is not normally necessary to specify them unless you need to + # override them. + # + #sp_config: + # # point this to the IdP's metadata. You can use either a local file or + # # (preferably) a URL. + # metadata: + # #local: ["saml2/idp.xml"] + # remote: + # - url: https://our_idp/metadata.xml + # + # # By default, the user has to go to our login page first. If you'd like + # # to allow IdP-initiated login, set 'allow_unsolicited: true' in a + # # 'service.sp' section: + # # + # #service: + # # sp: + # # allow_unsolicited: true + # + # # The examples below are just used to generate our metadata xml, and you + # # may well not need them, depending on your setup. Alternatively you + # # may need a whole lot more detail - see the pysaml2 docs! + # + # description: ["My awesome SP", "en"] + # name: ["Test SP", "en"] + # + # organization: + # name: Example com + # display_name: + # - ["Example co", "en"] + # url: "http://example.com" + # + # contact_person: + # - given_name: Bob + # sur_name: "the Sysadmin" + # email_address": ["admin@example.com"] + # contact_type": technical + + # Instead of putting the config inline as above, you can specify a + # separate pysaml2 configuration file: + # + #config_path: "CONFDIR/sp_conf.py" + + # The lifetime of a SAML session. 
This defines how long a user has to + # complete the authentication process, if allow_unsolicited is unset. + # The default is 5 minutes. + # + #saml_session_lifetime: 5m + + # An external module can be provided here as a custom solution to + # mapping attributes returned from a saml provider onto a matrix user. + # + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # + #module: mapping_provider.SamlMappingProvider + + # Custom configuration values for the module. Below options are + # intended for the built-in provider, they should be changed if + # using a custom module. This section will be passed as a Python + # dictionary to the module's `parse_config` method. + # + config: + # The SAML attribute (after mapping via the attribute maps) to use + # to derive the Matrix ID from. 'uid' by default. + # + # Note: This used to be configured by the + # saml2_config.mxid_source_attribute option. If that is still + # defined, its value will be used instead. + # + #mxid_source_attribute: displayName + + # The mapping system to use for mapping the saml attribute onto a + # matrix ID. + # + # Options include: + # * 'hexencode' (which maps unpermitted characters to '=xx') + # * 'dotreplace' (which replaces unpermitted characters with + # '.'). + # The default is 'hexencode'. + # + # Note: This used to be configured by the + # saml2_config.mxid_mapping option. If that is still defined, its + # value will be used instead. + # + #mxid_mapping: dotreplace + + # In previous versions of synapse, the mapping from SAML attribute to + # MXID was always calculated dynamically rather than stored in a + # table. For backwards- compatibility, we will look for user_ids + # matching such a pattern before creating a new account. + # + # This setting controls the SAML attribute which will be used for this + # backwards-compatibility lookup. Typically it should be 'uid', but if + # the attribute maps are changed, it may be necessary to change it. + # + # The default is 'uid'. + # + #grandfathered_mxid_source_attribute: upn + + # Directory in which Synapse will try to find the template files below. + # If not set, default templates from within the Synapse package will be used. + # + # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. + # If you *do* uncomment it, you will need to make sure that all the templates + # below are in the directory. + # + # Synapse will look for the following templates in this directory: + # + # * HTML page to display to users if something goes wrong during the + # authentication process: 'saml_error.html'. + # + # When rendering, this template is given the following variables: + # * code: an HTML error code corresponding to the error that is being + # returned (typically 400 or 500) + # + # * msg: a textual message describing the error. + # + # The variables will automatically be HTML-escaped. + # + # You can see the default templates at: + # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates + # + #template_dir: "res/templates" + + +# OpenID Connect integration. The following settings can be used to make Synapse +# use an OpenID Connect Provider for authentication, instead of its internal +# password database. +# +# See https://github.com/matrix-org/synapse/blob/master/openid.md. +# +oidc_config: + # Uncomment the following to enable authorization against an OpenID Connect + # server. Defaults to false. 
+ # + #enabled: true + + # Uncomment the following to disable use of the OIDC discovery mechanism to + # discover endpoints. Defaults to true. + # + #discover: false + + # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to + # discover the provider's endpoints. + # + # Required if 'enabled' is true. + # + #issuer: "https://accounts.example.com/" + + # oauth2 client id to use. + # + # Required if 'enabled' is true. + # + #client_id: "provided-by-your-issuer" + + # oauth2 client secret to use. + # + # Required if 'enabled' is true. + # + #client_secret: "provided-by-your-issuer" + + # auth method to use when exchanging the token. + # Valid values are 'client_secret_basic' (default), 'client_secret_post' and + # 'none'. + # + #client_auth_method: client_secret_post + + # list of scopes to request. This should normally include the "openid" scope. + # Defaults to ["openid"]. + # + #scopes: ["openid", "profile"] + + # the oauth2 authorization endpoint. Required if provider discovery is disabled. + # + #authorization_endpoint: "https://accounts.example.com/oauth2/auth" + + # the oauth2 token endpoint. Required if provider discovery is disabled. + # + #token_endpoint: "https://accounts.example.com/oauth2/token" + + # the OIDC userinfo endpoint. Required if discovery is disabled and the + # "openid" scope is not requested. + # + #userinfo_endpoint: "https://accounts.example.com/userinfo" + + # URI where to fetch the JWKS. Required if discovery is disabled and the + # "openid" scope is used. + # + #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + + # Uncomment to skip metadata verification. Defaults to false. + # + # Use this if you are connecting to a provider that is not OpenID Connect + # compliant. + # Avoid this in production. + # + #skip_verification: true + + # An external module can be provided here as a custom solution to mapping + # attributes returned from a OIDC provider onto a matrix user. + # + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'. + # + # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers + # for information on implementing a custom mapping provider. + # + #module: mapping_provider.OidcMappingProvider + + # Custom configuration values for the module. This section will be passed as + # a Python dictionary to the user mapping provider module's `parse_config` + # method. + # + # The examples below are intended for the default provider: they should be + # changed if using a custom provider. + # + config: + # name of the claim containing a unique identifier for the user. + # Defaults to `sub`, which OpenID Connect compliant providers should provide. + # + #subject_claim: "sub" + + # Jinja2 template for the localpart of the MXID. + # + # When rendering, this template is given the following variables: + # * user: The claims returned by the UserInfo Endpoint and/or in the ID + # Token + # + # This must be configured if using the default mapping provider. + # + localpart_template: "{{ user.preferred_username }}" + + # Jinja2 template for the display name to set on first login. + # + # If unset, no displayname will be set. 
+ # + #display_name_template: "{{ user.given_name }} {{ user.last_name }}" @@ -1169,10 +1657,97 @@ signing_key_path: "CONFDIR/SERVERNAME.signing.key" # enabled: true # server_url: "https://cas-server.com" # service_url: "https://homeserver.domain.com:8448" +# #displayname_attribute: name # #required_attributes: # # name: value +# Additional settings to use with single-sign on systems such as OpenID Connect, +# SAML2 and CAS. +# +sso: + # A list of client URLs which are whitelisted so that the user does not + # have to confirm giving access to their account to the URL. Any client + # whose URL starts with an entry in the following list will not be subject + # to an additional confirmation step after the SSO login is completed. + # + # WARNING: An entry such as "https://my.client" is insecure, because it + # will also match "https://my.client.evil.site", exposing your users to + # phishing attacks from evil.site. To avoid this, include a slash after the + # hostname: "https://my.client/". + # + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. + # + # By default, this list is empty. + # + #client_whitelist: + # - https://riot.im/develop + # - https://my.custom.client/ + + # Directory in which Synapse will try to find the template files below. + # If not set, default templates from within the Synapse package will be used. + # + # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. + # If you *do* uncomment it, you will need to make sure that all the templates + # below are in the directory. + # + # Synapse will look for the following templates in this directory: + # + # * HTML page for a confirmation step before redirecting back to the client + # with the login token: 'sso_redirect_confirm.html'. + # + # When rendering, this template is given three variables: + # * redirect_url: the URL the user is about to be redirected to. Needs + # manual escaping (see + # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * display_url: the same as `redirect_url`, but with the query + # parameters stripped. The intention is to have a + # human-readable URL to show to users, not to use it as + # the final address to redirect to. Needs manual escaping + # (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * server_name: the homeserver's name. + # + # * HTML page which notifies the user that they are authenticating to confirm + # an operation on their account during the user interactive authentication + # process: 'sso_auth_confirm.html'. + # + # When rendering, this template is given the following variables: + # * redirect_url: the URL the user is about to be redirected to. Needs + # manual escaping (see + # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * description: the operation which the user is being asked to confirm + # + # * HTML page shown after a successful user interactive authentication session: + # 'sso_auth_success.html'. + # + # Note that this page must include the JavaScript which notifies of a successful authentication + # (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback). + # + # This template has no additional variables. + # + # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) + # attempts to login: 'sso_account_deactivated.html'. 
+ # + # This template has no additional variables. + # + # * HTML page to display to users if something goes wrong during the + # OpenID Connect authentication process: 'sso_error.html'. + # + # When rendering, this template is given two variables: + # * error: the technical name of the error + # * error_description: a human-readable message for the error + # + # You can see the default templates at: + # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates + # + #template_dir: "res/templates" + + # The JWT needs to contain a globally unique "sub" (subject) claim. # #jwt_config: @@ -1197,84 +1772,162 @@ password_config: # #pepper: "EVEN_MORE_SECRET" + # Define and enforce a password policy. Each parameter is optional. + # This is an implementation of MSC2000. + # + policy: + # Whether to enforce the password policy. + # Defaults to 'false'. + # + #enabled: true + + # Minimum accepted length for a password. + # Defaults to 0. + # + #minimum_length: 15 + + # Whether a password must contain at least one digit. + # Defaults to 'false'. + # + #require_digit: true + + # Whether a password must contain at least one symbol. + # A symbol is any character that's not a number or a letter. + # Defaults to 'false'. + # + #require_symbol: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_lowercase: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_uppercase: true + + +# Configuration for sending emails from Synapse. +# +email: + # The hostname of the outgoing SMTP server to use. Defaults to 'localhost'. + # + #smtp_host: mail.server + + # The port on the mail server for outgoing SMTP. Defaults to 25. + # + #smtp_port: 587 + + # Username/password for authentication to the SMTP server. By default, no + # authentication is attempted. + # + #smtp_user: "exampleusername" + #smtp_pass: "examplepassword" + + # Uncomment the following to require TLS transport security for SMTP. + # By default, Synapse will connect over plain text, and will then switch to + # TLS via STARTTLS *if the SMTP server supports it*. If this option is set, + # Synapse will refuse to connect unless the server supports STARTTLS. + # + #require_transport_security: true + + # notif_from defines the "From" address to use when sending emails. + # It must be set if email sending is enabled. + # + # The placeholder '%(app)s' will be replaced by the application name, + # which is normally 'app_name' (below), but may be overridden by the + # Matrix client application. + # + # Note that the placeholder must be written '%(app)s', including the + # trailing 's'. + # + #notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>" + + # app_name defines the default value for '%(app)s' in notif_from. It + # defaults to 'Matrix'. + # + #app_name: my_branded_matrix_server + + # Uncomment the following to enable sending emails for messages that the user + # has missed. Disabled by default. + # + #enable_notifs: true + + # Uncomment the following to disable automatic subscription to email + # notifications for new users. Enabled by default. + # + #notif_for_new_users: false + + # Custom URL for client links within the email notifications. By default + # links will be based on "https://matrix.to". + # + # (This setting used to be called riot_base_url; the old name is still + # supported for backwards-compatibility but is now deprecated.) 
+ # + #client_base_url: "http://localhost/riot" + # Configure the time that a validation email will expire after sending. + # Defaults to 1h. + # + #validation_token_lifetime: 15m -# Enable sending emails for password resets, notification events or -# account expiry notices -# -# If your SMTP server requires authentication, the optional smtp_user & -# smtp_pass variables should be used -# -#email: -# enable_notifs: false -# smtp_host: "localhost" -# smtp_port: 25 # SSL: 465, STARTTLS: 587 -# smtp_user: "exampleusername" -# smtp_pass: "examplepassword" -# require_transport_security: False -# notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>" -# app_name: Matrix -# -# # Enable email notifications by default -# # -# notif_for_new_users: True -# -# # Defining a custom URL for Riot is only needed if email notifications -# # should contain links to a self-hosted installation of Riot; when set -# # the "app_name" setting is ignored -# # -# riot_base_url: "http://localhost/riot" -# -# # Configure the time that a validation email or text message code -# # will expire after sending -# # -# # This is currently used for password resets -# # -# #validation_token_lifetime: 1h -# -# # Template directory. All template files should be stored within this -# # directory. If not set, default templates from within the Synapse -# # package will be used -# # -# # For the list of default templates, please see -# # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates -# # -# #template_dir: res/templates -# -# # Templates for email notifications -# # -# notif_template_html: notif_mail.html -# notif_template_text: notif_mail.txt -# -# # Templates for account expiry notices -# # -# expiry_template_html: notice_expiry.html -# expiry_template_text: notice_expiry.txt -# -# # Templates for password reset emails sent by the homeserver -# # -# #password_reset_template_html: password_reset.html -# #password_reset_template_text: password_reset.txt -# -# # Templates for registration emails sent by the homeserver -# # -# #registration_template_html: registration.html -# #registration_template_text: registration.txt -# -# # Templates for password reset success and failure pages that a user -# # will see after attempting to reset their password -# # -# #password_reset_template_success_html: password_reset_success.html -# #password_reset_template_failure_html: password_reset_failure.html -# -# # Templates for registration success and failure pages that a user -# # will see after attempting to register using an email or phone -# # -# #registration_template_success_html: registration_success.html -# #registration_template_failure_html: registration_failure.html - - -#password_providers: + # Directory in which Synapse will try to find the template files below. + # If not set, default templates from within the Synapse package will be used. + # + # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. + # If you *do* uncomment it, you will need to make sure that all the templates + # below are in the directory. + # + # Synapse will look for the following templates in this directory: + # + # * The contents of email notifications of missed events: 'notif_mail.html' and + # 'notif_mail.txt'. + # + # * The contents of account expiry notice emails: 'notice_expiry.html' and + # 'notice_expiry.txt'. 
+ # + # * The contents of password reset emails sent by the homeserver: + # 'password_reset.html' and 'password_reset.txt' + # + # * HTML pages for success and failure that a user will see when they follow + # the link in the password reset email: 'password_reset_success.html' and + # 'password_reset_failure.html' + # + # * The contents of address verification emails sent during registration: + # 'registration.html' and 'registration.txt' + # + # * HTML pages for success and failure that a user will see when they follow + # the link in an address verification email sent during registration: + # 'registration_success.html' and 'registration_failure.html' + # + # * The contents of address verification emails sent when an address is added + # to a Matrix account: 'add_threepid.html' and 'add_threepid.txt' + # + # * HTML pages for success and failure that a user will see when they follow + # the link in an address verification email sent when an address is added + # to a Matrix account: 'add_threepid_success.html' and + # 'add_threepid_failure.html' + # + # You can see the default templates at: + # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates + # + #template_dir: "res/templates" + + +# Password providers allow homeserver administrators to integrate +# their Synapse installation with existing authentication methods +# ex. LDAP, external tokens, etc. +# +# For more information and known implementations, please see +# https://github.com/matrix-org/synapse/blob/master/docs/password_auth_providers.md +# +# Note: instances wishing to use SAML or CAS authentication should +# instead use the `saml2_config` or `cas_config` options, +# respectively. +# +password_providers: +# # Example config for an LDAP auth provider # - module: "ldap_auth_provider.LdapAuthProvider" # config: # enabled: true @@ -1307,10 +1960,17 @@ password_config: # include_content: true -#spam_checker: -# module: "my_custom_project.SuperSpamChecker" -# config: -# example_option: 'things' +# Spam checkers are third-party modules that can block specific actions +# of local users, such as creating rooms and registering undesirable +# usernames, as well as remote users by redacting incoming events. +# +spam_checker: + #- module: "my_custom_project.SuperSpamChecker" + # config: + # example_option: 'things' + #- module: "some_other_project.BadEventStopper" + # config: + # example_stop_events_from: ['@bad:example.com'] # Uncomment to allow non-server-admin users to create groups on this server @@ -1383,11 +2043,11 @@ password_config: # body: >- # To continue using this homeserver you must review and agree to the # terms and conditions at %(consent_uri)s -# send_server_notice_to_guests: True +# send_server_notice_to_guests: true # block_events_error: >- # To continue using this homeserver you must review and agree to the # terms and conditions at %(consent_uri)s -# require_at_registration: False +# require_at_registration: false # policy_name: Privacy Policy # diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml new file mode 100644 index 0000000000..1a2739455e --- /dev/null +++ b/docs/sample_log_config.yaml @@ -0,0 +1,43 @@ +# Log configuration for Synapse. +# +# This is a YAML file containing a standard Python logging configuration +# dictionary. See [1] for details on the valid settings. 
+# +# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema + +version: 1 + +formatters: + precise: + format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s' + +filters: + context: + (): synapse.logging.context.LoggingContextFilter + request: "" + +handlers: + file: + class: logging.handlers.RotatingFileHandler + formatter: precise + filename: /var/log/matrix-synapse/homeserver.log + maxBytes: 104857600 + backupCount: 10 + filters: [context] + encoding: utf8 + console: + class: logging.StreamHandler + formatter: precise + filters: [context] + +loggers: + synapse.storage.SQL: + # beware: increasing this to DEBUG will make synapse log sensitive + # information such as access tokens. + level: INFO + +root: + level: INFO + handlers: [file, console] + +disable_existing_loggers: false diff --git a/docs/spam_checker.md b/docs/spam_checker.md new file mode 100644 index 0000000000..eb10e115f9 --- /dev/null +++ b/docs/spam_checker.md @@ -0,0 +1,93 @@ +# Handling spam in Synapse + +Synapse has support to customize spam checking behavior. It can plug into a +variety of events and affect how they are presented to users on your homeserver. + +The spam checking behavior is implemented as a Python class, which must be +able to be imported by the running Synapse. + +## Python spam checker class + +The Python class is instantiated with two objects: + +* Any configuration (see below). +* An instance of `synapse.spam_checker_api.SpamCheckerApi`. + +It then implements methods which return a boolean to alter behavior in Synapse. + +There's a generic method for checking every event (`check_event_for_spam`), as +well as some specific methods: + +* `user_may_invite` +* `user_may_create_room` +* `user_may_create_room_alias` +* `user_may_publish_room` + +The details of the each of these methods (as well as their inputs and outputs) +are documented in the `synapse.events.spamcheck.SpamChecker` class. + +The `SpamCheckerApi` class provides a way for the custom spam checker class to +call back into the homeserver internals. It currently implements the following +methods: + +* `get_state_events_in_room` + +### Example + +```python +class ExampleSpamChecker: + def __init__(self, config, api): + self.config = config + self.api = api + + def check_event_for_spam(self, foo): + return False # allow all events + + def user_may_invite(self, inviter_userid, invitee_userid, room_id): + return True # allow all invites + + def user_may_create_room(self, userid): + return True # allow all room creations + + def user_may_create_room_alias(self, userid, room_alias): + return True # allow all room aliases + + def user_may_publish_room(self, userid, room_id): + return True # allow publishing of all rooms + + def check_username_for_spam(self, user_profile): + return False # allow all usernames +``` + +## Configuration + +Modify the `spam_checker` section of your `homeserver.yaml` in the following +manner: + +Create a list entry with the keys `module` and `config`. + +* `module` should point to the fully qualified Python class that implements your + custom logic, e.g. `my_module.ExampleSpamChecker`. + +* `config` is a dictionary that gets passed to the spam checker class. + +### Example + +This section might look like: + +```yaml +spam_checker: + - module: my_module.ExampleSpamChecker + config: + # Enable or disable a specific option in ExampleSpamChecker. + my_custom_option: true +``` + +More spam checkers can be added in tandem by appending more items to the list. 
An +action is blocked when at least one of the configured spam checkers flags it. + +## Examples + +The [Mjolnir](https://github.com/matrix-org/mjolnir) project is a full fledged +example using the Synapse spam checking API, including a bot for dynamic +configuration. diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md new file mode 100644 index 0000000000..abea432343 --- /dev/null +++ b/docs/sso_mapping_providers.md @@ -0,0 +1,148 @@ +# SSO Mapping Providers + +A mapping provider is a Python class (loaded via a Python module) that +works out how to map attributes of a SSO response to Matrix-specific +user attributes. Details such as user ID localpart, displayname, and even avatar +URLs are all things that can be mapped from talking to a SSO service. + +As an example, a SSO service may return the email address +"john.smith@example.com" for a user, whereas Synapse will need to figure out how +to turn that into a displayname when creating a Matrix user for this individual. +It may choose `John Smith`, or `Smith, John [Example.com]` or any number of +variations. As each Synapse configuration may want something different, this is +where SAML mapping providers come into play. + +SSO mapping providers are currently supported for OpenID and SAML SSO +configurations. Please see the details below for how to implement your own. + +External mapping providers are provided to Synapse in the form of an external +Python module. You can retrieve this module from [PyPi](https://pypi.org) or elsewhere, +but it must be importable via Synapse (e.g. it must be in the same virtualenv +as Synapse). The Synapse config is then modified to point to the mapping provider +(and optionally provide additional configuration for it). + +## OpenID Mapping Providers + +The OpenID mapping provider can be customized by editing the +`oidc_config.user_mapping_provider.module` config option. + +`oidc_config.user_mapping_provider.config` allows you to provide custom +configuration options to the module. Check with the module's documentation for +what options it provides (if any). The options listed by default are for the +user mapping provider built in to Synapse. If using a custom module, you should +comment these options out and use those specified by the module instead. + +### Building a Custom OpenID Mapping Provider + +A custom mapping provider must specify the following methods: + +* `__init__(self, parsed_config)` + - Arguments: + - `parsed_config` - A configuration object that is the return value of the + `parse_config` method. You should set any configuration options needed by + the module here. +* `parse_config(config)` + - This method should have the `@staticmethod` decoration. + - Arguments: + - `config` - A `dict` representing the parsed content of the + `oidc_config.user_mapping_provider.config` homeserver config option. + Runs on homeserver startup. Providers should extract and validate + any option values they need here. + - Whatever is returned will be passed back to the user mapping provider module's + `__init__` method during construction. +* `get_remote_user_id(self, userinfo)` + - Arguments: + - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user + information from. + - This method must return a string, which is the unique identifier for the + user. Commonly the ``sub`` claim of the response. +* `map_user_attributes(self, userinfo, token)` + - This method should be async. 
+ - Arguments: + - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user + information from. + - `token` - A dictionary which includes information necessary to make + further requests to the OpenID provider. + - Returns a dictionary with two keys: + - localpart: A required string, used to generate the Matrix ID. + - displayname: An optional string, the display name for the user. + +### Default OpenID Mapping Provider + +Synapse has a built-in OpenID mapping provider if a custom provider isn't +specified in the config. It is located at +[`synapse.handlers.oidc_handler.JinjaOidcMappingProvider`](../synapse/handlers/oidc_handler.py). + +## SAML Mapping Providers + +The SAML mapping provider can be customized by editing the +`saml2_config.user_mapping_provider.module` config option. + +`saml2_config.user_mapping_provider.config` allows you to provide custom +configuration options to the module. Check with the module's documentation for +what options it provides (if any). The options listed by default are for the +user mapping provider built in to Synapse. If using a custom module, you should +comment these options out and use those specified by the module instead. + +### Building a Custom SAML Mapping Provider + +A custom mapping provider must specify the following methods: + +* `__init__(self, parsed_config)` + - Arguments: + - `parsed_config` - A configuration object that is the return value of the + `parse_config` method. You should set any configuration options needed by + the module here. +* `parse_config(config)` + - This method should have the `@staticmethod` decoration. + - Arguments: + - `config` - A `dict` representing the parsed content of the + `saml_config.user_mapping_provider.config` homeserver config option. + Runs on homeserver startup. Providers should extract and validate + any option values they need here. + - Whatever is returned will be passed back to the user mapping provider module's + `__init__` method during construction. +* `get_saml_attributes(config)` + - This method should have the `@staticmethod` decoration. + - Arguments: + - `config` - A object resulting from a call to `parse_config`. + - Returns a tuple of two sets. The first set equates to the SAML auth + response attributes that are required for the module to function, whereas + the second set consists of those attributes which can be used if available, + but are not necessary. +* `get_remote_user_id(self, saml_response, client_redirect_url)` + - Arguments: + - `saml_response` - A `saml2.response.AuthnResponse` object to extract user + information from. + - `client_redirect_url` - A string, the URL that the client will be + redirected to. + - This method must return a string, which is the unique identifier for the + user. Commonly the ``uid`` claim of the response. +* `saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url)` + - Arguments: + - `saml_response` - A `saml2.response.AuthnResponse` object to extract user + information from. + - `failures` - An `int` that represents the amount of times the returned + mxid localpart mapping has failed. This should be used + to create a deduplicated mxid localpart which should be + returned instead. For example, if this method returns + `john.doe` as the value of `mxid_localpart` in the returned + dict, and that is already taken on the homeserver, this + method will be called again with the same parameters but + with failures=1. The method should then return a different + `mxid_localpart` value, such as `john.doe1`. 
+ - `client_redirect_url` - A string, the URL that the client will be + redirected to. + - This method must return a dictionary, which will then be used by Synapse + to build a new user. The following keys are allowed: + * `mxid_localpart` - Required. The mxid localpart of the new user. + * `displayname` - The displayname of the new user. If not provided, will default to + the value of `mxid_localpart`. + * `emails` - A list of emails for the new user. If not provided, will + default to an empty list. + +### Default SAML Mapping Provider + +Synapse has a built-in SAML mapping provider if a custom provider isn't +specified in the config. It is located at +[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py). diff --git a/docs/systemd-with-workers/README.md b/docs/systemd-with-workers/README.md new file mode 100644 index 0000000000..257c09446f --- /dev/null +++ b/docs/systemd-with-workers/README.md @@ -0,0 +1,67 @@ +# Setting up Synapse with Workers and Systemd + +This is a setup for managing synapse with systemd, including support for +managing workers. It provides a `matrix-synapse` service for the master, as +well as a `matrix-synapse-worker@` service template for any workers you +require. Additionally, to group the required services, it sets up a +`matrix-synapse.target`. + +See the folder [system](system) for the systemd unit files. + +The folder [workers](workers) contains an example configuration for the +`federation_reader` worker. + +## Synapse configuration files + +See [workers.md](../workers.md) for information on how to set up the +configuration files and reverse-proxy correctly. You can find an example worker +config in the [workers](workers) folder. + +Systemd manages daemonization itself, so ensure that none of the configuration +files set either `daemonize` or `worker_daemonize`. + +The config files of all workers are expected to be located in +`/etc/matrix-synapse/workers`. If you want to use a different location, edit +the provided `*.service` files accordingly. + +There is no need for a separate configuration file for the master process. + +## Set up + +1. Adjust synapse configuration files as above. +1. Copy the `*.service` and `*.target` files in [system](system) to +`/etc/systemd/system`. +1. Run `systemctl deamon-reload` to tell systemd to load the new unit files. +1. Run `systemctl enable matrix-synapse.service`. This will configure the +synapse master process to be started as part of the `matrix-synapse.target` +target. +1. For each worker process to be enabled, run `systemctl enable +matrix-synapse-worker@<worker_name>.service`. For each `<worker_name>`, there +should be a corresponding configuration file +`/etc/matrix-synapse/workers/<worker_name>.yaml`. +1. Start all the synapse processes with `systemctl start matrix-synapse.target`. +1. Tell systemd to start synapse on boot with `systemctl enable matrix-synapse.target`/ + +## Usage + +Once the services are correctly set up, you can use the following commands +to manage your synapse installation: + +```sh +# Restart Synapse master and all workers +systemctl restart matrix-synapse.target + +# Stop Synapse and all workers +systemctl stop matrix-synapse.target + +# Restart the master alone +systemctl start matrix-synapse.service + +# Restart a specific worker (eg. federation_reader); the master is +# unaffected by this. 
+systemctl restart matrix-synapse-worker@federation_reader.service + +# Add a new worker (assuming all configs are set up already) +systemctl enable matrix-synapse-worker@federation_writer.service +systemctl restart matrix-synapse.target +``` diff --git a/docs/systemd-with-workers/system/matrix-synapse-worker@.service b/docs/systemd-with-workers/system/matrix-synapse-worker@.service new file mode 100644 index 0000000000..39bc5e88e8 --- /dev/null +++ b/docs/systemd-with-workers/system/matrix-synapse-worker@.service @@ -0,0 +1,20 @@ +[Unit] +Description=Synapse %i +AssertPathExists=/etc/matrix-synapse/workers/%i.yaml +# This service should be restarted when the synapse target is restarted. +PartOf=matrix-synapse.target + +[Service] +Type=notify +NotifyAccess=main +User=matrix-synapse +WorkingDirectory=/var/lib/matrix-synapse +EnvironmentFile=/etc/default/matrix-synapse +ExecStart=/opt/venvs/matrix-synapse/bin/python -m synapse.app.generic_worker --config-path=/etc/matrix-synapse/homeserver.yaml --config-path=/etc/matrix-synapse/conf.d/ --config-path=/etc/matrix-synapse/workers/%i.yaml +ExecReload=/bin/kill -HUP $MAINPID +Restart=always +RestartSec=3 +SyslogIdentifier=matrix-synapse-%i + +[Install] +WantedBy=matrix-synapse.target diff --git a/contrib/systemd-with-workers/system/matrix-synapse.service b/docs/systemd-with-workers/system/matrix-synapse.service index 68e8991f18..c7b5ddfa49 100644 --- a/contrib/systemd-with-workers/system/matrix-synapse.service +++ b/docs/systemd-with-workers/system/matrix-synapse.service @@ -1,5 +1,8 @@ [Unit] -Description=Synapse Matrix Homeserver +Description=Synapse master + +# This service should be restarted when the synapse target is restarted. +PartOf=matrix-synapse.target [Service] Type=notify @@ -15,4 +18,4 @@ RestartSec=3 SyslogIdentifier=matrix-synapse [Install] -WantedBy=matrix.target +WantedBy=matrix-synapse.target diff --git a/docs/systemd-with-workers/system/matrix-synapse.target b/docs/systemd-with-workers/system/matrix-synapse.target new file mode 100644 index 0000000000..e0eba1b342 --- /dev/null +++ b/docs/systemd-with-workers/system/matrix-synapse.target @@ -0,0 +1,6 @@ +[Unit] +Description=Synapse parent target +After=network.target + +[Install] +WantedBy=multi-user.target diff --git a/contrib/systemd-with-workers/workers/federation_reader.yaml b/docs/systemd-with-workers/workers/federation_reader.yaml index 47c54ec0d4..5b65c7040d 100644 --- a/contrib/systemd-with-workers/workers/federation_reader.yaml +++ b/docs/systemd-with-workers/workers/federation_reader.yaml @@ -10,5 +10,4 @@ worker_listeners: resources: - names: [federation] -worker_daemonize: false worker_log_config: /etc/matrix-synapse/federation-reader-log.yaml diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index e099d8a87b..db318baa9d 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -14,16 +14,18 @@ example flow would be (where '>' indicates master to worker and '<' worker to master flows): > SERVER example.com - < REPLICATE events 53 - > RDATA events 54 ["$foo1:bar.com", ...] - > RDATA events 55 ["$foo4:bar.com", ...] - -The example shows the server accepting a new connection and sending its -identity with the `SERVER` command, followed by the client asking to -subscribe to the `events` stream from the token `53`. The server then -periodically sends `RDATA` commands which have the format -`RDATA <stream_name> <token> <row>`, where the format of `<row>` is -defined by the individual streams. 
+ < REPLICATE + > POSITION events master 53 + > RDATA events master 54 ["$foo1:bar.com", ...] + > RDATA events master 55 ["$foo4:bar.com", ...] + +The example shows the server accepting a new connection and sending its identity +with the `SERVER` command, followed by the client server to respond with the +position of all streams. The server then periodically sends `RDATA` commands +which have the format `RDATA <stream_name> <instance_name> <token> <row>`, where +the format of `<row>` is defined by the individual streams. The +`<instance_name>` is the name of the Synapse process that generated the data +(usually "master"). Error reporting happens by either the client or server sending an ERROR command, and usually the connection will be closed. @@ -32,9 +34,6 @@ Since the protocol is a simple line based, its possible to manually connect to the server using a tool like netcat. A few things should be noted when manually using the protocol: -- When subscribing to a stream using `REPLICATE`, the special token - `NOW` can be used to get all future updates. The special stream name - `ALL` can be used with `NOW` to subscribe to all available streams. - The federation stream is only available if federation sending has been disabled on the main process. - The server will only time connections out that have sent a `PING` @@ -55,7 +54,7 @@ The basic structure of the protocol is line based, where the initial word of each line specifies the command. The rest of the line is parsed based on the command. For example, the RDATA command is defined as: - RDATA <stream_name> <token> <row_json> + RDATA <stream_name> <instance_name> <token> <row_json> (Note that <row_json> may contains spaces, but cannot contain newlines.) @@ -91,9 +90,7 @@ The client: - Sends a `NAME` command, allowing the server to associate a human friendly name with the connection. This is optional. - Sends a `PING` as above -- For each stream the client wishes to subscribe to it sends a - `REPLICATE` with the `stream_name` and token it wants to subscribe - from. +- Sends a `REPLICATE` to get the current position of all streams. - On receipt of a `SERVER` command, checks that the server name matches the expected server name. @@ -140,14 +137,12 @@ the wire: > PING 1490197665618 < NAME synapse.app.appservice < PING 1490197665618 - < REPLICATE events 1 - < REPLICATE backfill 1 - < REPLICATE caches 1 - > POSITION events 1 - > POSITION backfill 1 - > POSITION caches 1 - > RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] - > RDATA events 14 ["$149019767112vOHxz:localhost:8823", + < REPLICATE + > POSITION events master 1 + > POSITION backfill master 1 + > POSITION caches master 1 + > RDATA caches master 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] + > RDATA events master 14 ["$149019767112vOHxz:localhost:8823", "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null] < PING 1490197675618 > ERROR server stopping @@ -158,10 +153,10 @@ position without needing to send data with the `RDATA` command. 
An example of a batched set of `RDATA` is: - > RDATA caches batch ["get_user_by_id",["@test:localhost:8823"],1490197670513] - > RDATA caches batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513] - > RDATA caches batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513] - > RDATA caches 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513] + > RDATA caches master 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513] In this case the client shouldn't advance their caches token until it sees the the last `RDATA`. @@ -181,9 +176,14 @@ client (C): #### POSITION (S) - The position of the stream has been updated. Sent to the client - after all missing updates for a stream have been sent to the client - and they're now up to date. + On receipt of a POSITION command clients should check if they have missed any + updates, and if so then fetch them out of band. Sent in response to a + REPLICATE command (but can happen at any time). + + The POSITION command includes the source of the stream. Currently all streams + are written by a single process (usually "master"). If fetching missing + updates via HTTP API, rather than via the DB, then processes should make the + request to the appropriate process. #### ERROR (S, C) @@ -199,11 +199,17 @@ client (C): #### REPLICATE (C) - Asks the server to replicate a given stream +Asks the server for the current position of all streams. #### USER_SYNC (C) - A user has started or stopped syncing + A user has started or stopped syncing on this process. + +#### CLEAR_USER_SYNC (C) + + The server should clear all associated user sync data from the worker. + + This is used when a worker is shutting down. #### FEDERATION_ACK (C) @@ -213,13 +219,9 @@ client (C): Inform the server a pusher should be removed -#### INVALIDATE_CACHE (C) +### REMOTE_SERVER_UP (S, C) - Inform the server a cache should be invalidated - -#### SYNC (S, C) - - Used exclusively in tests + Inform other processes that a remote server may have come back online. See `synapse/replication/tcp/commands.py` for a detailed description and the format of each command. @@ -235,7 +237,12 @@ Each individual cache invalidation results in a row being sent down replication, which includes the cache name (the name of the function) and they key to invalidate. For example: - > RDATA caches 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251] + > RDATA caches master 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251] + +Alternatively, an entire cache can be invalidated by sending down a `null` +instead of the key. For example: + + > RDATA caches master 550953772 ["get_user_by_id", null, 1550574873252] However, there are times when a number of caches need to be invalidated at the same time with the same key. To reduce traffic we batch those diff --git a/docs/turn-howto.md b/docs/turn-howto.md index 4a983621e5..d4a726be66 100644 --- a/docs/turn-howto.md +++ b/docs/turn-howto.md @@ -11,7 +11,14 @@ TURN server. The following sections describe how to install [coturn](<https://github.com/coturn/coturn>) (which implements the TURN REST API) and integrate it with synapse. -## `coturn` Setup +## Requirements + +For TURN relaying with `coturn` to work, it must be hosted on a server/endpoint with a public IP. 
+ +Hosting TURN behind a NAT (even with appropriate port forwarding) is known to cause issues +and to often not work. + +## `coturn` setup ### Initial installation @@ -19,7 +26,13 @@ The TURN daemon `coturn` is available from a variety of sources such as native p #### Debian installation - # apt install coturn +Just install the debian package: + +```sh +apt install coturn +``` + +This will install and start a systemd service called `coturn`. #### Source installation @@ -39,6 +52,8 @@ The TURN daemon `coturn` is available from a variety of sources such as native p make make install +### Configuration + 1. Create or edit the config file in `/etc/turnserver.conf`. The relevant lines, with example values, are: @@ -54,38 +69,52 @@ The TURN daemon `coturn` is available from a variety of sources such as native p 1. Consider your security settings. TURN lets users request a relay which will connect to arbitrary IP addresses and ports. The following configuration is suggested as a minimum starting point: - + # VoIP traffic is all UDP. There is no reason to let users connect to arbitrary TCP endpoints via the relay. no-tcp-relay - + # don't let the relay ever try to connect to private IP address ranges within your network (if any) # given the turn server is likely behind your firewall, remember to include any privileged public IPs too. denied-peer-ip=10.0.0.0-10.255.255.255 denied-peer-ip=192.168.0.0-192.168.255.255 denied-peer-ip=172.16.0.0-172.31.255.255 - + # special case the turn server itself so that client->TURN->TURN->client flows work allowed-peer-ip=10.0.0.1 - + # consider whether you want to limit the quota of relayed streams per user (or total) to avoid risk of DoS. user-quota=12 # 4 streams per video call, so 12 streams = 3 simultaneous relayed calls per user. total-quota=1200 - Ideally coturn should refuse to relay traffic which isn't SRTP; see - <https://github.com/matrix-org/synapse/issues/2009> +1. Also consider supporting TLS/DTLS. To do this, add the following settings + to `turnserver.conf`: + + # TLS certificates, including intermediate certs. + # For Let's Encrypt certificates, use `fullchain.pem` here. + cert=/path/to/fullchain.pem + + # TLS private key file + pkey=/path/to/privkey.pem 1. Ensure your firewall allows traffic into the TURN server on the ports - you've configured it to listen on (remember to allow both TCP and UDP TURN - traffic) + you've configured it to listen on (By default: 3478 and 5349 for the TURN(s) + traffic (remember to allow both TCP and UDP traffic), and ports 49152-65535 + for the UDP relay.) + +1. (Re)start the turn server: -1. If you've configured coturn to support TLS/DTLS, generate or import your - private key and certificate. + * If you used the Debian package (or have set up a systemd unit yourself): + ```sh + systemctl restart coturn + ``` -1. 
Start the turn server: + * If you installed from source: - bin/turnserver -o + ```sh + bin/turnserver -o + ``` -## synapse Setup +## Synapse setup Your home server configuration file needs the following extra keys: @@ -111,13 +140,20 @@ Your home server configuration file needs the following extra keys: As an example, here is the relevant section of the config file for matrix.org: turn_uris: [ "turn:turn.matrix.org:3478?transport=udp", "turn:turn.matrix.org:3478?transport=tcp" ] - turn_shared_secret: n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons + turn_shared_secret: "n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons" turn_user_lifetime: 86400000 turn_allow_guests: True After updating the homeserver configuration, you must restart synapse: + * If you use synctl: + ```sh cd /where/you/run/synapse ./synctl restart + ``` + * If you use systemd: + ``` + systemctl restart synapse.service + ``` ..and your Home Server now supports VoIP relaying! diff --git a/docs/user_directory.md b/docs/user_directory.md index e64aa453cc..37dc71e751 100644 --- a/docs/user_directory.md +++ b/docs/user_directory.md @@ -7,7 +7,6 @@ who are present in a publicly viewable room present on the server. The directory info is stored in various tables, which can (typically after DB corruption) get stale or out of sync. If this happens, for now the -solution to fix it is to execute the SQL here -https://github.com/matrix-org/synapse/blob/master/synapse/storage/schema/delta/53/user_dir_populate.sql +solution to fix it is to execute the SQL [here](../synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql) and then restart synapse. This should then start a background task to flush the current tables and regenerate the directory. diff --git a/docs/workers.md b/docs/workers.md index 4bd60ba0a0..7512eff43a 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -1,23 +1,31 @@ # Scaling synapse via workers -Synapse has experimental support for splitting out functionality into -multiple separate python processes, helping greatly with scalability. These +For small instances it recommended to run Synapse in monolith mode (the +default). For larger instances where performance is a concern it can be helpful +to split out functionality into multiple separate python processes. These processes are called 'workers', and are (eventually) intended to scale horizontally independently. -All of the below is highly experimental and subject to change as Synapse evolves, -but documenting it here to help folks needing highly scalable Synapses similar -to the one running matrix.org! +Synapse's worker support is under active development and subject to change as +we attempt to rapidly scale ever larger Synapse instances. However we are +documenting it here to help admins needing a highly scalable Synapse instance +similar to the one running `matrix.org`. -All processes continue to share the same database instance, and as such, workers -only work with postgres based synapse deployments (sharing a single sqlite -across multiple processes is a recipe for disaster, plus you should be using -postgres anyway if you care about scalability). +All processes continue to share the same database instance, and as such, +workers only work with PostgreSQL-based Synapse deployments. SQLite should only +be used for demo purposes and any admin considering workers should already be +running PostgreSQL. 
-The workers communicate with the master synapse process via a synapse-specific -TCP protocol called 'replication' - analogous to MySQL or Postgres style -database replication; feeding a stream of relevant data to the workers so they -can be kept in sync with the main synapse process and database state. +## Master/worker communication + +The workers communicate with the master process via a Synapse-specific protocol +called 'replication' (analogous to MySQL- or Postgres-style database +replication) which feeds a stream of relevant data from the master to the +workers so they can be kept in sync with the master process and database state. + +Additionally, workers may make HTTP requests to the master, to send information +in the other direction. Typically this is used for operations which need to +wait for a reply - such as sending an event. ## Configuration @@ -27,72 +35,61 @@ the correct worker, or to the main synapse instance. Note that this includes requests made to the federation port. See [reverse_proxy.md](reverse_proxy.md) for information on setting up a reverse proxy. -To enable workers, you need to add two replication listeners to the master -synapse, e.g.: - - listeners: - # The TCP replication port - - port: 9092 - bind_address: '127.0.0.1' - type: replication - # The HTTP replication port - - port: 9093 - bind_address: '127.0.0.1' - type: http - resources: - - names: [replication] - -Under **no circumstances** should these replication API listeners be exposed to -the public internet; it currently implements no authentication whatsoever and is -unencrypted. - -(Roughly, the TCP port is used for streaming data from the master to the -workers, and the HTTP port for the workers to send data to the main -synapse process.) - -You then create a set of configs for the various worker processes. These -should be worker configuration files, and should be stored in a dedicated -subdirectory, to allow synctl to manipulate them. An additional configuration -for the master synapse process will need to be created because the process will -not be started automatically. That configuration should look like this: - - worker_app: synapse.app.homeserver - daemonize: true - -Each worker configuration file inherits the configuration of the main homeserver -configuration file. You can then override configuration specific to that worker, -e.g. the HTTP listener that it provides (if any); logging configuration; etc. -You should minimise the number of overrides though to maintain a usable config. +To enable workers, you need to add *two* replication listeners to the +main Synapse configuration file (`homeserver.yaml`). For example: -You must specify the type of worker application (`worker_app`). The currently -available worker applications are listed below. You must also specify the -replication endpoints that it's talking to on the main synapse process. -`worker_replication_host` should specify the host of the main synapse, -`worker_replication_port` should point to the TCP replication listener port and -`worker_replication_http_port` should point to the HTTP replication port. +```yaml +listeners: + # The TCP replication port + - port: 9092 + bind_address: '127.0.0.1' + type: replication -Currently, the `event_creator` and `federation_reader` workers require specifying -`worker_replication_http_port`. 
+ # The HTTP replication port + - port: 9093 + bind_address: '127.0.0.1' + type: http + resources: + - names: [replication] +``` -For instance: - - worker_app: synapse.app.synchrotron - - # The replication listener on the synapse to talk to. - worker_replication_host: 127.0.0.1 - worker_replication_port: 9092 - worker_replication_http_port: 9093 - - worker_listeners: - - type: http - port: 8083 - resources: - - names: - - client - - worker_daemonize: True - worker_pid_file: /home/matrix/synapse/synchrotron.pid - worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml +Under **no circumstances** should these replication API listeners be exposed to +the public internet; they have no authentication and are unencrypted. + +You should then create a set of configs for the various worker processes. Each +worker configuration file inherits the configuration of the main homeserver +configuration file. You can then override configuration specific to that +worker, e.g. the HTTP listener that it provides (if any); logging +configuration; etc. You should minimise the number of overrides though to +maintain a usable config. + +In the config file for each worker, you must specify the type of worker +application (`worker_app`). The currently available worker applications are +listed below. You must also specify the replication endpoints that it should +talk to on the main synapse process. `worker_replication_host` should specify +the host of the main synapse, `worker_replication_port` should point to the TCP +replication listener port and `worker_replication_http_port` should point to +the HTTP replication port. + +For example: + +```yaml +worker_app: synapse.app.synchrotron + +# The replication listener on the synapse to talk to. +worker_replication_host: 127.0.0.1 +worker_replication_port: 9092 +worker_replication_http_port: 9093 + +worker_listeners: + - type: http + port: 8083 + resources: + - names: + - client + +worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml +``` ...is a full configuration for a synchrotron worker instance, which will expose a plain HTTP `/sync` endpoint on port 8083 separately from the `/sync` endpoint provided @@ -101,7 +98,75 @@ by the main synapse. Obviously you should configure your reverse-proxy to route the relevant endpoints to the worker (`localhost:8083` in the above example). -Finally, to actually run your worker-based synapse, you must pass synctl the -a +Finally, you need to start your worker processes. This can be done with either +`synctl` or your distribution's preferred service manager such as `systemd`. We +recommend the use of `systemd` where available: for information on setting up +`systemd` to start synapse workers, see +[systemd-with-workers](systemd-with-workers). To use `synctl`, see below. + +### **Experimental** support for replication over redis + +As of Synapse v1.13.0, it is possible to configure Synapse to send replication +via a [Redis pub/sub channel](https://redis.io/topics/pubsub). This is an +alternative to direct TCP connections to the master: rather than all the +workers connecting to the master, all the workers and the master connect to +Redis, which relays replication commands between processes. This can give a +significant cpu saving on the master and will be a prerequisite for upcoming +performance improvements. + +Note that this support is currently experimental; you may experience lost +messages and similar problems! 
It is strongly recommended that admins setting +up workers for the first time use direct TCP replication as above. + +To configure Synapse to use Redis: + +1. Install Redis following the normal procedure for your distribution - for + example, on Debian, `apt install redis-server`. (It is safe to use an + existing Redis deployment if you have one: we use a pub/sub stream named + according to the `server_name` of your synapse server.) +2. Check Redis is running and accessible: you should be able to `echo PING | nc -q1 + localhost 6379` and get a response of `+PONG`. +3. Install the python prerequisites. If you installed synapse into a + virtualenv, this can be done with: + ```sh + pip install matrix-synapse[redis] + ``` + The debian packages from matrix.org already include the required + dependencies. +4. Add config to the shared configuration (`homeserver.yaml`): + ```yaml + redis: + enabled: true + ``` + Optional parameters which can go alongside `enabled` are `host`, `port`, + `password`. Normally none of these are required. +5. Restart master and all workers. + +Once redis replication is in use, `worker_replication_port` is redundant and +can be removed from the worker configuration files. Similarly, the +configuration for the `listener` for the TCP replication port can be removed +from the main configuration file. Note that the HTTP replication port is +still required. + +### Using synctl + +If you want to use `synctl` to manage your synapse processes, you will need to +create an an additional configuration file for the master synapse process. That +configuration should look like this: + +```yaml +worker_app: synapse.app.homeserver +``` + +Additionally, each worker app must be configured with the name of a "pid file", +to which it will write its process ID when it starts. For example, for a +synchrotron, you might write: + +```yaml +worker_pid_file: /home/matrix/synapse/synchrotron.pid +``` + +Finally, to actually run your worker-based synapse, you must pass synctl the `-a` commandline option to tell it to operate on all the worker configurations found in the given directory, e.g.: @@ -168,20 +233,42 @@ endpoints matching the following regular expressions: ^/_matrix/federation/v1/make_join/ ^/_matrix/federation/v1/make_leave/ ^/_matrix/federation/v1/send_join/ + ^/_matrix/federation/v2/send_join/ ^/_matrix/federation/v1/send_leave/ + ^/_matrix/federation/v2/send_leave/ ^/_matrix/federation/v1/invite/ + ^/_matrix/federation/v2/invite/ ^/_matrix/federation/v1/query_auth/ ^/_matrix/federation/v1/event_auth/ ^/_matrix/federation/v1/exchange_third_party_invite/ + ^/_matrix/federation/v1/user/devices/ ^/_matrix/federation/v1/send/ + ^/_matrix/federation/v1/get_groups_publicised$ ^/_matrix/key/v2/query +Additionally, the following REST endpoints can be handled for GET requests: + + ^/_matrix/federation/v1/groups/ + The above endpoints should all be routed to the federation_reader worker by the reverse-proxy configuration. The `^/_matrix/federation/v1/send/` endpoint must only be handled by a single instance. +Note that `federation` must be added to the listener resources in the worker config: + +```yaml +worker_app: synapse.app.federation_reader +... +worker_listeners: + - type: http + port: <port> + resources: + - names: + - federation +``` + ### `synapse.app.federation_sender` Handles sending federation traffic to other servers. Doesn't handle any @@ -196,16 +283,30 @@ Handles the media repository. 
It can handle all endpoints starting with: /_matrix/media/ -And the following regular expressions matching media-specific administration APIs: +... and the following regular expressions matching media-specific administration APIs: ^/_synapse/admin/v1/purge_media_cache$ - ^/_synapse/admin/v1/room/.*/media$ + ^/_synapse/admin/v1/room/.*/media.*$ + ^/_synapse/admin/v1/user/.*/media.*$ + ^/_synapse/admin/v1/media/.*$ ^/_synapse/admin/v1/quarantine_media/.*$ You should also set `enable_media_repo: False` in the shared configuration file to stop the main synapse running background jobs related to managing the media repository. +In the `media_repository` worker configuration file, configure the http listener to +expose the `media` resource. For example: + +```yaml + worker_listeners: + - type: http + port: 8085 + resources: + - names: + - media +``` + Note this worker cannot be load-balanced: only one instance should be active. ### `synapse.app.client_reader` @@ -224,15 +325,22 @@ following regular expressions: ^/_matrix/client/(api/v1|r0|unstable)/keys/changes$ ^/_matrix/client/versions$ ^/_matrix/client/(api/v1|r0|unstable)/voip/turnServer$ + ^/_matrix/client/(api/v1|r0|unstable)/joined_groups$ + ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$ + ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/ Additionally, the following REST endpoints can be handled for GET requests: ^/_matrix/client/(api/v1|r0|unstable)/pushrules/.*$ + ^/_matrix/client/(api/v1|r0|unstable)/groups/.*$ + ^/_matrix/client/(api/v1|r0|unstable)/user/[^/]*/account_data/ + ^/_matrix/client/(api/v1|r0|unstable)/user/[^/]*/rooms/[^/]*/account_data/ Additionally, the following REST endpoints can be handled, but all requests must be routed to the same instance: ^/_matrix/client/(r0|unstable)/register$ + ^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$ Pagination requests can also be handled, but all requests with the same path room must be routed to the same instance. Additionally, care must be taken to @@ -248,6 +356,10 @@ the following regular expressions: ^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$ +When using this worker you must also set `update_user_directory: False` in the +shared configuration file to stop the main synapse running background +jobs related to updating the user directory. + ### `synapse.app.frontend_proxy` Proxies some frequently-requested client endpoints to add caching and remove @@ -276,6 +388,7 @@ file. For example: Handles some event creation. 
It can handle REST endpoints matching: ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send + ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$ ^/_matrix/client/(api/v1|r0|unstable)/join/ ^/_matrix/client/(api/v1|r0|unstable)/profile/ diff --git a/mypy.ini b/mypy.ini index 8788574ee3..3533797d68 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,12 +1,14 @@ [mypy] -namespace_packages=True -plugins=mypy_zope:plugin -follow_imports=skip -mypy_path=stubs - -[mypy-synapse.config.homeserver] -# this is a mess because of the metaclass shenanigans -ignore_errors = True +namespace_packages = True +plugins = mypy_zope:plugin +follow_imports = silent +check_untyped_defs = True +show_error_codes = True +show_traceback = True +mypy_path = stubs + +[mypy-pymacaroons.*] +ignore_missing_imports = True [mypy-zope] ignore_missing_imports = True @@ -52,3 +54,27 @@ ignore_missing_imports = True [mypy-signedjson.*] ignore_missing_imports = True + +[mypy-prometheus_client.*] +ignore_missing_imports = True + +[mypy-service_identity.*] +ignore_missing_imports = True + +[mypy-daemonize] +ignore_missing_imports = True + +[mypy-sentry_sdk] +ignore_missing_imports = True + +[mypy-PIL.*] +ignore_missing_imports = True + +[mypy-lxml] +ignore_missing_imports = True + +[mypy-jwt.*] +ignore_missing_imports = True + +[mypy-authlib.*] +ignore_missing_imports = True diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages index 93305ee9b1..e6f4bd1dca 100755 --- a/scripts-dev/build_debian_packages +++ b/scripts-dev/build_debian_packages @@ -20,11 +20,12 @@ from concurrent.futures import ThreadPoolExecutor DISTS = ( "debian:stretch", "debian:buster", + "debian:bullseye", "debian:sid", "ubuntu:xenial", "ubuntu:bionic", - "ubuntu:cosmic", - "ubuntu:disco", + "ubuntu:eoan", + "ubuntu:focal", ) DESC = '''\ diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment index 0ec5075e79..98a618f6b2 100755 --- a/scripts-dev/check-newsfragment +++ b/scripts-dev/check-newsfragment @@ -7,7 +7,9 @@ set -e # make sure that origin/develop is up to date git remote set-branches --add origin develop -git fetch origin develop +git fetch -q origin develop + +pr="$BUILDKITE_PULL_REQUEST" # if there are changes in the debian directory, check that the debian changelog # has been updated @@ -20,20 +22,30 @@ fi # if there are changes *outside* the debian directory, check that the # newsfragments have been updated. -if git diff --name-only FETCH_HEAD... | grep -qv '^debian/'; then - tox -e check-newsfragment +if ! git diff --name-only FETCH_HEAD... | grep -qv '^debian/'; then + exit 0 fi +tox -qe check-newsfragment + echo echo "--------------------------" echo -# check that any new newsfiles on this branch end with a full stop. +matched=0 for f in `git diff --name-only FETCH_HEAD... -- changelog.d`; do + # check that any modified newsfiles on this branch end with a full stop. lastchar=`tr -d '\n' < $f | tail -c 1` if [ $lastchar != '.' -a $lastchar != '!' ]; then echo -e "\e[31mERROR: newsfragment $f does not end with a '.' 
or '!'\e[39m" >&2 exit 1 fi + + # see if this newsfile corresponds to the right PR + [[ -n "$pr" && "$f" == changelog.d/"$pr".* ]] && matched=1 done +if [[ -n "$pr" && "$matched" -eq 0 ]]; then + echo -e "\e[31mERROR: Did not find a news fragment with the right number: expected changelog.d/$pr.*.\e[39m" >&2 + exit 1 +fi diff --git a/scripts-dev/check_auth.py b/scripts-dev/check_auth.py deleted file mode 100644 index 2a1c5f39d4..0000000000 --- a/scripts-dev/check_auth.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import print_function - -import argparse -import itertools -import json -import sys - -from mock import Mock - -from synapse.api.auth import Auth -from synapse.events import FrozenEvent - - -def check_auth(auth, auth_chain, events): - auth_chain.sort(key=lambda e: e.depth) - - auth_map = {e.event_id: e for e in auth_chain} - - create_events = {} - for e in auth_chain: - if e.type == "m.room.create": - create_events[e.room_id] = e - - for e in itertools.chain(auth_chain, events): - auth_events_list = [auth_map[i] for i, _ in e.auth_events] - - auth_events = {(e.type, e.state_key): e for e in auth_events_list} - - auth_events[("m.room.create", "")] = create_events[e.room_id] - - try: - auth.check(e, auth_events=auth_events) - except Exception as ex: - print("Failed:", e.event_id, e.type, e.state_key) - print("Auth_events:", auth_events) - print(ex) - print(json.dumps(e.get_dict(), sort_keys=True, indent=4)) - # raise - print("Success:", e.event_id, e.type, e.state_key) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument( - "json", nargs="?", type=argparse.FileType("r"), default=sys.stdin - ) - - args = parser.parse_args() - - js = json.load(args.json) - - auth = Auth(Mock()) - check_auth( - auth, - [FrozenEvent(d) for d in js["auth_chain"]], - [FrozenEvent(d) for d in js.get("pdus", [])], - ) diff --git a/scripts-dev/config-lint.sh b/scripts-dev/config-lint.sh new file mode 100755 index 0000000000..189ca66535 --- /dev/null +++ b/scripts-dev/config-lint.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# Find linting errors in Synapse's default config file. +# Exits with 0 if there are no problems, or another code otherwise. 
+ +# Fix non-lowercase true/false values +sed -i.bak -E "s/: +True/: true/g; s/: +False/: false/g;" docs/sample_config.yaml +rm docs/sample_config.yaml.bak + +# Check if anything changed +git diff --exit-code docs/sample_config.yaml diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py index 179be61c30..961dc59f11 100644 --- a/scripts-dev/convert_server_keys.py +++ b/scripts-dev/convert_server_keys.py @@ -3,8 +3,6 @@ import json import sys import time -import six - import psycopg2 import yaml from canonicaljson import encode_canonical_json @@ -12,10 +10,7 @@ from signedjson.key import read_signing_keys from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 -if six.PY2: - db_type = six.moves.builtins.buffer -else: - db_type = memoryview +db_binary_type = memoryview def select_v1_keys(connection): @@ -72,7 +67,7 @@ def rows_v2(server, json): valid_until = json["valid_until_ts"] key_json = encode_canonical_json(json) for key_id in json["verify_keys"]: - yield (server, key_id, "-", valid_until, valid_until, db_type(key_json)) + yield (server, key_id, "-", valid_until, valid_until, db_binary_type(key_json)) def main(): @@ -103,7 +98,7 @@ def main(): yaml.safe_dump(result, sys.stdout, default_flow_style=False) - rows = list(row for server, json in result.items() for row in rows_v2(server, json)) + rows = [row for server, json in result.items() for row in rows_v2(server, json)] cursor = connection.cursor() cursor.executemany( diff --git a/scripts-dev/generate_sample_config b/scripts-dev/generate_sample_config index 5e33b9b549..9cb4630a5c 100755 --- a/scripts-dev/generate_sample_config +++ b/scripts-dev/generate_sample_config @@ -7,12 +7,22 @@ set -e cd `dirname $0`/.. SAMPLE_CONFIG="docs/sample_config.yaml" +SAMPLE_LOG_CONFIG="docs/sample_log_config.yaml" + +check() { + diff -u "$SAMPLE_LOG_CONFIG" <(./scripts/generate_log_config) >/dev/null || return 1 +} if [ "$1" == "--check" ]; then diff -u "$SAMPLE_CONFIG" <(./scripts/generate_config --header-file docs/.sample_config_header.yaml) >/dev/null || { echo -e "\e[1m\e[31m$SAMPLE_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2 exit 1 } + diff -u "$SAMPLE_LOG_CONFIG" <(./scripts/generate_log_config) >/dev/null || { + echo -e "\e[1m\e[31m$SAMPLE_LOG_CONFIG is not up-to-date. 
Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2 + exit 1 + } else ./scripts/generate_config --header-file docs/.sample_config_header.yaml -o "$SAMPLE_CONFIG" + ./scripts/generate_log_config -o "$SAMPLE_LOG_CONFIG" fi diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py index d20f6db176..bf3862a386 100644 --- a/scripts-dev/hash_history.py +++ b/scripts-dev/hash_history.py @@ -27,7 +27,7 @@ class Store(object): "_store_pdu_reference_hash_txn" ] _store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"] - _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] + simple_insert_txn = SQLBaseStore.__dict__["simple_insert_txn"] store = Store() diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index ebb4d69f86..34c4854e11 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -7,6 +7,15 @@ set -e -isort -y -rc synapse tests scripts-dev scripts -flake8 synapse tests -python3 -m black synapse tests scripts-dev scripts +if [ $# -ge 1 ] +then + files=$* +else + files="synapse tests scripts-dev scripts" +fi + +echo "Linting these locations: $files" +isort -y -rc $files +flake8 $files +python3 -m black $files +./scripts-dev/config-lint.sh diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh new file mode 100755 index 0000000000..60e8970a35 --- /dev/null +++ b/scripts-dev/make_full_schema.sh @@ -0,0 +1,184 @@ +#!/bin/bash +# +# This script generates SQL files for creating a brand new Synapse DB with the latest +# schema, on both SQLite3 and Postgres. +# +# It does so by having Synapse generate an up-to-date SQLite DB, then running +# synapse_port_db to convert it to Postgres. It then dumps the contents of both. + +POSTGRES_HOST="localhost" +POSTGRES_DB_NAME="synapse_full_schema.$$" + +SQLITE_FULL_SCHEMA_OUTPUT_FILE="full.sql.sqlite" +POSTGRES_FULL_SCHEMA_OUTPUT_FILE="full.sql.postgres" + +REQUIRED_DEPS=("matrix-synapse" "psycopg2") + +usage() { + echo + echo "Usage: $0 -p <postgres_username> -o <path> [-c] [-n] [-h]" + echo + echo "-p <postgres_username>" + echo " Username to connect to local postgres instance. The password will be requested" + echo " during script execution." + echo "-c" + echo " CI mode. Enables coverage tracking and prints every command that the script runs." + echo "-o <path>" + echo " Directory to output full schema files to." + echo "-h" + echo " Display this help text." +} + +while getopts "p:co:h" opt; do + case $opt in + p) + POSTGRES_USERNAME=$OPTARG + ;; + c) + # Print all commands that are being executed + set -x + + # Modify required dependencies for coverage + REQUIRED_DEPS+=("coverage" "coverage-enable-subprocess") + + COVERAGE=1 + ;; + o) + command -v realpath > /dev/null || (echo "The -o flag requires the 'realpath' binary to be installed" && exit 1) + OUTPUT_DIR="$(realpath "$OPTARG")" + ;; + h) + usage + exit + ;; + \?) 
+ echo "ERROR: Invalid option: -$OPTARG" >&2 + usage + exit + ;; + esac +done + +# Check that required dependencies are installed +unsatisfied_requirements=() +for dep in "${REQUIRED_DEPS[@]}"; do + pip show "$dep" --quiet || unsatisfied_requirements+=("$dep") +done +if [ ${#unsatisfied_requirements} -ne 0 ]; then + echo "Please install the following python packages: ${unsatisfied_requirements[*]}" + exit 1 +fi + +if [ -z "$POSTGRES_USERNAME" ]; then + echo "No postgres username supplied" + usage + exit 1 +fi + +if [ -z "$OUTPUT_DIR" ]; then + echo "No output directory supplied" + usage + exit 1 +fi + +# Create the output directory if it doesn't exist +mkdir -p "$OUTPUT_DIR" + +read -rsp "Postgres password for '$POSTGRES_USERNAME': " POSTGRES_PASSWORD +echo "" + +# Exit immediately if a command fails +set -e + +# cd to root of the synapse directory +cd "$(dirname "$0")/.." + +# Create temporary SQLite and Postgres homeserver db configs and key file +TMPDIR=$(mktemp -d) +KEY_FILE=$TMPDIR/test.signing.key # default Synapse signing key path +SQLITE_CONFIG=$TMPDIR/sqlite.conf +SQLITE_DB=$TMPDIR/homeserver.db +POSTGRES_CONFIG=$TMPDIR/postgres.conf + +# Ensure these files are delete on script exit +trap 'rm -rf $TMPDIR' EXIT + +cat > "$SQLITE_CONFIG" <<EOF +server_name: "test" + +signing_key_path: "$KEY_FILE" +macaroon_secret_key: "abcde" + +report_stats: false + +database: + name: "sqlite3" + args: + database: "$SQLITE_DB" + +# Suppress the key server warning. +trusted_key_servers: [] +EOF + +cat > "$POSTGRES_CONFIG" <<EOF +server_name: "test" + +signing_key_path: "$KEY_FILE" +macaroon_secret_key: "abcde" + +report_stats: false + +database: + name: "psycopg2" + args: + user: "$POSTGRES_USERNAME" + host: "$POSTGRES_HOST" + password: "$POSTGRES_PASSWORD" + database: "$POSTGRES_DB_NAME" + +# Suppress the key server warning. +trusted_key_servers: [] +EOF + +# Generate the server's signing key. +echo "Generating SQLite3 db schema..." +python -m synapse.app.homeserver --generate-keys -c "$SQLITE_CONFIG" + +# Make sure the SQLite3 database is using the latest schema and has no pending background update. +echo "Running db background jobs..." +scripts-dev/update_database --database-config "$SQLITE_CONFIG" + +# Create the PostgreSQL database. +echo "Creating postgres database..." +createdb $POSTGRES_DB_NAME + +echo "Copying data from SQLite3 to Postgres with synapse_port_db..." +if [ -z "$COVERAGE" ]; then + # No coverage needed + scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG" +else + # Coverage desired + coverage run scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG" +fi + +# Delete schema_version, applied_schema_deltas and applied_module_schemas tables +# This needs to be done after synapse_port_db is run +echo "Dropping unwanted db tables..." +SQL=" +DROP TABLE schema_version; +DROP TABLE applied_schema_deltas; +DROP TABLE applied_module_schemas; +" +sqlite3 "$SQLITE_DB" <<< "$SQL" +psql $POSTGRES_DB_NAME -U "$POSTGRES_USERNAME" -w <<< "$SQL" + +echo "Dumping SQLite3 schema to '$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE'..." +sqlite3 "$SQLITE_DB" ".dump" > "$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE" + +echo "Dumping Postgres schema to '$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE'..." 
+pg_dump --format=plain --no-tablespaces --no-acl --no-owner $POSTGRES_DB_NAME | sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > "$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE" + +echo "Cleaning up temporary Postgres database..." +dropdb $POSTGRES_DB_NAME + +echo "Done! Files dumped to: $OUTPUT_DIR" diff --git a/scripts-dev/update_database b/scripts-dev/update_database new file mode 100755 index 0000000000..94aa8758b4 --- /dev/null +++ b/scripts-dev/update_database @@ -0,0 +1,101 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import sys + +import yaml + +from twisted.internet import defer, reactor + +import synapse +from synapse.config.homeserver import HomeServerConfig +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.server import HomeServer +from synapse.storage import DataStore +from synapse.util.versionstring import get_version_string + +logger = logging.getLogger("update_database") + + +class MockHomeserver(HomeServer): + DATASTORE_CLASS = DataStore + + def __init__(self, config, **kwargs): + super(MockHomeserver, self).__init__( + config.server_name, reactor=reactor, config=config, **kwargs + ) + + self.version_string = "Synapse/"+get_version_string(synapse) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=( + "Updates a synapse database to the latest schema and runs background updates" + " on it." + ) + ) + parser.add_argument("-v", action="store_true") + parser.add_argument( + "--database-config", + type=argparse.FileType("r"), + required=True, + help="A database config file for either a SQLite3 database or a PostgreSQL one.", + ) + + args = parser.parse_args() + + logging_config = { + "level": logging.DEBUG if args.v else logging.INFO, + "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", + } + + logging.basicConfig(**logging_config) + + # Load, process and sanity-check the config. + hs_config = yaml.safe_load(args.database_config) + + if "database" not in hs_config: + sys.stderr.write("The configuration file must have a 'database' section.\n") + sys.exit(4) + + config = HomeServerConfig() + config.parse_config_dict(hs_config, "", "") + + # Instantiate and initialise the homeserver object. + hs = MockHomeserver(config) + + # Setup instantiates the store within the homeserver object and updates the + # DB. + hs.setup() + store = hs.get_datastore() + + async def run_background_updates(): + await store.db.updates.run_background_updates(sleep=False) + # Stop the reactor to exit the script once every background update is run. + reactor.stop() + + def run(): + # Apply all background updates on the database. 
+ defer.ensureDeferred( + run_as_background_process("background_updates", run_background_updates) + ) + + reactor.callWhenRunning(run) + + reactor.run() diff --git a/scripts/export_signing_key b/scripts/export_signing_key new file mode 100755 index 0000000000..8aec9d802b --- /dev/null +++ b/scripts/export_signing_key @@ -0,0 +1,94 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import sys +import time +from typing import Optional + +import nacl.signing +from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys + + +def exit(status: int = 0, message: Optional[str] = None): + if message: + print(message, file=sys.stderr) + sys.exit(status) + + +def format_plain(public_key: nacl.signing.VerifyKey): + print( + "%s:%s %s" + % (public_key.alg, public_key.version, encode_verify_key_base64(public_key),) + ) + + +def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int): + print( + ' "%s:%s": { key: "%s", expired_ts: %i }' + % ( + public_key.alg, + public_key.version, + encode_verify_key_base64(public_key), + expiry_ts, + ) + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "key_file", nargs="+", type=argparse.FileType("r"), help="The key file to read", + ) + + parser.add_argument( + "-x", + action="store_true", + dest="for_config", + help="format the output for inclusion in the old_signing_keys config setting", + ) + + parser.add_argument( + "--expiry-ts", + type=int, + default=int(time.time() * 1000) + 6*3600000, + help=( + "The expiry time to use for -x, in milliseconds since 1970. The default " + "is (now+6h)." + ), + ) + + args = parser.parse_args() + + formatter = ( + (lambda k: format_for_config(k, args.expiry_ts)) + if args.for_config + else format_plain + ) + + keys = [] + for file in args.key_file: + try: + res = read_signing_keys(file) + except Exception as e: + exit( + status=1, + message="Error reading key from file %s: %s %s" + % (file.name, type(e), e), + ) + res = [] + for key in res: + formatter(get_verify_key(key)) diff --git a/scripts/generate_log_config b/scripts/generate_log_config new file mode 100755 index 0000000000..b6957f48a3 --- /dev/null +++ b/scripts/generate_log_config @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 + +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import sys + +from synapse.config.logger import DEFAULT_LOG_CONFIG + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-o", + "--output-file", + type=argparse.FileType("w"), + default=sys.stdout, + help="File to write the configuration to. Default: stdout", + ) + + parser.add_argument( + "-f", + "--log-file", + type=str, + default="/var/log/matrix-synapse/homeserver.log", + help="name of the log file", + ) + + args = parser.parse_args() + args.output_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file)) diff --git a/scripts/hash_password b/scripts/hash_password index a1eb0769da..a30767f758 100755 --- a/scripts/hash_password +++ b/scripts/hash_password @@ -52,7 +52,7 @@ if __name__ == "__main__": if "config" in args and args.config: config = yaml.safe_load(args.config) bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds) - password_config = config.get("password_config", {}) + password_config = config.get("password_config", None) or {} password_pepper = password_config.get("pepper", password_pepper) password = args.password diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py index 12747c6024..b5b63933ab 100755 --- a/scripts/move_remote_media_to_new_store.py +++ b/scripts/move_remote_media_to_new_store.py @@ -72,7 +72,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths): # check that the original exists original_file = src_paths.remote_media_filepath(origin_server, file_id) if not os.path.exists(original_file): - logger.warn( + logger.warning( "Original for %s/%s (%s) does not exist", origin_server, file_id, diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index b6ba19c776..9a0fbc61d8 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -26,12 +27,44 @@ from six import string_types import yaml -from twisted.enterprise import adbapi from twisted.internet import defer, reactor -from synapse.storage._base import LoggingTransaction, SQLBaseStore +import synapse +from synapse.config.database import DatabaseConnectionConfig +from synapse.config.homeserver import HomeServerConfig +from synapse.logging.context import ( + LoggingContext, + make_deferred_yieldable, + run_in_background, +) +from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore +from synapse.storage.data_stores.main.deviceinbox import ( + DeviceInboxBackgroundUpdateStore, +) +from synapse.storage.data_stores.main.devices import DeviceBackgroundUpdateStore +from synapse.storage.data_stores.main.events_bg_updates import ( + EventsBackgroundUpdatesStore, +) +from synapse.storage.data_stores.main.media_repository import ( + MediaRepositoryBackgroundUpdateStore, +) +from synapse.storage.data_stores.main.registration import ( + RegistrationBackgroundUpdateStore, +) +from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore +from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore +from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore +from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore +from synapse.storage.data_stores.main.stats import StatsStore +from synapse.storage.data_stores.main.user_directory import ( + UserDirectoryBackgroundUpdateStore, +) +from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore +from synapse.storage.database import Database, make_conn from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +from synapse.util import Clock +from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse_port_db") @@ -43,6 +76,7 @@ BOOLEAN_COLUMNS = { "presence_list": ["accepted"], "presence_stream": ["currently_active"], "public_room_list_stream": ["visibility"], + "devices": ["hidden"], "device_lists_outbound_pokes": ["sent"], "users_who_share_rooms": ["share_private"], "groups": ["is_public"], @@ -55,6 +89,8 @@ BOOLEAN_COLUMNS = { "local_group_membership": ["is_publicised", "is_admin"], "e2e_room_keys": ["is_verified"], "account_validity": ["email_sent"], + "redactions": ["have_censored"], + "room_stats_state": ["is_federatable"], } @@ -86,79 +122,47 @@ APPEND_ONLY_TABLES = [ "presence_stream", "push_rules_stream", "ex_outlier_stream", - "cache_invalidation_stream", + "cache_invalidation_stream_by_instance", "public_room_list_stream", "state_group_edges", "stream_ordering_to_exterm", ] +# Error returned by the run function. Used at the top-level part of the script to +# handle errors and return codes. +end_error = None +# The exec_info for the error, if any. If error is defined but not exec_info the script +# will show only the error message without the stacktrace, if exec_info is defined but +# not the error then the script will show nothing outside of what's printed in the run +# function. If both are defined, the script will print both the error and the stacktrace. end_error_exec_info = None -class Store(object): - """This object is used to pull out some of the convenience API from the - Storage layer. - - *All* database interactions should go through this object. 
- """ - - def __init__(self, db_pool, engine): - self.db_pool = db_pool - self.database_engine = engine - - _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] - _simple_insert = SQLBaseStore.__dict__["_simple_insert"] - - _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"] - _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"] - _simple_select_one = SQLBaseStore.__dict__["_simple_select_one"] - _simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"] - _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"] - _simple_select_one_onecol_txn = SQLBaseStore.__dict__[ - "_simple_select_one_onecol_txn" - ] - - _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"] - _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"] - _simple_update_txn = SQLBaseStore.__dict__["_simple_update_txn"] - - def runInteraction(self, desc, func, *args, **kwargs): - def r(conn): - try: - i = 0 - N = 5 - while True: - try: - txn = conn.cursor() - return func( - LoggingTransaction(txn, desc, self.database_engine, [], []), - *args, - **kwargs - ) - except self.database_engine.module.DatabaseError as e: - if self.database_engine.is_deadlock(e): - logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N) - if i < N: - i += 1 - conn.rollback() - continue - raise - except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", desc, e) - raise - - return self.db_pool.runWithConnection(r) - +class Store( + ClientIpBackgroundUpdateStore, + DeviceInboxBackgroundUpdateStore, + DeviceBackgroundUpdateStore, + EventsBackgroundUpdatesStore, + MediaRepositoryBackgroundUpdateStore, + RegistrationBackgroundUpdateStore, + RoomBackgroundUpdateStore, + RoomMemberBackgroundUpdateStore, + SearchBackgroundUpdateStore, + StateBackgroundUpdateStore, + MainStateBackgroundUpdateStore, + UserDirectoryBackgroundUpdateStore, + StatsStore, +): def execute(self, f, *args, **kwargs): - return self.runInteraction(f.__name__, f, *args, **kwargs) + return self.db.runInteraction(f.__name__, f, *args, **kwargs) def execute_sql(self, sql, *args): def r(txn): txn.execute(sql, args) return txn.fetchall() - return self.runInteraction("execute_sql", r) + return self.db.runInteraction("execute_sql", r) def insert_many_txn(self, txn, table, headers, rows): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( @@ -173,16 +177,37 @@ class Store(object): logger.exception("Failed to insert: %s", table) raise + def set_room_is_public(self, room_id, is_public): + raise Exception( + "Attempt to set room_is_public during port_db: database not empty?" + ) + + +class MockHomeserver: + def __init__(self, config): + self.clock = Clock(reactor) + self.config = config + self.hostname = config.server_name + self.version_string = "Synapse/" + get_version_string(synapse) + + def get_clock(self): + return self.clock + + def get_reactor(self): + return reactor + + def get_instance_name(self): + return "master" + class Porter(object): def __init__(self, **kwargs): self.__dict__.update(kwargs) - @defer.inlineCallbacks - def setup_table(self, table): + async def setup_table(self, table): if table in APPEND_ONLY_TABLES: # It's safe to just carry on inserting. 
- row = yield self.postgres_store._simple_select_one( + row = await self.postgres_store.db.simple_select_one( table="port_from_sqlite3", keyvalues={"table_name": table}, retcols=("forward_rowid", "backward_rowid"), @@ -192,12 +217,14 @@ class Porter(object): total_to_port = None if row is None: if table == "sent_transactions": - forward_chunk, already_ported, total_to_port = ( - yield self._setup_sent_transactions() - ) + ( + forward_chunk, + already_ported, + total_to_port, + ) = await self._setup_sent_transactions() backward_chunk = 0 else: - yield self.postgres_store._simple_insert( + await self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={ "table_name": table, @@ -214,7 +241,7 @@ class Porter(object): backward_chunk = row["backward_rowid"] if total_to_port is None: - already_ported, total_to_port = yield self._get_total_count_to_port( + already_ported, total_to_port = await self._get_total_count_to_port( table, forward_chunk, backward_chunk ) else: @@ -225,9 +252,9 @@ class Porter(object): ) txn.execute("TRUNCATE %s CASCADE" % (table,)) - yield self.postgres_store.execute(delete_all) + await self.postgres_store.execute(delete_all) - yield self.postgres_store._simple_insert( + await self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, ) @@ -235,16 +262,13 @@ class Porter(object): forward_chunk = 1 backward_chunk = 0 - already_ported, total_to_port = yield self._get_total_count_to_port( + already_ported, total_to_port = await self._get_total_count_to_port( table, forward_chunk, backward_chunk ) - defer.returnValue( - (table, already_ported, total_to_port, forward_chunk, backward_chunk) - ) + return table, already_ported, total_to_port, forward_chunk, backward_chunk - @defer.inlineCallbacks - def handle_table( + async def handle_table( self, table, postgres_size, table_size, forward_chunk, backward_chunk ): logger.info( @@ -262,7 +286,7 @@ class Porter(object): self.progress.add_table(table, postgres_size, table_size) if table == "event_search": - yield self.handle_search_table( + await self.handle_search_table( postgres_size, table_size, forward_chunk, backward_chunk ) return @@ -281,7 +305,7 @@ class Porter(object): if table == "user_directory_stream_pos": # We need to make sure there is a single row, `(X, null), as that is # what synapse expects to be there. 
- yield self.postgres_store._simple_insert( + await self.postgres_store.db.simple_insert( table=table, values={"stream_id": None} ) self.progress.update(table, table_size) # Mark table as done @@ -322,7 +346,9 @@ class Porter(object): return headers, forward_rows, backward_rows - headers, frows, brows = yield self.sqlite_store.runInteraction("select", r) + headers, frows, brows = await self.sqlite_store.db.runInteraction( + "select", r + ) if frows or brows: if frows: @@ -336,7 +362,7 @@ class Porter(object): def insert(txn): self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) - self.postgres_store._simple_update_one_txn( + self.postgres_store.db.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": table}, @@ -346,7 +372,7 @@ class Porter(object): }, ) - yield self.postgres_store.execute(insert) + await self.postgres_store.execute(insert) postgres_size += len(rows) @@ -354,8 +380,7 @@ class Porter(object): else: return - @defer.inlineCallbacks - def handle_search_table( + async def handle_search_table( self, postgres_size, table_size, forward_chunk, backward_chunk ): select = ( @@ -375,7 +400,7 @@ class Porter(object): return headers, rows - headers, rows = yield self.sqlite_store.runInteraction("select", r) + headers, rows = await self.sqlite_store.db.runInteraction("select", r) if rows: forward_chunk = rows[-1][0] + 1 @@ -392,8 +417,8 @@ class Porter(object): rows_dict = [] for row in rows: d = dict(zip(headers, row)) - if "\0" in d['value']: - logger.warn('dropping search row %s', d) + if "\0" in d["value"]: + logger.warning("dropping search row %s", d) else: rows_dict.append(d) @@ -413,7 +438,7 @@ class Porter(object): ], ) - self.postgres_store._simple_update_one_txn( + self.postgres_store.db.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": "event_search"}, @@ -423,7 +448,7 @@ class Porter(object): }, ) - yield self.postgres_store.execute(insert) + await self.postgres_store.execute(insert) postgres_size += len(rows) @@ -432,44 +457,85 @@ class Porter(object): else: return - def setup_db(self, db_config, database_engine): - db_conn = database_engine.module.connect( - **{ - k: v - for k, v in db_config.get("args", {}).items() - if not k.startswith("cp_") - } - ) + def build_db_store( + self, db_config: DatabaseConnectionConfig, allow_outdated_version: bool = False, + ): + """Builds and returns a database store using the provided configuration. - prepare_database(db_conn, database_engine, config=None) + Args: + db_config: The database configuration + allow_outdated_version: True to suppress errors about the database server + version being too old to run a complete synapse - db_conn.commit() + Returns: + The built Store object. 
+ """ + self.progress.set_state("Preparing %s" % db_config.config["name"]) - @defer.inlineCallbacks - def run(self): - try: - sqlite_db_pool = adbapi.ConnectionPool( - self.sqlite_config["name"], **self.sqlite_config["args"] + engine = create_engine(db_config.config) + + hs = MockHomeserver(self.hs_config) + + with make_conn(db_config, engine) as db_conn: + engine.check_database( + db_conn, allow_outdated_version=allow_outdated_version ) + prepare_database(db_conn, engine, config=self.hs_config) + store = Store(Database(hs, db_config, engine), db_conn, hs) + db_conn.commit() + + return store - postgres_db_pool = adbapi.ConnectionPool( - self.postgres_config["name"], **self.postgres_config["args"] + async def run_background_updates_on_postgres(self): + # Manually apply all background updates on the PostgreSQL database. + postgres_ready = ( + await self.postgres_store.db.updates.has_completed_background_updates() + ) + + if not postgres_ready: + # Only say that we're running background updates when there are background + # updates to run. + self.progress.set_state("Running background updates on PostgreSQL") + + while not postgres_ready: + await self.postgres_store.db.updates.do_next_background_update(100) + postgres_ready = await ( + self.postgres_store.db.updates.has_completed_background_updates() ) - sqlite_engine = create_engine(sqlite_config) - postgres_engine = create_engine(postgres_config) + async def run(self): + """Ports the SQLite database to a PostgreSQL database. - self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) - self.postgres_store = Store(postgres_db_pool, postgres_engine) + When a fatal error is met, its message is assigned to the global "end_error" + variable. When this error comes with a stacktrace, its exec_info is assigned to + the global "end_error_exec_info" variable. + """ + global end_error + + try: + # we allow people to port away from outdated versions of sqlite. + self.sqlite_store = self.build_db_store( + DatabaseConnectionConfig("master-sqlite", self.sqlite_config), + allow_outdated_version=True, + ) - yield self.postgres_store.execute(postgres_engine.check_database) + # Check if all background updates are done, abort if not. + updates_complete = ( + await self.sqlite_store.db.updates.has_completed_background_updates() + ) + if not updates_complete: + end_error = ( + "Pending background updates exist in the SQLite3 database." + " Please start Synapse again and wait until every update has finished" + " before running this script.\n" + ) + return - # Step 1. Set up databases. - self.progress.set_state("Preparing SQLite3") - self.setup_db(sqlite_config, sqlite_engine) + self.postgres_store = self.build_db_store( + self.hs_config.get_single_database() + ) - self.progress.set_state("Preparing PostgreSQL") - self.setup_db(postgres_config, postgres_engine) + await self.run_background_updates_on_postgres() self.progress.set_state("Creating port tables") @@ -497,22 +563,22 @@ class Porter(object): ) try: - yield self.postgres_store.runInteraction("alter_table", alter_table) + await self.postgres_store.db.runInteraction("alter_table", alter_table) except Exception: # On Error Resume Next pass - yield self.postgres_store.runInteraction( + await self.postgres_store.db.runInteraction( "create_port_table", create_port_table ) # Step 2. Get tables. 
self.progress.set_state("Fetching tables") - sqlite_tables = yield self.sqlite_store._simple_select_onecol( + sqlite_tables = await self.sqlite_store.db.simple_select_onecol( table="sqlite_master", keyvalues={"type": "table"}, retcol="name" ) - postgres_tables = yield self.postgres_store._simple_select_onecol( + postgres_tables = await self.postgres_store.db.simple_select_onecol( table="information_schema.tables", keyvalues={}, retcol="distinct table_name", @@ -523,28 +589,34 @@ class Porter(object): # Step 3. Figure out what still needs copying self.progress.set_state("Checking on port progress") - setup_res = yield defer.gatherResults( - [ - self.setup_table(table) - for table in tables - if table not in ["schema_version", "applied_schema_deltas"] - and not table.startswith("sqlite_") - ], - consumeErrors=True, + setup_res = await make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(self.setup_table, table) + for table in tables + if table not in ["schema_version", "applied_schema_deltas"] + and not table.startswith("sqlite_") + ], + consumeErrors=True, + ) ) # Step 4. Do the copying. self.progress.set_state("Copying to postgres") - yield defer.gatherResults( - [self.handle_table(*res) for res in setup_res], consumeErrors=True + await make_deferred_yieldable( + defer.gatherResults( + [run_in_background(self.handle_table, *res) for res in setup_res], + consumeErrors=True, + ) ) # Step 5. Do final post-processing - yield self._setup_state_group_id_seq() + await self._setup_state_group_id_seq() self.progress.done() - except Exception: + except Exception as e: global end_error_exec_info + end_error = e end_error_exec_info = sys.exc_info() logger.exception("") finally: @@ -561,8 +633,10 @@ class Porter(object): def conv(j, col): if j in bool_cols: return bool(col) + if isinstance(col, bytes): + return bytearray(col) elif isinstance(col, string_types) and "\0" in col: - logger.warn( + logger.warning( "DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], @@ -582,8 +656,7 @@ class Porter(object): return outrows - @defer.inlineCallbacks - def _setup_sent_transactions(self): + async def _setup_sent_transactions(self): # Only save things from the last day yesterday = int(time.time() * 1000) - 86400000 @@ -600,11 +673,11 @@ class Porter(object): rows = txn.fetchall() headers = [column[0] for column in txn.description] - ts_ind = headers.index('ts') + ts_ind = headers.index("ts") return headers, [r for r in rows if r[ts_ind] < yesterday] - headers, rows = yield self.sqlite_store.runInteraction("select", r) + headers, rows = await self.sqlite_store.db.runInteraction("select", r) rows = self._convert_rows("sent_transactions", headers, rows) @@ -617,7 +690,7 @@ class Porter(object): txn, "sent_transactions", headers[1:], rows ) - yield self.postgres_store.execute(insert) + await self.postgres_store.execute(insert) else: max_inserted_rowid = 0 @@ -634,10 +707,10 @@ class Porter(object): else: return 1 - next_chunk = yield self.sqlite_store.execute(get_start_id) + next_chunk = await self.sqlite_store.execute(get_start_id) next_chunk = max(max_inserted_rowid + 1, next_chunk) - yield self.postgres_store._simple_insert( + await self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={ "table_name": "sent_transactions", @@ -650,57 +723,63 @@ class Porter(object): txn.execute( "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) ) - size, = txn.fetchone() + (size,) = txn.fetchone() return int(size) - remaining_count = yield 
self.sqlite_store.execute(get_sent_table_size) + remaining_count = await self.sqlite_store.execute(get_sent_table_size) total_count = remaining_count + inserted_rows - defer.returnValue((next_chunk, inserted_rows, total_count)) + return next_chunk, inserted_rows, total_count - @defer.inlineCallbacks - def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): - frows = yield self.sqlite_store.execute_sql( + async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): + frows = await self.sqlite_store.execute_sql( "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk ) - brows = yield self.sqlite_store.execute_sql( + brows = await self.sqlite_store.execute_sql( "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk ) - defer.returnValue(frows[0][0] + brows[0][0]) + return frows[0][0] + brows[0][0] - @defer.inlineCallbacks - def _get_already_ported_count(self, table): - rows = yield self.postgres_store.execute_sql( + async def _get_already_ported_count(self, table): + rows = await self.postgres_store.execute_sql( "SELECT count(*) FROM %s" % (table,) ) - defer.returnValue(rows[0][0]) + return rows[0][0] - @defer.inlineCallbacks - def _get_total_count_to_port(self, table, forward_chunk, backward_chunk): - remaining, done = yield defer.gatherResults( - [ - self._get_remaining_count_to_port(table, forward_chunk, backward_chunk), - self._get_already_ported_count(table), - ], - consumeErrors=True, + async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk): + remaining, done = await make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self._get_remaining_count_to_port, + table, + forward_chunk, + backward_chunk, + ), + run_in_background(self._get_already_ported_count, table), + ], + ) ) remaining = int(remaining) if remaining else 0 done = int(done) if done else 0 - defer.returnValue((done, remaining + done)) + return done, remaining + done def _setup_state_group_id_seq(self): def r(txn): txn.execute("SELECT MAX(id) FROM state_groups") - next_id = txn.fetchone()[0] + 1 + curr_id = txn.fetchone()[0] + if not curr_id: + return + next_id = curr_id + 1 txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) - return self.postgres_store.runInteraction("setup_state_group_id_seq", r) + return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r) ############################################## @@ -781,7 +860,7 @@ class CursesProgress(Progress): duration = int(now) - int(self.start_time) minutes, seconds = divmod(duration, 60) - duration_str = '%02dm %02ds' % (minutes, seconds) + duration_str = "%02dm %02ds" % (minutes, seconds) if self.finished: status = "Time spent: %s (Done!)" % (duration_str,) @@ -791,7 +870,7 @@ class CursesProgress(Progress): left = float(self.total_remaining) / self.total_processed est_remaining = (int(now) - self.start_time) * left - est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60) + est_remaining_str = "%02dm %02ds remaining" % divmod(est_remaining, 60) else: est_remaining_str = "Unknown" status = "Time spent: %s (est. remaining: %s)" % ( @@ -877,7 +956,7 @@ if __name__ == "__main__": description="A script to port an existing synapse SQLite database to" " a new PostgreSQL database." 
) - parser.add_argument("-v", action='store_true') + parser.add_argument("-v", action="store_true") parser.add_argument( "--sqlite-database", required=True, @@ -886,12 +965,12 @@ if __name__ == "__main__": ) parser.add_argument( "--postgres-config", - type=argparse.FileType('r'), + type=argparse.FileType("r"), required=True, help="The database config file for the PostgreSQL database", ) parser.add_argument( - "--curses", action='store_true', help="display a curses based progress UI" + "--curses", action="store_true", help="display a curses based progress UI" ) parser.add_argument( @@ -924,18 +1003,24 @@ if __name__ == "__main__": }, } - postgres_config = yaml.safe_load(args.postgres_config) + hs_config = yaml.safe_load(args.postgres_config) + + if "database" not in hs_config: + sys.stderr.write("The configuration file must have a 'database' section.\n") + sys.exit(4) - if "database" in postgres_config: - postgres_config = postgres_config["database"] + postgres_config = hs_config["database"] if "name" not in postgres_config: - sys.stderr.write("Malformed database config: no 'name'") + sys.stderr.write("Malformed database config: no 'name'\n") sys.exit(2) if postgres_config["name"] != "psycopg2": - sys.stderr.write("Database must use 'psycopg2' connector.") + sys.stderr.write("Database must use the 'psycopg2' connector.\n") sys.exit(3) + config = HomeServerConfig() + config.parse_config_dict(hs_config, "", "") + def start(stdscr=None): if stdscr: progress = CursesProgress(stdscr) @@ -944,12 +1029,17 @@ if __name__ == "__main__": porter = Porter( sqlite_config=sqlite_config, - postgres_config=postgres_config, progress=progress, batch_size=args.batch_size, + hs_config=config, ) - reactor.callWhenRunning(porter.run) + @defer.inlineCallbacks + def run(): + with LoggingContext("synapse_port_db_run"): + yield defer.ensureDeferred(porter.run()) + + reactor.callWhenRunning(run) reactor.run() @@ -958,6 +1048,11 @@ if __name__ == "__main__": else: start() - if end_error_exec_info: - exc_type, exc_value, exc_traceback = end_error_exec_info - traceback.print_exception(exc_type, exc_value, exc_traceback) + if end_error: + if end_error_exec_info: + exc_type, exc_value, exc_traceback = end_error_exec_info + traceback.print_exception(exc_type, exc_value, exc_traceback) + + sys.stderr.write(end_error) + + sys.exit(5) diff --git a/setup.py b/setup.py index 5ce06c8987..54ddec8f9f 100755 --- a/setup.py +++ b/setup.py @@ -114,6 +114,7 @@ setup( "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", ], scripts=["synctl"] + glob.glob("scripts/*"), cmdclass={"test": TestCommand}, diff --git a/snap/snapcraft.yaml b/snap/snapcraft.yaml new file mode 100644 index 0000000000..9a01152c15 --- /dev/null +++ b/snap/snapcraft.yaml @@ -0,0 +1,57 @@ +name: matrix-synapse +base: core18 +version: git +summary: Reference Matrix homeserver +description: | + Synapse is the reference Matrix homeserver. + Matrix is a federated and decentralised instant messaging and VoIP system. 
+ +grade: stable +confinement: strict + +apps: + matrix-synapse: + command: synctl --no-daemonize start $SNAP_COMMON/homeserver.yaml + stop-command: synctl -c $SNAP_COMMON stop + plugs: [network-bind, network] + daemon: simple + hash-password: + command: hash_password + generate-config: + command: generate_config + generate-signing-key: + command: generate_signing_key.py + register-new-matrix-user: + command: register_new_matrix_user + plugs: [network] + synctl: + command: synctl +parts: + matrix-synapse: + source: . + plugin: python + python-version: python3 + python-packages: + - '.[all]' + - pip + - setuptools + - setuptools-scm + - wheel + build-packages: + - libffi-dev + - libturbojpeg0-dev + - libssl-dev + - libxslt1-dev + - libpq-dev + - zlib1g-dev + stage-packages: + - libasn1-8-heimdal + - libgssapi3-heimdal + - libhcrypto4-heimdal + - libheimbase1-heimdal + - libheimntlm0-heimdal + - libhx509-5-heimdal + - libkrb5-26-heimdal + - libldap-2.4-2 + - libpq5 + - libsasl2-2 diff --git a/stubs/sortedcontainers/__init__.pyi b/stubs/sortedcontainers/__init__.pyi new file mode 100644 index 0000000000..073b806d3c --- /dev/null +++ b/stubs/sortedcontainers/__init__.pyi @@ -0,0 +1,13 @@ +from .sorteddict import ( + SortedDict, + SortedKeysView, + SortedItemsView, + SortedValuesView, +) + +__all__ = [ + "SortedDict", + "SortedKeysView", + "SortedItemsView", + "SortedValuesView", +] diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi new file mode 100644 index 0000000000..68779f968e --- /dev/null +++ b/stubs/sortedcontainers/sorteddict.pyi @@ -0,0 +1,124 @@ +# stub for SortedDict. This is a lightly edited copy of +# https://github.com/grantjenks/python-sortedcontainers/blob/eea42df1f7bad2792e8da77335ff888f04b9e5ae/sortedcontainers/sorteddict.pyi +# (from https://github.com/grantjenks/python-sortedcontainers/pull/107) + +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterator, + Iterable, + ItemsView, + KeysView, + List, + Mapping, + Optional, + Sequence, + Type, + TypeVar, + Tuple, + Union, + ValuesView, + overload, +) + +_T = TypeVar("_T") +_S = TypeVar("_S") +_T_h = TypeVar("_T_h", bound=Hashable) +_KT = TypeVar("_KT", bound=Hashable) # Key type. +_VT = TypeVar("_VT") # Value type. +_KT_co = TypeVar("_KT_co", covariant=True, bound=Hashable) +_VT_co = TypeVar("_VT_co", covariant=True) +_SD = TypeVar("_SD", bound=SortedDict) +_Key = Callable[[_T], Any] + +class SortedDict(Dict[_KT, _VT]): + @overload + def __init__(self, **kwargs: _VT) -> None: ... + @overload + def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ... + @overload + def __init__( + self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... + @overload + def __init__(self, __key: _Key[_KT], **kwargs: _VT) -> None: ... + @overload + def __init__( + self, __key: _Key[_KT], __map: Mapping[_KT, _VT], **kwargs: _VT + ) -> None: ... + @overload + def __init__( + self, __key: _Key[_KT], __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... + @property + def key(self) -> Optional[_Key[_KT]]: ... + @property + def iloc(self) -> SortedKeysView[_KT]: ... + def clear(self) -> None: ... + def __delitem__(self, key: _KT) -> None: ... + def __iter__(self) -> Iterator[_KT]: ... + def __reversed__(self) -> Iterator[_KT]: ... + def __setitem__(self, key: _KT, value: _VT) -> None: ... + def _setitem(self, key: _KT, value: _VT) -> None: ... + def copy(self: _SD) -> _SD: ... + def __copy__(self: _SD) -> _SD: ... 
+ @classmethod + @overload + def fromkeys(cls, seq: Iterable[_T_h]) -> SortedDict[_T_h, None]: ... + @classmethod + @overload + def fromkeys(cls, seq: Iterable[_T_h], value: _S) -> SortedDict[_T_h, _S]: ... + def keys(self) -> SortedKeysView[_KT]: ... + def items(self) -> SortedItemsView[_KT, _VT]: ... + def values(self) -> SortedValuesView[_VT]: ... + @overload + def pop(self, key: _KT) -> _VT: ... + @overload + def pop(self, key: _KT, default: _T = ...) -> Union[_VT, _T]: ... + def popitem(self, index: int = ...) -> Tuple[_KT, _VT]: ... + def peekitem(self, index: int = ...) -> Tuple[_KT, _VT]: ... + def setdefault(self, key: _KT, default: Optional[_VT] = ...) -> _VT: ... + @overload + def update(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ... + @overload + def update(self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT) -> None: ... + @overload + def update(self, **kwargs: _VT) -> None: ... + def __reduce__( + self, + ) -> Tuple[ + Type[SortedDict[_KT, _VT]], Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]], + ]: ... + def __repr__(self) -> str: ... + def _check(self) -> None: ... + def islice( + self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool, + ) -> Iterator[_KT]: ... + def bisect_left(self, value: _KT) -> int: ... + def bisect_right(self, value: _KT) -> int: ... + +class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]): + @overload + def __getitem__(self, index: int) -> _KT_co: ... + @overload + def __getitem__(self, index: slice) -> List[_KT_co]: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... + +class SortedItemsView( # type: ignore + ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]] +): + def __iter__(self) -> Iterator[Tuple[_KT_co, _VT_co]]: ... + @overload + def __getitem__(self, index: int) -> Tuple[_KT_co, _VT_co]: ... + @overload + def __getitem__(self, index: slice) -> List[Tuple[_KT_co, _VT_co]]: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... + +class SortedValuesView(ValuesView[_VT_co], Sequence[_VT_co]): + @overload + def __getitem__(self, index: int) -> _VT_co: ... + @overload + def __getitem__(self, index: slice) -> List[_VT_co]: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi new file mode 100644 index 0000000000..cac689d4f3 --- /dev/null +++ b/stubs/txredisapi.pyi @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains *incomplete* type hints for txredisapi. +""" + +from typing import List, Optional, Union + +class RedisProtocol: + def publish(self, channel: str, message: bytes): ... + +class SubscriberProtocol: + password: Optional[str] + def subscribe(self, channels: Union[str, List[str]]): ... + def connectionMade(self): ... + def connectionLost(self, reason): ... 
+ +def lazyConnection( + host: str = ..., + port: int = ..., + dbid: Optional[int] = ..., + reconnect: bool = ..., + charset: str = ..., + password: Optional[str] = ..., + connectTimeout: Optional[int] = ..., + replyTimeout: Optional[int] = ..., + convertNumbers: bool = ..., +) -> RedisProtocol: ... + +class SubscriberFactory: + def buildProtocol(self, addr): ... diff --git a/synapse/__init__.py b/synapse/__init__.py index 6766ef445c..39c3a3ba4d 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -14,9 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" This is a reference implementation of a Matrix home server. +""" This is a reference implementation of a Matrix homeserver. """ +import os import sys # Check that we're not running on an unsupported Python version. @@ -35,4 +36,11 @@ try: except ImportError: pass -__version__ = "1.3.1" +__version__ = "1.15.0rc1" + +if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): + # We import here so that we don't have to install a bunch of deps when + # running the packaging tox test. + from synapse.util.patch_inline_callbacks import do_patch + + do_patch() diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index bdcd915bbe..d528450c78 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -144,8 +144,8 @@ def main(): logging.captureWarnings(True) parser = argparse.ArgumentParser( - description="Used to register new users with a given home server when" - " registration has been disabled. The home server must be" + description="Used to register new users with a given homeserver when" + " registration has been disabled. The homeserver must be" " configured with the 'registration_shared_secret' option" " set." ) @@ -202,7 +202,7 @@ def main(): "server_url", default="https://localhost:8448", nargs="?", - help="URL to use to talk to the home server. Defaults to " + help="URL to use to talk to the homeserver. Defaults to " " 'https://localhost:8448'.", ) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9e445cd808..06ade25674 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -14,6 +14,7 @@ # limitations under the License. 
import logging +from typing import Optional from six import itervalues @@ -21,21 +22,23 @@ import pymacaroons from netaddr import IPAddress from twisted.internet import defer +from twisted.web.server import Request import synapse.logging.opentracing as opentracing import synapse.types from synapse import event_auth -from synapse.api.constants import EventTypes, JoinRules, Membership, UserTypes +from synapse.api.auth_blocking import AuthBlocking +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, InvalidClientTokenError, MissingClientTokenError, - ResourceLimitError, ) -from synapse.config.server import is_threepid_reserved -from synapse.types import UserID -from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase +from synapse.types import StateMap, UserID +from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -71,55 +74,58 @@ class Auth(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) + self.token_cache = LruCache(10000) register_cache("cache", "token_cache", self.token_cache) + self._auth_blocking = AuthBlocking(self.hs) + self._account_validity = hs.config.account_validity + self._track_appservice_user_ips = hs.config.track_appservice_user_ips + self._macaroon_secret_key = hs.config.macaroon_secret_key @defer.inlineCallbacks - def check_from_context(self, room_version, event, context, do_sig_check=True): - prev_state_ids = yield context.get_prev_state_ids(self.store) + def check_from_context(self, room_version: str, event, context, do_sig_check=True): + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} - self.check( - room_version, event, auth_events=auth_events, do_sig_check=do_sig_check - ) - def check(self, room_version, event, auth_events, do_sig_check=True): - """ Checks if this event is correctly authed. + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + event_auth.check( + room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check + ) + @defer.inlineCallbacks + def check_user_in_room( + self, + room_id: str, + user_id: str, + current_state: Optional[StateMap[EventBase]] = None, + allow_departed_users: bool = False, + ): + """Check if the user is in the room, or was at some point. Args: - room_version (str): version of the room - event: the event being checked. - auth_events (dict: event-key -> event): the existing room state. + room_id: The room to check. + user_id: The user to check. - Returns: - True if the auth checks pass. - """ - with Measure(self.clock, "auth.check"): - event_auth.check( - room_version, event, auth_events, do_sig_check=do_sig_check - ) - - @defer.inlineCallbacks - def check_joined_room(self, room_id, user_id, current_state=None): - """Check if the user is currently joined in the room - Args: - room_id(str): The room to check. - user_id(str): The user to check. - current_state(dict): Optional map of the current state of the room. + current_state: Optional map of the current state of the room. If provided then that map is used to check whether they are a member of the room. 
Otherwise the current membership is loaded from the database. + + allow_departed_users: if True, accept users that were previously + members but have now departed. + Raises: - AuthError if the user is not in the room. + AuthError if the user is/was not in the room. Returns: - A deferred membership event for the user if the user is in - the room. + Deferred[Optional[EventBase]]: + Membership event for the user if the user was in the + room. This will be the join event if they are currently joined to + the room. This will be the leave event if they have left the room. """ if current_state: member = current_state.get((EventTypes.Member, user_id), None) @@ -127,37 +133,19 @@ class Auth(object): member = yield self.state.get_current_state( room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) - - self._check_joined_room(member, user_id, room_id) - return member - - @defer.inlineCallbacks - def check_user_was_in_room(self, room_id, user_id): - """Check if the user was in the room at some point. - Args: - room_id(str): The room to check. - user_id(str): The user to check. - Raises: - AuthError if the user was never in the room. - Returns: - A deferred membership event for the user if the user was in the - room. This will be the join event if they are currently joined to - the room. This will be the leave event if they have left the room. - """ - member = yield self.state.get_current_state( - room_id=room_id, event_type=EventTypes.Member, state_key=user_id - ) membership = member.membership if member else None - if membership not in (Membership.JOIN, Membership.LEAVE): - raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) + if membership == Membership.JOIN: + return member - if membership == Membership.LEAVE: + # XXX this looks totally bogus. Why do we not allow users who have been banned, + # or those who were members previously and have been re-invited? + if allow_departed_users and membership == Membership.LEAVE: forgot = yield self.store.did_forget(user_id, room_id) - if forgot: - raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) + if not forgot: + return member - return member + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) @defer.inlineCallbacks def check_host_in_room(self, room_id, host): @@ -165,12 +153,6 @@ class Auth(object): latest_event_ids = yield self.store.is_host_joined(room_id, host) return latest_event_ids - def _check_joined_room(self, member, user_id, room_id): - if not member or member.membership != Membership.JOIN: - raise AuthError( - 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member)) - ) - def can_federate(self, event, auth_events): creation_event = auth_events.get((EventTypes.Create, "")) @@ -179,22 +161,27 @@ class Auth(object): def get_public_keys(self, invite_event): return event_auth.get_public_keys(invite_event) - @opentracing.trace @defer.inlineCallbacks def get_user_by_req( - self, request, allow_guest=False, rights="access", allow_expired=False + self, + request: Request, + allow_guest: bool = False, + rights: str = "access", + allow_expired: bool = False, ): """ Get a registered user's ID. Args: - request - An HTTP request with an access_token query parameter. - allow_expired - Whether to allow the request through even if the account is - expired. If true, Synapse will still require an access token to be - provided but won't check if the account it belongs to has expired. This - works thanks to /login delivering access tokens regardless of accounts' - expiration. 
+ request: An HTTP request with an access_token query parameter. + allow_guest: If False, will raise an AuthError if the user making the + request is a guest. + rights: The operation being performed; the access token must allow this + allow_expired: If True, allow the request through even if the account + is expired, or session token lifetime has ended. Note that + /login will deliver access tokens regardless of expiration. + Returns: - defer.Deferred: resolves to a ``synapse.types.Requester`` object + defer.Deferred: resolves to a `synapse.types.Requester` object Raises: InvalidClientCredentialsError if no user by that token exists or the token is invalid. @@ -212,8 +199,9 @@ class Auth(object): if user_id: request.authenticated_entity = user_id opentracing.set_tag("authenticated_entity", user_id) + opentracing.set_tag("appservice_id", app_service.id) - if ip_addr and self.hs.config.track_appservice_user_ips: + if ip_addr and self._track_appservice_user_ips: yield self.store.insert_client_ip( user_id=user_id, access_token=access_token, @@ -224,7 +212,9 @@ class Auth(object): return synapse.types.create_requester(user_id, app_service=app_service) - user_info = yield self.get_user_by_access_token(access_token, rights) + user_info = yield self.get_user_by_access_token( + access_token, rights, allow_expired=allow_expired + ) user = user_info["user"] token_id = user_info["token_id"] is_guest = user_info["is_guest"] @@ -263,6 +253,8 @@ class Auth(object): request.authenticated_entity = user.to_string() opentracing.set_tag("authenticated_entity", user.to_string()) + if device_id: + opentracing.set_tag("device_id", device_id) return synapse.types.create_requester( user, token_id, is_guest, device_id, app_service=app_service @@ -297,13 +289,17 @@ class Auth(object): return user_id, app_service @defer.inlineCallbacks - def get_user_by_access_token(self, token, rights="access"): + def get_user_by_access_token( + self, token: str, rights: str = "access", allow_expired: bool = False, + ): """ Validate access token and get user_id from it Args: - token (str): The access token to get the user by. - rights (str): The operation being performed; the access token must - allow this. + token: The access token to get the user by + rights: The operation being performed; the access token must + allow this + allow_expired: If False, raises an InvalidClientTokenError + if the token is expired Returns: Deferred[dict]: dict that includes: `user` (UserID) @@ -311,8 +307,10 @@ class Auth(object): `token_id` (int|None): access token id. May be None if guest `device_id` (str|None): device corresponding to access token Raises: + InvalidClientTokenError if a user by that token exists, but the token is + expired InvalidClientCredentialsError if no user by that token exists or the token - is invalid. + is invalid """ if rights == "access": @@ -321,7 +319,8 @@ class Auth(object): if r: valid_until_ms = r["valid_until_ms"] if ( - valid_until_ms is not None + not allow_expired + and valid_until_ms is not None and valid_until_ms < self.clock.time_msec() ): # there was a valid access token, but it has expired. 
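The hunk above threads a new allow_expired flag through get_user_by_access_token: expiry is only enforced when the caller has not opted out, and a token with no valid_until_ms never expires. A minimal standalone sketch of that gate; the helper name and signature below are invented for illustration and are not Synapse's API:

    import time

    def is_token_usable(valid_until_ms, allow_expired, now_ms=None):
        """Return True if a token with the given expiry may still be used."""
        if now_ms is None:
            now_ms = int(time.time() * 1000)
        if allow_expired:
            # callers such as /login-driven requests opt out of the expiry check
            return True
        if valid_until_ms is None:
            # the token has no expiry configured
            return True
        # otherwise compare the expiry timestamp against the clock
        return valid_until_ms >= now_ms

    assert is_token_usable(None, allow_expired=False)
    assert not is_token_usable(0, allow_expired=False)
    assert is_token_usable(0, allow_expired=True)
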
@@ -474,7 +473,7 @@ class Auth(object): # access_tokens include a nonce for uniqueness: any value is acceptable v.satisfy_general(lambda c: c.startswith("nonce = ")) - v.verify(macaroon, self.hs.config.macaroon_secret_key) + v.verify(macaroon, self._macaroon_secret_key) def _verify_expiry(self, caveat): prefix = "time < " @@ -506,109 +505,77 @@ class Auth(object): token = self.get_access_token_from_request(request) service = self.store.get_app_service_by_token(token) if not service: - logger.warn("Unrecognised appservice access token.") + logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() request.authenticated_entity = service.sender return defer.succeed(service) - def is_server_admin(self, user): + async def is_server_admin(self, user: UserID) -> bool: """ Check if the given user is a local server admin. Args: - user (UserID): user to check + user: user to check Returns: - bool: True if the user is an admin + True if the user is an admin """ - return self.store.is_server_admin(user) - - @defer.inlineCallbacks - def compute_auth_events(self, event, current_state_ids, for_verification=False): - if event.type == EventTypes.Create: - return [] - - auth_ids = [] + return await self.store.is_server_admin(user) - key = (EventTypes.PowerLevels, "") - power_level_event_id = current_state_ids.get(key) - - if power_level_event_id: - auth_ids.append(power_level_event_id) - - key = (EventTypes.JoinRules, "") - join_rule_event_id = current_state_ids.get(key) - - key = (EventTypes.Member, event.sender) - member_event_id = current_state_ids.get(key) + def compute_auth_events( + self, event, current_state_ids: StateMap[str], for_verification: bool = False, + ): + """Given an event and current state return the list of event IDs used + to auth an event. - key = (EventTypes.Create, "") - create_event_id = current_state_ids.get(key) - if create_event_id: - auth_ids.append(create_event_id) + If `for_verification` is False then only return auth events that + should be added to the event's `auth_events`. - if join_rule_event_id: - join_rule_event = yield self.store.get_event(join_rule_event_id) - join_rule = join_rule_event.content.get("join_rule") - is_public = join_rule == JoinRules.PUBLIC if join_rule else False - else: - is_public = False + Returns: + defer.Deferred(list[str]): List of event IDs. + """ - if event.type == EventTypes.Member: - e_type = event.content["membership"] - if e_type in [Membership.JOIN, Membership.INVITE]: - if join_rule_event_id: - auth_ids.append(join_rule_event_id) + if event.type == EventTypes.Create: + return defer.succeed([]) + + # Currently we ignore the `for_verification` flag even though there are + # some situations where we can drop particular auth events when adding + # to the event's `auth_events` (e.g. joins pointing to previous joins + # when room is publically joinable). Dropping event IDs has the + # advantage that the auth chain for the room grows slower, but we use + # the auth chain in state resolution v2 to order events, which means + # care must be taken if dropping events to ensure that it doesn't + # introduce undesirable "state reset" behaviour. + # + # All of which sounds a bit tricky so we don't bother for now. 
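The rewritten compute_auth_events above delegates the choice of (type, state_key) pairs to event_auth.auth_types_for_event and then simply keeps whichever of those pairs exist in the current state map. A rough standalone illustration of that lookup pattern; the type alias and function name are invented for the example and are not Synapse identifiers:

    from typing import Dict, Iterable, List, Tuple

    StateMap = Dict[Tuple[str, str], str]  # (event type, state key) -> event ID

    def select_auth_event_ids(
        auth_types: Iterable[Tuple[str, str]], current_state_ids: StateMap
    ) -> List[str]:
        """Return the IDs of the required auth events present in the current state."""
        return [
            current_state_ids[key] for key in auth_types if key in current_state_ids
        ]

    state = {
        ("m.room.create", ""): "$create",
        ("m.room.member", "@alice:example.org"): "$join",
    }
    needed = [
        ("m.room.create", ""),
        ("m.room.power_levels", ""),  # absent from state, so skipped
        ("m.room.member", "@alice:example.org"),
    ]
    assert select_auth_event_ids(needed, state) == ["$create", "$join"]
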
- if e_type == Membership.JOIN: - if member_event_id and not is_public: - auth_ids.append(member_event_id) - else: - if member_event_id: - auth_ids.append(member_event_id) - - if for_verification: - key = (EventTypes.Member, event.state_key) - existing_event_id = current_state_ids.get(key) - if existing_event_id: - auth_ids.append(existing_event_id) - - if e_type == Membership.INVITE: - if "third_party_invite" in event.content: - key = ( - EventTypes.ThirdPartyInvite, - event.content["third_party_invite"]["signed"]["token"], - ) - third_party_invite_id = current_state_ids.get(key) - if third_party_invite_id: - auth_ids.append(third_party_invite_id) - elif member_event_id: - member_event = yield self.store.get_event(member_event_id) - if member_event.content["membership"] == Membership.JOIN: - auth_ids.append(member_event.event_id) + auth_ids = [] + for etype, state_key in event_auth.auth_types_for_event(event): + auth_ev_id = current_state_ids.get((etype, state_key)) + if auth_ev_id: + auth_ids.append(auth_ev_id) - return auth_ids + return defer.succeed(auth_ids) - @defer.inlineCallbacks - def check_can_change_room_list(self, room_id, user): - """Check if the user is allowed to edit the room's entry in the + async def check_can_change_room_list(self, room_id: str, user: UserID): + """Determine whether the user is allowed to edit the room's entry in the published room list. Args: - room_id (str) - user (UserID) + room_id + user """ - is_admin = yield self.is_server_admin(user) + is_admin = await self.is_server_admin(user) if is_admin: return True user_id = user.to_string() - yield self.check_joined_room(room_id, user_id) + await self.check_user_in_room(room_id, user_id) # We currently require the user is a "moderator" in the room. We do this # by checking if they would (theoretically) be able to change the - # m.room.aliases events - power_level_event = yield self.state.get_current_state( + # m.room.canonical_alias events + power_level_event = await self.state.get_current_state( room_id, EventTypes.PowerLevels, "" ) @@ -617,19 +584,14 @@ class Auth(object): auth_events[(EventTypes.PowerLevels, "")] = power_level_event send_level = event_auth.get_send_level( - EventTypes.Aliases, "", power_level_event + EventTypes.CanonicalAlias, "", power_level_event ) user_level = event_auth.get_user_power_level(user_id, auth_events) - if user_level < send_level: - raise AuthError( - 403, - "This server requires you to be a moderator in the room to" - " edit its room list entry", - ) + return user_level >= send_level @staticmethod - def has_access_token(request): + def has_access_token(request: Request): """Checks if the request has an access_token. Returns: @@ -640,7 +602,7 @@ class Auth(object): return bool(query_params) or bool(auth_headers) @staticmethod - def get_access_token_from_request(request): + def get_access_token_from_request(request: Request): """Extracts the access_token from the request. Args: @@ -676,10 +638,18 @@ class Auth(object): return query_params[0].decode("ascii") @defer.inlineCallbacks - def check_in_room_or_world_readable(self, room_id, user_id): + def check_user_in_room_or_world_readable( + self, room_id: str, user_id: str, allow_departed_users: bool = False + ): """Checks that the user is or was in the room or the room is world readable. If it isn't then an exception is raised. 
+ Args: + room_id: room to check + user_id: user to check + allow_departed_users: if True, accept users that were previously + members but have now departed + Returns: Deferred[tuple[str, str|None]]: Resolves to the current membership of the user in the room and the membership event ID of the user. If @@ -688,12 +658,14 @@ class Auth(object): """ try: - # check_user_was_in_room will return the most recent membership + # check_user_in_room will return the most recent membership # event for the user if: # * The user is a non-guest user, and was ever in the room # * The user is a guest user, and has joined the room # else it will throw. - member_event = yield self.check_user_was_in_room(room_id, user_id) + member_event = yield self.check_user_in_room( + room_id, user_id, allow_departed_users=allow_departed_users + ) return member_event.membership, member_event.event_id except AuthError: visibility = yield self.state.get_current_state( @@ -705,74 +677,10 @@ class Auth(object): ): return Membership.JOIN, None raise AuthError( - 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN - ) - - @defer.inlineCallbacks - def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): - """Checks if the user should be rejected for some external reason, - such as monthly active user limiting or global disable flag - - Args: - user_id(str|None): If present, checks for presence against existing - MAU cohort - - threepid(dict|None): If present, checks for presence against configured - reserved threepid. Used in cases where the user is trying register - with a MAU blocked server, normally they would be rejected but their - threepid is on the reserved list. user_id and - threepid should never be set at the same time. - - user_type(str|None): If present, is used to decide whether to check against - certain blocking reasons like MAU. - """ - - # Never fail an auth check for the server notices users or support user - # This can be a problem where event creation is prohibited due to blocking - if user_id is not None: - if user_id == self.hs.config.server_notices_mxid: - return - if (yield self.store.is_support_user(user_id)): - return - - if self.hs.config.hs_disabled: - raise ResourceLimitError( 403, - self.hs.config.hs_disabled_message, - errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - admin_contact=self.hs.config.admin_contact, - limit_type=self.hs.config.hs_disabled_limit_type, + "User %s not in room %s, and room previews are disabled" + % (user_id, room_id), ) - if self.hs.config.limit_usage_by_mau is True: - assert not (user_id and threepid) - # If the user is already part of the MAU cohort or a trial user - if user_id: - timestamp = yield self.store.user_last_seen_monthly_active(user_id) - if timestamp: - return - - is_trial = yield self.store.is_trial_user(user_id) - if is_trial: - return - elif threepid: - # If the user does not exist yet, but is signing up with a - # reserved threepid then pass auth check - if is_threepid_reserved( - self.hs.config.mau_limits_reserved_threepids, threepid - ): - return - elif user_type == UserTypes.SUPPORT: - # If the user does not exist yet and is of type "support", - # allow registration. Support users are excluded from MAU checks. 
- return - # Else if there is no room in the MAU bucket, bail - current_mau = yield self.store.get_monthly_active_count() - if current_mau >= self.hs.config.max_mau_value: - raise ResourceLimitError( - 403, - "Monthly Active User Limit Exceeded", - admin_contact=self.hs.config.admin_contact, - errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - limit_type="monthly_active_user", - ) + def check_auth_blocking(self, *args, **kwargs): + return self._auth_blocking.check_auth_blocking(*args, **kwargs) diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py new file mode 100644 index 0000000000..5c499b6b4e --- /dev/null +++ b/synapse/api/auth_blocking.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import LimitBlockingTypes, UserTypes +from synapse.api.errors import Codes, ResourceLimitError +from synapse.config.server import is_threepid_reserved + +logger = logging.getLogger(__name__) + + +class AuthBlocking(object): + def __init__(self, hs): + self.store = hs.get_datastore() + + self._server_notices_mxid = hs.config.server_notices_mxid + self._hs_disabled = hs.config.hs_disabled + self._hs_disabled_message = hs.config.hs_disabled_message + self._admin_contact = hs.config.admin_contact + self._max_mau_value = hs.config.max_mau_value + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids + + @defer.inlineCallbacks + def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): + """Checks if the user should be rejected for some external reason, + such as monthly active user limiting or global disable flag + + Args: + user_id(str|None): If present, checks for presence against existing + MAU cohort + + threepid(dict|None): If present, checks for presence against configured + reserved threepid. Used in cases where the user is trying register + with a MAU blocked server, normally they would be rejected but their + threepid is on the reserved list. user_id and + threepid should never be set at the same time. + + user_type(str|None): If present, is used to decide whether to check against + certain blocking reasons like MAU. 
+ """ + + # Never fail an auth check for the server notices users or support user + # This can be a problem where event creation is prohibited due to blocking + if user_id is not None: + if user_id == self._server_notices_mxid: + return + if (yield self.store.is_support_user(user_id)): + return + + if self._hs_disabled: + raise ResourceLimitError( + 403, + self._hs_disabled_message, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, + admin_contact=self._admin_contact, + limit_type=LimitBlockingTypes.HS_DISABLED, + ) + if self._limit_usage_by_mau is True: + assert not (user_id and threepid) + + # If the user is already part of the MAU cohort or a trial user + if user_id: + timestamp = yield self.store.user_last_seen_monthly_active(user_id) + if timestamp: + return + + is_trial = yield self.store.is_trial_user(user_id) + if is_trial: + return + elif threepid: + # If the user does not exist yet, but is signing up with a + # reserved threepid then pass auth check + if is_threepid_reserved(self._mau_limits_reserved_threepids, threepid): + return + elif user_type == UserTypes.SUPPORT: + # If the user does not exist yet and is of type "support", + # allow registration. Support users are excluded from MAU checks. + return + # Else if there is no room in the MAU bucket, bail + current_mau = yield self.store.get_monthly_active_count() + if current_mau >= self._max_mau_value: + raise ResourceLimitError( + 403, + "Monthly Active User Limit Exceeded", + admin_contact=self._admin_contact, + errcode=Codes.RESOURCE_LIMIT_EXCEEDED, + limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER, + ) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f29bce560c..5ec4a77ccd 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -60,12 +61,9 @@ class LoginType(object): MSISDN = "m.login.msisdn" RECAPTCHA = "m.login.recaptcha" TERMS = "m.login.terms" + SSO = "m.login.sso" DUMMY = "m.login.dummy" - # Only for C/S API v1 - APPLICATION_SERVICE = "m.login.application_service" - SHARED_SECRET = "org.matrix.login.shared_secret" - class EventTypes(object): Member = "m.room.member" @@ -76,12 +74,11 @@ class EventTypes(object): Aliases = "m.room.aliases" Redaction = "m.room.redaction" ThirdPartyInvite = "m.room.third_party_invite" - Encryption = "m.room.encryption" RelatedGroups = "m.room.related_groups" RoomHistoryVisibility = "m.room.history_visibility" CanonicalAlias = "m.room.canonical_alias" - Encryption = "m.room.encryption" + Encrypted = "m.room.encrypted" RoomAvatar = "m.room.avatar" RoomEncryption = "m.room.encryption" GuestAccess = "m.room.guest_access" @@ -94,11 +91,13 @@ class EventTypes(object): ServerACL = "m.room.server_acl" Pinned = "m.room.pinned_events" + Retention = "m.room.retention" + + Presence = "m.presence" + class RejectedReason(object): AUTH_ERROR = "auth_error" - REPLACED = "replaced" - NOT_ANCESTOR = "not_ancestor" class RoomCreationPreset(object): @@ -133,3 +132,21 @@ class RelationTypes(object): ANNOTATION = "m.annotation" REPLACE = "m.replace" REFERENCE = "m.reference" + + +class LimitBlockingTypes(object): + """Reasons that a server may be blocked""" + + MONTHLY_ACTIVE_USER = "monthly_active_user" + HS_DISABLED = "hs_disabled" + + +class EventContentFields(object): + """Fields found in events' content, regardless of type.""" + + # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 + LABELS = "org.matrix.labels" + + # Timestamp to delete the event after + # cf https://github.com/matrix-org/matrix-doc/pull/2228 + SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index cf1ebf1af2..d54dfb385d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -17,12 +17,15 @@ """Contains exceptions and error codes.""" import logging +from typing import Dict, List from six import iteritems from six.moves import http_client from canonicaljson import json +from twisted.web import http + logger = logging.getLogger(__name__) @@ -61,7 +64,16 @@ class Codes(object): INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION" WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION" EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT" + PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT" + PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT" + PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE" + PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE" + PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL" + PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY" + WEAK_PASSWORD = "M_WEAK_PASSWORD" + INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" + BAD_ALIAS = "M_BAD_ALIAS" class CodeMessageException(RuntimeError): @@ -74,10 +86,40 @@ class CodeMessageException(RuntimeError): def __init__(self, code, msg): super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) - self.code = code + + # Some calls to this method pass instances of http.HTTPStatus for `code`. + # While HTTPStatus is a subclass of int, it has magic __str__ methods + # which emit `HTTPStatus.FORBIDDEN` when converted to a str, instead of `403`. + # This causes inconsistency in our log lines. + # + # To eliminate this behaviour, we convert them to their integer equivalents here. 
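
A minimal illustration of the behaviour that comment describes, assuming a pre-3.11 Python where HTTPStatus keeps the enum-style __str__:

from http import HTTPStatus

code = HTTPStatus.FORBIDDEN
print(str(code))       # "HTTPStatus.FORBIDDEN" - what would leak into log lines
print(str(int(code)))  # "403" - the normalised form stored on the exception
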
+ self.code = int(code) self.msg = msg +class RedirectException(CodeMessageException): + """A pseudo-error indicating that we want to redirect the client to a different + location + + Attributes: + cookies: a list of set-cookies values to add to the response. For example: + b"sessionId=a3fWa; Expires=Wed, 21 Oct 2015 07:28:00 GMT" + """ + + def __init__(self, location: bytes, http_code: int = http.FOUND): + """ + + Args: + location: the URI to redirect to + http_code: the HTTP response code + """ + msg = "Redirect to %s" % (location.decode("utf-8"),) + super().__init__(code=http_code, msg=msg) + self.location = location + + self.cookies = [] # type: List[bytes] + + class SynapseError(CodeMessageException): """A base exception type for matrix errors which have an errcode and error message (as well as an HTTP status code). @@ -111,7 +153,7 @@ class ProxiedRequestError(SynapseError): def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None): super(ProxiedRequestError, self).__init__(code, msg, errcode) if additional_fields is None: - self._additional_fields = {} + self._additional_fields = {} # type: Dict else: self._additional_fields = dict(additional_fields) @@ -156,12 +198,6 @@ class UserDeactivatedError(SynapseError): ) -class RegistrationError(SynapseError): - """An error raised when a registration event fails.""" - - pass - - class FederationDeniedError(SynapseError): """An error raised when the server tries to federate with a server which is not on its federation whitelist. @@ -381,11 +417,9 @@ class UnsupportedRoomVersionError(SynapseError): """The client's request to create a room used a room version that the server does not support.""" - def __init__(self): + def __init__(self, msg="Homeserver does not support this room version"): super(UnsupportedRoomVersionError, self).__init__( - code=400, - msg="Homeserver does not support this room version", - errcode=Codes.UNSUPPORTED_ROOM_VERSION, + code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION, ) @@ -419,6 +453,20 @@ class IncompatibleRoomVersionError(SynapseError): return cs_error(self.msg, self.errcode, room_version=self._room_version) +class PasswordRefusedError(SynapseError): + """A password has been refused, either during password reset/change or registration. + """ + + def __init__( + self, + msg="This password doesn't comply with the server's policy", + errcode=Codes.WEAK_PASSWORD, + ): + super(PasswordRefusedError, self).__init__( + code=400, msg=msg, errcode=errcode, + ) + + class RequestSendFailed(RuntimeError): """Sending a HTTP request over federation failed due to not being able to talk to the remote server for some reason. @@ -455,7 +503,7 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs): class FederationError(RuntimeError): - """ This class is used to inform remote home servers about erroneous + """ This class is used to inform remote homeservers about erroneous PDUs they sent us. FATAL: The remote server could not interpret the source event. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 9f06556bd2..8b64d0a285 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
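
As a sketch of how the new RedirectException above is intended to be used (the target URL and cookie value are made up), a handler raises it and attaches Set-Cookie values for the redirect response:

from twisted.web import http

from synapse.api.errors import RedirectException

def _redirect_to_sso():
    e = RedirectException(b"https://sso.example.com/login", http_code=http.FOUND)
    # Each entry becomes a Set-Cookie header on the redirect response.
    e.cookies.append(b"session=abc123; Path=/; HttpOnly")
    raise e
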
@@ -12,6 +15,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + from six import text_type import jsonschema @@ -20,6 +25,7 @@ from jsonschema import FormatChecker from twisted.internet import defer +from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.storage.presence import UserPresenceState from synapse.types import RoomID, UserID @@ -66,6 +72,10 @@ ROOM_EVENT_FILTER_SCHEMA = { "contains_url": {"type": "boolean"}, "lazy_load_members": {"type": "boolean"}, "include_redundant_members": {"type": "boolean"}, + # Include or exclude events with the provided labels. + # cf https://github.com/matrix-org/matrix-doc/pull/2326 + "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, + "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, }, } @@ -259,6 +269,9 @@ class Filter(object): self.contains_url = self.filter_json.get("contains_url", None) + self.labels = self.filter_json.get("org.matrix.labels", None) + self.not_labels = self.filter_json.get("org.matrix.not_labels", []) + def filters_all_types(self): return "*" in self.not_types @@ -282,6 +295,7 @@ class Filter(object): room_id = None ev_type = "m.presence" contains_url = False + labels = [] # type: List[str] else: sender = event.get("sender", None) if not sender: @@ -300,10 +314,11 @@ class Filter(object): content = event.get("content", {}) # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), text_type) + labels = content.get(EventContentFields.LABELS, []) - return self.check_fields(room_id, sender, ev_type, contains_url) + return self.check_fields(room_id, sender, ev_type, labels, contains_url) - def check_fields(self, room_id, sender, event_type, contains_url): + def check_fields(self, room_id, sender, event_type, labels, contains_url): """Checks whether the filter matches the given event fields. Returns: @@ -313,6 +328,7 @@ class Filter(object): "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, "types": lambda v: _matches_wildcard(event_type, v), + "labels": lambda v: v in labels, } for name, match_func in literal_keys.items(): diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 172841f595..ec6b3a69a2 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,76 +13,161 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections +from collections import OrderedDict +from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.util import Clock class Ratelimiter(object): """ - Ratelimit message sending by user. + Ratelimit actions marked by arbitrary keys. + + Args: + clock: A homeserver clock, for retrieving the current time + rate_hz: The long term number of actions that can be performed in a second. + burst_count: How many actions that can be performed before being limited. 
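
A minimal usage sketch for the reworked Ratelimiter (using can_do_action and ratelimit as defined further down in this hunk), assuming any clock-like object exposing time() in seconds; a stand-in is used here rather than a real homeserver Clock:

import time

from synapse.api.ratelimiting import Ratelimiter

class _StubClock:
    """Stands in for a homeserver Clock; only time() is needed."""
    def time(self):
        return time.time()

# Allow roughly one action every ten seconds per key, with bursts of up to three.
limiter = Ratelimiter(clock=_StubClock(), rate_hz=0.1, burst_count=3)

allowed, time_allowed = limiter.can_do_action(key="@alice:example.com")
if not allowed:
    print("next action allowed at reactor time %.1f" % time_allowed)

# Or let the limiter raise LimitExceededError on the caller's behalf:
limiter.ratelimit(key="@alice:example.com")
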
""" - def __init__(self): - self.message_counts = collections.OrderedDict() + def __init__(self, clock: Clock, rate_hz: float, burst_count: int): + self.clock = clock + self.rate_hz = rate_hz + self.burst_count = burst_count + + # A ordered dictionary keeping track of actions, when they were last + # performed and how often. Each entry is a mapping from a key of arbitrary type + # to a tuple representing: + # * How many times an action has occurred since a point in time + # * The point in time + # * The rate_hz of this particular entry. This can vary per request + self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] - def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True): + def can_do_action( + self, + key: Any, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? + Args: key: The key we should use when rate limiting. Can be a user ID (when sending events), an IP address, etc. - time_now_s: The time now. - rate_hz: The long term number of messages a user can send in a - second. - burst_count: How many messages the user can send before being - limited. - update (bool): Whether to update the message rates or not. This is - useful to check if a message would be allowed to be sent before - its ready to be actually sent. + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + Returns: - A pair of a bool indicating if they can send a message now and a - time in seconds of when they can next send a message. + A tuple containing: + * A bool indicating if they can perform the action now + * The reactor timestamp for when the action can be performed next. 
+ -1 if rate_hz is less than or equal to zero """ - self.prune_message_counts(time_now_s) - message_count, time_start, _ignored = self.message_counts.get( - key, (0.0, time_now_s, None) - ) + # Override default values if set + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() + rate_hz = rate_hz if rate_hz is not None else self.rate_hz + burst_count = burst_count if burst_count is not None else self.burst_count + + # Remove any expired entries + self._prune_message_counts(time_now_s) + + # Check if there is an existing count entry for this key + action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0)) + + # Check whether performing another action is allowed time_delta = time_now_s - time_start - sent_count = message_count - time_delta * rate_hz - if sent_count < 0: + performed_count = action_count - time_delta * rate_hz + if performed_count < 0: + # Allow, reset back to count 1 allowed = True time_start = time_now_s - message_count = 1.0 - elif sent_count > burst_count - 1.0: + action_count = 1.0 + elif performed_count > burst_count - 1.0: + # Deny, we have exceeded our burst count allowed = False else: + # We haven't reached our limit yet allowed = True - message_count += 1 + action_count += 1.0 if update: - self.message_counts[key] = (message_count, time_start, rate_hz) + self.actions[key] = (action_count, time_start, rate_hz) if rate_hz > 0: - time_allowed = time_start + (message_count - burst_count + 1) / rate_hz + # Find out when the count of existing actions expires + time_allowed = time_start + (action_count - burst_count + 1) / rate_hz + + # Don't give back a time in the past if time_allowed < time_now_s: time_allowed = time_now_s + else: + # XXX: Why is this -1? This seems to only be used in + # self.ratelimit. I guess so that clients get a time in the past and don't + # feel afraid to try again immediately time_allowed = -1 return allowed, time_allowed - def prune_message_counts(self, time_now_s): - for key in list(self.message_counts.keys()): - message_count, time_start, rate_hz = self.message_counts[key] + def _prune_message_counts(self, time_now_s: int): + """Remove message count entries that have not exceeded their defined + rate_hz limit + + Args: + time_now_s: The current time + """ + # We create a copy of the key list here as the dictionary is modified during + # the loop + for key in list(self.actions.keys()): + action_count, time_start, rate_hz = self.actions[key] + + # Rate limit = "seconds since we started limiting this action" * rate_hz + # If this limit has not been exceeded, wipe our record of this action time_delta = time_now_s - time_start - if message_count - time_delta * rate_hz > 0: - break + if action_count - time_delta * rate_hz > 0: + continue else: - del self.message_counts[key] + del self.actions[key] + + def ratelimit( + self, + key: Any, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ): + """Checks if an action can be performed. If not, raises a LimitExceededError + + Args: + key: An arbitrary key used to classify an action + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. 
Optional, defaults to the current time according + to self.clock. Only used by tests. + + Raises: + LimitExceededError: If an action could not be performed, along with the time in + milliseconds until the action can be performed again + """ + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() - def ratelimit(self, key, time_now_s, rate_hz, burst_count, update=True): allowed, time_allowed = self.can_do_action( - key, time_now_s, rate_hz, burst_count, update + key, + rate_hz=rate_hz, + burst_count=burst_count, + update=update, + _time_now_s=time_now_s, ) if not allowed: diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index 95292b7dec..d7baf2bc39 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -12,6 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from typing import Dict + import attr @@ -54,6 +57,17 @@ class RoomVersion(object): state_res = attr.ib() # int; one of the StateResolutionVersions enforce_key_validity = attr.ib() # bool + # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules + special_case_aliases_auth = attr.ib(type=bool) + # Strictly enforce canonicaljson, do not allow: + # * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1] + # * Floats + # * NaN, Infinity, -Infinity + strict_canonicaljson = attr.ib(type=bool) + # bool: MSC2209: Check 'notifications' key while verifying + # m.room.power_levels auth rules. + limit_notifications_power_levels = attr.ib(type=bool) + class RoomVersions(object): V1 = RoomVersion( @@ -62,6 +76,9 @@ class RoomVersions(object): EventFormatVersions.V1, StateResolutionVersions.V1, enforce_key_validity=False, + special_case_aliases_auth=True, + strict_canonicaljson=False, + limit_notifications_power_levels=False, ) V2 = RoomVersion( "2", @@ -69,6 +86,9 @@ class RoomVersions(object): EventFormatVersions.V1, StateResolutionVersions.V2, enforce_key_validity=False, + special_case_aliases_auth=True, + strict_canonicaljson=False, + limit_notifications_power_levels=False, ) V3 = RoomVersion( "3", @@ -76,6 +96,9 @@ class RoomVersions(object): EventFormatVersions.V2, StateResolutionVersions.V2, enforce_key_validity=False, + special_case_aliases_auth=True, + strict_canonicaljson=False, + limit_notifications_power_levels=False, ) V4 = RoomVersion( "4", @@ -83,6 +106,9 @@ class RoomVersions(object): EventFormatVersions.V3, StateResolutionVersions.V2, enforce_key_validity=False, + special_case_aliases_auth=True, + strict_canonicaljson=False, + limit_notifications_power_levels=False, ) V5 = RoomVersion( "5", @@ -90,6 +116,19 @@ class RoomVersions(object): EventFormatVersions.V3, StateResolutionVersions.V2, enforce_key_validity=True, + special_case_aliases_auth=True, + strict_canonicaljson=False, + limit_notifications_power_levels=False, + ) + V6 = RoomVersion( + "6", + RoomDisposition.STABLE, + EventFormatVersions.V3, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, ) @@ -101,5 +140,6 @@ KNOWN_ROOM_VERSIONS = { RoomVersions.V3, RoomVersions.V4, RoomVersions.V5, + RoomVersions.V6, ) -} # type: dict[str, RoomVersion] +} # type: Dict[str, RoomVersion] diff --git a/synapse/api/urls.py b/synapse/api/urls.py index ff1f39e86c..f34434bd67 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -29,7 
+29,6 @@ FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2" FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable" STATIC_PREFIX = "/_matrix/static" WEB_CLIENT_PREFIX = "/_matrix/client" -CONTENT_REPO_PREFIX = "/_matrix/content" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" MEDIA_PREFIX = "/_matrix/media/r0" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index d877c77834..a01bac2997 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -44,6 +44,8 @@ def check_bind_error(e, address, bind_addresses): bind_addresses (list): Addresses on which the service listens. """ if address == "0.0.0.0" and "::" in bind_addresses: - logger.warn("Failed to listen on 0.0.0.0, continuing because listening on [::]") + logger.warning( + "Failed to listen on 0.0.0.0, continuing because listening on [::]" + ) else: raise e diff --git a/synapse/app/_base.py b/synapse/app/_base.py index c30fdeee9a..dedff81af3 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -22,6 +22,7 @@ import sys import traceback from daemonize import Daemonize +from typing_extensions import NoReturn from twisted.internet import defer, error, reactor from twisted.protocols.tls import TLSMemoryBIOFactory @@ -139,9 +140,9 @@ def start_reactor( run() -def quit_with_error(error_string): +def quit_with_error(error_string: str) -> NoReturn: message_lines = error_string.split("\n") - line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2 + line_length = max(len(line) for line in message_lines if len(line) < 80) + 2 sys.stderr.write("*" * line_length + "\n") for line in message_lines: sys.stderr.write(" %s\n" % (line.rstrip(),)) @@ -237,6 +238,12 @@ def start(hs, listeners=None): """ Start a Synapse server or worker. + Should be called once the reactor is running and (if we're using ACME) the + TLS certificates are in place. + + Will start the main HTTP listeners and do some other startup tasks, and then + notify systemd. + Args: hs (synapse.server.HomeServer) listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml) @@ -263,14 +270,26 @@ def start(hs, listeners=None): refresh_certificate(hs) # Start the tracer - synapse.logging.opentracing.init_tracer(hs.config) + synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa + hs + ) # It is now safe to start your Synapse. hs.start_listening(listeners) - hs.get_datastore().start_profiling() + hs.get_datastore().db.start_profiling() + hs.get_pusherpool().start() setup_sentry(hs) setup_sdnotify(hs) + + # We now freeze all allocated objects in the hopes that (almost) + # everything currently allocated are things that will be used for the + # rest of time. Doing so means less work each GC (hopefully). + # + # This only works on Python 3.7 + if sys.version_info >= (3, 7): + gc.collect() + gc.freeze() except Exception: traceback.print_exc(file=sys.stderr) reactor = hs.get_reactor() @@ -298,7 +317,7 @@ def setup_sentry(hs): scope.set_tag("matrix_server_name", hs.config.server_name) app = hs.config.worker_app if hs.config.worker_app else "synapse.app.homeserver" - name = hs.config.worker_name if hs.config.worker_name else "master" + name = hs.get_instance_name() scope.set_tag("worker_app", app) scope.set_tag("worker_name", name) @@ -309,9 +328,7 @@ def setup_sdnotify(hs): # Tell systemd our state, if we're using it. This will silently fail if # we're not using systemd. 
- hs.get_reactor().addSystemEventTrigger( - "after", "startup", sdnotify, b"READY=1\nMAINPID=%i" % (os.getpid(),) - ) + sdnotify(b"READY=1\nMAINPID=%i" % (os.getpid(),)) hs.get_reactor().addSystemEventTrigger( "before", "shutdown", sdnotify, b"STOPPING=1" diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 04751a6a5e..a37818fe9a 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -43,9 +43,7 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -80,18 +78,6 @@ class AdminCmdServer(HomeServer): def start_listening(self, listeners): pass - def build_tcp_replication(self): - return AdminCmdReplicationHandler(self) - - -class AdminCmdReplicationHandler(ReplicationClientHandler): - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - pass - - def get_streams_to_replicate(self): - return {} - @defer.inlineCallbacks def export_data_command(hs, args): @@ -105,8 +91,10 @@ def export_data_command(hs, args): user_id = args.user_id directory = args.output_directory - res = yield hs.get_handlers().admin_handler.export_user_data( - user_id, FileExfiltrationWriter(user_id, directory=directory) + res = yield defer.ensureDeferred( + hs.get_handlers().admin_handler.export_user_data( + user_id, FileExfiltrationWriter(user_id, directory=directory) + ) ) print(res) @@ -229,14 +217,10 @@ def start(config_options): synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = AdminCmdServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 767b87d2db..add43147b3 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -13,167 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
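
The admin_cmd hunk above wraps the handler call in defer.ensureDeferred; as a sketch of the general pattern (the handler and its run() method are hypothetical), this is how inlineCallbacks code waits on a native coroutine:

from twisted.internet import defer

async def _do_export(handler):
    # Hypothetical native coroutine, standing in for methods like
    # admin_handler.export_user_data which are now async.
    return await handler.run()

@defer.inlineCallbacks
def run_export(handler):
    # ensureDeferred wraps the coroutine in a Deferred that yield can wait on,
    # which is what the admin_cmd change above does.
    result = yield defer.ensureDeferred(_do_export(handler))
    return result
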
-import logging -import sys - -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.directory import DirectoryStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.server import HomeServer -from synapse.storage.engines import create_engine -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.appservice") - - -class AppserviceSlaveStore( - DirectoryStore, - SlavedEventStore, - SlavedApplicationServiceStore, - SlavedRegistrationStore, -): - pass - - -class AppserviceServer(HomeServer): - DATASTORE_CLASS = AppserviceSlaveStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - - logger.info("Synapse appservice now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warn( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" 
- ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warn("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return ASReplicationHandler(self) - -class ASReplicationHandler(ReplicationClientHandler): - def __init__(self, hs): - super(ASReplicationHandler, self).__init__(hs.get_datastore()) - self.appservice_handler = hs.get_application_service_handler() - - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) - - if stream_name == "events": - max_stream_id = self.store.get_room_max_stream_ordering() - run_in_background(self._notify_app_services, max_stream_id) - - @defer.inlineCallbacks - def _notify_app_services(self, room_stream_id): - try: - yield self.appservice_handler.notify_interested_services(room_stream_id) - except Exception: - logger.exception("Error notifying application services of event") - - -def start(config_options): - try: - config = HomeServerConfig.load_config("Synapse appservice", config_options) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.appservice" - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - database_engine = create_engine(config.database_config) - - if config.notify_appservices: - sys.stderr.write( - "\nThe appservices must be disabled in the main synapse process" - "\nbefore they can be run in a separate worker." - "\nPlease add ``notify_appservices: false`` to the main config" - "\n" - ) - sys.exit(1) - - # Force the pushers to start since they will be disabled in the main config - config.notify_appservices = True - - ps = AppserviceServer( - config.server_name, - db_config=config.database_config, - config=config, - version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, - ) - - setup_logging(ps, config, use_worker_options=True) - - ps.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ps, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-appservice", config) +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index dbcc414c42..add43147b3 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -13,193 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import logging -import sys - -from twisted.internet import reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.http.server import JsonResource -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage.account_data import SlavedAccountDataStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore -from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.directory import DirectoryStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.groups import SlavedGroupServerStore -from synapse.replication.slave.storage.keys import SlavedKeyStore -from synapse.replication.slave.storage.profile import SlavedProfileStore -from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.rest.client.v1.login import LoginRestServlet -from synapse.rest.client.v1.push_rule import PushRuleRestServlet -from synapse.rest.client.v1.room import ( - JoinedRoomMemberListRestServlet, - PublicRoomListRestServlet, - RoomEventContextServlet, - RoomMemberListRestServlet, - RoomMessageListRestServlet, - RoomStateRestServlet, -) -from synapse.rest.client.v1.voip import VoipRestServlet -from synapse.rest.client.v2_alpha.account import ThreepidRestServlet -from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet -from synapse.rest.client.v2_alpha.register import RegisterRestServlet -from synapse.rest.client.versions import VersionsRestServlet -from synapse.server import HomeServer -from synapse.storage.engines import create_engine -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.client_reader") - - -class ClientReaderSlavedStore( - SlavedDeviceInboxStore, - SlavedDeviceStore, - SlavedReceiptsStore, - SlavedPushRuleStore, - SlavedGroupServerStore, - SlavedAccountDataStore, - SlavedEventStore, - SlavedKeyStore, - RoomStore, - DirectoryStore, - SlavedApplicationServiceStore, - SlavedRegistrationStore, - SlavedTransactionStore, - SlavedProfileStore, - SlavedClientIpStore, - BaseSlavedStore, -): - pass - - -class ClientReaderServer(HomeServer): - DATASTORE_CLASS = ClientReaderSlavedStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in 
listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - elif name == "client": - resource = JsonResource(self, canonical_json=False) - - PublicRoomListRestServlet(self).register(resource) - RoomMemberListRestServlet(self).register(resource) - JoinedRoomMemberListRestServlet(self).register(resource) - RoomStateRestServlet(self).register(resource) - RoomEventContextServlet(self).register(resource) - RoomMessageListRestServlet(self).register(resource) - RegisterRestServlet(self).register(resource) - LoginRestServlet(self).register(resource) - ThreepidRestServlet(self).register(resource) - KeyQueryServlet(self).register(resource) - KeyChangesServlet(self).register(resource) - VoipRestServlet(self).register(resource) - PushRuleRestServlet(self).register(resource) - VersionsRestServlet(self).register(resource) - - resources.update({"/_matrix/client": resource}) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - logger.info("Synapse client reader now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warn( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warn("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return ReplicationClientHandler(self.get_datastore()) - - -def start(config_options): - try: - config = HomeServerConfig.load_config("Synapse client reader", config_options) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.client_reader" - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - database_engine = create_engine(config.database_config) - - ss = ClientReaderServer( - config.server_name, - db_config=config.database_config, - config=config, - version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-client-reader", config) +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index c67fe69a50..e9c098c4e7 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -13,192 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
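
Each of these removed worker entry points is replaced by the same thin shim around synapse.app.generic_worker, as the added lines above show; a sketch of what the whole shim module plausibly looks like (the final start(...) call is an assumption, since the tail of each new file is not shown here):

import sys

from synapse.app.generic_worker import start
from synapse.util.logcontext import LoggingContext

if __name__ == "__main__":
    with LoggingContext("main"):
        # Hand the command-line arguments straight to the shared worker entry point.
        start(sys.argv[1:])  # assumed call; not visible in the hunks above
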
-import logging -import sys - -from twisted.internet import reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.http.server import JsonResource -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage.account_data import SlavedAccountDataStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.directory import DirectoryStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.profile import SlavedProfileStore -from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore -from synapse.replication.slave.storage.pushers import SlavedPusherStore -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.rest.client.v1.profile import ( - ProfileAvatarURLRestServlet, - ProfileDisplaynameRestServlet, - ProfileRestServlet, -) -from synapse.rest.client.v1.room import ( - JoinRoomAliasServlet, - RoomMembershipRestServlet, - RoomSendEventRestServlet, - RoomStateEventRestServlet, -) -from synapse.server import HomeServer -from synapse.storage.engines import create_engine -from synapse.storage.user_directory import UserDirectoryStore -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.event_creator") - - -class EventCreatorSlavedStore( - # FIXME(#3714): We need to add UserDirectoryStore as we write directly - # rather than going via the correct worker. 
- UserDirectoryStore, - DirectoryStore, - SlavedTransactionStore, - SlavedProfileStore, - SlavedAccountDataStore, - SlavedPusherStore, - SlavedReceiptsStore, - SlavedPushRuleStore, - SlavedDeviceStore, - SlavedClientIpStore, - SlavedApplicationServiceStore, - SlavedEventStore, - SlavedRegistrationStore, - RoomStore, - BaseSlavedStore, -): - pass - - -class EventCreatorServer(HomeServer): - DATASTORE_CLASS = EventCreatorSlavedStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - elif name == "client": - resource = JsonResource(self, canonical_json=False) - RoomSendEventRestServlet(self).register(resource) - RoomMembershipRestServlet(self).register(resource) - RoomStateEventRestServlet(self).register(resource) - JoinRoomAliasServlet(self).register(resource) - ProfileAvatarURLRestServlet(self).register(resource) - ProfileDisplaynameRestServlet(self).register(resource) - ProfileRestServlet(self).register(resource) - resources.update( - { - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - } - ) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - - logger.info("Synapse event creator now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warn( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" 
- ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warn("Unrecognized listener type: %s", listener["type"]) - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return ReplicationClientHandler(self.get_datastore()) - - -def start(config_options): - try: - config = HomeServerConfig.load_config("Synapse event creator", config_options) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.event_creator" - - assert config.worker_replication_http_port is not None - - # This should only be done on the user directory worker or the master - config.update_user_directory = False - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - database_engine = create_engine(config.database_config) - - ss = EventCreatorServer( - config.server_name, - db_config=config.database_config, - config=config, - version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-event-creator", config) +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 1ef027a88c..add43147b3 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -13,174 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import logging -import sys - -from twisted.internet import reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.api.urls import FEDERATION_PREFIX, SERVER_KEY_V2_PREFIX -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.federation.transport.server import TransportLayerServer -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage.account_data import SlavedAccountDataStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.directory import DirectoryStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.keys import SlavedKeyStore -from synapse.replication.slave.storage.profile import SlavedProfileStore -from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore -from synapse.replication.slave.storage.pushers import SlavedPusherStore -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.rest.key.v2 import KeyApiV2Resource -from synapse.server import HomeServer -from synapse.storage.engines import create_engine -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.federation_reader") - - -class FederationReaderSlavedStore( - SlavedAccountDataStore, - SlavedProfileStore, - SlavedApplicationServiceStore, - SlavedPusherStore, - SlavedPushRuleStore, - SlavedReceiptsStore, - SlavedEventStore, - SlavedKeyStore, - SlavedRegistrationStore, - RoomStore, - DirectoryStore, - SlavedTransactionStore, - BaseSlavedStore, -): - pass - - -class FederationReaderServer(HomeServer): - DATASTORE_CLASS = FederationReaderSlavedStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - elif name == "federation": - resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) - if name == "openid" and "federation" not in res["names"]: - # Only load the openid resource separately if federation resource - # is not specified since federation resource includes openid - # resource. 
- resources.update( - { - FEDERATION_PREFIX: TransportLayerServer( - self, servlet_groups=["openid"] - ) - } - ) - - if name in ["keys", "federation"]: - resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - reactor=self.get_reactor(), - ) - - logger.info("Synapse federation reader now listening on port %d", port) - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warn( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warn("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return ReplicationClientHandler(self.get_datastore()) - - -def start(config_options): - try: - config = HomeServerConfig.load_config( - "Synapse federation reader", config_options - ) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.federation_reader" - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - database_engine = create_engine(config.database_config) - - ss = FederationReaderServer( - config.server_name, - db_config=config.database_config, - config=config, - version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-federation-reader", config) +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 04fbb407af..add43147b3 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -13,285 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import logging -import sys - -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.federation import send_queue -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.presence import SlavedPresenceStore -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.replication.tcp.streams._base import ReceiptsStream -from synapse.server import HomeServer -from synapse.storage.engines import create_engine -from synapse.types import ReadReceipt -from synapse.util.async_helpers import Linearizer -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.federation_sender") - - -class FederationSenderSlaveStore( - SlavedDeviceInboxStore, - SlavedTransactionStore, - SlavedReceiptsStore, - SlavedEventStore, - SlavedRegistrationStore, - SlavedDeviceStore, - SlavedPresenceStore, -): - def __init__(self, db_conn, hs): - super(FederationSenderSlaveStore, self).__init__(db_conn, hs) - - # We pull out the current federation stream position now so that we - # always have a known value for the federation position in memory so - # that we don't have to bounce via a deferred once when we start the - # replication streams. - self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) - - def _get_federation_out_pos(self, db_conn): - sql = "SELECT stream_id FROM federation_stream_position" " WHERE type = ?" 
- sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, ("federation",)) - rows = txn.fetchall() - txn.close() - - return rows[0][0] if rows else -1 - - -class FederationSenderServer(HomeServer): - DATASTORE_CLASS = FederationSenderSlaveStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - - logger.info("Synapse federation_sender now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warn( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warn("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return FederationSenderReplicationHandler(self) - - -class FederationSenderReplicationHandler(ReplicationClientHandler): - def __init__(self, hs): - super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) - self.send_handler = FederationSenderHandler(hs, self) - - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(FederationSenderReplicationHandler, self).on_rdata( - stream_name, token, rows - ) - self.send_handler.process_replication_rows(stream_name, token, rows) - - def get_streams_to_replicate(self): - args = super( - FederationSenderReplicationHandler, self - ).get_streams_to_replicate() - args.update(self.send_handler.stream_positions()) - return args - - -def start(config_options): - try: - config = HomeServerConfig.load_config( - "Synapse federation sender", config_options - ) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - assert config.worker_app == "synapse.app.federation_sender" - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - database_engine = create_engine(config.database_config) - - if config.send_federation: - sys.stderr.write( - "\nThe send_federation must be disabled in the main synapse process" - "\nbefore they can be run in a separate worker." 
- "\nPlease add ``send_federation: false`` to the main config" - "\n" - ) - sys.exit(1) - - # Force the pushers to start since they will be disabled in the main config - config.send_federation = True - - ss = FederationSenderServer( - config.server_name, - db_config=config.database_config, - config=config, - version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-federation-sender", config) - - -class FederationSenderHandler(object): - """Processes the replication stream and forwards the appropriate entries - to the federation sender. - """ - - def __init__(self, hs, replication_client): - self.store = hs.get_datastore() - self._is_mine_id = hs.is_mine_id - self.federation_sender = hs.get_federation_sender() - self.replication_client = replication_client - - self.federation_position = self.store.federation_out_pos_startup - self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") - - self._last_ack = self.federation_position - - self._room_serials = {} - self._room_typing = {} - - def on_start(self): - # There may be some events that are persisted but haven't been sent, - # so send them now. - self.federation_sender.notify_new_events( - self.store.get_room_max_stream_ordering() - ) - - def stream_positions(self): - return {"federation": self.federation_position} - - def process_replication_rows(self, stream_name, token, rows): - # The federation stream contains things that we want to send out, e.g. - # presence, typing, etc. - if stream_name == "federation": - send_queue.process_rows_for_federation(self.federation_sender, rows) - run_in_background(self.update_token, token) - - # We also need to poke the federation sender when new events happen - elif stream_name == "events": - self.federation_sender.notify_new_events(token) - - # ... 
and when new receipts happen - elif stream_name == ReceiptsStream.NAME: - run_as_background_process( - "process_receipts_for_federation", self._on_new_receipts, rows - ) - - @defer.inlineCallbacks - def _on_new_receipts(self, rows): - """ - Args: - rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]): - new receipts to be processed - """ - for receipt in rows: - # we only want to send on receipts for our own users - if not self._is_mine_id(receipt.user_id): - continue - receipt_info = ReadReceipt( - receipt.room_id, - receipt.receipt_type, - receipt.user_id, - [receipt.event_id], - receipt.data, - ) - yield self.federation_sender.send_read_receipt(receipt_info) - - @defer.inlineCallbacks - def update_token(self, token): - try: - self.federation_position = token - - # We linearize here to ensure we don't have races updating the token - with (yield self._fed_position_linearizer.queue(None)): - if self._last_ack < self.federation_position: - yield self.store.update_federation_out_pos( - "federation", self.federation_position - ) - - # We ACK this token over replication so that the master can drop - # its in memory queues - self.replication_client.send_federation_ack( - self.federation_position - ) - self._last_ack = self.federation_position - except Exception: - logger.exception("Error updating federation stream position") +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 9504bfbc70..add43147b3 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -13,246 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import logging -import sys - -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.api.errors import HttpResponseException, SynapseError -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.http.server import JsonResource -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.server import HomeServer -from synapse.storage.engines import create_engine -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.frontend_proxy") - - -class PresenceStatusStubServlet(RestServlet): - PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status") - - def __init__(self, hs): - super(PresenceStatusStubServlet, self).__init__() - self.http_client = hs.get_simple_http_client() - self.auth = hs.get_auth() - self.main_uri = hs.config.worker_main_http_uri - - @defer.inlineCallbacks - def on_GET(self, request, user_id): - # Pass through the auth headers, if any, in case the access token - # is there. - auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) - headers = {"Authorization": auth_headers} - - try: - result = yield self.http_client.get_json( - self.main_uri + request.uri.decode("ascii"), headers=headers - ) - except HttpResponseException as e: - raise e.to_synapse_error() - - return 200, result - - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - yield self.auth.get_user_by_req(request) - return 200, {} - - -class KeyUploadServlet(RestServlet): - PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super(KeyUploadServlet, self).__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.http_client = hs.get_simple_http_client() - self.main_uri = hs.config.worker_main_http_uri - - @defer.inlineCallbacks - def on_POST(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - - if device_id is not None: - # passing the device_id here is deprecated; however, we allow it - # for now for compatibility with older clients. 
- if requester.device_id is not None and device_id != requester.device_id: - logger.warning( - "Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, - device_id, - ) - else: - device_id = requester.device_id - - if device_id is None: - raise SynapseError( - 400, "To upload keys, you must pass device_id when authenticating" - ) - - if body: - # They're actually trying to upload something, proxy to main synapse. - # Pass through the auth headers, if any, in case the access token - # is there. - auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", []) - headers = {"Authorization": auth_headers} - result = yield self.http_client.post_json_get_json( - self.main_uri + request.uri.decode("ascii"), body, headers=headers - ) - - return 200, result - else: - # Just interested in counts. - result = yield self.store.count_e2e_one_time_keys(user_id, device_id) - return 200, {"one_time_key_counts": result} - - -class FrontendProxySlavedStore( - SlavedDeviceStore, - SlavedClientIpStore, - SlavedApplicationServiceStore, - SlavedRegistrationStore, - BaseSlavedStore, -): - pass +import sys -class FrontendProxyServer(HomeServer): - DATASTORE_CLASS = FrontendProxySlavedStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - elif name == "client": - resource = JsonResource(self, canonical_json=False) - KeyUploadServlet(self).register(resource) - - # If presence is disabled, use the stub servlet that does - # not allow sending presence - if not self.config.use_presence: - PresenceStatusStubServlet(self).register(resource) - - resources.update( - { - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - } - ) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - reactor=self.get_reactor(), - ) - - logger.info("Synapse client reader now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warn( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" 
- ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warn("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return ReplicationClientHandler(self.get_datastore()) - - -def start(config_options): - try: - config = HomeServerConfig.load_config("Synapse frontend proxy", config_options) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.frontend_proxy" - - assert config.worker_main_http_uri is not None - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - database_engine = create_engine(config.database_config) - - ss = FrontendProxyServer( - config.server_name, - db_config=config.database_config, - config=config, - version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-frontend-proxy", config) - +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py new file mode 100644 index 0000000000..f3ec2a34ec --- /dev/null +++ b/synapse/app/generic_worker.py @@ -0,0 +1,1020 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
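The new generic_worker module introduced below consumes the worker's `worker_listeners` configuration in its `_listen_http` and `start_listening` methods further down in this file. As a rough orientation, a single listener entry has the following shape. This is an illustrative sketch only: the field names are taken from the dictionary lookups in the diffed code, while the port, address, and resource values are invented examples rather than anything prescribed by this commit.

# Illustrative only: the shape of one worker_listeners entry as read by
# GenericWorkerServer._listen_http / start_listening below. Values are examples.
example_listener = {
    "type": "http",                   # also handled: "manhole", "metrics"
    "port": 8083,                     # listener_config["port"]
    "bind_addresses": ["127.0.0.1"],  # listener_config["bind_addresses"]
    "tag": "client",                  # optional; defaults to the port
    "resources": [
        # each resource lists the names it serves, e.g. "client", "federation",
        # "media", "keys", "openid", "replication", "metrics"
        {"names": ["client", "federation"]}
    ],
}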
+import contextlib +import logging +import sys +from typing import Dict, Iterable, Optional, Set + +from typing_extensions import ContextManager + +from twisted.internet import defer, reactor + +import synapse +import synapse.events +from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError +from synapse.api.urls import ( + CLIENT_API_PREFIX, + FEDERATION_PREFIX, + LEGACY_MEDIA_PREFIX, + MEDIA_PREFIX, + SERVER_KEY_V2_PREFIX, +) +from synapse.app import _base +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging +from synapse.federation import send_queue +from synapse.federation.transport.server import TransportLayerServer +from synapse.handlers.presence import ( + BasePresenceHandler, + PresenceState, + get_interested_parties, +) +from synapse.http.server import JsonResource, OptionsResource +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseSite +from synapse.logging.context import LoggingContext +from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource +from synapse.replication.http.presence import ( + ReplicationBumpPresenceActiveTime, + ReplicationPresenceSetState, +) +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.account_data import SlavedAccountDataStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.client_ips import SlavedClientIpStore +from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore +from synapse.replication.slave.storage.devices import SlavedDeviceStore +from synapse.replication.slave.storage.directory import DirectoryStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.filtering import SlavedFilteringStore +from synapse.replication.slave.storage.groups import SlavedGroupServerStore +from synapse.replication.slave.storage.keys import SlavedKeyStore +from synapse.replication.slave.storage.presence import SlavedPresenceStore +from synapse.replication.slave.storage.profile import SlavedProfileStore +from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore +from synapse.replication.slave.storage.pushers import SlavedPusherStore +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.replication.slave.storage.room import RoomStore +from synapse.replication.slave.storage.transactions import SlavedTransactionStore +from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.tcp.commands import ClearUserSyncsCommand +from synapse.replication.tcp.streams import ( + AccountDataStream, + DeviceListsStream, + GroupServerStream, + PresenceStream, + PushersStream, + PushRulesStream, + ReceiptsStream, + TagAccountDataStream, + ToDeviceStream, + TypingStream, +) +from synapse.rest.admin import register_servlets_for_media_repo +from synapse.rest.client.v1 import events +from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet +from synapse.rest.client.v1.login import LoginRestServlet +from 
synapse.rest.client.v1.profile import ( + ProfileAvatarURLRestServlet, + ProfileDisplaynameRestServlet, + ProfileRestServlet, +) +from synapse.rest.client.v1.push_rule import PushRuleRestServlet +from synapse.rest.client.v1.room import ( + JoinedRoomMemberListRestServlet, + JoinRoomAliasServlet, + PublicRoomListRestServlet, + RoomEventContextServlet, + RoomInitialSyncRestServlet, + RoomMemberListRestServlet, + RoomMembershipRestServlet, + RoomMessageListRestServlet, + RoomSendEventRestServlet, + RoomStateEventRestServlet, + RoomStateRestServlet, +) +from synapse.rest.client.v1.voip import VoipRestServlet +from synapse.rest.client.v2_alpha import groups, sync, user_directory +from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.rest.client.v2_alpha.account import ThreepidRestServlet +from synapse.rest.client.v2_alpha.account_data import ( + AccountDataServlet, + RoomAccountDataServlet, +) +from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet +from synapse.rest.client.v2_alpha.register import RegisterRestServlet +from synapse.rest.client.versions import VersionsRestServlet +from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.server import HomeServer +from synapse.storage.data_stores.main.censor_events import CensorEventsStore +from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore +from synapse.storage.data_stores.main.monthly_active_users import ( + MonthlyActiveUsersWorkerStore, +) +from synapse.storage.data_stores.main.presence import UserPresenceState +from synapse.storage.data_stores.main.search import SearchWorkerStore +from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore +from synapse.storage.data_stores.main.user_directory import UserDirectoryStore +from synapse.types import ReadReceipt +from synapse.util.async_helpers import Linearizer +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.manhole import manhole +from synapse.util.versionstring import get_version_string + +logger = logging.getLogger("synapse.app.generic_worker") + + +class PresenceStatusStubServlet(RestServlet): + """If presence is disabled this servlet can be used to stub out setting + presence status. + """ + + PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status") + + def __init__(self, hs): + super(PresenceStatusStubServlet, self).__init__() + self.auth = hs.get_auth() + + async def on_GET(self, request, user_id): + await self.auth.get_user_by_req(request) + return 200, {"presence": "offline"} + + async def on_PUT(self, request, user_id): + await self.auth.get_user_by_req(request) + return 200, {} + + +class KeyUploadServlet(RestServlet): + """An implementation of the `KeyUploadServlet` that responds to read only + requests, but otherwise proxies through to the master instance. 
+ """ + + PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(KeyUploadServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.http_client = hs.get_simple_http_client() + self.main_uri = hs.config.worker_main_http_uri + + async def on_POST(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + if device_id is not None: + # passing the device_id here is deprecated; however, we allow it + # for now for compatibility with older clients. + if requester.device_id is not None and device_id != requester.device_id: + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) + else: + device_id = requester.device_id + + if device_id is None: + raise SynapseError( + 400, "To upload keys, you must pass device_id when authenticating" + ) + + if body: + # They're actually trying to upload something, proxy to main synapse. + # Pass through the auth headers, if any, in case the access token + # is there. + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", []) + headers = {"Authorization": auth_headers} + try: + result = await self.http_client.post_json_get_json( + self.main_uri + request.uri.decode("ascii"), body, headers=headers + ) + except HttpResponseException as e: + raise e.to_synapse_error() from e + except RequestSendFailed as e: + raise SynapseError(502, "Failed to talk to master") from e + + return 200, result + else: + # Just interested in counts. + result = await self.store.count_e2e_one_time_keys(user_id, device_id) + return 200, {"one_time_key_counts": result} + + +class _NullContextManager(ContextManager[None]): + """A context manager which does nothing.""" + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +UPDATE_SYNCING_USERS_MS = 10 * 1000 + + +class GenericWorkerPresence(BasePresenceHandler): + def __init__(self, hs): + super().__init__(hs) + self.hs = hs + self.is_mine_id = hs.is_mine_id + self.http_client = hs.get_simple_http_client() + + self._presence_enabled = hs.config.use_presence + + # The number of ongoing syncs on this process, by user id. + # Empty if _presence_enabled is false. + self._user_to_num_current_syncs = {} # type: Dict[str, int] + + self.notifier = hs.get_notifier() + self.instance_id = hs.get_instance_id() + + # user_id -> last_sync_ms. 
Lists the users that have stopped syncing + # but we haven't notified the master of that yet + self.users_going_offline = {} + + self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) + self._set_state_client = ReplicationPresenceSetState.make_client(hs) + + self._send_stop_syncing_loop = self.clock.looping_call( + self.send_stop_syncing, UPDATE_SYNCING_USERS_MS + ) + + hs.get_reactor().addSystemEventTrigger( + "before", + "shutdown", + run_as_background_process, + "generic_presence.on_shutdown", + self._on_shutdown, + ) + + def _on_shutdown(self): + if self._presence_enabled: + self.hs.get_tcp_replication().send_command( + ClearUserSyncsCommand(self.instance_id) + ) + + def send_user_sync(self, user_id, is_syncing, last_sync_ms): + if self._presence_enabled: + self.hs.get_tcp_replication().send_user_sync( + self.instance_id, user_id, is_syncing, last_sync_ms + ) + + def mark_as_coming_online(self, user_id): + """A user has started syncing. Send a UserSync to the master, unless they + had recently stopped syncing. + + Args: + user_id (str) + """ + going_offline = self.users_going_offline.pop(user_id, None) + if not going_offline: + # Safe to skip because we haven't yet told the master they were offline + self.send_user_sync(user_id, True, self.clock.time_msec()) + + def mark_as_going_offline(self, user_id): + """A user has stopped syncing. We wait before notifying the master as + its likely they'll come back soon. This allows us to avoid sending + a stopped syncing immediately followed by a started syncing notification + to the master + + Args: + user_id (str) + """ + self.users_going_offline[user_id] = self.clock.time_msec() + + def send_stop_syncing(self): + """Check if there are any users who have stopped syncing a while ago + and haven't come back yet. If there are poke the master about them. + """ + now = self.clock.time_msec() + for user_id, last_sync_ms in list(self.users_going_offline.items()): + if now - last_sync_ms > UPDATE_SYNCING_USERS_MS: + self.users_going_offline.pop(user_id, None) + self.send_user_sync(user_id, False, last_sync_ms) + + async def user_syncing( + self, user_id: str, affect_presence: bool + ) -> ContextManager[None]: + """Record that a user is syncing. + + Called by the sync and events servlets to record that a user has connected to + this worker and is waiting for some events. + """ + if not affect_presence or not self._presence_enabled: + return _NullContextManager() + + curr_sync = self._user_to_num_current_syncs.get(user_id, 0) + self._user_to_num_current_syncs[user_id] = curr_sync + 1 + + # If we went from no in flight sync to some, notify replication + if self._user_to_num_current_syncs[user_id] == 1: + self.mark_as_coming_online(user_id) + + def _end(): + # We check that the user_id is in user_to_num_current_syncs because + # user_to_num_current_syncs may have been cleared if we are + # shutting down. 
+ if user_id in self._user_to_num_current_syncs: + self._user_to_num_current_syncs[user_id] -= 1 + + # If we went from one in flight sync to non, notify replication + if self._user_to_num_current_syncs[user_id] == 0: + self.mark_as_going_offline(user_id) + + @contextlib.contextmanager + def _user_syncing(): + try: + yield + finally: + _end() + + return _user_syncing() + + @defer.inlineCallbacks + def notify_from_replication(self, states, stream_id): + parties = yield get_interested_parties(self.store, states) + room_ids_to_states, users_to_states = parties + + self.notifier.on_new_event( + "presence_key", + stream_id, + rooms=room_ids_to_states.keys(), + users=users_to_states.keys(), + ) + + @defer.inlineCallbacks + def process_replication_rows(self, token, rows): + states = [ + UserPresenceState( + row.user_id, + row.state, + row.last_active_ts, + row.last_federation_update_ts, + row.last_user_sync_ts, + row.status_msg, + row.currently_active, + ) + for row in rows + ] + + for state in states: + self.user_to_current_state[state.user_id] = state + + stream_id = token + yield self.notify_from_replication(states, stream_id) + + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + return [ + user_id + for user_id, count in self._user_to_num_current_syncs.items() + if count > 0 + ] + + async def set_state(self, target_user, state, ignore_status_msg=False): + """Set the presence state of the user. + """ + presence = state["presence"] + + valid_presence = ( + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + PresenceState.OFFLINE, + ) + if presence not in valid_presence: + raise SynapseError(400, "Invalid presence state") + + user_id = target_user.to_string() + + # If presence is disabled, no-op + if not self.hs.config.use_presence: + return + + # Proxy request to master + await self._set_state_client( + user_id=user_id, state=state, ignore_status_msg=ignore_status_msg + ) + + async def bump_presence_active_time(self, user): + """We've seen the user do something that indicates they're interacting + with the app. + """ + # If presence is disabled, no-op + if not self.hs.config.use_presence: + return + + # Proxy request to master + user_id = user.to_string() + await self._bump_active_client(user_id=user_id) + + +class GenericWorkerTyping(object): + def __init__(self, hs): + self._latest_room_serial = 0 + self._reset() + + def _reset(self): + """ + Reset the typing handler's data caches. + """ + # map room IDs to serial numbers + self._room_serials = {} + # map room IDs to sets of users currently typing + self._room_typing = {} + + def process_replication_rows(self, token, rows): + if self._latest_room_serial > token: + # The master has gone backwards. To prevent inconsistent data, just + # clear everything. + self._reset() + + # Set the latest serial token to whatever the server gave us. + self._latest_room_serial = token + + for row in rows: + self._room_serials[row.room_id] = token + self._room_typing[row.room_id] = row.user_ids + + def get_current_token(self) -> int: + return self._latest_room_serial + + +class GenericWorkerSlavedStore( + # FIXME(#3714): We need to add UserDirectoryStore as we write directly + # rather than going via the correct worker. 
+ UserDirectoryStore, + UIAuthWorkerStore, + SlavedDeviceInboxStore, + SlavedDeviceStore, + SlavedReceiptsStore, + SlavedPushRuleStore, + SlavedGroupServerStore, + SlavedAccountDataStore, + SlavedPusherStore, + CensorEventsStore, + SlavedEventStore, + SlavedKeyStore, + RoomStore, + DirectoryStore, + SlavedApplicationServiceStore, + SlavedRegistrationStore, + SlavedTransactionStore, + SlavedProfileStore, + SlavedClientIpStore, + SlavedPresenceStore, + SlavedFilteringStore, + MonthlyActiveUsersWorkerStore, + MediaRepositoryStore, + SearchWorkerStore, + BaseSlavedStore, +): + def __init__(self, database, db_conn, hs): + super(GenericWorkerSlavedStore, self).__init__(database, db_conn, hs) + + # We pull out the current federation stream position now so that we + # always have a known value for the federation position in memory so + # that we don't have to bounce via a deferred once when we start the + # replication streams. + self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) + + def _get_federation_out_pos(self, db_conn): + sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?" + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, ("federation",)) + rows = txn.fetchall() + txn.close() + + return rows[0][0] if rows else -1 + + +class GenericWorkerServer(HomeServer): + DATASTORE_CLASS = GenericWorkerSlavedStore + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_addresses = listener_config["bind_addresses"] + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) + elif name == "client": + resource = JsonResource(self, canonical_json=False) + + PublicRoomListRestServlet(self).register(resource) + RoomMemberListRestServlet(self).register(resource) + JoinedRoomMemberListRestServlet(self).register(resource) + RoomStateRestServlet(self).register(resource) + RoomEventContextServlet(self).register(resource) + RoomMessageListRestServlet(self).register(resource) + RegisterRestServlet(self).register(resource) + LoginRestServlet(self).register(resource) + ThreepidRestServlet(self).register(resource) + KeyQueryServlet(self).register(resource) + KeyChangesServlet(self).register(resource) + VoipRestServlet(self).register(resource) + PushRuleRestServlet(self).register(resource) + VersionsRestServlet(self).register(resource) + RoomSendEventRestServlet(self).register(resource) + RoomMembershipRestServlet(self).register(resource) + RoomStateEventRestServlet(self).register(resource) + JoinRoomAliasServlet(self).register(resource) + ProfileAvatarURLRestServlet(self).register(resource) + ProfileDisplaynameRestServlet(self).register(resource) + ProfileRestServlet(self).register(resource) + KeyUploadServlet(self).register(resource) + AccountDataServlet(self).register(resource) + RoomAccountDataServlet(self).register(resource) + + sync.register_servlets(self, resource) + events.register_servlets(self, resource) + InitialSyncRestServlet(self).register(resource) + RoomInitialSyncRestServlet(self).register(resource) + + user_directory.register_servlets(self, resource) + + # If presence is disabled, use the stub servlet that does + # not allow sending presence + if not self.config.use_presence: + PresenceStatusStubServlet(self).register(resource) + + groups.register_servlets(self, resource) + + resources.update({CLIENT_API_PREFIX: resource}) + 
elif name == "federation": + resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) + elif name == "media": + if self.config.can_load_media_repo: + media_repo = self.get_media_repository_resource() + + # We need to serve the admin servlets for media on the + # worker. + admin_resource = JsonResource(self, canonical_json=False) + register_servlets_for_media_repo(self, admin_resource) + + resources.update( + { + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + "/_synapse/admin": admin_resource, + } + ) + else: + logger.warning( + "A 'media' listener is configured but the media" + " repository is disabled. Ignoring." + ) + + if name == "openid" and "federation" not in res["names"]: + # Only load the openid resource separately if federation resource + # is not specified since federation resource includes openid + # resource. + resources.update( + { + FEDERATION_PREFIX: TransportLayerServer( + self, servlet_groups=["openid"] + ) + } + ) + + if name in ["keys", "federation"]: + resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) + + if name == "replication": + resources[REPLICATION_PREFIX] = ReplicationRestResource(self) + + root_resource = create_resource_tree(resources, OptionsResource()) + + _base.listen_tcp( + bind_addresses, + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + self.version_string, + ), + reactor=self.get_reactor(), + ) + + logger.info("Synapse worker now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + _base.listen_tcp( + listener["bind_addresses"], + listener["port"], + manhole( + username="matrix", password="rabbithole", globals={"hs": self} + ), + ) + elif listener["type"] == "metrics": + if not self.get_config().enable_metrics: + logger.warning( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" 
+ ) + ) + else: + _base.listen_metrics(listener["bind_addresses"], listener["port"]) + else: + logger.warning("Unrecognized listener type: %s", listener["type"]) + + self.get_tcp_replication().start_replication(self) + + def remove_pusher(self, app_id, push_key, user_id): + self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) + + def build_replication_data_handler(self): + return GenericWorkerReplicationHandler(self) + + def build_presence_handler(self): + return GenericWorkerPresence(self) + + def build_typing_handler(self): + return GenericWorkerTyping(self) + + +class GenericWorkerReplicationHandler(ReplicationDataHandler): + def __init__(self, hs): + super(GenericWorkerReplicationHandler, self).__init__(hs) + + self.store = hs.get_datastore() + self.typing_handler = hs.get_typing_handler() + self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence + self.notifier = hs.get_notifier() + + self.notify_pushers = hs.config.start_pushers + self.pusher_pool = hs.get_pusherpool() + + self.send_handler = None # type: Optional[FederationSenderHandler] + if hs.config.send_federation: + self.send_handler = FederationSenderHandler(hs) + + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) + await self._process_and_notify(stream_name, instance_name, token, rows) + + async def _process_and_notify(self, stream_name, instance_name, token, rows): + try: + if self.send_handler: + await self.send_handler.process_replication_rows( + stream_name, token, rows + ) + + if stream_name == PushRulesStream.NAME: + self.notifier.on_new_event( + "push_rules_key", token, users=[row.user_id for row in rows] + ) + elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME): + self.notifier.on_new_event( + "account_data_key", token, users=[row.user_id for row in rows] + ) + elif stream_name == ReceiptsStream.NAME: + self.notifier.on_new_event( + "receipt_key", token, rooms=[row.room_id for row in rows] + ) + await self.pusher_pool.on_new_receipts( + token, token, {row.room_id for row in rows} + ) + elif stream_name == TypingStream.NAME: + self.typing_handler.process_replication_rows(token, rows) + self.notifier.on_new_event( + "typing_key", token, rooms=[row.room_id for row in rows] + ) + elif stream_name == ToDeviceStream.NAME: + entities = [row.entity for row in rows if row.entity.startswith("@")] + if entities: + self.notifier.on_new_event("to_device_key", token, users=entities) + elif stream_name == DeviceListsStream.NAME: + all_room_ids = set() # type: Set[str] + for row in rows: + if row.entity.startswith("@"): + room_ids = await self.store.get_rooms_for_user(row.entity) + all_room_ids.update(room_ids) + self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids) + elif stream_name == PresenceStream.NAME: + await self.presence_handler.process_replication_rows(token, rows) + elif stream_name == GroupServerStream.NAME: + self.notifier.on_new_event( + "groups_key", token, users=[row.user_id for row in rows] + ) + elif stream_name == PushersStream.NAME: + for row in rows: + if row.deleted: + self.stop_pusher(row.user_id, row.app_id, row.pushkey) + else: + await self.start_pusher(row.user_id, row.app_id, row.pushkey) + except Exception: + logger.exception("Error processing replication") + + def stop_pusher(self, user_id, app_id, pushkey): + if not self.notify_pushers: + return + + key = "%s:%s" % (app_id, pushkey) + pushers_for_user = self.pusher_pool.pushers.get(user_id, 
{})
+        pusher = pushers_for_user.pop(key, None)
+        if pusher is None:
+            return
+        logger.info("Stopping pusher %r / %r", user_id, key)
+        pusher.on_stop()
+
+    async def start_pusher(self, user_id, app_id, pushkey):
+        if not self.notify_pushers:
+            return
+
+        key = "%s:%s" % (app_id, pushkey)
+        logger.info("Starting pusher %r / %r", user_id, key)
+        return await self.pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
+
+    def on_remote_server_up(self, server: str):
+        """Called when we get a new REMOTE_SERVER_UP command."""
+
+        # Let's wake up the transaction queue for the server in case we have
+        # pending stuff to send to it.
+        if self.send_handler:
+            self.send_handler.wake_destination(server)
+
+
+class FederationSenderHandler(object):
+    """Processes the federation replication stream
+
+    This class is only instantiated on the worker responsible for sending outbound
+    federation transactions. It receives rows from the replication stream and forwards
+    the appropriate entries to the FederationSender class.
+    """
+
+    def __init__(self, hs: GenericWorkerServer):
+        self.store = hs.get_datastore()
+        self._is_mine_id = hs.is_mine_id
+        self.federation_sender = hs.get_federation_sender()
+        self._hs = hs
+
+        # if the worker is restarted, we want to pick up where we left off in
+        # the replication stream, so load the position from the database.
+        #
+        # XXX is this actually worthwhile? Whenever the master is restarted, we'll
+        # drop some rows anyway (which is mostly fine because we're only dropping
+        # typing and presence notifications). If the replication stream is
+        # unreliable, why do we do all this hoop-jumping to store the position in the
+        # database? See also https://github.com/matrix-org/synapse/issues/7535.
+        #
+        self.federation_position = self.store.federation_out_pos_startup
+
+        self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
+        self._last_ack = self.federation_position
+
+    def on_start(self):
+        # There may be some events that are persisted but haven't been sent,
+        # so send them now.
+        self.federation_sender.notify_new_events(
+            self.store.get_room_max_stream_ordering()
+        )
+
+    def wake_destination(self, server: str):
+        self.federation_sender.wake_destination(server)
+
+    async def process_replication_rows(self, stream_name, token, rows):
+        # The federation stream contains things that we want to send out, e.g.
+        # presence, typing, etc.
+        if stream_name == "federation":
+            send_queue.process_rows_for_federation(self.federation_sender, rows)
+            await self.update_token(token)
+
+        # We also need to poke the federation sender when new events happen
+        elif stream_name == "events":
+            self.federation_sender.notify_new_events(token)
+
+        # ... and when new receipts happen
+        elif stream_name == ReceiptsStream.NAME:
+            await self._on_new_receipts(rows)
+
+        # ... as well as device updates and messages
+        elif stream_name == DeviceListsStream.NAME:
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
+            hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+            for host in hosts:
+                self.federation_sender.send_device_messages(host)
+
+        elif stream_name == ToDeviceStream.NAME:
+            # The to_device stream includes stuff to be pushed to both local
+            # clients and remote servers, so we ignore entities that start with
+            # '@' (since they'll be local users rather than destinations).
+            hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+            for host in hosts:
+                self.federation_sender.send_device_messages(host)
+
+    async def _on_new_receipts(self, rows):
+        """
+        Args:
+            rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
+                new receipts to be processed
+        """
+        for receipt in rows:
+            # we only want to send on receipts for our own users
+            if not self._is_mine_id(receipt.user_id):
+                continue
+            receipt_info = ReadReceipt(
+                receipt.room_id,
+                receipt.receipt_type,
+                receipt.user_id,
+                [receipt.event_id],
+                receipt.data,
+            )
+            await self.federation_sender.send_read_receipt(receipt_info)
+
+    async def update_token(self, token):
+        """Update the record of where we have processed to in the federation stream.
+
+        Called after we have processed an update received over replication. Sends
+        a FEDERATION_ACK back to the master, and stores the token that we have processed
+        in `federation_stream_position` so that we can restart where we left off.
+        """
+        self.federation_position = token
+
+        # We save and send the ACK to master asynchronously, so we don't block
+        # processing on persistence. We don't need to do this operation for
+        # every single RDATA we receive, we just need to do it periodically.
+
+        if self._fed_position_linearizer.is_queued(None):
+            # There is already a task queued up to save and send the token, so
+            # no need to queue up another task.
+            return
+
+        run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
+
+    async def _save_and_send_ack(self):
+        """Save the current federation position in the database and send an ACK
+        to master with where we're up to.
+        """
+        try:
+            # We linearize here to ensure we don't have races updating the token
+            #
+            # XXX this appears to be redundant, since the ReplicationCommandHandler
+            # has a linearizer which ensures that we only process one line of
+            # replication data at a time. Should we remove it, or is it doing useful
+            # service for robustness? Or could we replace it with an assertion that
+            # we're not being re-entered?
+
+            with (await self._fed_position_linearizer.queue(None)):
+                # We persist and ack the same position, so we take a copy of it
+                # here as otherwise it can get modified from underneath us.
+                current_position = self.federation_position
+
+                await self.store.update_federation_out_pos(
+                    "federation", current_position
+                )
+
+                # We ACK this token over replication so that the master can drop
+                # its in memory queues
+                self._hs.get_tcp_replication().send_federation_ack(current_position)
+                self._last_ack = current_position
+        except Exception:
+            logger.exception("Error updating federation stream position")
+
+
+def start(config_options):
+    try:
+        config = HomeServerConfig.load_config("Synapse worker", config_options)
+    except ConfigError as e:
+        sys.stderr.write("\n" + str(e) + "\n")
+        sys.exit(1)
+
+    # For backwards compatibility, allow any of the old app names.
+    assert config.worker_app in (
+        "synapse.app.appservice",
+        "synapse.app.client_reader",
+        "synapse.app.event_creator",
+        "synapse.app.federation_reader",
+        "synapse.app.federation_sender",
+        "synapse.app.frontend_proxy",
+        "synapse.app.generic_worker",
+        "synapse.app.media_repository",
+        "synapse.app.pusher",
+        "synapse.app.synchrotron",
+        "synapse.app.user_dir",
+    )
+
+    if config.worker_app == "synapse.app.appservice":
+        if config.notify_appservices:
+            sys.stderr.write(
+                "\nThe appservices must be disabled in the main synapse process"
+                "\nbefore they can be run in a separate worker."
+                "\nPlease add ``notify_appservices: false`` to the main config"
+                "\n"
+            )
+            sys.exit(1)
+
+        # Force the appservice to start since they will be disabled in the main config
+        config.notify_appservices = True
+    else:
+        # For other worker types we force this to off.
+        config.notify_appservices = False
+
+    if config.worker_app == "synapse.app.pusher":
+        if config.start_pushers:
+            sys.stderr.write(
+                "\nThe pushers must be disabled in the main synapse process"
+                "\nbefore they can be run in a separate worker."
+                "\nPlease add ``start_pushers: false`` to the main config"
+                "\n"
+            )
+            sys.exit(1)
+
+        # Force the pushers to start since they will be disabled in the main config
+        config.start_pushers = True
+    else:
+        # For other worker types we force this to off.
+        config.start_pushers = False
+
+    if config.worker_app == "synapse.app.user_dir":
+        if config.update_user_directory:
+            sys.stderr.write(
+                "\nThe update_user_directory must be disabled in the main synapse process"
+                "\nbefore they can be run in a separate worker."
+                "\nPlease add ``update_user_directory: false`` to the main config"
+                "\n"
+            )
+            sys.exit(1)
+
+        # Force the user directory updates to start since they will be disabled in the main config
+        config.update_user_directory = True
+    else:
+        # For other worker types we force this to off.
+        config.update_user_directory = False
+
+    if config.worker_app == "synapse.app.federation_sender":
+        if config.send_federation:
+            sys.stderr.write(
+                "\nThe send_federation must be disabled in the main synapse process"
+                "\nbefore they can be run in a separate worker."
+                "\nPlease add ``send_federation: false`` to the main config"
+                "\n"
+            )
+            sys.exit(1)
+
+        # Force the federation sending to start since it will be disabled in the main config
+        config.send_federation = True
+    else:
+        # For other worker types we force this to off.
+        config.send_federation = False
+
+    synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+    hs = GenericWorkerServer(
+        config.server_name,
+        config=config,
+        version_string="Synapse/" + get_version_string(synapse),
+    )
+
+    setup_logging(hs, config, use_worker_options=True)
+
+    hs.setup()
+
+    # Ensure the replication streamer is always started in case we write to any
+    # streams. Will no-op if no streams can be written to by this worker.
+ hs.get_replication_streamer() + + reactor.addSystemEventTrigger( + "before", "startup", _base.start, hs, config.worker_listeners + ) + + _base.start_worker_reactor("synapse-generic-worker", config) + + +if __name__ == "__main__": + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 774326dff9..8454d74858 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -19,18 +19,19 @@ from __future__ import print_function import gc import logging +import math import os +import resource import sys from six import iteritems -import psutil from prometheus_client import Gauge from twisted.application import service from twisted.internet import defer, reactor from twisted.python.failure import Failure -from twisted.web.resource import EncodingResourceWrapper, NoResource +from twisted.web.resource import EncodingResourceWrapper, IResource from twisted.web.server import GzipEncoderFactory from twisted.web.static import File @@ -38,7 +39,6 @@ import synapse import synapse.config.logger from synapse import events from synapse.api.urls import ( - CONTENT_REPO_PREFIX, FEDERATION_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, @@ -52,7 +52,11 @@ from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.federation.transport.server import TransportLayerServer from synapse.http.additional_resource import AdditionalResource -from synapse.http.server import RootRedirect +from synapse.http.server import ( + OptionsResource, + RootOptionsRedirectResource, + RootRedirect, +) from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy @@ -64,13 +68,11 @@ from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.key.v2 import KeyApiV2Resource -from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer -from synapse.storage import DataStore, are_all_users_on_domain -from synapse.storage.engines import IncorrectDatabaseSetup, create_engine -from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database -from synapse.util.caches import CACHE_SIZE_FACTOR +from synapse.storage import DataStore +from synapse.storage.engines import IncorrectDatabaseSetup +from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.module_loader import load_module @@ -110,15 +112,24 @@ class SynapseHomeServer(HomeServer): for path, resmodule in additional_resources.items(): handler_cls, config = load_module(resmodule) handler = handler_cls(config, module_api) - resources[path] = AdditionalResource(self, handler.handle_request) + if IResource.providedBy(handler): + resource = handler + elif hasattr(handler, "handle_request"): + resource = AdditionalResource(self, handler.handle_request) + else: + raise ConfigError( + "additional_resource %s does not implement a known interface" + % (resmodule["module"],) + ) + resources[path] = resource # try to find something useful to redirect '/' to if WEB_CLIENT_PREFIX in resources: - root_resource = RootRedirect(WEB_CLIENT_PREFIX) + root_resource = 
RootOptionsRedirectResource(WEB_CLIENT_PREFIX) elif STATIC_PREFIX in resources: - root_resource = RootRedirect(STATIC_PREFIX) + root_resource = RootOptionsRedirectResource(STATIC_PREFIX) else: - root_resource = NoResource() + root_resource = OptionsResource() root_resource = create_resource_tree(resources, root_resource) @@ -184,6 +195,11 @@ class SynapseHomeServer(HomeServer): } ) + if self.get_config().oidc_enabled: + from synapse.rest.oidc import OIDCResource + + resources["/_synapse/oidc"] = OIDCResource(self) + if self.get_config().saml2_enabled: from synapse.rest.saml2 import SAML2Resource @@ -222,13 +238,7 @@ class SynapseHomeServer(HomeServer): if self.get_config().enable_media_repo: media_repo = self.get_media_repository_resource() resources.update( - { - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - } + {MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo} ) elif name == "media": raise ConfigError( @@ -239,16 +249,26 @@ class SynapseHomeServer(HomeServer): resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) if name == "webclient": - webclient_path = self.get_config().web_client_location + webclient_loc = self.get_config().web_client_location - if webclient_path is None: + if webclient_loc is None: logger.warning( "Not enabling webclient resource, as web_client_location is unset." ) + elif webclient_loc.startswith("http://") or webclient_loc.startswith( + "https://" + ): + resources[WEB_CLIENT_PREFIX] = RootRedirect(webclient_loc) else: + logger.warning( + "Running webclient on the same domain is not recommended: " + "https://github.com/matrix-org/synapse#security-note - " + "after you move webclient to different host you can set " + "web_client_location to its full URL to enable redirection." + ) # GZip is disabled here due to # https://twistedmatrix.com/trac/ticket/7678 - resources[WEB_CLIENT_PREFIX] = File(webclient_path) + resources[WEB_CLIENT_PREFIX] = File(webclient_loc) if name == "metrics" and self.get_config().enable_metrics: resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) @@ -261,6 +281,12 @@ class SynapseHomeServer(HomeServer): def start_listening(self, listeners): config = self.get_config() + if config.redis_enabled: + # If redis is enabled we connect via the replication command handler + # in the same way as the workers (since we're effectively a client + # rather than a server). + self.get_tcp_replication().start_replication(self) + for listener in listeners: if listener["type"] == "http": self._listening_services.extend(self._listener_http(config, listener)) @@ -282,7 +308,7 @@ class SynapseHomeServer(HomeServer): reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" 
@@ -291,27 +317,16 @@ class SynapseHomeServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) - - def run_startup_checks(self, db_conn, database_engine): - all_users_native = are_all_users_on_domain( - db_conn.cursor(), database_engine, self.hostname - ) - if not all_users_native: - quit_with_error( - "Found users in database not native to %s!\n" - "You cannot changed a synapse server_name after it's been configured" - % (self.hostname,) - ) - - try: - database_engine.check_database(db_conn.cursor()) - except IncorrectDatabaseSetup as e: - quit_with_error(str(e)) + logger.warning("Unrecognized listener type: %s", listener["type"]) # Gauges to expose monthly active user control metrics current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") +current_mau_by_service_gauge = Gauge( + "synapse_admin_mau_current_mau_by_service", + "Current MAU by service", + ["app_service"], +) max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") registered_reserved_users_mau_gauge = Gauge( "synapse_admin_mau:registered_reserved_users", @@ -333,7 +348,7 @@ def setup(config_options): "Synapse Homeserver", config_options ) except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") + sys.stderr.write("\nERROR: %s\n" % (e,)) sys.exit(1) if not config: @@ -343,40 +358,23 @@ def setup(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection - hs = SynapseHomeServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) synapse.config.logger.setup_logging(hs, config, use_worker_options=False) - logger.info("Preparing database: %s...", config.database_config["name"]) + logger.info("Setting up server") try: - with hs.get_db_conn(run_new_connection=False) as db_conn: - prepare_database(db_conn, database_engine, config=config) - database_engine.on_new_connection(db_conn) - - hs.run_startup_checks(db_conn, database_engine) - - db_conn.commit() - except UpgradeDatabaseException: - sys.stderr.write( - "\nFailed to upgrade database.\n" - "Have you checked for version specific instructions in" - " UPGRADES.rst?\n" - ) - sys.exit(1) + hs.setup() + except IncorrectDatabaseSetup as e: + quit_with_error(str(e)) + except UpgradeDatabaseException as e: + quit_with_error("Failed to upgrade database: %s" % (e,)) - logger.info("Database prepared in %s.", config.database_config["name"]) - - hs.setup() hs.setup_master() @defer.inlineCallbacks @@ -432,10 +430,16 @@ def setup(config_options): # Check if it needs to be reprovisioned every day. hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000) + # Load the OIDC provider metadatas, if OIDC is enabled. + if hs.config.oidc_enabled: + oidc = hs.get_oidc_handler() + # Loading the provider metadata also ensures the provider config is valid. + yield defer.ensureDeferred(oidc.load_metadata()) + yield defer.ensureDeferred(oidc.load_jwks()) + _base.start(hs, config.listeners) - hs.get_pusherpool().start() - hs.get_datastore().start_doing_background_updates() + hs.get_datastore().db.updates.start_doing_background_updates() except Exception: # Print the exception and bail out. 
print("Error during startup:", file=sys.stderr) @@ -471,6 +475,93 @@ class SynapseService(service.Service): return self._port.stopListening() +# Contains the list of processes we will be monitoring +# currently either 0 or 1 +_stats_process = [] + + +@defer.inlineCallbacks +def phone_stats_home(hs, stats, stats_process=_stats_process): + logger.info("Gathering stats for reporting") + now = int(hs.get_clock().time()) + uptime = int(now - hs.start_time) + if uptime < 0: + uptime = 0 + + # + # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test. + # + old = stats_process[0] + new = (now, resource.getrusage(resource.RUSAGE_SELF)) + stats_process[0] = new + + # Get RSS in bytes + stats["memory_rss"] = new[1].ru_maxrss + + # Get CPU time in % of a single core, not % of all cores + used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - ( + old[1].ru_utime + old[1].ru_stime + ) + if used_cpu_time == 0 or new[0] == old[0]: + stats["cpu_average"] = 0 + else: + stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100) + + # + # General statistics + # + + stats["homeserver"] = hs.config.server_name + stats["server_context"] = hs.config.server_context + stats["timestamp"] = now + stats["uptime_seconds"] = uptime + version = sys.version_info + stats["python_version"] = "{}.{}.{}".format( + version.major, version.minor, version.micro + ) + stats["total_users"] = yield hs.get_datastore().count_all_users() + + total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() + stats["total_nonbridged_users"] = total_nonbridged_users + + daily_user_type_results = yield hs.get_datastore().count_daily_user_type() + for name, count in iteritems(daily_user_type_results): + stats["daily_user_type_" + name] = count + + room_count = yield hs.get_datastore().get_room_count() + stats["total_room_count"] = room_count + + stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() + stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users() + stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() + stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() + + r30_results = yield hs.get_datastore().count_r30_users() + for name, count in iteritems(r30_results): + stats["r30_users_" + name] = count + + daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() + stats["daily_sent_messages"] = daily_sent_messages + stats["cache_factor"] = hs.config.caches.global_factor + stats["event_cache_size"] = hs.config.caches.event_cache_size + + # + # Database version + # + + # This only reports info about the *main* database. 
+ stats["database_engine"] = hs.get_datastore().db.engine.module.__name__ + stats["database_server_version"] = hs.get_datastore().db.engine.server_version + + logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) + try: + yield hs.get_proxied_http_client().put_json( + hs.config.report_stats_endpoint, stats + ) + except Exception as e: + logger.warning("Error reporting stats: %s", e) + + def run(hs): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: @@ -497,91 +588,19 @@ def run(hs): reactor.run = profile(reactor.run) clock = hs.get_clock() - start_time = clock.time() stats = {} - # Contains the list of processes we will be monitoring - # currently either 0 or 1 - stats_process = [] + def performance_stats_init(): + _stats_process.clear() + _stats_process.append( + (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF)) + ) def start_phone_stats_home(): - return run_as_background_process("phone_stats_home", phone_stats_home) - - @defer.inlineCallbacks - def phone_stats_home(): - logger.info("Gathering stats for reporting") - now = int(hs.get_clock().time()) - uptime = int(now - start_time) - if uptime < 0: - uptime = 0 - - stats["homeserver"] = hs.config.server_name - stats["server_context"] = hs.config.server_context - stats["timestamp"] = now - stats["uptime_seconds"] = uptime - version = sys.version_info - stats["python_version"] = "{}.{}.{}".format( - version.major, version.minor, version.micro - ) - stats["total_users"] = yield hs.get_datastore().count_all_users() - - total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users() - stats["total_nonbridged_users"] = total_nonbridged_users - - daily_user_type_results = yield hs.get_datastore().count_daily_user_type() - for name, count in iteritems(daily_user_type_results): - stats["daily_user_type_" + name] = count - - room_count = yield hs.get_datastore().get_room_count() - stats["total_room_count"] = room_count - - stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() - stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users() - stats[ - "daily_active_rooms" - ] = yield hs.get_datastore().count_daily_active_rooms() - stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() - - r30_results = yield hs.get_datastore().count_r30_users() - for name, count in iteritems(r30_results): - stats["r30_users_" + name] = count - - daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() - stats["daily_sent_messages"] = daily_sent_messages - stats["cache_factor"] = CACHE_SIZE_FACTOR - stats["event_cache_size"] = hs.config.event_cache_size - - if len(stats_process) > 0: - stats["memory_rss"] = 0 - stats["cpu_average"] = 0 - for process in stats_process: - stats["memory_rss"] += process.memory_info().rss - stats["cpu_average"] += int(process.cpu_percent(interval=None)) - - stats["database_engine"] = hs.get_datastore().database_engine_name - stats["database_server_version"] = hs.get_datastore().get_server_version() - logger.info( - "Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats) + return run_as_background_process( + "phone_stats_home", phone_stats_home, hs, stats ) - try: - yield hs.get_simple_http_client().put_json( - hs.config.report_stats_endpoint, stats - ) - except Exception as e: - logger.warn("Error reporting stats: %s", e) - - def performance_stats_init(): - try: - process = psutil.Process() - # Ensure we can fetch both, and make the initial request for cpu_percent - # so the next request will use 
this as the initial point. - process.memory_info().rss - process.cpu_percent(interval=None) - logger.info("report_stats can use psutil") - stats_process.append(process) - except (AttributeError): - logger.warning("Unable to read memory/cpu stats. Disabling reporting.") def generate_user_daily_visit_stats(): return run_as_background_process( @@ -602,16 +621,23 @@ def run(hs): clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60) reap_monthly_active_users() - @defer.inlineCallbacks - def generate_monthly_active_users(): + async def generate_monthly_active_users(): current_mau_count = 0 - reserved_count = 0 + current_mau_count_by_service = {} + reserved_users = () store = hs.get_datastore() if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: - current_mau_count = yield store.get_monthly_active_count() - reserved_count = yield store.get_registered_reserved_users_count() + current_mau_count = await store.get_monthly_active_count() + current_mau_count_by_service = ( + await store.get_monthly_active_count_by_service() + ) + reserved_users = await store.get_registered_reserved_users() current_mau_gauge.set(float(current_mau_count)) - registered_reserved_users_mau_gauge.set(float(reserved_count)) + + for app_service, count in current_mau_count_by_service.items(): + current_mau_by_service_gauge.labels(app_service).set(float(count)) + + registered_reserved_users_mau_gauge.set(float(len(reserved_users))) max_mau_gauge.set(float(hs.config.max_mau_value)) def start_generate_monthly_active_users(): diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 2ac783ffa3..add43147b3 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -13,169 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import logging -import sys - -from twisted.internet import reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.api.urls import CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.http.server import JsonResource -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.rest.admin import register_servlets_for_media_repo -from synapse.rest.media.v0.content_repository import ContentRepoResource -from synapse.server import HomeServer -from synapse.storage.engines import create_engine -from synapse.storage.media_repository import MediaRepositoryStore -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.media_repository") - - -class MediaRepositorySlavedStore( - SlavedApplicationServiceStore, - SlavedRegistrationStore, - SlavedClientIpStore, - SlavedTransactionStore, - BaseSlavedStore, - MediaRepositoryStore, -): - pass - - -class MediaRepositoryServer(HomeServer): - DATASTORE_CLASS = MediaRepositorySlavedStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - elif name == "media": - media_repo = self.get_media_repository_resource() - - # We need to serve the admin servlets for media on the - # worker. 
- admin_resource = JsonResource(self, canonical_json=False) - register_servlets_for_media_repo(self, admin_resource) - - resources.update( - { - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - "/_synapse/admin": admin_resource, - } - ) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - - logger.info("Synapse media repository now listening on port %d", port) - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warn( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warn("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return ReplicationClientHandler(self.get_datastore()) - - -def start(config_options): - try: - config = HomeServerConfig.load_config( - "Synapse media repository", config_options - ) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.media_repository" - - if config.enable_media_repo: - _base.quit_with_error( - "enable_media_repo must be disabled in the main synapse process\n" - "before the media repo can be run in a separate worker.\n" - "Please add ``enable_media_repo: false`` to the main config\n" - ) - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - database_engine = create_engine(config.database_config) - - ss = MediaRepositoryServer( - config.server_name, - db_config=config.database_config, - config=config, - version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-media-repository", config) +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index d84732ee3c..add43147b3 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -13,214 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import logging -import sys - -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import __func__ -from synapse.replication.slave.storage.account_data import SlavedAccountDataStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.pushers import SlavedPusherStore -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.server import HomeServer -from synapse.storage import DataStore -from synapse.storage.engines import create_engine -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.pusher") - - -class PusherSlaveStore( - SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, SlavedAccountDataStore -): - update_pusher_last_stream_ordering_and_success = __func__( - DataStore.update_pusher_last_stream_ordering_and_success - ) - - update_pusher_failing_since = __func__(DataStore.update_pusher_failing_since) - - update_pusher_last_stream_ordering = __func__( - DataStore.update_pusher_last_stream_ordering - ) - - get_throttle_params_by_room = __func__(DataStore.get_throttle_params_by_room) - - set_throttle_params = __func__(DataStore.set_throttle_params) - - get_time_of_last_push_action_before = __func__( - DataStore.get_time_of_last_push_action_before - ) - - get_profile_displayname = __func__(DataStore.get_profile_displayname) - - -class PusherServer(HomeServer): - DATASTORE_CLASS = PusherSlaveStore - - def remove_pusher(self, app_id, push_key, user_id): - self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - - logger.info("Synapse pusher now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warn( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" 
- ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warn("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - def build_tcp_replication(self): - return PusherReplicationHandler(self) - - -class PusherReplicationHandler(ReplicationClientHandler): - def __init__(self, hs): - super(PusherReplicationHandler, self).__init__(hs.get_datastore()) - - self.pusher_pool = hs.get_pusherpool() - - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) - run_in_background(self.poke_pushers, stream_name, token, rows) - - @defer.inlineCallbacks - def poke_pushers(self, stream_name, token, rows): - try: - if stream_name == "pushers": - for row in rows: - if row.deleted: - yield self.stop_pusher(row.user_id, row.app_id, row.pushkey) - else: - yield self.start_pusher(row.user_id, row.app_id, row.pushkey) - elif stream_name == "events": - yield self.pusher_pool.on_new_notifications(token, token) - elif stream_name == "receipts": - yield self.pusher_pool.on_new_receipts( - token, token, set(row.room_id for row in rows) - ) - except Exception: - logger.exception("Error poking pushers") - - def stop_pusher(self, user_id, app_id, pushkey): - key = "%s:%s" % (app_id, pushkey) - pushers_for_user = self.pusher_pool.pushers.get(user_id, {}) - pusher = pushers_for_user.pop(key, None) - if pusher is None: - return - logger.info("Stopping pusher %r / %r", user_id, key) - pusher.on_stop() - - def start_pusher(self, user_id, app_id, pushkey): - key = "%s:%s" % (app_id, pushkey) - logger.info("Starting pusher %r / %r", user_id, key) - return self.pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) - - -def start(config_options): - try: - config = HomeServerConfig.load_config("Synapse pusher", config_options) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.pusher" - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - if config.start_pushers: - sys.stderr.write( - "\nThe pushers must be disabled in the main synapse process" - "\nbefore they can be run in a separate worker." - "\nPlease add ``start_pushers: false`` to the main config" - "\n" - ) - sys.exit(1) - - # Force the pushers to start since they will be disabled in the main config - config.start_pushers = True - - database_engine = create_engine(config.database_config) - - ps = PusherServer( - config.server_name, - db_config=config.database_config, - config=config, - version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, - ) - - setup_logging(ps, config, use_worker_options=True) - - ps.setup() - - def start(): - _base.start(ps, config.worker_listeners) - ps.get_pusherpool().start() - - reactor.addSystemEventTrigger("before", "startup", start) - - _base.start_worker_reactor("synapse-pusher", config) +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): - ps = start(sys.argv[1:]) + start(sys.argv[1:]) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 473026fce5..add43147b3 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -13,450 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import logging -import sys - -from six import iteritems - -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource - -import synapse -from synapse.api.constants import EventTypes -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.handlers.presence import PresenceHandler, get_interested_parties -from synapse.http.server import JsonResource -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import BaseSlavedStore, __func__ -from synapse.replication.slave.storage.account_data import SlavedAccountDataStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore -from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore -from synapse.replication.slave.storage.devices import SlavedDeviceStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.filtering import SlavedFilteringStore -from synapse.replication.slave.storage.groups import SlavedGroupServerStore -from synapse.replication.slave.storage.presence import SlavedPresenceStore -from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.replication.tcp.streams.events import EventsStreamEventRow -from synapse.rest.client.v1 import events -from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet -from synapse.rest.client.v1.room import RoomInitialSyncRestServlet -from synapse.rest.client.v2_alpha import sync -from synapse.server import HomeServer -from synapse.storage.engines import create_engine -from synapse.storage.presence import UserPresenceState -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.stringutils import random_string -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.synchrotron") - - -class SynchrotronSlavedStore( - SlavedReceiptsStore, - SlavedAccountDataStore, - SlavedApplicationServiceStore, - SlavedRegistrationStore, - SlavedFilteringStore, - SlavedPresenceStore, - SlavedGroupServerStore, - SlavedDeviceInboxStore, - SlavedDeviceStore, - SlavedPushRuleStore, - SlavedEventStore, - SlavedClientIpStore, - RoomStore, - BaseSlavedStore, -): - pass - - -UPDATE_SYNCING_USERS_MS = 10 * 1000 - - -class SynchrotronPresence(object): - def __init__(self, hs): - self.hs = hs - self.is_mine_id = hs.is_mine_id - self.http_client = hs.get_simple_http_client() - self.store = hs.get_datastore() - self.user_to_num_current_syncs = {} - self.clock = hs.get_clock() - self.notifier = hs.get_notifier() - - active_presence = self.store.take_presence_startup_info() - self.user_to_current_state = {state.user_id: state for state in 
active_presence} - - # user_id -> last_sync_ms. Lists the users that have stopped syncing - # but we haven't notified the master of that yet - self.users_going_offline = {} - - self._send_stop_syncing_loop = self.clock.looping_call( - self.send_stop_syncing, 10 * 1000 - ) - - self.process_id = random_string(16) - logger.info("Presence process_id is %r", self.process_id) - - def send_user_sync(self, user_id, is_syncing, last_sync_ms): - if self.hs.config.use_presence: - self.hs.get_tcp_replication().send_user_sync( - user_id, is_syncing, last_sync_ms - ) - - def mark_as_coming_online(self, user_id): - """A user has started syncing. Send a UserSync to the master, unless they - had recently stopped syncing. - - Args: - user_id (str) - """ - going_offline = self.users_going_offline.pop(user_id, None) - if not going_offline: - # Safe to skip because we haven't yet told the master they were offline - self.send_user_sync(user_id, True, self.clock.time_msec()) - - def mark_as_going_offline(self, user_id): - """A user has stopped syncing. We wait before notifying the master as - its likely they'll come back soon. This allows us to avoid sending - a stopped syncing immediately followed by a started syncing notification - to the master - - Args: - user_id (str) - """ - self.users_going_offline[user_id] = self.clock.time_msec() - - def send_stop_syncing(self): - """Check if there are any users who have stopped syncing a while ago - and haven't come back yet. If there are poke the master about them. - """ - now = self.clock.time_msec() - for user_id, last_sync_ms in list(self.users_going_offline.items()): - if now - last_sync_ms > 10 * 1000: - self.users_going_offline.pop(user_id, None) - self.send_user_sync(user_id, False, last_sync_ms) - - def set_state(self, user, state, ignore_status_msg=False): - # TODO Hows this supposed to work? - pass - - get_states = __func__(PresenceHandler.get_states) - get_state = __func__(PresenceHandler.get_state) - current_state_for_users = __func__(PresenceHandler.current_state_for_users) - - def user_syncing(self, user_id, affect_presence): - if affect_presence: - curr_sync = self.user_to_num_current_syncs.get(user_id, 0) - self.user_to_num_current_syncs[user_id] = curr_sync + 1 - - # If we went from no in flight sync to some, notify replication - if self.user_to_num_current_syncs[user_id] == 1: - self.mark_as_coming_online(user_id) - - def _end(): - # We check that the user_id is in user_to_num_current_syncs because - # user_to_num_current_syncs may have been cleared if we are - # shutting down. 
- if affect_presence and user_id in self.user_to_num_current_syncs: - self.user_to_num_current_syncs[user_id] -= 1 - - # If we went from one in flight sync to non, notify replication - if self.user_to_num_current_syncs[user_id] == 0: - self.mark_as_going_offline(user_id) - - @contextlib.contextmanager - def _user_syncing(): - try: - yield - finally: - _end() - - return defer.succeed(_user_syncing()) - - @defer.inlineCallbacks - def notify_from_replication(self, states, stream_id): - parties = yield get_interested_parties(self.store, states) - room_ids_to_states, users_to_states = parties - - self.notifier.on_new_event( - "presence_key", - stream_id, - rooms=room_ids_to_states.keys(), - users=users_to_states.keys(), - ) - - @defer.inlineCallbacks - def process_replication_rows(self, token, rows): - states = [ - UserPresenceState( - row.user_id, - row.state, - row.last_active_ts, - row.last_federation_update_ts, - row.last_user_sync_ts, - row.status_msg, - row.currently_active, - ) - for row in rows - ] - - for state in states: - self.user_to_current_state[state.user_id] = state - - stream_id = token - yield self.notify_from_replication(states, stream_id) - - def get_currently_syncing_users(self): - if self.hs.config.use_presence: - return [ - user_id - for user_id, count in iteritems(self.user_to_num_current_syncs) - if count > 0 - ] - else: - return set() - -class SynchrotronTyping(object): - def __init__(self, hs): - self._latest_room_serial = 0 - self._reset() - - def _reset(self): - """ - Reset the typing handler's data caches. - """ - # map room IDs to serial numbers - self._room_serials = {} - # map room IDs to sets of users currently typing - self._room_typing = {} - - def stream_positions(self): - # We must update this typing token from the response of the previous - # sync. In particular, the stream id may "reset" back to zero/a low - # value which we *must* use for the next replication request. - return {"typing": self._latest_room_serial} - - def process_replication_rows(self, token, rows): - if self._latest_room_serial > token: - # The master has gone backwards. To prevent inconsistent data, just - # clear everything. - self._reset() - - # Set the latest serial token to whatever the server gave us. 
- self._latest_room_serial = token - - for row in rows: - self._room_serials[row.room_id] = token - self._room_typing[row.room_id] = row.user_ids - - -class SynchrotronApplicationService(object): - def notify_interested_services(self, event): - pass - - -class SynchrotronServer(HomeServer): - DATASTORE_CLASS = SynchrotronSlavedStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - elif name == "client": - resource = JsonResource(self, canonical_json=False) - sync.register_servlets(self, resource) - events.register_servlets(self, resource) - InitialSyncRestServlet(self).register(resource) - RoomInitialSyncRestServlet(self).register(resource) - resources.update( - { - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - } - ) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - - logger.info("Synapse synchrotron now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warn( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warn("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return SyncReplicationHandler(self) - - def build_presence_handler(self): - return SynchrotronPresence(self) - - def build_typing_handler(self): - return SynchrotronTyping(self) - - -class SyncReplicationHandler(ReplicationClientHandler): - def __init__(self, hs): - super(SyncReplicationHandler, self).__init__(hs.get_datastore()) - - self.store = hs.get_datastore() - self.typing_handler = hs.get_typing_handler() - # NB this is a SynchrotronPresence, not a normal PresenceHandler - self.presence_handler = hs.get_presence_handler() - self.notifier = hs.get_notifier() - - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) - run_in_background(self.process_and_notify, stream_name, token, rows) - - def get_streams_to_replicate(self): - args = super(SyncReplicationHandler, self).get_streams_to_replicate() - args.update(self.typing_handler.stream_positions()) - return args - - def get_currently_syncing_users(self): - return self.presence_handler.get_currently_syncing_users() - - @defer.inlineCallbacks - def process_and_notify(self, stream_name, token, rows): - try: - if stream_name == "events": - # We shouldn't get multiple rows per token for events stream, so - # we don't need to optimise this for multiple rows. 
- for row in rows: - if row.type != EventsStreamEventRow.TypeId: - continue - event = yield self.store.get_event(row.data.event_id) - extra_users = () - if event.type == EventTypes.Member: - extra_users = (event.state_key,) - max_token = self.store.get_room_max_stream_ordering() - self.notifier.on_new_room_event( - event, token, max_token, extra_users - ) - elif stream_name == "push_rules": - self.notifier.on_new_event( - "push_rules_key", token, users=[row.user_id for row in rows] - ) - elif stream_name in ("account_data", "tag_account_data"): - self.notifier.on_new_event( - "account_data_key", token, users=[row.user_id for row in rows] - ) - elif stream_name == "receipts": - self.notifier.on_new_event( - "receipt_key", token, rooms=[row.room_id for row in rows] - ) - elif stream_name == "typing": - self.typing_handler.process_replication_rows(token, rows) - self.notifier.on_new_event( - "typing_key", token, rooms=[row.room_id for row in rows] - ) - elif stream_name == "to_device": - entities = [row.entity for row in rows if row.entity.startswith("@")] - if entities: - self.notifier.on_new_event("to_device_key", token, users=entities) - elif stream_name == "device_lists": - all_room_ids = set() - for row in rows: - room_ids = yield self.store.get_rooms_for_user(row.user_id) - all_room_ids.update(room_ids) - self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids) - elif stream_name == "presence": - yield self.presence_handler.process_replication_rows(token, rows) - elif stream_name == "receipts": - self.notifier.on_new_event( - "groups_key", token, users=[row.user_id for row in rows] - ) - except Exception: - logger.exception("Error processing replication") - - -def start(config_options): - try: - config = HomeServerConfig.load_config("Synapse synchrotron", config_options) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.synchrotron" - - synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts - - database_engine = create_engine(config.database_config) - - ss = SynchrotronServer( - config.server_name, - db_config=config.database_config, - config=config, - version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, - application_service_handler=SynchrotronApplicationService(), - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-synchrotron", config) +import sys +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index e01afb39f2..503d44f687 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -14,222 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
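# As with media_repository.py, pusher.py and synchrotron.py above, this hunk strips the
# standalone user directory worker down to a thin shim over the shared generic worker
# entry point. Sketching the end state (reconstructed from the context and "+" lines of
# these hunks; the final start() call is unchanged context), the surviving module is
# essentially just:
#
#     import sys
#
#     from synapse.app.generic_worker import start
#     from synapse.util.logcontext import LoggingContext
#
#     if __name__ == "__main__":
#         with LoggingContext("main"):
#             start(sys.argv[1:])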
-import logging import sys -from twisted.internet import defer, reactor -from twisted.web.resource import NoResource - -import synapse -from synapse import events -from synapse.app import _base -from synapse.config._base import ConfigError -from synapse.config.homeserver import HomeServerConfig -from synapse.config.logger import setup_logging -from synapse.http.server import JsonResource -from synapse.http.site import SynapseSite -from synapse.logging.context import LoggingContext, run_in_background -from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore -from synapse.replication.slave.storage.client_ips import SlavedClientIpStore -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.replication.tcp.streams.events import ( - EventsStream, - EventsStreamCurrentStateRow, -) -from synapse.rest.client.v2_alpha import user_directory -from synapse.server import HomeServer -from synapse.storage.engines import create_engine -from synapse.storage.user_directory import UserDirectoryStore -from synapse.util.caches.stream_change_cache import StreamChangeCache -from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.manhole import manhole -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger("synapse.app.user_dir") - - -class UserDirectorySlaveStore( - SlavedEventStore, - SlavedApplicationServiceStore, - SlavedRegistrationStore, - SlavedClientIpStore, - UserDirectoryStore, - BaseSlavedStore, -): - def __init__(self, db_conn, hs): - super(UserDirectorySlaveStore, self).__init__(db_conn, hs) - - events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( - db_conn, - "current_state_delta_stream", - entity_column="room_id", - stream_column="stream_id", - max_value=events_max, # As we share the stream id with events token - limit=1000, - ) - self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", - min_curr_state_delta_id, - prefilled_cache=curr_state_delta_prefill, - ) - - def stream_positions(self): - result = super(UserDirectorySlaveStore, self).stream_positions() - return result - - def process_replication_rows(self, stream_name, token, rows): - if stream_name == EventsStream.NAME: - self._stream_id_gen.advance(token) - for row in rows: - if row.type != EventsStreamCurrentStateRow.TypeId: - continue - self._curr_state_delta_stream_cache.entity_has_changed( - row.data.room_id, token - ) - return super(UserDirectorySlaveStore, self).process_replication_rows( - stream_name, token, rows - ) - - -class UserDirectoryServer(HomeServer): - DATASTORE_CLASS = UserDirectorySlaveStore - - def _listen_http(self, listener_config): - port = listener_config["port"] - bind_addresses = listener_config["bind_addresses"] - site_tag = listener_config.get("tag", port) - resources = {} - for res in listener_config["resources"]: - for name in res["names"]: - if name == "metrics": - resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) - elif name == "client": - resource = JsonResource(self, canonical_json=False) - user_directory.register_servlets(self, resource) - resources.update( - { - 
"/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - } - ) - - root_resource = create_resource_tree(resources, NoResource()) - - _base.listen_tcp( - bind_addresses, - port, - SynapseSite( - "synapse.access.http.%s" % (site_tag,), - site_tag, - listener_config, - root_resource, - self.version_string, - ), - ) - - logger.info("Synapse user_dir now listening on port %d", port) - - def start_listening(self, listeners): - for listener in listeners: - if listener["type"] == "http": - self._listen_http(listener) - elif listener["type"] == "manhole": - _base.listen_tcp( - listener["bind_addresses"], - listener["port"], - manhole( - username="matrix", password="rabbithole", globals={"hs": self} - ), - ) - elif listener["type"] == "metrics": - if not self.get_config().enable_metrics: - logger.warn( - ( - "Metrics listener configured, but " - "enable_metrics is not True!" - ) - ) - else: - _base.listen_metrics(listener["bind_addresses"], listener["port"]) - else: - logger.warn("Unrecognized listener type: %s", listener["type"]) - - self.get_tcp_replication().start_replication(self) - - def build_tcp_replication(self): - return UserDirectoryReplicationHandler(self) - - -class UserDirectoryReplicationHandler(ReplicationClientHandler): - def __init__(self, hs): - super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore()) - self.user_directory = hs.get_user_directory_handler() - - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(UserDirectoryReplicationHandler, self).on_rdata( - stream_name, token, rows - ) - if stream_name == EventsStream.NAME: - run_in_background(self._notify_directory) - - @defer.inlineCallbacks - def _notify_directory(self): - try: - yield self.user_directory.notify_new_event() - except Exception: - logger.exception("Error notifiying user directory of state update") - - -def start(config_options): - try: - config = HomeServerConfig.load_config("Synapse user directory", config_options) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") - sys.exit(1) - - assert config.worker_app == "synapse.app.user_dir" - - events.USE_FROZEN_DICTS = config.use_frozen_dicts - - database_engine = create_engine(config.database_config) - - if config.update_user_directory: - sys.stderr.write( - "\nThe update_user_directory must be disabled in the main synapse process" - "\nbefore they can be run in a separate worker." 
- "\nPlease add ``update_user_directory: false`` to the main config" - "\n" - ) - sys.exit(1) - - # Force the pushers to start since they will be disabled in the main config - config.update_user_directory = True - - ss = UserDirectoryServer( - config.server_name, - db_config=config.database_config, - config=config, - version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, - ) - - setup_logging(ss, config, use_worker_options=True) - - ss.setup() - reactor.addSystemEventTrigger( - "before", "startup", _base.start, ss, config.worker_listeners - ) - - _base.start_worker_reactor("synapse-user-dir", config) - +from synapse.app.generic_worker import start +from synapse.util.logcontext import LoggingContext if __name__ == "__main__": with LoggingContext("main"): diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 33b3579425..1b13e84425 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -94,7 +94,9 @@ class ApplicationService(object): ip_range_whitelist=None, ): self.token = token - self.url = url + self.url = ( + url.rstrip("/") if isinstance(url, str) else None + ) # url must not end with a slash self.hs_token = hs_token self.sender = sender self.server_name = hostname @@ -268,7 +270,7 @@ class ApplicationService(object): def is_exclusive_room(self, room_id): return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) - def get_exlusive_user_regexes(self): + def get_exclusive_user_regexes(self): """Get the list of regexes used to determine if a user is exclusively registered by the AS """ diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 3e25bf5747..57174da021 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -185,7 +185,7 @@ class ApplicationServiceApi(SimpleHttpClient): if not _is_valid_3pe_metadata(info): logger.warning( - "query_3pe_protocol to %s did not return a" " valid result", uri + "query_3pe_protocol to %s did not return a valid result", uri ) return None diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 31f6530978..30d1050a91 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -18,7 +18,9 @@ import argparse import errno import os +from collections import OrderedDict from textwrap import dedent +from typing import Any, MutableMapping, Optional from six import integer_types @@ -51,7 +53,68 @@ Missing mandatory `server_name` config option. """ +CONFIG_FILE_HEADER = """\ +# Configuration file for Synapse. +# +# This is a YAML file: see [1] for a quick introduction. Note in particular +# that *indentation is important*: all the elements of a list or dictionary +# should have the same indentation. +# +# [1] https://docs.ansible.com/ansible/latest/reference_appendices/YAMLSyntax.html + +""" + + +def path_exists(file_path): + """Check if a file exists + + Unlike os.path.exists, this throws an exception if there is an error + checking if the file exists (for example, if there is a perms error on + the parent dir). + + Returns: + bool: True if the file exists; False if not. + """ + try: + os.stat(file_path) + return True + except OSError as e: + if e.errno != errno.ENOENT: + raise e + return False + + class Config(object): + """ + A configuration section, containing configuration keys and values. + + Attributes: + section (str): The section title of this config object, such as + "tls" or "logger". This is used to refer to it on the root + logger (for example, `config.tls.some_option`). 
Must be + defined in subclasses. + """ + + section = None + + def __init__(self, root_config=None): + self.root = root_config + + def __getattr__(self, item: str) -> Any: + """ + Try and fetch a configuration option that does not exist on this class. + + This is so that existing configs that rely on `self.value`, where value + is actually from a different config section, continue to work. + """ + if item in ["generate_config_section", "read_config"]: + raise AttributeError(item) + + if self.root is None: + raise AttributeError(item) + else: + return self.root._get_unclassed_config(self.section, item) + @staticmethod def parse_size(value): if isinstance(value, integer_types): @@ -88,22 +151,7 @@ class Config(object): @classmethod def path_exists(cls, file_path): - """Check if a file exists - - Unlike os.path.exists, this throws an exception if there is an error - checking if the file exists (for example, if there is a perms error on - the parent dir). - - Returns: - bool: True if the file exists; False if not. - """ - try: - os.stat(file_path) - return True - except OSError as e: - if e.errno != errno.ENOENT: - raise e - return False + return path_exists(file_path) @classmethod def check_file(cls, file_path, config_name): @@ -136,42 +184,106 @@ class Config(object): with open(file_path) as file_stream: return file_stream.read() - def invoke_all(self, name, *args, **kargs): - """Invoke all instance methods with the given name and arguments in the - class's MRO. + +class RootConfig(object): + """ + Holder of an application's configuration. + + What configuration this object holds is defined by `config_classes`, a list + of Config classes that will be instantiated and given the contents of a + configuration file to read. They can then be accessed on this class by their + section name, defined in the Config or dynamically set to be the name of the + class, lower-cased and with "Config" removed. + """ + + config_classes = [] + + def __init__(self): + self._configs = OrderedDict() + + for config_class in self.config_classes: + if config_class.section is None: + raise ValueError("%r requires a section name" % (config_class,)) + + try: + conf = config_class(self) + except Exception as e: + raise Exception("Failed making %s: %r" % (config_class.section, e)) + self._configs[config_class.section] = conf + + def __getattr__(self, item: str) -> Any: + """ + Redirect lookups on this object either to config objects, or values on + config objects, so that `config.tls.blah` works, as well as legacy uses + of things like `config.server_name`. It will first look up the config + section name, and then values on those config classes. + """ + if item in self._configs.keys(): + return self._configs[item] + + return self._get_unclassed_config(None, item) + + def _get_unclassed_config(self, asking_section: Optional[str], item: str): + """ + Fetch a config value from one of the instantiated config classes that + has not been fetched directly. Args: - name (str): Name of function to invoke + asking_section: If this check is coming from a Config child, which + one? This section will not be asked if it has the value. + item: The configuration value key. + + Raises: + AttributeError if no config classes have the config key. The body + will contain what sections were checked. 
+ """ + for key, val in self._configs.items(): + if key == asking_section: + continue + + if item in dir(val): + return getattr(val, item) + + raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),)) + + def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]: + """ + Invoke a function on all instantiated config objects this RootConfig is + configured to use. + + Args: + func_name: Name of function to invoke *args **kwargs - Returns: - list: The list of the return values from each method called + ordered dictionary of config section name and the result of the + function from it. """ - results = [] - for cls in type(self).mro(): - if name in cls.__dict__: - results.append(getattr(cls, name)(self, *args, **kargs)) - return results + res = OrderedDict() + + for name, config in self._configs.items(): + if hasattr(config, func_name): + res[name] = getattr(config, func_name)(*args, **kwargs) + + return res @classmethod - def invoke_all_static(cls, name, *args, **kargs): - """Invoke all static methods with the given name and arguments in the - class's MRO. + def invoke_all_static(cls, func_name: str, *args, **kwargs): + """ + Invoke a static function on config objects this RootConfig is + configured to use. Args: - name (str): Name of function to invoke + func_name: Name of function to invoke *args **kwargs - Returns: - list: The list of the return values from each method called + ordered dictionary of config section name and the result of the + function from it. """ - results = [] - for c in cls.mro(): - if name in c.__dict__: - results.append(getattr(c, name)(*args, **kargs)) - return results + for config in cls.config_classes: + if hasattr(config, func_name): + getattr(config, func_name)(*args, **kwargs) def generate_config( self, @@ -182,12 +294,12 @@ class Config(object): report_stats=None, open_private_ports=False, listeners=None, - database_conf=None, tls_certificate_path=None, tls_private_key_path=None, acme_domain=None, ): - """Build a default configuration file + """ + Build a default configuration file This is used when the user explicitly asks us to generate a config file (eg with --generate_config). 
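# A minimal runnable sketch (not from the patch) of the section/root delegation that the
# Config/RootConfig split above introduces. The class names and the "tls_certificate_path"
# key are invented for this example; only the mechanism (a per-section `section` attribute
# plus a __getattr__ fallback through the root) mirrors the hunks above.

class ExampleSectionConfig:
    section = "tls"

    def __init__(self, root):
        self.root = root

    def read_config(self, config):
        self.certificate_path = config.get("tls_certificate_path")

    def __getattr__(self, item):
        # Unknown attributes are looked up on the other sections via the root,
        # so legacy flat lookups keep working.
        return self.root._get_unclassed_config(self.section, item)


class ExampleRootConfig:
    config_classes = [ExampleSectionConfig]

    def __init__(self):
        self._configs = {cls.section: cls(self) for cls in self.config_classes}

    def __getattr__(self, item):
        if item in self._configs:
            return self._configs[item]
        return self._get_unclassed_config(None, item)

    def _get_unclassed_config(self, asking_section, item):
        for section, conf in self._configs.items():
            if section != asking_section and item in conf.__dict__:
                return conf.__dict__[item]
        raise AttributeError(item)


root = ExampleRootConfig()
root.tls.read_config({"tls_certificate_path": "/path/to/cert.pem"})
assert root.tls.certificate_path == "/path/to/cert.pem"  # section-qualified access
assert root.certificate_path == "/path/to/cert.pem"      # legacy flat access still works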
@@ -242,7 +354,8 @@ class Config(object): Returns: str: the yaml config file """ - return "\n\n".join( + + return CONFIG_FILE_HEADER + "\n\n".join( dedent(conf) for conf in self.invoke_all( "generate_config_section", @@ -253,11 +366,10 @@ class Config(object): report_stats=report_stats, open_private_ports=open_private_ports, listeners=listeners, - database_conf=database_conf, tls_certificate_path=tls_certificate_path, tls_private_key_path=tls_private_key_path, acme_domain=acme_domain, - ) + ).values() ) @classmethod @@ -356,8 +468,8 @@ class Config(object): Returns: Config object, or None if --generate-config or --generate-keys was set """ - config_parser = argparse.ArgumentParser(add_help=False) - config_parser.add_argument( + parser = argparse.ArgumentParser(description=description) + parser.add_argument( "-c", "--config-path", action="append", @@ -366,7 +478,7 @@ class Config(object): " may specify directories containing *.yaml files.", ) - generate_group = config_parser.add_argument_group("Config generation") + generate_group = parser.add_argument_group("Config generation") generate_group.add_argument( "--generate-config", action="store_true", @@ -414,12 +526,13 @@ class Config(object): ), ) - config_args, remaining_args = config_parser.parse_known_args(argv) + cls.invoke_all_static("add_arguments", parser) + config_args = parser.parse_args(argv) config_files = find_config_files(search_paths=config_args.config_path) if not config_files: - config_parser.error( + parser.error( "Must supply a config file.\nA config file can be automatically" ' generated using "--generate-config -H SERVER_NAME' ' -c CONFIG-FILE"' @@ -438,13 +551,13 @@ class Config(object): if config_args.generate_config: if config_args.report_stats is None: - config_parser.error( + parser.error( "Please specify either --report-stats=yes or --report-stats=no\n\n" + MISSING_REPORT_STATS_SPIEL ) (config_path,) = config_files - if not cls.path_exists(config_path): + if not path_exists(config_path): print("Generating config file %s" % (config_path,)) if config_args.data_directory: @@ -469,11 +582,11 @@ class Config(object): open_private_ports=config_args.open_private_ports, ) - if not cls.path_exists(config_dir_path): + if not path_exists(config_dir_path): os.makedirs(config_dir_path) with open(config_path, "w") as config_file: - config_file.write("# vim:ft=yaml\n\n") config_file.write(config_str) + config_file.write("\n\n# vim:ft=yaml") config_dict = yaml.safe_load(config_str) obj.generate_missing_files(config_dict, config_dir_path) @@ -497,15 +610,6 @@ class Config(object): ) generate_missing_configs = True - parser = argparse.ArgumentParser( - parents=[config_parser], - description=description, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - - obj.invoke_all_static("add_arguments", parser) - args = parser.parse_args(remaining_args) - config_dict = read_config_files(config_files) if generate_missing_configs: obj.generate_missing_files(config_dict, config_dir_path) @@ -514,11 +618,11 @@ class Config(object): obj.parse_config_dict( config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path ) - obj.invoke_all("read_arguments", args) + obj.invoke_all("read_arguments", config_args) return obj - def parse_config_dict(self, config_dict, config_dir_path, data_dir_path): + def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None): """Read the information from the config dict into this Config object. 
Args: @@ -553,6 +657,12 @@ def read_config_files(config_files): for config_file in config_files: with open(config_file) as file_stream: yaml_config = yaml.safe_load(file_stream) + + if not isinstance(yaml_config, dict): + err = "File %r is empty or doesn't parse into a key-value map. IGNORING." + print(err % (config_file,)) + continue + specified_config.update(yaml_config) if "server_name" not in specified_config: @@ -607,3 +717,6 @@ def find_config_files(search_paths): else: config_files.append(config_path) return config_files + + +__all__ = ["Config", "RootConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi new file mode 100644 index 0000000000..9e576060d4 --- /dev/null +++ b/synapse/config/_base.pyi @@ -0,0 +1,139 @@ +from typing import Any, List, Optional + +from synapse.config import ( + api, + appservice, + captcha, + cas, + consent_config, + database, + emailconfig, + groups, + jwt_config, + key, + logger, + metrics, + oidc_config, + password, + password_auth_providers, + push, + ratelimiting, + registration, + repository, + room_directory, + saml2_config, + server, + server_notices_config, + spam_checker, + sso, + stats, + third_party_event_rules, + tls, + tracer, + user_directory, + voip, + workers, +) + +class ConfigError(Exception): ... + +MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str +MISSING_REPORT_STATS_SPIEL: str +MISSING_SERVER_NAME: str + +def path_exists(file_path: str): ... + +class RootConfig: + server: server.ServerConfig + tls: tls.TlsConfig + database: database.DatabaseConfig + logging: logger.LoggingConfig + ratelimit: ratelimiting.RatelimitConfig + media: repository.ContentRepositoryConfig + captcha: captcha.CaptchaConfig + voip: voip.VoipConfig + registration: registration.RegistrationConfig + metrics: metrics.MetricsConfig + api: api.ApiConfig + appservice: appservice.AppServiceConfig + key: key.KeyConfig + saml2: saml2_config.SAML2Config + cas: cas.CasConfig + sso: sso.SSOConfig + oidc: oidc_config.OIDCConfig + jwt: jwt_config.JWTConfig + password: password.PasswordConfig + email: emailconfig.EmailConfig + worker: workers.WorkerConfig + authproviders: password_auth_providers.PasswordAuthProviderConfig + push: push.PushConfig + spamchecker: spam_checker.SpamCheckerConfig + groups: groups.GroupsConfig + userdirectory: user_directory.UserDirectoryConfig + consent: consent_config.ConsentConfig + stats: stats.StatsConfig + servernotices: server_notices_config.ServerNoticesConfig + roomdirectory: room_directory.RoomDirectoryConfig + thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig + tracer: tracer.TracerConfig + + config_classes: List = ... + def __init__(self) -> None: ... + def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ... + @classmethod + def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ... + def __getattr__(self, item: str): ... + def parse_config_dict( + self, + config_dict: Any, + config_dir_path: Optional[Any] = ..., + data_dir_path: Optional[Any] = ..., + ) -> None: ... + read_config: Any = ... + def generate_config( + self, + config_dir_path: str, + data_dir_path: str, + server_name: str, + generate_secrets: bool = ..., + report_stats: Optional[str] = ..., + open_private_ports: bool = ..., + listeners: Optional[Any] = ..., + database_conf: Optional[Any] = ..., + tls_certificate_path: Optional[str] = ..., + tls_private_key_path: Optional[str] = ..., + acme_domain: Optional[str] = ..., + ): ... + @classmethod + def load_or_generate_config(cls, description: Any, argv: Any): ... 
+ @classmethod + def load_config(cls, description: Any, argv: Any): ... + @classmethod + def add_arguments_to_parser(cls, config_parser: Any) -> None: ... + @classmethod + def load_config_with_parser(cls, parser: Any, argv: Any): ... + def generate_missing_files( + self, config_dict: dict, config_dir_path: str + ) -> None: ... + +class Config: + root: RootConfig + def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ... + def __getattr__(self, item: str, from_root: bool = ...): ... + @staticmethod + def parse_size(value: Any): ... + @staticmethod + def parse_duration(value: Any): ... + @staticmethod + def abspath(file_path: Optional[str]): ... + @classmethod + def path_exists(cls, file_path: str): ... + @classmethod + def check_file(cls, file_path: str, config_name: str): ... + @classmethod + def ensure_directory(cls, dir_path: str): ... + @classmethod + def read_file(cls, file_path: str, config_name: str): ... + +def read_config_files(config_files: List[str]): ... +def find_config_files(search_paths: List[str]): ... diff --git a/synapse/config/api.py b/synapse/config/api.py index dddea79a8a..74cd53a8ed 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -18,6 +18,8 @@ from ._base import Config class ApiConfig(Config): + section = "api" + def read_config(self, config, **kwargs): self.room_invite_state_types = config.get( "room_invite_state_types", diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 8387ff6805..ca43e96bd1 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from typing import Dict from six import string_types from six.moves.urllib import parse as urlparse @@ -29,6 +30,8 @@ logger = logging.getLogger(__name__) class AppServiceConfig(Config): + section = "appservice" + def read_config(self, config, **kwargs): self.app_service_config_files = config.get("app_service_config_files", []) self.notify_appservices = config.get("notify_appservices", True) @@ -45,7 +48,7 @@ class AppServiceConfig(Config): # Uncomment to enable tracking of application service IP addresses. Implicitly # enables MAU tracking for application service users. # - #track_appservice_user_ips: True + #track_appservice_user_ips: true """ @@ -56,8 +59,8 @@ def load_appservices(hostname, config_files): return [] # Dicts of value -> filename - seen_as_tokens = {} - seen_ids = {} + seen_as_tokens = {} # type: Dict[str, str] + seen_ids = {} # type: Dict[str, str] appservices = [] @@ -131,7 +134,7 @@ def _load_appservice(hostname, as_info, config_filename): for regex_obj in as_info["namespaces"][ns]: if not isinstance(regex_obj, dict): raise ValueError( - "Expected namespace entry in %s to be an object," " but got %s", + "Expected namespace entry in %s to be an object, but got %s", ns, regex_obj, ) diff --git a/synapse/config/cache.py b/synapse/config/cache.py new file mode 100644 index 0000000000..0672538796 --- /dev/null +++ b/synapse/config/cache.py @@ -0,0 +1,198 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from typing import Callable, Dict + +from ._base import Config, ConfigError + +# The prefix for all cache factor-related environment variables +_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" + +# Map from canonicalised cache name to cache. +_CACHES = {} + +_DEFAULT_FACTOR_SIZE = 0.5 +_DEFAULT_EVENT_CACHE_SIZE = "10K" + + +class CacheProperties(object): + def __init__(self): + # The default factor size for all caches + self.default_factor_size = float( + os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) + ) + self.resize_all_caches_func = None + + +properties = CacheProperties() + + +def _canonicalise_cache_name(cache_name: str) -> str: + """Gets the canonical form of the cache name. + + Since we specify cache names in config and environment variables we need to + ignore case and special characters. For example, some caches have asterisks + in their name to denote that they're not attached to a particular database + function, and these asterisks need to be stripped out + """ + + cache_name = re.sub(r"[^A-Za-z_1-9]", "", cache_name) + + return cache_name.lower() + + +def add_resizable_cache(cache_name: str, cache_resize_callback: Callable): + """Register a cache that's size can dynamically change + + Args: + cache_name: A reference to the cache + cache_resize_callback: A callback function that will be ran whenever + the cache needs to be resized + """ + # Some caches have '*' in them which we strip out. + cache_name = _canonicalise_cache_name(cache_name) + + _CACHES[cache_name] = cache_resize_callback + + # Ensure all loaded caches are sized appropriately + # + # This method should only run once the config has been read, + # as it uses values read from it + if properties.resize_all_caches_func: + properties.resize_all_caches_func() + + +class CacheConfig(Config): + section = "caches" + _environ = os.environ + + @staticmethod + def reset(): + """Resets the caches to their defaults. Used for tests.""" + properties.default_factor_size = float( + os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) + ) + properties.resize_all_caches_func = None + _CACHES.clear() + + def generate_config_section(self, **kwargs): + return """\ + ## Caching ## + + # Caching can be configured through the following options. + # + # A cache 'factor' is a multiplier that can be applied to each of + # Synapse's caches in order to increase or decrease the maximum + # number of entries that can be stored. + + # The number of events to cache in memory. Not affected by + # caches.global_factor. + # + #event_cache_size: 10K + + caches: + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 + + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. 
+ # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + # Some caches have '*' and other characters that are not + # alphanumeric or underscores. These caches can be named with or + # without the special characters stripped. For example, to specify + # the cache factor for `*stateGroupCache*` via an environment + # variable would be `SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE=2.0`. + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + """ + + def read_config(self, config, **kwargs): + self.event_cache_size = self.parse_size( + config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE) + ) + self.cache_factors = {} # type: Dict[str, float] + + cache_config = config.get("caches") or {} + self.global_factor = cache_config.get( + "global_factor", properties.default_factor_size + ) + if not isinstance(self.global_factor, (int, float)): + raise ConfigError("caches.global_factor must be a number.") + + # Set the global one so that it's reflected in new caches + properties.default_factor_size = self.global_factor + + # Load cache factors from the config + individual_factors = cache_config.get("per_cache_factors") or {} + if not isinstance(individual_factors, dict): + raise ConfigError("caches.per_cache_factors must be a dictionary") + + # Canonicalise the cache names *before* updating with the environment + # variables. + individual_factors = { + _canonicalise_cache_name(key): val + for key, val in individual_factors.items() + } + + # Override factors from environment if necessary + individual_factors.update( + { + _canonicalise_cache_name(key[len(_CACHE_PREFIX) + 1 :]): float(val) + for key, val in self._environ.items() + if key.startswith(_CACHE_PREFIX + "_") + } + ) + + for cache, factor in individual_factors.items(): + if not isinstance(factor, (int, float)): + raise ConfigError( + "caches.per_cache_factors.%s must be a number" % (cache,) + ) + self.cache_factors[cache] = factor + + # Resize all caches (if necessary) with the new factors we've loaded + self.resize_all_caches() + + # Store this function so that it can be called from other classes without + # needing an instance of Config + properties.resize_all_caches_func = self.resize_all_caches + + def resize_all_caches(self): + """Ensure all cache sizes are up to date + + For each cache, run the mapped callback function with either + a specific cache factor or the default, global one. 
+ """ + for cache_name, callback in _CACHES.items(): + new_factor = self.cache_factors.get(cache_name, self.global_factor) + callback(new_factor) diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index 8dac8152cf..82f04d7966 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -16,13 +16,14 @@ from ._base import Config class CaptchaConfig(Config): + section = "captcha" + def read_config(self, config, **kwargs): self.recaptcha_private_key = config.get("recaptcha_private_key") self.recaptcha_public_key = config.get("recaptcha_public_key") self.enable_registration_captcha = config.get( "enable_registration_captcha", False ) - self.captcha_bypass_secret = config.get("captcha_bypass_secret") self.recaptcha_siteverify_api = config.get( "recaptcha_siteverify_api", "https://www.recaptcha.net/recaptcha/api/siteverify", @@ -31,27 +32,26 @@ class CaptchaConfig(Config): def generate_config_section(self, **kwargs): return """\ ## Captcha ## - # See docs/CAPTCHA_SETUP for full details of configuring this. + # See docs/CAPTCHA_SETUP.md for full details of configuring this. - # This Home Server's ReCAPTCHA public key. + # This homeserver's ReCAPTCHA public key. Must be specified if + # enable_registration_captcha is enabled. # #recaptcha_public_key: "YOUR_PUBLIC_KEY" - # This Home Server's ReCAPTCHA private key. + # This homeserver's ReCAPTCHA private key. Must be specified if + # enable_registration_captcha is enabled. # #recaptcha_private_key: "YOUR_PRIVATE_KEY" - # Enables ReCaptcha checks when registering, preventing signup + # Uncomment to enable ReCaptcha checks when registering, preventing signup # unless a captcha is answered. Requires a valid ReCaptcha - # public/private key. - # - #enable_registration_captcha: false - - # A secret key used to bypass the captcha test entirely. + # public/private key. Defaults to 'false'. # - #captcha_bypass_secret: "YOUR_SECRET_HERE" + #enable_registration_captcha: true # The API endpoint to use for verifying m.login.recaptcha responses. + # Defaults to "https://www.recaptcha.net/recaptcha/api/siteverify". 
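The cache section above accepts factors both from per_cache_factors in the config file and from SYNAPSE_CACHE_FACTOR_* environment variables, with cache names canonicalised before comparison. A small sketch of how one factor resolves under those rules (the cache name and values are examples; the environment override wins):

import re

_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"

def _canonicalise_cache_name(cache_name: str) -> str:
    # Mirrors the config module above: drop special characters such as '*'
    # and compare case-insensitively.
    return re.sub(r"[^A-Za-z_1-9]", "", cache_name).lower()

# Hypothetical inputs: a per_cache_factors block from homeserver.yaml and an
# environment override for the same cache.
per_cache_factors = {"*stateGroupCache*": 1.0}
environ = {"SYNAPSE_CACHE_FACTOR_STATEGROUPCACHE": "2.0"}

factors = {
    _canonicalise_cache_name(name): float(val)
    for name, val in per_cache_factors.items()
}
factors.update(
    {
        _canonicalise_cache_name(key[len(_CACHE_PREFIX) + 1 :]): float(val)
        for key, val in environ.items()
        if key.startswith(_CACHE_PREFIX + "_")
    }
)

print(factors)  # {'stategroupcache': 2.0} -- the environment value takes priority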
# - #recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify" + #recaptcha_siteverify_api: "https://my.recaptcha.site" """ diff --git a/synapse/config/cas.py b/synapse/config/cas.py index ebe34d933b..4526c1a67b 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -22,17 +22,21 @@ class CasConfig(Config): cas_server_url: URL of CAS server """ + section = "cas" + def read_config(self, config, **kwargs): cas_config = config.get("cas_config", None) if cas_config: self.cas_enabled = cas_config.get("enabled", True) self.cas_server_url = cas_config["server_url"] self.cas_service_url = cas_config["service_url"] + self.cas_displayname_attribute = cas_config.get("displayname_attribute") self.cas_required_attributes = cas_config.get("required_attributes", {}) else: self.cas_enabled = False self.cas_server_url = None self.cas_service_url = None + self.cas_displayname_attribute = None self.cas_required_attributes = {} def generate_config_section(self, config_dir_path, server_name, **kwargs): @@ -43,6 +47,7 @@ class CasConfig(Config): # enabled: true # server_url: "https://cas-server.com" # service_url: "https://homeserver.domain.com:8448" + # #displayname_attribute: name # #required_attributes: # # name: value """ diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index 94916f3a49..aec9c4bbce 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py @@ -62,19 +62,22 @@ DEFAULT_CONFIG = """\ # body: >- # To continue using this homeserver you must review and agree to the # terms and conditions at %(consent_uri)s -# send_server_notice_to_guests: True +# send_server_notice_to_guests: true # block_events_error: >- # To continue using this homeserver you must review and agree to the # terms and conditions at %(consent_uri)s -# require_at_registration: False +# require_at_registration: false # policy_name: Privacy Policy # """ class ConsentConfig(Config): - def __init__(self): - super(ConsentConfig, self).__init__() + + section = "consent" + + def __init__(self, *args): + super(ConsentConfig, self).__init__(*args) self.user_consent_version = None self.user_consent_template_dir = None diff --git a/synapse/config/database.py b/synapse/config/database.py index 118aafbd4a..1064c2697b 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,73 +13,185 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os -from textwrap import indent -import yaml +from synapse.config._base import Config, ConfigError -from ._base import Config +logger = logging.getLogger(__name__) +NON_SQLITE_DATABASE_PATH_WARNING = """\ +Ignoring 'database_path' setting: not using a sqlite3 database. +-------------------------------------------------------------------------------- +""" -class DatabaseConfig(Config): - def read_config(self, config, **kwargs): - self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) +DEFAULT_CONFIG = """\ +## Database ## + +# The 'database' setting defines the database that synapse uses to store all of +# its data. 
+# +# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or +# 'psycopg2' (for PostgreSQL). +# +# 'args' gives options which are passed through to the database engine, +# except for options starting 'cp_', which are used to configure the Twisted +# connection pool. For a reference to valid arguments, see: +# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect +# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS +# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__ +# +# +# Example SQLite configuration: +# +#database: +# name: sqlite3 +# args: +# database: /path/to/homeserver.db +# +# +# Example Postgres configuration: +# +#database: +# name: psycopg2 +# args: +# user: synapse +# password: secretpassword +# database: synapse +# host: localhost +# cp_min: 5 +# cp_max: 10 +# +# For more information on using Synapse with Postgres, see `docs/postgres.md`. +# +database: + name: sqlite3 + args: + database: %(database_path)s +""" + + +class DatabaseConnectionConfig: + """Contains the connection config for a particular database. - self.database_config = config.get("database") + Args: + name: A label for the database, used for logging. + db_config: The config for a particular database, as per `database` + section of main config. Has three fields: `name` for database + module name, `args` for the args to give to the database + connector, and optional `data_stores` that is a list of stores to + provision on this database (defaulting to all). + """ - if self.database_config is None: - self.database_config = {"name": "sqlite3", "args": {}} + def __init__(self, name: str, db_config: dict): + db_engine = db_config.get("name", "sqlite3") - name = self.database_config.get("name", None) - if name == "psycopg2": - pass - elif name == "sqlite3": - self.database_config.setdefault("args", {}).update( + if db_engine not in ("sqlite3", "psycopg2"): + raise ConfigError("Unsupported database type %r" % (db_engine,)) + + if db_engine == "sqlite3": + db_config.setdefault("args", {}).update( {"cp_min": 1, "cp_max": 1, "check_same_thread": False} ) - else: - raise RuntimeError("Unsupported database type '%s'" % (name,)) - - self.set_databasepath(config.get("database_path")) - - def generate_config_section(self, data_dir_path, database_conf, **kwargs): - if not database_conf: - database_path = os.path.join(data_dir_path, "homeserver.db") - database_conf = ( - """# The database engine name - name: "sqlite3" - # Arguments to pass to the engine - args: - # Path to the database - database: "%(database_path)s" - """ - % locals() - ) - else: - database_conf = indent(yaml.dump(database_conf), " " * 10).lstrip() - return ( - """\ - ## Database ## + data_stores = db_config.get("data_stores") + if data_stores is None: + data_stores = ["main", "state"] + + self.name = name + self.config = db_config + self.data_stores = data_stores + + +class DatabaseConfig(Config): + section = "database" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - database: - %(database_conf)s - # Number of events to cache in memory. + self.databases = [] + + def read_config(self, config, **kwargs): + # We *experimentally* support specifying multiple databases via the + # `databases` key. This is a map from a label to database config in the + # same format as the `database` config option, plus an extra + # `data_stores` key to specify which data store goes where. 
For example: # - #event_cache_size: 10K - """ - % locals() - ) + # databases: + # master: + # name: psycopg2 + # data_stores: ["main"] + # args: {} + # state: + # name: psycopg2 + # data_stores: ["state"] + # args: {} + + multi_database_config = config.get("databases") + database_config = config.get("database") + database_path = config.get("database_path") + + if multi_database_config and database_config: + raise ConfigError("Can't specify both 'database' and 'databases' in config") + + if multi_database_config: + if database_path: + raise ConfigError("Can't specify 'database_path' with 'databases'") + + self.databases = [ + DatabaseConnectionConfig(name, db_conf) + for name, db_conf in multi_database_config.items() + ] + + if database_config: + self.databases = [DatabaseConnectionConfig("master", database_config)] + + if database_path: + if self.databases and self.databases[0].name != "sqlite3": + logger.warning(NON_SQLITE_DATABASE_PATH_WARNING) + return + + database_config = {"name": "sqlite3", "args": {}} + self.databases = [DatabaseConnectionConfig("master", database_config)] + self.set_databasepath(database_path) + + def generate_config_section(self, data_dir_path, **kwargs): + return DEFAULT_CONFIG % { + "database_path": os.path.join(data_dir_path, "homeserver.db") + } def read_arguments(self, args): - self.set_databasepath(args.database_path) + """ + Cases for the cli input: + - If no databases are configured and no database_path is set, raise. + - No databases and only database_path available ==> sqlite3 db. + - If there are multiple databases and a database_path raise an error. + - If the database set in the config file is sqlite then + overwrite with the command line argument. + """ + + if args.database_path is None: + if not self.databases: + raise ConfigError("No database config provided") + return + + if len(self.databases) == 0: + database_config = {"name": "sqlite3", "args": {}} + self.databases = [DatabaseConnectionConfig("master", database_config)] + self.set_databasepath(args.database_path) + return + + if self.get_single_database().name == "sqlite3": + self.set_databasepath(args.database_path) + else: + logger.warning(NON_SQLITE_DATABASE_PATH_WARNING) def set_databasepath(self, database_path): + if database_path != ":memory:": database_path = self.abspath(database_path) - if self.database_config.get("name", None) == "sqlite3": - if database_path is not None: - self.database_config["args"]["database"] = database_path + + self.databases[0].config["args"]["database"] = database_path @staticmethod def add_arguments(parser): @@ -89,3 +202,11 @@ class DatabaseConfig(Config): metavar="SQLITE_DATABASE_PATH", help="The path to a sqlite database to use.", ) + + def get_single_database(self) -> DatabaseConnectionConfig: + """Returns the database if there is only one, useful for e.g. tests + """ + if not self.databases: + raise Exception("More than one database exists") + + return self.databases[0] diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index e5de768b0c..ca61214454 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -21,23 +21,34 @@ from __future__ import print_function import email.utils import os from enum import Enum +from typing import Optional import pkg_resources from ._base import Config, ConfigError +MISSING_PASSWORD_RESET_CONFIG_ERROR = """\ +Password reset emails are enabled on this homeserver due to a partial +'email' block. 
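The rewritten database handling above accepts either a single `database` block or the experimental `databases` map, but never both, and `database_path` only applies to sqlite3. A rough sketch of just the block-selection logic, reusing the DatabaseConnectionConfig shape from the patch (ValueError stands in for ConfigError, and the values are examples):

from typing import List

class DatabaseConnectionConfig:
    """Cut-down version of the class added above."""

    def __init__(self, name: str, db_config: dict):
        db_engine = db_config.get("name", "sqlite3")
        if db_engine not in ("sqlite3", "psycopg2"):
            raise ValueError("Unsupported database type %r" % (db_engine,))
        self.name = name
        self.config = db_config
        # data_stores defaults to "all stores" when not given.
        self.data_stores = db_config.get("data_stores") or ["main", "state"]

def parse_database_section(config: dict) -> List[DatabaseConnectionConfig]:
    multi = config.get("databases")
    single = config.get("database")
    if multi and single:
        raise ValueError("Can't specify both 'database' and 'databases' in config")
    if multi:
        return [DatabaseConnectionConfig(name, db) for name, db in multi.items()]
    if single:
        return [DatabaseConnectionConfig("master", single)]
    return []

# Example: one postgres database for the main store, another for state groups.
dbs = parse_database_section(
    {
        "databases": {
            "master": {"name": "psycopg2", "data_stores": ["main"], "args": {}},
            "state": {"name": "psycopg2", "data_stores": ["state"], "args": {}},
        }
    }
)
print([(db.name, db.data_stores) for db in dbs])
# [('master', ['main']), ('state', ['state'])]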
However, the following required keys are missing: + %s +""" + class EmailConfig(Config): + section = "email" + def read_config(self, config, **kwargs): # TODO: We should separate better the email configuration from the notification # and account validity config. self.email_enable_notifs = False - email_config = config.get("email", {}) + email_config = config.get("email") + if email_config is None: + email_config = {} - self.email_smtp_host = email_config.get("smtp_host", None) - self.email_smtp_port = email_config.get("smtp_port", None) + self.email_smtp_host = email_config.get("smtp_host", "localhost") + self.email_smtp_port = email_config.get("smtp_port", 25) self.email_smtp_user = email_config.get("smtp_user", None) self.email_smtp_pass = email_config.get("smtp_pass", None) self.require_transport_security = email_config.get( @@ -71,9 +82,9 @@ class EmailConfig(Config): self.email_template_dir = os.path.abspath(template_dir) self.email_enable_notifs = email_config.get("enable_notifs", False) - account_validity_renewal_enabled = config.get("account_validity", {}).get( - "renew_at" - ) + + account_validity_config = config.get("account_validity") or {} + account_validity_renewal_enabled = account_validity_config.get("renew_at") self.threepid_behaviour_email = ( # Have Synapse handle the email sending if account_threepid_delegates.email @@ -97,9 +108,14 @@ class EmailConfig(Config): if self.trusted_third_party_id_servers: # XXX: It's a little confusing that account_threepid_delegate_email is modified # both in RegistrationConfig and here. We should factor this bit out - self.account_threepid_delegate_email = self.trusted_third_party_id_servers[ - 0 - ] + + first_trusted_identity_server = self.trusted_third_party_id_servers[0] + + # trusted_third_party_id_servers does not contain a scheme whereas + # account_threepid_delegate_email is expected to. Presume https + self.account_threepid_delegate_email = ( + "https://" + first_trusted_identity_server + ) # type: Optional[str] self.using_identity_server_from_trusted_list = True else: raise ConfigError( @@ -137,22 +153,18 @@ class EmailConfig(Config): bleach if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - required = ["smtp_host", "smtp_port", "notif_from"] - missing = [] - for k in required: - if k not in email_config: - missing.append("email." + k) + if not self.email_notif_from: + missing.append("email.notif_from") + # public_baseurl is required to build password reset and validation links that + # will be emailed to users if config.get("public_baseurl") is None: missing.append("public_baseurl") - if len(missing) > 0: - raise RuntimeError( - "Password resets emails are configured to be sent from " - "this homeserver due to a partial 'email' block. 
" - "However, the following required keys are missing: %s" - % (", ".join(missing),) + if missing: + raise ConfigError( + MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),) ) # These email templates have placeholders in them, and thus must be @@ -169,12 +181,22 @@ class EmailConfig(Config): self.email_registration_template_text = email_config.get( "registration_template_text", "registration.txt" ) + self.email_add_threepid_template_html = email_config.get( + "add_threepid_template_html", "add_threepid.html" + ) + self.email_add_threepid_template_text = email_config.get( + "add_threepid_template_text", "add_threepid.txt" + ) + self.email_password_reset_template_failure_html = email_config.get( "password_reset_template_failure_html", "password_reset_failure.html" ) self.email_registration_template_failure_html = email_config.get( "registration_template_failure_html", "registration_failure.html" ) + self.email_add_threepid_template_failure_html = email_config.get( + "add_threepid_template_failure_html", "add_threepid_failure.html" + ) # These templates do not support any placeholder variables, so we # will read them from disk once during setup @@ -184,6 +206,9 @@ class EmailConfig(Config): email_registration_template_success_html = email_config.get( "registration_template_success_html", "registration_success.html" ) + email_add_threepid_template_success_html = email_config.get( + "add_threepid_template_success_html", "add_threepid_success.html" + ) # Check templates exist for f in [ @@ -191,9 +216,14 @@ class EmailConfig(Config): self.email_password_reset_template_text, self.email_registration_template_html, self.email_registration_template_text, + self.email_add_threepid_template_html, + self.email_add_threepid_template_text, self.email_password_reset_template_failure_html, + self.email_registration_template_failure_html, + self.email_add_threepid_template_failure_html, email_password_reset_template_success_html, email_registration_template_success_html, + email_add_threepid_template_success_html, ]: p = os.path.join(self.email_template_dir, f) if not os.path.isfile(p): @@ -212,34 +242,33 @@ class EmailConfig(Config): self.email_registration_template_success_html_content = self.read_file( filepath, "email.registration_template_success_html" ) + filepath = os.path.join( + self.email_template_dir, email_add_threepid_template_success_html + ) + self.email_add_threepid_template_success_html_content = self.read_file( + filepath, "email.add_threepid_template_success_html" + ) if self.email_enable_notifs: - required = [ - "smtp_host", - "smtp_port", - "notif_from", - "notif_template_html", - "notif_template_text", - ] - missing = [] - for k in required: - if k not in email_config: - missing.append(k) - - if len(missing) > 0: - raise RuntimeError( - "email.enable_notifs is True but required keys are missing: %s" - % (", ".join(["email." 
+ k for k in missing]),) - ) + if not self.email_notif_from: + missing.append("email.notif_from") if config.get("public_baseurl") is None: - raise RuntimeError( - "email.enable_notifs is True but no public_baseurl is set" + missing.append("public_baseurl") + + if missing: + raise ConfigError( + "email.enable_notifs is True but required keys are missing: %s" + % (", ".join(missing),) ) - self.email_notif_template_html = email_config["notif_template_html"] - self.email_notif_template_text = email_config["notif_template_text"] + self.email_notif_template_html = email_config.get( + "notif_template_html", "notif_mail.html" + ) + self.email_notif_template_text = email_config.get( + "notif_template_text", "notif_mail.txt" + ) for f in self.email_notif_template_text, self.email_notif_template_html: p = os.path.join(self.email_template_dir, f) @@ -249,7 +278,9 @@ class EmailConfig(Config): self.email_notif_for_new_users = email_config.get( "notif_for_new_users", True ) - self.email_riot_base_url = email_config.get("riot_base_url", None) + self.email_riot_base_url = email_config.get( + "client_base_url", email_config.get("riot_base_url", None) + ) if account_validity_renewal_enabled: self.email_expiry_template_html = email_config.get( @@ -265,80 +296,112 @@ class EmailConfig(Config): raise ConfigError("Unable to find email template file %s" % (p,)) def generate_config_section(self, config_dir_path, server_name, **kwargs): - return """ - # Enable sending emails for password resets, notification events or - # account expiry notices - # - # If your SMTP server requires authentication, the optional smtp_user & - # smtp_pass variables should be used - # - #email: - # enable_notifs: false - # smtp_host: "localhost" - # smtp_port: 25 # SSL: 465, STARTTLS: 587 - # smtp_user: "exampleusername" - # smtp_pass: "examplepassword" - # require_transport_security: False - # notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>" - # app_name: Matrix - # - # # Enable email notifications by default - # # - # notif_for_new_users: True - # - # # Defining a custom URL for Riot is only needed if email notifications - # # should contain links to a self-hosted installation of Riot; when set - # # the "app_name" setting is ignored - # # - # riot_base_url: "http://localhost/riot" - # - # # Configure the time that a validation email or text message code - # # will expire after sending - # # - # # This is currently used for password resets - # # - # #validation_token_lifetime: 1h - # - # # Template directory. All template files should be stored within this - # # directory. 
If not set, default templates from within the Synapse - # # package will be used - # # - # # For the list of default templates, please see - # # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # # - # #template_dir: res/templates - # - # # Templates for email notifications - # # - # notif_template_html: notif_mail.html - # notif_template_text: notif_mail.txt - # - # # Templates for account expiry notices - # # - # expiry_template_html: notice_expiry.html - # expiry_template_text: notice_expiry.txt - # - # # Templates for password reset emails sent by the homeserver - # # - # #password_reset_template_html: password_reset.html - # #password_reset_template_text: password_reset.txt - # - # # Templates for registration emails sent by the homeserver - # # - # #registration_template_html: registration.html - # #registration_template_text: registration.txt - # - # # Templates for password reset success and failure pages that a user - # # will see after attempting to reset their password - # # - # #password_reset_template_success_html: password_reset_success.html - # #password_reset_template_failure_html: password_reset_failure.html + return """\ + # Configuration for sending emails from Synapse. # - # # Templates for registration success and failure pages that a user - # # will see after attempting to register using an email or phone - # # - # #registration_template_success_html: registration_success.html - # #registration_template_failure_html: registration_failure.html + email: + # The hostname of the outgoing SMTP server to use. Defaults to 'localhost'. + # + #smtp_host: mail.server + + # The port on the mail server for outgoing SMTP. Defaults to 25. + # + #smtp_port: 587 + + # Username/password for authentication to the SMTP server. By default, no + # authentication is attempted. + # + #smtp_user: "exampleusername" + #smtp_pass: "examplepassword" + + # Uncomment the following to require TLS transport security for SMTP. + # By default, Synapse will connect over plain text, and will then switch to + # TLS via STARTTLS *if the SMTP server supports it*. If this option is set, + # Synapse will refuse to connect unless the server supports STARTTLS. + # + #require_transport_security: true + + # notif_from defines the "From" address to use when sending emails. + # It must be set if email sending is enabled. + # + # The placeholder '%(app)s' will be replaced by the application name, + # which is normally 'app_name' (below), but may be overridden by the + # Matrix client application. + # + # Note that the placeholder must be written '%(app)s', including the + # trailing 's'. + # + #notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>" + + # app_name defines the default value for '%(app)s' in notif_from. It + # defaults to 'Matrix'. + # + #app_name: my_branded_matrix_server + + # Uncomment the following to enable sending emails for messages that the user + # has missed. Disabled by default. + # + #enable_notifs: true + + # Uncomment the following to disable automatic subscription to email + # notifications for new users. Enabled by default. + # + #notif_for_new_users: false + + # Custom URL for client links within the email notifications. By default + # links will be based on "https://matrix.to". + # + # (This setting used to be called riot_base_url; the old name is still + # supported for backwards-compatibility but is now deprecated.) + # + #client_base_url: "http://localhost/riot" + + # Configure the time that a validation email will expire after sending. 
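Both the password-reset and notification paths above now gather every missing setting into one list and raise a single ConfigError, instead of the old per-condition RuntimeError. A condensed sketch of the notification-side check, assuming only the two requirements shown in the patch (email.notif_from and public_baseurl):

class ConfigError(Exception):
    pass

def check_email_notif_config(config: dict) -> None:
    """Raise one error listing every missing key, mirroring the new checks."""
    email_config = config.get("email") or {}
    missing = []
    if not email_config.get("notif_from"):
        missing.append("email.notif_from")
    if config.get("public_baseurl") is None:
        missing.append("public_baseurl")
    if missing:
        raise ConfigError(
            "email.enable_notifs is True but required keys are missing: %s"
            % (", ".join(missing),)
        )

# Example: notif_from is present but public_baseurl is absent.
try:
    check_email_notif_config(
        {"email": {"enable_notifs": True, "notif_from": "Synapse <noreply@example.com>"}}
    )
except ConfigError as e:
    print(e)  # ... required keys are missing: public_baseurl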
+ # Defaults to 1h. + # + #validation_token_lifetime: 15m + + # Directory in which Synapse will try to find the template files below. + # If not set, default templates from within the Synapse package will be used. + # + # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. + # If you *do* uncomment it, you will need to make sure that all the templates + # below are in the directory. + # + # Synapse will look for the following templates in this directory: + # + # * The contents of email notifications of missed events: 'notif_mail.html' and + # 'notif_mail.txt'. + # + # * The contents of account expiry notice emails: 'notice_expiry.html' and + # 'notice_expiry.txt'. + # + # * The contents of password reset emails sent by the homeserver: + # 'password_reset.html' and 'password_reset.txt' + # + # * HTML pages for success and failure that a user will see when they follow + # the link in the password reset email: 'password_reset_success.html' and + # 'password_reset_failure.html' + # + # * The contents of address verification emails sent during registration: + # 'registration.html' and 'registration.txt' + # + # * HTML pages for success and failure that a user will see when they follow + # the link in an address verification email sent during registration: + # 'registration_success.html' and 'registration_failure.html' + # + # * The contents of address verification emails sent when an address is added + # to a Matrix account: 'add_threepid.html' and 'add_threepid.txt' + # + # * HTML pages for success and failure that a user will see when they follow + # the link in an address verification email sent when an address is added + # to a Matrix account: 'add_threepid_success.html' and + # 'add_threepid_failure.html' + # + # You can see the default templates at: + # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates + # + #template_dir: "res/templates" """ diff --git a/synapse/config/groups.py b/synapse/config/groups.py index 2a522b5f44..d6862d9a64 100644 --- a/synapse/config/groups.py +++ b/synapse/config/groups.py @@ -17,6 +17,8 @@ from ._base import Config class GroupsConfig(Config): + section = "groups" + def read_config(self, config, **kwargs): self.enable_group_creation = config.get("enable_group_creation", False) self.group_creation_prefix = config.get("group_creation_prefix", "") diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 72acad4f18..2c7b3a699f 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -14,8 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from ._base import RootConfig from .api import ApiConfig from .appservice import AppServiceConfig +from .cache import CacheConfig from .captcha import CaptchaConfig from .cas import CasConfig from .consent_config import ConsentConfig @@ -26,10 +28,12 @@ from .jwt_config import JWTConfig from .key import KeyConfig from .logger import LoggingConfig from .metrics import MetricsConfig +from .oidc_config import OIDCConfig from .password import PasswordConfig from .password_auth_providers import PasswordAuthProviderConfig from .push import PushConfig from .ratelimiting import RatelimitConfig +from .redis import RedisConfig from .registration import RegistrationConfig from .repository import ContentRepositoryConfig from .room_directory import RoomDirectoryConfig @@ -37,6 +41,7 @@ from .saml2_config import SAML2Config from .server import ServerConfig from .server_notices_config import ServerNoticesConfig from .spam_checker import SpamCheckerConfig +from .sso import SSOConfig from .stats import StatsConfig from .third_party_event_rules import ThirdPartyRulesConfig from .tls import TlsConfig @@ -46,36 +51,41 @@ from .voip import VoipConfig from .workers import WorkerConfig -class HomeServerConfig( - ServerConfig, - TlsConfig, - DatabaseConfig, - LoggingConfig, - RatelimitConfig, - ContentRepositoryConfig, - CaptchaConfig, - VoipConfig, - RegistrationConfig, - MetricsConfig, - ApiConfig, - AppServiceConfig, - KeyConfig, - SAML2Config, - CasConfig, - JWTConfig, - PasswordConfig, - EmailConfig, - WorkerConfig, - PasswordAuthProviderConfig, - PushConfig, - SpamCheckerConfig, - GroupsConfig, - UserDirectoryConfig, - ConsentConfig, - StatsConfig, - ServerNoticesConfig, - RoomDirectoryConfig, - ThirdPartyRulesConfig, - TracerConfig, -): - pass +class HomeServerConfig(RootConfig): + + config_classes = [ + ServerConfig, + TlsConfig, + CacheConfig, + DatabaseConfig, + LoggingConfig, + RatelimitConfig, + ContentRepositoryConfig, + CaptchaConfig, + VoipConfig, + RegistrationConfig, + MetricsConfig, + ApiConfig, + AppServiceConfig, + KeyConfig, + SAML2Config, + OIDCConfig, + CasConfig, + SSOConfig, + JWTConfig, + PasswordConfig, + EmailConfig, + WorkerConfig, + PasswordAuthProviderConfig, + PushConfig, + SpamCheckerConfig, + GroupsConfig, + UserDirectoryConfig, + ConsentConfig, + StatsConfig, + ServerNoticesConfig, + RoomDirectoryConfig, + ThirdPartyRulesConfig, + TracerConfig, + RedisConfig, + ] diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py index 36d87cef03..a568726985 100644 --- a/synapse/config/jwt_config.py +++ b/synapse/config/jwt_config.py @@ -23,6 +23,8 @@ MISSING_JWT = """Missing jwt library. This is required for jwt login. class JWTConfig(Config): + section = "jwt" + def read_config(self, config, **kwargs): jwt_config = config.get("jwt_config", None) if jwt_config: diff --git a/synapse/config/key.py b/synapse/config/key.py index ba2199bceb..b529ea5da0 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -50,6 +50,33 @@ and you should enable 'federation_verify_certificates' in your configuration. If you are *sure* you want to do this, set 'accept_keys_insecurely' on the trusted_key_server configuration.""" +TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN = """\ +Synapse requires that a list of trusted key servers are specified in order to +provide signing keys for other servers in the federation. + +This homeserver does not have a trusted key server configured in +homeserver.yaml and will fall back to the default of 'matrix.org'. 
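HomeServerConfig above now enumerates its per-section Config classes in config_classes on a RootConfig subclass rather than relying on multiple inheritance, with each class declaring a `section` name (see the `section = "api"`, `section = "caches"` hunks earlier). The toy sketch below only illustrates that shape; the attribute wiring and classes are simplified stand-ins, not the real RootConfig machinery:

class Config:
    section = None  # each subclass overrides this

class RootConfig:
    config_classes = []  # type: list

    def __init__(self):
        # Simplified: expose one instance per section under its section name,
        # e.g. root.caches, root.database.
        for config_class in self.config_classes:
            setattr(self, config_class.section, config_class())

class CacheConfig(Config):
    section = "caches"

class DatabaseConfig(Config):
    section = "database"

class MyHomeServerConfig(RootConfig):
    config_classes = [CacheConfig, DatabaseConfig]

root = MyHomeServerConfig()
print(type(root.caches).__name__, type(root.database).__name__)
# prints: CacheConfig DatabaseConfig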
+ +Trusted key servers should be long-lived and stable which makes matrix.org a +good choice for many admins, but some admins may wish to choose another. To +suppress this warning, the admin should set 'trusted_key_servers' in +homeserver.yaml to their desired key server and 'suppress_key_server_warning' +to 'true'. + +In a future release the software-defined default will be removed entirely and +the trusted key server will be defined exclusively by the value of +'trusted_key_servers'. +--------------------------------------------------------------------------------""" + +TRUSTED_KEY_SERVER_CONFIGURED_AS_M_ORG_WARN = """\ +This server is configured to use 'matrix.org' as its trusted key server via the +'trusted_key_servers' config option. 'matrix.org' is a good choice for a key +server since it is long-lived, stable and trusted. However, some admins may +wish to use another server for this purpose. + +To suppress this warning and continue using 'matrix.org', admins should set +'suppress_key_server_warning' to 'true' in homeserver.yaml. +--------------------------------------------------------------------------------""" logger = logging.getLogger(__name__) @@ -65,6 +92,8 @@ class TrustedKeyServer(object): class KeyConfig(Config): + section = "key" + def read_config(self, config, config_dir_path, **kwargs): # the signing key can be specified inline or in a separate file if "signing_key" in config: @@ -79,12 +108,13 @@ class KeyConfig(Config): self.signing_key = self.read_signing_keys(signing_key_path, "signing_key") self.old_signing_keys = self.read_old_signing_keys( - config.get("old_signing_keys", {}) + config.get("old_signing_keys") ) self.key_refresh_interval = self.parse_duration( config.get("key_refresh_interval", "1d") ) + suppress_key_server_warning = config.get("suppress_key_server_warning", False) key_server_signing_keys_path = config.get("key_server_signing_keys_path") if key_server_signing_keys_path: self.key_server_signing_keys = self.read_signing_keys( @@ -95,6 +125,7 @@ class KeyConfig(Config): # if neither trusted_key_servers nor perspectives are given, use the default. if "perspectives" not in config and "trusted_key_servers" not in config: + logger.warning(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN) key_servers = [{"server_name": "matrix.org"}] else: key_servers = config.get("trusted_key_servers", []) @@ -108,6 +139,11 @@ class KeyConfig(Config): # merge the 'perspectives' config into the 'trusted_key_servers' config. key_servers.extend(_perspectives_to_key_servers(config)) + if not suppress_key_server_warning and "matrix.org" in ( + s["server_name"] for s in key_servers + ): + logger.warning(TRUSTED_KEY_SERVER_CONFIGURED_AS_M_ORG_WARN) + # list of TrustedKeyServer objects self.key_servers = list( _parse_key_servers(key_servers, self.federation_verify_certificates) @@ -120,7 +156,7 @@ class KeyConfig(Config): if not self.macaroon_secret_key: # Unfortunately, there are people out there that don't have this # set. Lets just be "nice" and derive one from their secret key. 
- logger.warn("Config is missing macaroon_secret_key") + logger.warning("Config is missing macaroon_secret_key") seed = bytes(self.signing_key[0]) self.macaroon_secret_key = hashlib.sha256(seed).digest() @@ -139,8 +175,8 @@ class KeyConfig(Config): ) form_secret = 'form_secret: "%s"' % random_string_with_symbols(50) else: - macaroon_secret_key = "# macaroon_secret_key: <PRIVATE STRING>" - form_secret = "# form_secret: <PRIVATE STRING>" + macaroon_secret_key = "#macaroon_secret_key: <PRIVATE STRING>" + form_secret = "#form_secret: <PRIVATE STRING>" return ( """\ @@ -163,14 +199,19 @@ class KeyConfig(Config): signing_key_path: "%(base_key_name)s.signing.key" # The keys that the server used to sign messages with but won't use - # to sign new messages. E.g. it has lost its private key + # to sign new messages. # - #old_signing_keys: - # "ed25519:auto": - # # Base64 encoded public key - # key: "The public part of your old signing key." - # # Millisecond POSIX timestamp when the key expired. - # expired_ts: 123456789123 + old_signing_keys: + # For each key, `key` should be the base64-encoded public key, and + # `expired_ts`should be the time (in milliseconds since the unix epoch) that + # it was last used. + # + # It is possible to build an entry from an old signing.key file using the + # `export_signing_key` script which is provided with synapse. + # + # For example: + # + #"ed25519:id": { key: "base64string", expired_ts: 123456789123 } # How long key response published by this server is valid for. # Used to set the valid_until_ts in /key/v2 APIs. @@ -190,6 +231,10 @@ class KeyConfig(Config): # This setting supercedes an older setting named `perspectives`. The old format # is still supported for backwards-compatibility, but it is deprecated. # + # 'trusted_key_servers' defaults to matrix.org, but using it will generate a + # warning on start-up. To suppress this warning, set + # 'suppress_key_server_warning' to true. + # # Options for each entry in the list include: # # server_name: the name of the server. required. @@ -214,11 +259,13 @@ class KeyConfig(Config): # "ed25519:auto": "abcdefghijklmnopqrstuvwxyzabcdefghijklmopqr" # - server_name: "my_other_trusted_server.example.com" # - # The default configuration is: - # - #trusted_key_servers: - # - server_name: "matrix.org" + trusted_key_servers: + - server_name: "matrix.org" + + # Uncomment the following to disable the warning that is emitted when the + # trusted_key_servers include 'matrix.org'. See above. # + #suppress_key_server_warning: true # The signing keys to use when acting as a trusted key server. If not specified # defaults to the server signing key. @@ -248,6 +295,8 @@ class KeyConfig(Config): raise ConfigError("Error reading %s: %s" % (name, str(e))) def read_old_signing_keys(self, old_signing_keys): + if old_signing_keys is None: + return {} keys = {} for key_id, key_data in old_signing_keys.items(): if is_signing_algorithm_supported(key_id): diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 767ecfdf09..49f6c32beb 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
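After this change, a homeserver that ends up trusting matrix.org for key fetching, whether by explicit configuration or by the software default, logs a warning unless suppress_key_server_warning is set. A small sketch of that second check (the separate "no trusted key server configured" warning is not modelled here):

def would_warn_about_matrix_org(config: dict) -> bool:
    """True if the 'trusted key server is matrix.org' warning would fire."""
    if config.get("suppress_key_server_warning", False):
        return False
    if "perspectives" not in config and "trusted_key_servers" not in config:
        # Software default when nothing is configured.
        key_servers = [{"server_name": "matrix.org"}]
    else:
        key_servers = config.get("trusted_key_servers", [])
    return "matrix.org" in (s["server_name"] for s in key_servers)

print(would_warn_about_matrix_org({}))  # True: implicit matrix.org default
print(
    would_warn_about_matrix_org(
        {
            "trusted_key_servers": [{"server_name": "matrix.org"}],
            "suppress_key_server_warning": True,
        }
    )
)  # False: explicitly suppressed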
- +import argparse import logging import logging.config import os @@ -37,10 +37,17 @@ from synapse.logging._structured import ( from synapse.logging.context import LoggingContextFilter from synapse.util.versionstring import get_version_string -from ._base import Config +from ._base import Config, ConfigError DEFAULT_LOG_CONFIG = Template( - """ + """\ +# Log configuration for Synapse. +# +# This is a YAML file containing a standard Python logging configuration +# dictionary. See [1] for details on the valid settings. +# +# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema + version: 1 formatters: @@ -68,9 +75,6 @@ handlers: filters: [context] loggers: - synapse: - level: INFO - synapse.storage.SQL: # beware: increasing this to DEBUG will make synapse log sensitive # information such as access tokens. @@ -79,12 +83,23 @@ loggers: root: level: INFO handlers: [file, console] + +disable_existing_loggers: false """ ) +LOG_FILE_ERROR = """\ +Support for the log_file configuration option and --log-file command-line option was +removed in Synapse 1.3.0. You should instead set up a separate log configuration file. +""" + class LoggingConfig(Config): + section = "logging" + def read_config(self, config, **kwargs): + if config.get("log_file"): + raise ConfigError(LOG_FILE_ERROR) self.log_config = self.abspath(config.get("log_config")) self.no_redirect_stdio = config.get("no_redirect_stdio", False) @@ -105,6 +120,8 @@ class LoggingConfig(Config): def read_arguments(self, args): if args.no_redirect_stdio is not None: self.no_redirect_stdio = args.no_redirect_stdio + if args.log_file is not None: + raise ConfigError(LOG_FILE_ERROR) @staticmethod def add_arguments(parser): @@ -117,6 +134,10 @@ class LoggingConfig(Config): help="Do not redirect stdout/stderr to the log", ) + logging_group.add_argument( + "-f", "--log-file", dest="log_file", help=argparse.SUPPRESS, + ) + def generate_files(self, config, config_dir_path): log_config = config.get("log_config") if log_config and not os.path.exists(log_config): @@ -181,7 +202,7 @@ def _reload_stdlib_logging(*args, log_config=None): logger = logging.getLogger("") if not log_config: - logger.warn("Reloaded a blank config?") + logger.warning("Reloaded a blank config?") logging.config.dictConfig(log_config) @@ -233,8 +254,9 @@ def setup_logging( # make sure that the first thing we log is a thing we can grep backwards # for - logging.warn("***** STARTING SERVER *****") - logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse)) + logging.warning("***** STARTING SERVER *****") + logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.info("Server hostname: %s", config.server_name) + logging.info("Instance name: %s", hs.get_instance_name()) return logger diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index ec35a6b868..6aad0d37c0 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -34,6 +34,8 @@ class MetricsFlags(object): class MetricsConfig(Config): + section = "metrics" + def read_config(self, config, **kwargs): self.enable_metrics = config.get("enable_metrics", False) self.report_stats = config.get("report_stats", None) @@ -68,7 +70,7 @@ class MetricsConfig(Config): # Enable collection and rendering of performance metrics # - #enable_metrics: False + #enable_metrics: false # Enable sentry integration # NOTE: While attempts are made to ensure that the logs don't contain @@ -84,17 +86,18 @@ class MetricsConfig(Config): # enabled 
by default, either for performance reasons or limited use. # metrics_flags: - # Publish synapse_federation_known_servers, a g auge of the number of + # Publish synapse_federation_known_servers, a gauge of the number of # servers this homeserver knows about, including itself. May cause # performance problems on large homeservers. # #known_servers: true # Whether or not to report anonymized homeserver usage statistics. + # """ if report_stats is None: - res += "# report_stats: true|false\n" + res += "#report_stats: true|false\n" else: res += "report_stats: %s\n" % ("true" if report_stats else "false") diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py new file mode 100644 index 0000000000..e24dd637bc --- /dev/null +++ b/synapse/config/oidc_config.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.module_loader import load_module + +from ._base import Config, ConfigError + +DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider" + + +class OIDCConfig(Config): + section = "oidc" + + def read_config(self, config, **kwargs): + self.oidc_enabled = False + + oidc_config = config.get("oidc_config") + + if not oidc_config or not oidc_config.get("enabled", False): + return + + try: + check_requirements("oidc") + except DependencyException as e: + raise ConfigError(e.message) + + public_baseurl = self.public_baseurl + if public_baseurl is None: + raise ConfigError("oidc_config requires a public_baseurl to be set") + self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" + + self.oidc_enabled = True + self.oidc_discover = oidc_config.get("discover", True) + self.oidc_issuer = oidc_config["issuer"] + self.oidc_client_id = oidc_config["client_id"] + self.oidc_client_secret = oidc_config["client_secret"] + self.oidc_client_auth_method = oidc_config.get( + "client_auth_method", "client_secret_basic" + ) + self.oidc_scopes = oidc_config.get("scopes", ["openid"]) + self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint") + self.oidc_token_endpoint = oidc_config.get("token_endpoint") + self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") + self.oidc_jwks_uri = oidc_config.get("jwks_uri") + self.oidc_skip_verification = oidc_config.get("skip_verification", False) + + ump_config = oidc_config.get("user_mapping_provider", {}) + ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + ump_config.setdefault("config", {}) + + ( + self.oidc_user_mapping_provider_class, + self.oidc_user_mapping_provider_config, + ) = load_module(ump_config) + + # Ensure loaded user mapping module has defined all necessary methods + required_methods = [ + "get_remote_user_id", + "map_user_attributes", + ] + missing_methods = [ + method + for method in required_methods + if not hasattr(self.oidc_user_mapping_provider_class, method) + ] + if 
missing_methods: + raise ConfigError( + "Class specified by oidc_config." + "user_mapping_provider.module is missing required " + "methods: %s" % (", ".join(missing_methods),) + ) + + def generate_config_section(self, config_dir_path, server_name, **kwargs): + return """\ + # OpenID Connect integration. The following settings can be used to make Synapse + # use an OpenID Connect Provider for authentication, instead of its internal + # password database. + # + # See https://github.com/matrix-org/synapse/blob/master/openid.md. + # + oidc_config: + # Uncomment the following to enable authorization against an OpenID Connect + # server. Defaults to false. + # + #enabled: true + + # Uncomment the following to disable use of the OIDC discovery mechanism to + # discover endpoints. Defaults to true. + # + #discover: false + + # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to + # discover the provider's endpoints. + # + # Required if 'enabled' is true. + # + #issuer: "https://accounts.example.com/" + + # oauth2 client id to use. + # + # Required if 'enabled' is true. + # + #client_id: "provided-by-your-issuer" + + # oauth2 client secret to use. + # + # Required if 'enabled' is true. + # + #client_secret: "provided-by-your-issuer" + + # auth method to use when exchanging the token. + # Valid values are 'client_secret_basic' (default), 'client_secret_post' and + # 'none'. + # + #client_auth_method: client_secret_post + + # list of scopes to request. This should normally include the "openid" scope. + # Defaults to ["openid"]. + # + #scopes: ["openid", "profile"] + + # the oauth2 authorization endpoint. Required if provider discovery is disabled. + # + #authorization_endpoint: "https://accounts.example.com/oauth2/auth" + + # the oauth2 token endpoint. Required if provider discovery is disabled. + # + #token_endpoint: "https://accounts.example.com/oauth2/token" + + # the OIDC userinfo endpoint. Required if discovery is disabled and the + # "openid" scope is not requested. + # + #userinfo_endpoint: "https://accounts.example.com/userinfo" + + # URI where to fetch the JWKS. Required if discovery is disabled and the + # "openid" scope is used. + # + #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + + # Uncomment to skip metadata verification. Defaults to false. + # + # Use this if you are connecting to a provider that is not OpenID Connect + # compliant. + # Avoid this in production. + # + #skip_verification: true + + # An external module can be provided here as a custom solution to mapping + # attributes returned from a OIDC provider onto a matrix user. + # + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # Default is {mapping_provider!r}. + # + # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers + # for information on implementing a custom mapping provider. + # + #module: mapping_provider.OidcMappingProvider + + # Custom configuration values for the module. This section will be passed as + # a Python dictionary to the user mapping provider module's `parse_config` + # method. + # + # The examples below are intended for the default provider: they should be + # changed if using a custom provider. + # + config: + # name of the claim containing a unique identifier for the user. + # Defaults to `sub`, which OpenID Connect compliant providers should provide. + # + #subject_claim: "sub" + + # Jinja2 template for the localpart of the MXID. 
+ # + # When rendering, this template is given the following variables: + # * user: The claims returned by the UserInfo Endpoint and/or in the ID + # Token + # + # This must be configured if using the default mapping provider. + # + localpart_template: "{{{{ user.preferred_username }}}}" + + # Jinja2 template for the display name to set on first login. + # + # If unset, no displayname will be set. + # + #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}" + """.format( + mapping_provider=DEFAULT_USER_MAPPING_PROVIDER + ) diff --git a/synapse/config/password.py b/synapse/config/password.py index d5b5953f2f..9c0ea8c30a 100644 --- a/synapse/config/password.py +++ b/synapse/config/password.py @@ -20,6 +20,8 @@ class PasswordConfig(Config): """Password login configuration """ + section = "password" + def read_config(self, config, **kwargs): password_config = config.get("password_config", {}) if password_config is None: @@ -29,6 +31,10 @@ class PasswordConfig(Config): self.password_localdb_enabled = password_config.get("localdb_enabled", True) self.password_pepper = password_config.get("pepper", "") + # Password policy + self.password_policy = password_config.get("policy") or {} + self.password_policy_enabled = self.password_policy.get("enabled", False) + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ password_config: @@ -46,4 +52,39 @@ class PasswordConfig(Config): # DO NOT CHANGE THIS AFTER INITIAL SETUP! # #pepper: "EVEN_MORE_SECRET" + + # Define and enforce a password policy. Each parameter is optional. + # This is an implementation of MSC2000. + # + policy: + # Whether to enforce the password policy. + # Defaults to 'false'. + # + #enabled: true + + # Minimum accepted length for a password. + # Defaults to 0. + # + #minimum_length: 15 + + # Whether a password must contain at least one digit. + # Defaults to 'false'. + # + #require_digit: true + + # Whether a password must contain at least one symbol. + # A symbol is any character that's not a number or a letter. + # Defaults to 'false'. + # + #require_symbol: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_lowercase: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_uppercase: true """ diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 788c39c9fb..4fda8ae987 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
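The OIDC hunk above loads oidc_config.user_mapping_provider.module and rejects any class that does not define get_remote_user_id and map_user_attributes. A bare-bones sketch of a custom provider that would pass that check follows; only the two method names come from the patch, while the signatures, the parse_config hook and the claim names are illustrative assumptions (see docs/sso_mapping_providers.md for the real interface):

class ExampleOidcMappingProvider:
    """Hypothetical mapping provider; structure and signatures are assumed."""

    def __init__(self, parsed_config):
        self._config = parsed_config

    @staticmethod
    def parse_config(config: dict):
        # Assumed hook: accept the `config:` block from oidc_config as-is.
        return config

    def get_remote_user_id(self, userinfo: dict) -> str:
        # Use the standard OIDC subject identifier as the stable remote ID.
        return userinfo["sub"]

    async def map_user_attributes(self, userinfo: dict, token: dict) -> dict:
        # Derive the Matrix localpart and display name from common claims.
        return {
            "localpart": userinfo.get("preferred_username"),
            "display_name": userinfo.get("name"),
        }

Such a class would then be referenced from homeserver.yaml as, for example, user_mapping_provider.module: my_package.ExampleOidcMappingProvider (a hypothetical module path).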
+from typing import Any, List + from synapse.util.module_loader import load_module from ._base import Config @@ -21,8 +23,10 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider" class PasswordAuthProviderConfig(Config): + section = "authproviders" + def read_config(self, config, **kwargs): - self.password_providers = [] + self.password_providers = [] # type: List[Any] providers = [] # We want to be backwards compatible with the old `ldap_config` @@ -31,7 +35,7 @@ class PasswordAuthProviderConfig(Config): if ldap_config.get("enabled", False): providers.append({"module": LDAP_PROVIDER, "config": ldap_config}) - providers.extend(config.get("password_providers", [])) + providers.extend(config.get("password_providers") or []) for provider in providers: mod_name = provider["module"] @@ -48,7 +52,19 @@ class PasswordAuthProviderConfig(Config): def generate_config_section(self, **kwargs): return """\ - #password_providers: + # Password providers allow homeserver administrators to integrate + # their Synapse installation with existing authentication methods + # ex. LDAP, external tokens, etc. + # + # For more information and known implementations, please see + # https://github.com/matrix-org/synapse/blob/master/docs/password_auth_providers.md + # + # Note: instances wishing to use SAML or CAS authentication should + # instead use the `saml2_config` or `cas_config` options, + # respectively. + # + password_providers: + # # Example config for an LDAP auth provider # - module: "ldap_auth_provider.LdapAuthProvider" # config: # enabled: true diff --git a/synapse/config/push.py b/synapse/config/push.py index 1b932722a5..6f2b3a7faa 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -18,6 +18,8 @@ from ._base import Config class PushConfig(Config): + section = "push" + def read_config(self, config, **kwargs): push_config = config.get("push", {}) self.push_include_content = push_config.get("include_content", True) @@ -33,7 +35,7 @@ class PushConfig(Config): # Now check for the one in the 'email' section and honour it, # with a warning. - push_config = config.get("email", {}) + push_config = config.get("email") or {} redact_content = push_config.get("redact_content") if redact_content is not None: print( diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 587e2862b7..2dd94bae2b 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -12,11 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict + from ._base import Config class RateLimitConfig(object): - def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}): + def __init__( + self, + config: Dict[str, float], + defaults={"per_second": 0.17, "burst_count": 3.0}, + ): self.per_second = config.get("per_second", defaults["per_second"]) self.burst_count = config.get("burst_count", defaults["burst_count"]) @@ -36,6 +42,8 @@ class FederationRateLimitConfig(object): class RatelimitConfig(Config): + section = "ratelimiting" + def read_config(self, config, **kwargs): # Load the new-style messages config if it exists. 
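RateLimitConfig above gains type hints but keeps its behaviour: each rc_* block is a per_second/burst_count pair with the defaults filled in for whichever key is omitted. A quick illustration (the 0.5 value is an example):

from typing import Dict

class RateLimitConfig(object):
    # Same shape as the class in the hunk above.
    def __init__(
        self,
        config: Dict[str, float],
        defaults={"per_second": 0.17, "burst_count": 3.0},
    ):
        self.per_second = config.get("per_second", defaults["per_second"])
        self.burst_count = config.get("burst_count", defaults["burst_count"])

# An rc_message block that only overrides the sustained rate:
rc_message = RateLimitConfig({"per_second": 0.5})
print(rc_message.per_second, rc_message.burst_count)  # 0.5 3.0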
Otherwise fall back @@ -81,10 +89,9 @@ class RatelimitConfig(Config): ) rc_admin_redaction = config.get("rc_admin_redaction") + self.rc_admin_redaction = None if rc_admin_redaction: self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction) - else: - self.rc_admin_redaction = None def generate_config_section(self, **kwargs): return """\ diff --git a/synapse/config/redis.py b/synapse/config/redis.py new file mode 100644 index 0000000000..d5d3ca1c9e --- /dev/null +++ b/synapse/config/redis.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config +from synapse.python_dependencies import check_requirements + + +class RedisConfig(Config): + section = "redis" + + def read_config(self, config, **kwargs): + redis_config = config.get("redis", {}) + self.redis_enabled = redis_config.get("enabled", False) + + if not self.redis_enabled: + return + + check_requirements("redis") + + self.redis_host = redis_config.get("host", "localhost") + self.redis_port = redis_config.get("port", 6379) + self.redis_password = redis_config.get("password") diff --git a/synapse/config/registration.py b/synapse/config/registration.py index d4654e99b3..fecced2d57 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -24,7 +24,12 @@ from synapse.util.stringutils import random_string_with_symbols class AccountValidityConfig(Config): + section = "accountvalidity" + def __init__(self, config, synapse_config): + if config is None: + return + super(AccountValidityConfig, self).__init__() self.enabled = config.get("enabled", False) self.renew_by_email_enabled = "renew_at" in config @@ -77,6 +82,8 @@ class AccountValidityConfig(Config): class RegistrationConfig(Config): + section = "registration" + def read_config(self, config, **kwargs): self.enable_registration = bool( strtobool(str(config.get("enable_registration", False))) @@ -87,7 +94,7 @@ class RegistrationConfig(Config): ) self.account_validity = AccountValidityConfig( - config.get("account_validity", {}), config + config.get("account_validity") or {}, config ) self.registrations_require_3pid = config.get("registrations_require_3pid", []) @@ -102,6 +109,13 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") + if self.account_threepid_delegate_msisdn and not self.public_baseurl: + raise ConfigError( + "The configuration option `public_baseurl` is required if " + "`account_threepid_delegate.msisdn` is set, such that " + "clients know where to submit validation tokens to. Please " + "configure `public_baseurl`." 
+ ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) @@ -114,6 +128,11 @@ class RegistrationConfig(Config): if not RoomAlias.is_valid(room_alias): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) + self.auto_join_rooms_for_guests = config.get("auto_join_rooms_for_guests", True) + + self.enable_set_displayname = config.get("enable_set_displayname", True) + self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) + self.enable_3pid_changes = config.get("enable_3pid_changes", True) self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False @@ -130,9 +149,7 @@ class RegistrationConfig(Config): random_string_with_symbols(50), ) else: - registration_shared_secret = ( - "# registration_shared_secret: <PRIVATE STRING>" - ) + registration_shared_secret = "#registration_shared_secret: <PRIVATE STRING>" return ( """\ @@ -148,23 +165,6 @@ class RegistrationConfig(Config): # Optional account validity configuration. This allows for accounts to be denied # any request after a given period. # - # ``enabled`` defines whether the account validity feature is enabled. Defaults - # to False. - # - # ``period`` allows setting the period after which an account is valid - # after its registration. When renewing the account, its validity period - # will be extended by this amount of time. This parameter is required when using - # the account validity feature. - # - # ``renew_at`` is the amount of time before an account's expiry date at which - # Synapse will send an email to the account's email address with a renewal link. - # This needs the ``email`` and ``public_baseurl`` configuration sections to be - # filled. - # - # ``renew_email_subject`` is the subject of the email sent out with the renewal - # link. ``%%(app)s`` can be used as a placeholder for the ``app_name`` parameter - # from the ``email`` section. - # # Once this feature is enabled, Synapse will look for registered users without an # expiration date at startup and will add one to every account it found using the # current settings at that time. @@ -175,21 +175,55 @@ class RegistrationConfig(Config): # date will be randomly selected within a range [now + period - d ; now + period], # where d is equal to 10%% of the validity period. # - #account_validity: - # enabled: True - # period: 6w - # renew_at: 1w - # renew_email_subject: "Renew your %%(app)s account" - # # Directory in which Synapse will try to find the HTML files to serve to the - # # user when trying to renew an account. Optional, defaults to - # # synapse/res/templates. - # template_dir: "res/templates" - # # HTML to be displayed to the user after they successfully renewed their - # # account. Optional. - # account_renewed_html_path: "account_renewed.html" - # # HTML to be displayed when the user tries to renew an account with an invalid - # # renewal token. Optional. - # invalid_token_html_path: "invalid_token.html" + account_validity: + # The account validity feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # The period after which an account is valid after its registration. When + # renewing the account, its validity period will be extended by this amount + # of time. This parameter is required when using the account validity + # feature. 
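Values such as 'period: 6w' and 'renew_at: 1w' in the account validity section below are duration strings that Synapse converts to milliseconds with its parse_duration helper. A rough standalone equivalent, shown only to illustrate the format (the real helper may accept additional forms):

    def parse_duration_ms(value) -> int:
        # Bare integers are treated as milliseconds; otherwise the trailing
        # character selects the unit.
        if isinstance(value, int):
            return value
        units = {
            "s": 1000,
            "m": 60 * 1000,
            "h": 60 * 60 * 1000,
            "d": 24 * 60 * 60 * 1000,
            "w": 7 * 24 * 60 * 60 * 1000,
            "y": 365 * 24 * 60 * 60 * 1000,
        }
        if value[-1] in units:
            return int(value[:-1]) * units[value[-1]]
        return int(value)

    print(parse_duration_ms("6w"))  # 3628800000
    print(parse_duration_ms("1w"))  # 604800000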
+ # + #period: 6w + + # The amount of time before an account's expiry date at which Synapse will + # send an email to the account's email address with a renewal link. By + # default, no such emails are sent. + # + # If you enable this setting, you will also need to fill out the 'email' and + # 'public_baseurl' configuration sections. + # + #renew_at: 1w + + # The subject of the email sent out with the renewal link. '%%(app)s' can be + # used as a placeholder for the 'app_name' parameter from the 'email' + # section. + # + # Note that the placeholder must be written '%%(app)s', including the + # trailing 's'. + # + # If this is not set, a default value is used. + # + #renew_email_subject: "Renew your %%(app)s account" + + # Directory in which Synapse will try to find templates for the HTML files to + # serve to the user when trying to renew an account. If not set, default + # templates from within the Synapse package will be used. + # + #template_dir: "res/templates" + + # File within 'template_dir' giving the HTML to be displayed to the user after + # they successfully renewed their account. If not set, default text is used. + # + #account_renewed_html_path: "account_renewed.html" + + # File within 'template_dir' giving the HTML to be displayed when the user + # tries to renew an account with an invalid renewal token. If not set, + # default text is used. + # + #invalid_token_html_path: "invalid_token.html" # Time that a user's session remains valid for, after they log in. # @@ -293,10 +327,35 @@ class RegistrationConfig(Config): # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # + # If a delegate is specified, the config option public_baseurl must also be filled out. + # account_threepid_delegates: - #email: https://example.com # Delegate email sending to example.org + #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process + # Whether users are allowed to change their displayname after it has + # been initially set. Useful when provisioning users based on the + # contents of a third-party directory. + # + # Does not apply to server administrators. Defaults to 'true' + # + #enable_set_displayname: false + + # Whether users are allowed to change their avatar after it has been + # initially set. Useful when provisioning users based on the contents + # of a third-party directory. + # + # Does not apply to server administrators. Defaults to 'true' + # + #enable_set_avatar_url: false + + # Whether users can change the 3PIDs associated with their accounts + # (email address and msisdn). + # + # Defaults to 'true' + # + #enable_3pid_changes: false + # Users who register on this homeserver will automatically be joined # to these rooms # @@ -310,6 +369,13 @@ class RegistrationConfig(Config): # users cannot be auto-joined since they do not exist. # #autocreate_auto_join_rooms: true + + # When auto_join_rooms is specified, setting this flag to false prevents + # guest accounts from being automatically joined to the rooms. + # + # Defaults to true. 
+ # + #auto_join_rooms_for_guests: false """ % locals() ) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 34f1a9a92d..b751d02d37 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014, 2015 matrix.org +# Copyright 2014, 2015 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ import os from collections import namedtuple +from typing import Dict, List from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module @@ -61,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): Dictionary mapping from media type string to list of ThumbnailRequirement tuples. """ - requirements = {} + requirements = {} # type: Dict[str, List] for size in thumbnail_sizes: width = size["width"] height = size["height"] @@ -69,6 +70,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg") png_thumbnail = ThumbnailRequirement(width, height, method, "image/png") requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail) + requirements.setdefault("image/webp", []).append(jpeg_thumbnail) requirements.setdefault("image/gif", []).append(png_thumbnail) requirements.setdefault("image/png", []).append(png_thumbnail) return { @@ -77,6 +79,8 @@ def parse_thumbnail_requirements(thumbnail_sizes): class ContentRepositoryConfig(Config): + section = "media" + def read_config(self, config, **kwargs): # Only enable the media repo if either the media repo is enabled or the @@ -130,7 +134,7 @@ class ContentRepositoryConfig(Config): # # We don't create the storage providers here as not all workers need # them to be started. - self.media_storage_providers = [] + self.media_storage_providers = [] # type: List[tuple] for provider_config in storage_providers: # We special case the module "file_system" so as not to need to @@ -153,7 +157,6 @@ class ContentRepositoryConfig(Config): (provider_class, parsed_config, wrapper_config) ) - self.uploads_path = self.ensure_directory(config.get("uploads_path", "uploads")) self.dynamic_thumbnails = config.get("dynamic_thumbnails", False) self.thumbnail_requirements = parse_thumbnail_requirements( config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES) @@ -190,6 +193,10 @@ class ContentRepositoryConfig(Config): self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ()) + self.url_preview_accept_language = config.get( + "url_preview_accept_language" + ) or ["en"] + def generate_config_section(self, data_dir_path, **kwargs): media_store = os.path.join(data_dir_path, "media_store") uploads_path = os.path.join(data_dir_path, "uploads") @@ -218,20 +225,15 @@ class ContentRepositoryConfig(Config): # #media_storage_providers: # - module: file_system - # # Whether to write new local files. + # # Whether to store newly uploaded local files # store_local: false - # # Whether to write new remote media + # # Whether to store newly downloaded remote files # store_remote: false - # # Whether to block upload requests waiting for write to this - # # provider to complete + # # Whether to wait for successful storage for local uploads # store_synchronous: false # config: # directory: /mnt/some/other/directory - # Directory where in-progress uploads are stored. 
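One effect of the new image/webp line in parse_thumbnail_requirements above is that WebP uploads get JPEG thumbnails. A standalone re-creation of that mapping, included purely to make the resulting data structure visible:

    from collections import namedtuple

    ThumbnailRequirement = namedtuple(
        "ThumbnailRequirement", ["width", "height", "method", "media_type"]
    )

    def thumbnail_requirements(thumbnail_sizes):
        requirements = {}
        for size in thumbnail_sizes:
            jpeg = ThumbnailRequirement(size["width"], size["height"], size["method"], "image/jpeg")
            png = ThumbnailRequirement(size["width"], size["height"], size["method"], "image/png")
            requirements.setdefault("image/jpeg", []).append(jpeg)
            requirements.setdefault("image/webp", []).append(jpeg)  # WebP sources -> JPEG thumbnails
            requirements.setdefault("image/gif", []).append(png)
            requirements.setdefault("image/png", []).append(png)
        return {media_type: tuple(reqs) for media_type, reqs in requirements.items()}

    reqs = thumbnail_requirements([{"width": 32, "height": 32, "method": "crop"}])
    print(reqs["image/webp"][0].media_type)  # image/jpeg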
- # - uploads_path: "%(uploads_path)s" - # The largest allowed upload size in bytes # #max_upload_size: 10M @@ -331,6 +333,31 @@ class ContentRepositoryConfig(Config): # The largest allowed URL preview spidering size in bytes # #max_spider_size: 10M + + # A list of values for the Accept-Language HTTP header used when + # downloading webpages during URL preview generation. This allows + # Synapse to specify the preferred languages that URL previews should + # be in when communicating with remote servers. + # + # Each value is a IETF language tag; a 2-3 letter identifier for a + # language, optionally followed by subtags separated by '-', specifying + # a country or region variant. + # + # Multiple values can be provided, and a weight can be added to each by + # using quality value syntax (;q=). '*' translates to any language. + # + # Defaults to "en". + # + # Example: + # + # url_preview_accept_language: + # - en-UK + # - en-US;q=0.9 + # - fr;q=0.8 + # - *;q=0.7 + # + url_preview_accept_language: + # - en """ % locals() ) diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index a92693017b..7ac7699676 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -19,6 +19,8 @@ from ._base import Config, ConfigError class RoomDirectoryConfig(Config): + section = "roomdirectory" + def read_config(self, config, **kwargs): self.enable_room_list_search = config.get("enable_room_list_search", True) @@ -168,7 +170,7 @@ class _RoomDirectoryRule(object): self.action = action else: raise ConfigError( - "%s rules can only have action of 'allow'" " or 'deny'" % (option_name,) + "%s rules can only have action of 'allow' or 'deny'" % (option_name,) ) self._alias_matches_all = alias == "*" diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 6a8161547a..d0a19751e8 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,12 +13,55 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import logging + +import jinja2 +import pkg_resources + from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError +logger = logging.getLogger(__name__) + +DEFAULT_USER_MAPPING_PROVIDER = ( + "synapse.handlers.saml_handler.DefaultSamlMappingProvider" +) + + +def _dict_merge(merge_dict, into_dict): + """Do a deep merge of two dicts + + Recursively merges `merge_dict` into `into_dict`: + * For keys where both `merge_dict` and `into_dict` have a dict value, the values + are recursively merged + * For all other keys, the values in `into_dict` (if any) are overwritten with + the value from `merge_dict`. 
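To make the merge semantics described above concrete, here is a small standalone illustration; the merge logic is inlined (following the same rules as documented for _dict_merge) so the snippet runs on its own, and the keys are arbitrary examples:

    def deep_merge(merge_dict, into_dict):
        # Nested dicts are merged recursively; any other value in into_dict is
        # overwritten by the value from merge_dict.
        for key, value in merge_dict.items():
            if key in into_dict and isinstance(value, dict) and isinstance(into_dict[key], dict):
                deep_merge(value, into_dict[key])
            else:
                into_dict[key] = value

    defaults = {"service": {"sp": {"allow_unsolicited": False}}, "entityid": "default-id"}
    overrides = {"service": {"sp": {"extra_option": True}}, "entityid": "https://example.com/sp"}

    deep_merge(overrides, defaults)
    print(defaults)
    # {'service': {'sp': {'allow_unsolicited': False, 'extra_option': True}},
    #  'entityid': 'https://example.com/sp'}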
+ + Args: + merge_dict (dict): dict to merge + into_dict (dict): target dict + """ + for k, v in merge_dict.items(): + if k not in into_dict: + into_dict[k] = v + continue + + current_val = into_dict[k] + + if isinstance(v, dict) and isinstance(current_val, dict): + _dict_merge(v, current_val) + continue + + # otherwise we just overwrite + into_dict[k] = v + class SAML2Config(Config): + section = "saml2" + def read_config(self, config, **kwargs): self.saml2_enabled = False @@ -26,6 +70,9 @@ class SAML2Config(Config): if not saml2_config or not saml2_config.get("enabled", True): return + if not saml2_config.get("sp_config") and not saml2_config.get("config_path"): + return + try: check_requirements("saml2") except DependencyException as e: @@ -33,28 +80,124 @@ class SAML2Config(Config): self.saml2_enabled = True - import saml2.config + self.saml2_grandfathered_mxid_source_attribute = saml2_config.get( + "grandfathered_mxid_source_attribute", "uid" + ) - self.saml2_sp_config = saml2.config.SPConfig() - self.saml2_sp_config.load(self._default_saml_config_dict()) - self.saml2_sp_config.load(saml2_config.get("sp_config", {})) + # user_mapping_provider may be None if the key is present but has no value + ump_dict = saml2_config.get("user_mapping_provider") or {} + + # Use the default user mapping provider if not set + ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + + # Ensure a config is present + ump_dict["config"] = ump_dict.get("config") or {} + + if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER: + # Load deprecated options for use by the default module + old_mxid_source_attribute = saml2_config.get("mxid_source_attribute") + if old_mxid_source_attribute: + logger.warning( + "The config option saml2_config.mxid_source_attribute is deprecated. " + "Please use saml2_config.user_mapping_provider.config" + ".mxid_source_attribute instead." + ) + ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute + + old_mxid_mapping = saml2_config.get("mxid_mapping") + if old_mxid_mapping: + logger.warning( + "The config option saml2_config.mxid_mapping is deprecated. Please " + "use saml2_config.user_mapping_provider.config.mxid_mapping instead." + ) + ump_dict["config"]["mxid_mapping"] = old_mxid_mapping + + # Retrieve an instance of the module's class + # Pass the config dictionary to the module for processing + ( + self.saml2_user_mapping_provider_class, + self.saml2_user_mapping_provider_config, + ) = load_module(ump_dict) + + # Ensure loaded user mapping module has defined all necessary methods + # Note parse_config() is already checked during the call to load_module + required_methods = [ + "get_saml_attributes", + "saml_response_to_user_attributes", + "get_remote_user_id", + ] + missing_methods = [ + method + for method in required_methods + if not hasattr(self.saml2_user_mapping_provider_class, method) + ] + if missing_methods: + raise ConfigError( + "Class specified by saml2_config." 
+ "user_mapping_provider.module is missing required " + "methods: %s" % (", ".join(missing_methods),) + ) + + # Get the desired saml auth response attributes from the module + saml2_config_dict = self._default_saml_config_dict( + *self.saml2_user_mapping_provider_class.get_saml_attributes( + self.saml2_user_mapping_provider_config + ) + ) + _dict_merge( + merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict + ) config_path = saml2_config.get("config_path", None) if config_path is not None: - self.saml2_sp_config.load_file(config_path) + mod = load_python_module(config_path) + _dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict) + + import saml2.config + + self.saml2_sp_config = saml2.config.SPConfig() + self.saml2_sp_config.load(saml2_config_dict) # session lifetime: in milliseconds self.saml2_session_lifetime = self.parse_duration( saml2_config.get("saml_session_lifetime", "5m") ) - def _default_saml_config_dict(self): + template_dir = saml2_config.get("template_dir") + if not template_dir: + template_dir = pkg_resources.resource_filename("synapse", "res/templates",) + + loader = jinja2.FileSystemLoader(template_dir) + # enable auto-escape here, to having to remember to escape manually in the + # template + env = jinja2.Environment(loader=loader, autoescape=True) + self.saml2_error_html_template = env.get_template("saml_error.html") + + def _default_saml_config_dict( + self, required_attributes: set, optional_attributes: set + ): + """Generate a configuration dictionary with required and optional attributes that + will be needed to process new user registration + + Args: + required_attributes: SAML auth response attributes that are + necessary to function + optional_attributes: SAML auth response attributes that can be used to add + additional information to Synapse user accounts, but are not required + + Returns: + dict: A SAML configuration dictionary + """ import saml2 public_baseurl = self.public_baseurl if public_baseurl is None: raise ConfigError("saml2_config requires a public_baseurl to be set") + if self.saml2_grandfathered_mxid_source_attribute: + optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) + optional_attributes -= required_attributes + metadata_url = public_baseurl + "_matrix/saml2/metadata.xml" response_url = public_baseurl + "_matrix/saml2/authn_response" return { @@ -66,22 +209,26 @@ class SAML2Config(Config): (response_url, saml2.BINDING_HTTP_POST) ] }, - "required_attributes": ["uid"], - "optional_attributes": ["mail", "surname", "givenname"], + "required_attributes": list(required_attributes), + "optional_attributes": list(optional_attributes), + # "name_id_format": saml2.saml.NAMEID_FORMAT_PERSISTENT, } }, } def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ + ## Single sign-on integration ## + # Enable SAML2 for registration and login. Uses pysaml2. # - # `sp_config` is the configuration for the pysaml2 Service Provider. - # See pysaml2 docs for format of config. + # At least one of `sp_config` or `config_path` must be set in this section to + # enable SAML login. # - # Default values will be used for the 'entityid' and 'service' settings, - # so it is not normally necessary to specify them unless you need to - # override them. 
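For reference, a module supplied through user_mapping_provider has to expose the methods checked for above (parse_config is validated separately when the module is loaded). The skeleton below is a sketch only; the attribute names are placeholders and the exact call signatures should be checked against the Synapse version in use:

    class ExampleSamlMappingProvider:
        def __init__(self, parsed_config, module_api=None):
            self._config = parsed_config

        @staticmethod
        def parse_config(config: dict) -> dict:
            # Receives the user_mapping_provider.config block; validate it here.
            config.setdefault("mxid_source_attribute", "uid")
            return config

        @staticmethod
        def get_saml_attributes(config: dict):
            # (required, optional) SAML attributes to request from the IdP.
            return {config["mxid_source_attribute"]}, {"displayName", "mail"}

        def get_remote_user_id(self, saml_response, client_redirect_url):
            # A stable identifier for the authenticated user.
            return saml_response.ava["uid"][0]

        def saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url):
            # Derive the localpart (and optionally a display name) for the MXID.
            localpart = saml_response.ava[self._config["mxid_source_attribute"]][0]
            return {
                "mxid_localpart": localpart,
                "displayname": saml_response.ava.get("displayName", [None])[0],
            }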
+ # (You will probably also want to set the following options to `false` to + # disable the regular login/registration flows: + # * enable_registration + # * password_config.enabled # # Once SAML support is enabled, a metadata file will be exposed at # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to @@ -89,52 +236,135 @@ class SAML2Config(Config): # the IdP to use an ACS location of # https://<server>:<port>/_matrix/saml2/authn_response. # - #saml2_config: - # sp_config: - # # point this to the IdP's metadata. You can use either a local file or - # # (preferably) a URL. - # metadata: - # #local: ["saml2/idp.xml"] - # remote: - # - url: https://our_idp/metadata.xml - # - # # By default, the user has to go to our login page first. If you'd like to - # # allow IdP-initiated login, set 'allow_unsolicited: True' in a - # # 'service.sp' section: - # # - # #service: - # # sp: - # # allow_unsolicited: True - # - # # The examples below are just used to generate our metadata xml, and you - # # may well not need it, depending on your setup. Alternatively you - # # may need a whole lot more detail - see the pysaml2 docs! - # - # description: ["My awesome SP", "en"] - # name: ["Test SP", "en"] - # - # organization: - # name: Example com - # display_name: - # - ["Example co", "en"] - # url: "http://example.com" - # - # contact_person: - # - given_name: Bob - # sur_name: "the Sysadmin" - # email_address": ["admin@example.com"] - # contact_type": technical - # - # # Instead of putting the config inline as above, you can specify a - # # separate pysaml2 configuration file: - # # - # config_path: "%(config_dir_path)s/sp_conf.py" - # - # # the lifetime of a SAML session. This defines how long a user has to - # # complete the authentication process, if allow_unsolicited is unset. - # # The default is 5 minutes. - # # - # # saml_session_lifetime: 5m + saml2_config: + # `sp_config` is the configuration for the pysaml2 Service Provider. + # See pysaml2 docs for format of config. + # + # Default values will be used for the 'entityid' and 'service' settings, + # so it is not normally necessary to specify them unless you need to + # override them. + # + #sp_config: + # # point this to the IdP's metadata. You can use either a local file or + # # (preferably) a URL. + # metadata: + # #local: ["saml2/idp.xml"] + # remote: + # - url: https://our_idp/metadata.xml + # + # # By default, the user has to go to our login page first. If you'd like + # # to allow IdP-initiated login, set 'allow_unsolicited: true' in a + # # 'service.sp' section: + # # + # #service: + # # sp: + # # allow_unsolicited: true + # + # # The examples below are just used to generate our metadata xml, and you + # # may well not need them, depending on your setup. Alternatively you + # # may need a whole lot more detail - see the pysaml2 docs! + # + # description: ["My awesome SP", "en"] + # name: ["Test SP", "en"] + # + # organization: + # name: Example com + # display_name: + # - ["Example co", "en"] + # url: "http://example.com" + # + # contact_person: + # - given_name: Bob + # sur_name: "the Sysadmin" + # email_address": ["admin@example.com"] + # contact_type": technical + + # Instead of putting the config inline as above, you can specify a + # separate pysaml2 configuration file: + # + #config_path: "%(config_dir_path)s/sp_conf.py" + + # The lifetime of a SAML session. This defines how long a user has to + # complete the authentication process, if allow_unsolicited is unset. + # The default is 5 minutes. 
+ # + #saml_session_lifetime: 5m + + # An external module can be provided here as a custom solution to + # mapping attributes returned from a saml provider onto a matrix user. + # + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # + #module: mapping_provider.SamlMappingProvider + + # Custom configuration values for the module. Below options are + # intended for the built-in provider, they should be changed if + # using a custom module. This section will be passed as a Python + # dictionary to the module's `parse_config` method. + # + config: + # The SAML attribute (after mapping via the attribute maps) to use + # to derive the Matrix ID from. 'uid' by default. + # + # Note: This used to be configured by the + # saml2_config.mxid_source_attribute option. If that is still + # defined, its value will be used instead. + # + #mxid_source_attribute: displayName + + # The mapping system to use for mapping the saml attribute onto a + # matrix ID. + # + # Options include: + # * 'hexencode' (which maps unpermitted characters to '=xx') + # * 'dotreplace' (which replaces unpermitted characters with + # '.'). + # The default is 'hexencode'. + # + # Note: This used to be configured by the + # saml2_config.mxid_mapping option. If that is still defined, its + # value will be used instead. + # + #mxid_mapping: dotreplace + + # In previous versions of synapse, the mapping from SAML attribute to + # MXID was always calculated dynamically rather than stored in a + # table. For backwards- compatibility, we will look for user_ids + # matching such a pattern before creating a new account. + # + # This setting controls the SAML attribute which will be used for this + # backwards-compatibility lookup. Typically it should be 'uid', but if + # the attribute maps are changed, it may be necessary to change it. + # + # The default is 'uid'. + # + #grandfathered_mxid_source_attribute: upn + + # Directory in which Synapse will try to find the template files below. + # If not set, default templates from within the Synapse package will be used. + # + # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. + # If you *do* uncomment it, you will need to make sure that all the templates + # below are in the directory. + # + # Synapse will look for the following templates in this directory: + # + # * HTML page to display to users if something goes wrong during the + # authentication process: 'saml_error.html'. + # + # When rendering, this template is given the following variables: + # * code: an HTML error code corresponding to the error that is being + # returned (typically 400 or 500) + # + # * msg: a textual message describing the error. + # + # The variables will automatically be HTML-escaped. + # + # You can see the default templates at: + # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates + # + #template_dir: "res/templates" """ % { "config_dir_path": config_dir_path } diff --git a/synapse/config/server.py b/synapse/config/server.py index 7f8d315954..f57eefc99c 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -19,6 +19,7 @@ import logging import os.path import re from textwrap import indent +from typing import Dict, List, Optional import attr import yaml @@ -40,7 +41,7 @@ logger = logging.Logger(__name__) # in the list. DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"] -DEFAULT_ROOM_VERSION = "4" +DEFAULT_ROOM_VERSION = "5" ROOM_COMPLEXITY_TOO_GREAT = ( "Your homeserver is unable to join rooms this large or complex. 
" @@ -48,8 +49,17 @@ ROOM_COMPLEXITY_TOO_GREAT = ( "to join this room." ) +METRICS_PORT_WARNING = """\ +The metrics_port configuration option is deprecated in Synapse 0.31 in favour of +a listener. Please see +https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md +on how to configure the new listener. +--------------------------------------------------------------------------------""" + class ServerConfig(Config): + section = "server" + def read_config(self, config, **kwargs): self.server_name = config["server_name"] self.server_context = config.get("server_context", None) @@ -92,6 +102,12 @@ class ServerConfig(Config): "require_auth_for_profile_requests", False ) + # Whether to require sharing a room with a user to retrieve their + # profile data + self.limit_profile_requests_to_users_who_share_rooms = config.get( + "limit_profile_requests_to_users_who_share_rooms", False, + ) + if "restrict_public_rooms_to_local_users" in config and ( "allow_public_rooms_without_auth" in config or "allow_public_rooms_over_federation" in config @@ -108,15 +124,16 @@ class ServerConfig(Config): self.allow_public_rooms_without_auth = False self.allow_public_rooms_over_federation = False else: - # If set to 'False', requires authentication to access the server's public - # rooms directory through the client API. Defaults to 'True'. + # If set to 'true', removes the need for authentication to access the server's + # public rooms directory through the client API, meaning that anyone can + # query the room directory. Defaults to 'false'. self.allow_public_rooms_without_auth = config.get( - "allow_public_rooms_without_auth", True + "allow_public_rooms_without_auth", False ) - # If set to 'False', forbids any other homeserver to fetch the server's public - # rooms directory via federation. Defaults to 'True'. + # If set to 'true', allows any other homeserver to fetch the server's public + # rooms directory via federation. Defaults to 'false'. self.allow_public_rooms_over_federation = config.get( - "allow_public_rooms_over_federation", True + "allow_public_rooms_over_federation", False ) default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION) @@ -161,6 +178,7 @@ class ServerConfig(Config): ) self.mau_trial_days = config.get("mau_trial_days", 0) + self.mau_limit_alerting = config.get("mau_limit_alerting", True) # How long to keep redacted events in the database in unredacted form # before redacting them. @@ -172,17 +190,23 @@ class ServerConfig(Config): else: self.redaction_retention_period = None + # How long to keep entries in the `users_ips` table. 
+ user_ips_max_age = config.get("user_ips_max_age", "28d") + if user_ips_max_age is not None: + self.user_ips_max_age = self.parse_duration(user_ips_max_age) + else: + self.user_ips_max_age = None + # Options to disable HS self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled_message = config.get("hs_disabled_message", "") - self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "") # Admin uri to direct users at should their instance become blocked # due to resource constraints self.admin_contact = config.get("admin_contact", None) # FIXME: federation_domain_whitelist needs sytests - self.federation_domain_whitelist = None + self.federation_domain_whitelist = None # type: Optional[dict] federation_domain_whitelist = config.get("federation_domain_whitelist", None) if federation_domain_whitelist is not None: @@ -206,7 +230,7 @@ class ServerConfig(Config): self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) except Exception as e: raise ConfigError( - "Invalid range(s) provided in " "federation_ip_range_blacklist: %s" % e + "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e ) if self.public_baseurl is not None: @@ -229,7 +253,133 @@ class ServerConfig(Config): # events with profile information that differ from the target's global profile. self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) - self.listeners = [] + retention_config = config.get("retention") + if retention_config is None: + retention_config = {} + + self.retention_enabled = retention_config.get("enabled", False) + + retention_default_policy = retention_config.get("default_policy") + + if retention_default_policy is not None: + self.retention_default_min_lifetime = retention_default_policy.get( + "min_lifetime" + ) + if self.retention_default_min_lifetime is not None: + self.retention_default_min_lifetime = self.parse_duration( + self.retention_default_min_lifetime + ) + + self.retention_default_max_lifetime = retention_default_policy.get( + "max_lifetime" + ) + if self.retention_default_max_lifetime is not None: + self.retention_default_max_lifetime = self.parse_duration( + self.retention_default_max_lifetime + ) + + if ( + self.retention_default_min_lifetime is not None + and self.retention_default_max_lifetime is not None + and ( + self.retention_default_min_lifetime + > self.retention_default_max_lifetime + ) + ): + raise ConfigError( + "The default retention policy's 'min_lifetime' can not be greater" + " than its 'max_lifetime'" + ) + else: + self.retention_default_min_lifetime = None + self.retention_default_max_lifetime = None + + if self.retention_enabled: + logger.info( + "Message retention policies support enabled with the following default" + " policy: min_lifetime = %s ; max_lifetime = %s", + self.retention_default_min_lifetime, + self.retention_default_max_lifetime, + ) + + self.retention_allowed_lifetime_min = retention_config.get( + "allowed_lifetime_min" + ) + if self.retention_allowed_lifetime_min is not None: + self.retention_allowed_lifetime_min = self.parse_duration( + self.retention_allowed_lifetime_min + ) + + self.retention_allowed_lifetime_max = retention_config.get( + "allowed_lifetime_max" + ) + if self.retention_allowed_lifetime_max is not None: + self.retention_allowed_lifetime_max = self.parse_duration( + self.retention_allowed_lifetime_max + ) + + if ( + self.retention_allowed_lifetime_min is not None + and self.retention_allowed_lifetime_max is not None + and self.retention_allowed_lifetime_min + > 
self.retention_allowed_lifetime_max + ): + raise ConfigError( + "Invalid retention policy limits: 'allowed_lifetime_min' can not be" + " greater than 'allowed_lifetime_max'" + ) + + self.retention_purge_jobs = [] # type: List[Dict[str, Optional[int]]] + for purge_job_config in retention_config.get("purge_jobs", []): + interval_config = purge_job_config.get("interval") + + if interval_config is None: + raise ConfigError( + "A retention policy's purge jobs configuration must have the" + " 'interval' key set." + ) + + interval = self.parse_duration(interval_config) + + shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime") + + if shortest_max_lifetime is not None: + shortest_max_lifetime = self.parse_duration(shortest_max_lifetime) + + longest_max_lifetime = purge_job_config.get("longest_max_lifetime") + + if longest_max_lifetime is not None: + longest_max_lifetime = self.parse_duration(longest_max_lifetime) + + if ( + shortest_max_lifetime is not None + and longest_max_lifetime is not None + and shortest_max_lifetime > longest_max_lifetime + ): + raise ConfigError( + "A retention policy's purge jobs configuration's" + " 'shortest_max_lifetime' value can not be greater than its" + " 'longest_max_lifetime' value." + ) + + self.retention_purge_jobs.append( + { + "interval": interval, + "shortest_max_lifetime": shortest_max_lifetime, + "longest_max_lifetime": longest_max_lifetime, + } + ) + + if not self.retention_purge_jobs: + self.retention_purge_jobs = [ + { + "interval": self.parse_duration("1d"), + "shortest_max_lifetime": None, + "longest_max_lifetime": None, + } + ] + + self.listeners = [] # type: List[dict] for listener in config.get("listeners", []): if not isinstance(listener.get("port", None), int): raise ConfigError( @@ -273,7 +423,10 @@ class ServerConfig(Config): validator=attr.validators.instance_of(bool), default=False ) complexity = attr.ib( - validator=attr.validators.instance_of((int, float)), default=1.0 + validator=attr.validators.instance_of( + (float, int) # type: ignore[arg-type] # noqa + ), + default=1.0, ) complexity_error = attr.ib( validator=attr.validators.instance_of(str), @@ -281,7 +434,7 @@ class ServerConfig(Config): ) self.limit_remote_rooms = LimitRemoteRoomsConfig( - **config.get("limit_remote_rooms", {}) + **(config.get("limit_remote_rooms") or {}) ) bind_port = config.get("bind_port") @@ -334,14 +487,7 @@ class ServerConfig(Config): metrics_port = config.get("metrics_port") if metrics_port: - logger.warn( - ( - "The metrics_port configuration option is deprecated in Synapse 0.31 " - "in favour of a listener. Please see " - "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md" - " on how to configure the new listener." - ) - ) + logger.warning(METRICS_PORT_WARNING) self.listeners.append( { @@ -355,14 +501,28 @@ class ServerConfig(Config): _check_resource_config(self.listeners) - # An experimental option to try and periodically clean up extremities - # by sending dummy events. self.cleanup_extremities_with_dummy_events = config.get( - "cleanup_extremities_with_dummy_events", False + "cleanup_extremities_with_dummy_events", True + ) + + # The number of forward extremities in a room needed to send a dummy event. 
+ self.dummy_events_threshold = config.get("dummy_events_threshold", 10) + + self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False) + + # Inhibits the /requestToken endpoints from returning an error that might leak + # information about whether an e-mail address is in use or not on this + # homeserver, and instead return a 200 with a fake sid if this kind of error is + # met, without sending anything. + # This is a compromise between sending an email, which could be a spam vector, + # and letting the client know which email address is bound to an account and + # which one isn't. + self.request_token_inhibit_3pid_errors = config.get( + "request_token_inhibit_3pid_errors", False, ) - def has_tls_listener(self): - return any(l["tls"] for l in self.listeners) + def has_tls_listener(self) -> bool: + return any(listener["tls"] for listener in self.listeners) def generate_config_section( self, server_name, data_dir_path, open_private_ports, listeners, **kwargs @@ -458,10 +618,15 @@ class ServerConfig(Config): # pid_file: %(pid_file)s - # The path to the web client which will be served at /_matrix/client/ - # if 'webclient' is configured under the 'listeners' configuration. + # The absolute URL to the web client which /_matrix/client will redirect + # to if 'webclient' is configured under the 'listeners' configuration. + # + # This option can be also set to the filesystem path to the web client + # which will be served at /_matrix/client/ if 'webclient' is configured + # under the 'listeners' configuration, however this is a security risk: + # https://github.com/matrix-org/synapse#security-note # - #web_client_location: "/path/to/web/root" + #web_client_location: https://riot.example.com/ # The public-facing base URL that clients use to access this HS # (not including _matrix/...). This is the same URL a user would @@ -489,15 +654,23 @@ class ServerConfig(Config): # #require_auth_for_profile_requests: true - # If set to 'false', requires authentication to access the server's public rooms - # directory through the client API. Defaults to 'true'. + # Uncomment to require a user to share a room with another user in order + # to retrieve their profile information. Only checked on Client-Server + # requests. Profile requests from other servers should be checked by the + # requesting server. Defaults to 'false'. # - #allow_public_rooms_without_auth: false + #limit_profile_requests_to_users_who_share_rooms: true - # If set to 'false', forbids any other homeserver to fetch the server's public - # rooms directory via federation. Defaults to 'true'. + # If set to 'true', removes the need for authentication to access the server's + # public rooms directory through the client API, meaning that anyone can + # query the room directory. Defaults to 'false'. # - #allow_public_rooms_over_federation: false + #allow_public_rooms_without_auth: true + + # If set to 'true', allows any other homeserver to fetch the server's public + # rooms directory via federation. Defaults to 'false'. + # + #allow_public_rooms_over_federation: true # The default room version for newly created rooms. # @@ -521,7 +694,7 @@ class ServerConfig(Config): # Whether room invites to users on this server should be blocked # (except those sent by local server admins). The default is False. # - #block_non_admin_invites: True + #block_non_admin_invites: true # Room searching # @@ -545,6 +718,9 @@ class ServerConfig(Config): # blacklist IP address CIDR ranges. 
If this option is not specified, or # specified with an empty list, no ip range blacklist will be enforced. # + # As of Synapse v1.4.0 this option also affects any outbound requests to identity + # servers provided by user input. + # # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly # listed here, since they correspond to unroutable addresses.) # @@ -650,6 +826,18 @@ class ServerConfig(Config): # bind_addresses: ['::1', '127.0.0.1'] # type: manhole + # Forward extremities can build up in a room due to networking delays between + # homeservers. Once this happens in a large room, calculation of the state of + # that room can become quite expensive. To mitigate this, once the number of + # forward extremities reaches a given threshold, Synapse will send an + # org.matrix.dummy_event event, which will reduce the forward extremities + # in the room. + # + # This setting defines the threshold (i.e. number of forward extremities in the + # room) at which dummy events are sent. The default value is 10. + # + #dummy_events_threshold: 5 + ## Homeserver blocking ## @@ -659,9 +847,8 @@ class ServerConfig(Config): # Global blocking # - #hs_disabled: False + #hs_disabled: false #hs_disabled_message: 'Human readable reason for why the HS is blocked' - #hs_disabled_limit_type: 'error code(str), to help clients decode reason' # Monthly Active User Blocking # @@ -681,15 +868,22 @@ class ServerConfig(Config): # sign up in a short space of time never to return after their initial # session. # - #limit_usage_by_mau: False + # 'mau_limit_alerting' is a means of limiting client side alerting + # should the mau limit be reached. This is useful for small instances + # where the admin has 5 mau seats (say) for 5 specific people and no + # interest increasing the mau limit further. Defaults to True, which + # means that alerting is enabled + # + #limit_usage_by_mau: false #max_mau_value: 50 #mau_trial_days: 2 + #mau_limit_alerting: false # If enabled, the metrics for the number of monthly active users will # be populated, however no one will be limited. If limit_usage_by_mau # is true, this is implied to be true. # - #mau_stats_only: False + #mau_stats_only: false # Sometimes the server admin will want to ensure certain accounts are # never blocked by mau checking. These accounts are specified here. @@ -701,22 +895,27 @@ class ServerConfig(Config): # Used by phonehome stats to group together related servers. #server_context: context - # Resource-constrained Homeserver Settings + # Resource-constrained homeserver settings # - # If limit_remote_rooms.enabled is True, the room complexity will be - # checked before a user joins a new remote room. If it is above - # limit_remote_rooms.complexity, it will disallow joining or - # instantly leave. + # When this is enabled, the room "complexity" will be checked before a user + # joins a new remote room. If it is above the complexity limit, the server will + # disallow joining, or will instantly leave. # - # limit_remote_rooms.complexity_error can be set to customise the text - # displayed to the user when a room above the complexity threshold has - # its join cancelled. + # Room complexity is an arbitrary measure based on factors such as the number of + # users in the room. # - # Uncomment the below lines to enable: - #limit_remote_rooms: - # enabled: True - # complexity: 1.0 - # complexity_error: "This room is too complex." + limit_remote_rooms: + # Uncomment to enable room complexity checking. 
+ # + #enabled: true + + # the limit above which rooms cannot be joined. The default is 1.0. + # + #complexity: 0.5 + + # override the error which is returned when the room is too complex. + # + #complexity_error: "This room is too complex." # Whether to require a user to be in the room to add an alias to it. # Defaults to 'true'. @@ -734,7 +933,86 @@ class ServerConfig(Config): # # Defaults to `7d`. Set to `null` to disable. # - redaction_retention_period: 7d + #redaction_retention_period: 28d + + # How long to track users' last seen time and IPs in the database. + # + # Defaults to `28d`. Set to `null` to disable clearing out of old rows. + # + #user_ips_max_age: 14d + + # Message retention policy at the server level. + # + # Room admins and mods can define a retention period for their rooms using the + # 'm.room.retention' state event, and server admins can cap this period by setting + # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. + # + # If this feature is enabled, Synapse will regularly look for and purge events + # which are older than the room's maximum retention period. Synapse will also + # filter events received over federation so that events that should have been + # purged are ignored and not stored again. + # + retention: + # The message retention policies feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # Default retention policy. If set, Synapse will apply it to rooms that lack the + # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't + # matter much because Synapse doesn't take it into account yet. + # + #default_policy: + # min_lifetime: 1d + # max_lifetime: 1y + + # Retention policy limits. If set, a user won't be able to send a + # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' + # that's not within this range. This is especially useful in closed federations, + # in which server admins can make sure every federating server applies the same + # rules. + # + #allowed_lifetime_min: 1d + #allowed_lifetime_max: 1y + + # Server admins can define the settings of the background jobs purging the + # events which lifetime has expired under the 'purge_jobs' section. + # + # If no configuration is provided, a single job will be set up to delete expired + # events in every room daily. + # + # Each job's configuration defines which range of message lifetimes the job + # takes care of. For example, if 'shortest_max_lifetime' is '2d' and + # 'longest_max_lifetime' is '3d', the job will handle purging expired events in + # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and + # lower than or equal to 3 days. Both the minimum and the maximum value of a + # range are optional, e.g. a job with no 'shortest_max_lifetime' and a + # 'longest_max_lifetime' of '3d' will handle every room with a retention policy + # which 'max_lifetime' is lower than or equal to three days. + # + # The rationale for this per-job configuration is that some rooms might have a + # retention policy with a low 'max_lifetime', where history needs to be purged + # of outdated messages on a more frequent basis than for the rest of the rooms + # (e.g. every 12h), but not want that purge to be performed by a job that's + # iterating over every room it knows, which could be heavy on the server. 
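Put another way, a purge job covers a room when the room's max_lifetime is strictly greater than the job's shortest_max_lifetime and no greater than its longest_max_lifetime, with a missing bound treated as unbounded. A small sketch of that rule (illustrative only, not the scheduler's actual code):

    ONE_DAY_MS = 24 * 60 * 60 * 1000

    def job_covers_room(job: dict, room_max_lifetime_ms: int) -> bool:
        shortest = job.get("shortest_max_lifetime")
        longest = job.get("longest_max_lifetime")
        if shortest is not None and room_max_lifetime_ms <= shortest:
            return False
        if longest is not None and room_max_lifetime_ms > longest:
            return False
        return True

    job = {"shortest_max_lifetime": 2 * ONE_DAY_MS, "longest_max_lifetime": 3 * ONE_DAY_MS}
    print(job_covers_room(job, 3 * ONE_DAY_MS))  # True: above 2d and at most 3d
    print(job_covers_room(job, 2 * ONE_DAY_MS))  # False: not above 2d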
+ # + #purge_jobs: + # - shortest_max_lifetime: 1d + # longest_max_lifetime: 3d + # interval: 12h + # - shortest_max_lifetime: 3d + # longest_max_lifetime: 1y + # interval: 1d + + # Inhibits the /requestToken endpoints from returning an error that might leak + # information about whether an e-mail address is in use or not on this + # homeserver. + # Note that for some endpoints the error situation is the e-mail already being + # used, and for others the error is entering the e-mail being unused. + # If this option is enabled, instead of returning an error, these endpoints will + # act as if no error happened and return a fake session ID ('sid') to clients. + # + #request_token_inhibit_3pid_errors: true """ % locals() ) @@ -755,20 +1033,20 @@ class ServerConfig(Config): "--daemonize", action="store_true", default=None, - help="Daemonize the home server", + help="Daemonize the homeserver", ) server_group.add_argument( "--print-pidfile", action="store_true", default=None, - help="Print the path to the pidfile just" " before daemonizing", + help="Print the path to the pidfile just before daemonizing", ) server_group.add_argument( "--manhole", metavar="PORT", dest="manhole", type=int, - help="Turn on the twisted telnet manhole" " service on the given port.", + help="Turn on the twisted telnet manhole service on the given port.", ) @@ -834,12 +1112,12 @@ KNOWN_RESOURCES = ( def _check_resource_config(listeners): - resource_names = set( + resource_names = { res_name for listener in listeners for res in listener.get("resources", []) for res_name in res.get("names", []) - ) + } for resource in resource_names: if resource not in KNOWN_RESOURCES: diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py index eaac3d73bc..6c427b6f92 100644 --- a/synapse/config/server_notices_config.py +++ b/synapse/config/server_notices_config.py @@ -51,7 +51,7 @@ class ServerNoticesConfig(Config): None if server notices are not enabled. server_notices_mxid_avatar_url (str|None): - The display name to use for the server notices user. + The MXC URL for the avatar of the server notices user. None if server notices are not enabled. server_notices_room_name (str|None): @@ -59,8 +59,10 @@ class ServerNoticesConfig(Config): None if server notices are not enabled. """ - def __init__(self): - super(ServerNoticesConfig, self).__init__() + section = "servernotices" + + def __init__(self, *args): + super(ServerNoticesConfig, self).__init__(*args) self.server_notices_mxid = None self.server_notices_mxid_display_name = None self.server_notices_mxid_avatar_url = None diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py index e40797ab50..3d067d29db 100644 --- a/synapse/config/spam_checker.py +++ b/synapse/config/spam_checker.py @@ -13,23 +13,47 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, List, Tuple + +from synapse.config import ConfigError from synapse.util.module_loader import load_module from ._base import Config class SpamCheckerConfig(Config): + section = "spamchecker" + def read_config(self, config, **kwargs): - self.spam_checker = None + self.spam_checkers = [] # type: List[Tuple[Any, Dict]] + + spam_checkers = config.get("spam_checker") or [] + if isinstance(spam_checkers, dict): + # The spam_checker config option used to only support one + # spam checker, and thus was simply a dictionary with module + # and config keys. 
Support this old behaviour by checking + # to see if the option resolves to a dictionary + self.spam_checkers.append(load_module(spam_checkers)) + elif isinstance(spam_checkers, list): + for spam_checker in spam_checkers: + if not isinstance(spam_checker, dict): + raise ConfigError("spam_checker syntax is incorrect") - provider = config.get("spam_checker", None) - if provider is not None: - self.spam_checker = load_module(provider) + self.spam_checkers.append(load_module(spam_checker)) + else: + raise ConfigError("spam_checker syntax is incorrect") def generate_config_section(self, **kwargs): return """\ - #spam_checker: - # module: "my_custom_project.SuperSpamChecker" - # config: - # example_option: 'things' + # Spam checkers are third-party modules that can block specific actions + # of local users, such as creating rooms and registering undesirable + # usernames, as well as remote users by redacting incoming events. + # + spam_checker: + #- module: "my_custom_project.SuperSpamChecker" + # config: + # example_option: 'things' + #- module: "some_other_project.BadEventStopper" + # config: + # example_stop_events_from: ['@bad:example.com'] """ diff --git a/synapse/config/sso.py b/synapse/config/sso.py new file mode 100644 index 0000000000..73b7296399 --- /dev/null +++ b/synapse/config/sso.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Any, Dict + +import pkg_resources + +from ._base import Config + + +class SSOConfig(Config): + """SSO Configuration + """ + + section = "sso" + + def read_config(self, config, **kwargs): + sso_config = config.get("sso") or {} # type: Dict[str, Any] + + # Pick a template directory in order of: + # * The sso-specific template_dir + # * /path/to/synapse/install/res/templates + template_dir = sso_config.get("template_dir") + if not template_dir: + template_dir = pkg_resources.resource_filename("synapse", "res/templates",) + + self.sso_template_dir = template_dir + self.sso_account_deactivated_template = self.read_file( + os.path.join(self.sso_template_dir, "sso_account_deactivated.html"), + "sso_account_deactivated_template", + ) + self.sso_auth_success_template = self.read_file( + os.path.join(self.sso_template_dir, "sso_auth_success.html"), + "sso_auth_success_template", + ) + + self.sso_client_whitelist = sso_config.get("client_whitelist") or [] + + # Attempt to also whitelist the server's login fallback, since that fallback sets + # the redirect URL to itself (so it can process the login token then return + # gracefully to the client). This would make it pointless to ask the user for + # confirmation, since the URL the confirmation page would be showing wouldn't be + # the client's. + # public_baseurl is an optional setting, so we only add the fallback's URL to the + # list if it's provided (because we can't figure out what that URL is otherwise). 
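Relatedly, each entry in the spam_checker list shown earlier points at a class loaded via load_module, so it needs a parse_config method plus whichever callbacks it wants to implement. A minimal sketch follows; the callback names and signatures should be checked against the spam checker documentation for the Synapse version in use, and the event is treated as a plain dict here for brevity:

    class BadEventStopper:
        def __init__(self, config):
            self._blocked_senders = set(config.get("example_stop_events_from", []))

        @staticmethod
        def parse_config(config: dict) -> dict:
            # load_module() passes in the 'config' block from the YAML entry.
            return config

        def check_event_for_spam(self, event) -> bool:
            # Returning True tells Synapse to reject the event.
            return event.get("sender") in self._blocked_senders

        def user_may_create_room(self, userid: str) -> bool:
            # Allow everyone to create rooms in this example.
            return True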
+ if self.public_baseurl: + login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + self.sso_client_whitelist.append(login_fallback_url) + + def generate_config_section(self, **kwargs): + return """\ + # Additional settings to use with single-sign on systems such as OpenID Connect, + # SAML2 and CAS. + # + sso: + # A list of client URLs which are whitelisted so that the user does not + # have to confirm giving access to their account to the URL. Any client + # whose URL starts with an entry in the following list will not be subject + # to an additional confirmation step after the SSO login is completed. + # + # WARNING: An entry such as "https://my.client" is insecure, because it + # will also match "https://my.client.evil.site", exposing your users to + # phishing attacks from evil.site. To avoid this, include a slash after the + # hostname: "https://my.client/". + # + # If public_baseurl is set, then the login fallback page (used by clients + # that don't natively support the required login flows) is whitelisted in + # addition to any URLs in this list. + # + # By default, this list is empty. + # + #client_whitelist: + # - https://riot.im/develop + # - https://my.custom.client/ + + # Directory in which Synapse will try to find the template files below. + # If not set, default templates from within the Synapse package will be used. + # + # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. + # If you *do* uncomment it, you will need to make sure that all the templates + # below are in the directory. + # + # Synapse will look for the following templates in this directory: + # + # * HTML page for a confirmation step before redirecting back to the client + # with the login token: 'sso_redirect_confirm.html'. + # + # When rendering, this template is given three variables: + # * redirect_url: the URL the user is about to be redirected to. Needs + # manual escaping (see + # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * display_url: the same as `redirect_url`, but with the query + # parameters stripped. The intention is to have a + # human-readable URL to show to users, not to use it as + # the final address to redirect to. Needs manual escaping + # (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * server_name: the homeserver's name. + # + # * HTML page which notifies the user that they are authenticating to confirm + # an operation on their account during the user interactive authentication + # process: 'sso_auth_confirm.html'. + # + # When rendering, this template is given the following variables: + # * redirect_url: the URL the user is about to be redirected to. Needs + # manual escaping (see + # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * description: the operation which the user is being asked to confirm + # + # * HTML page shown after a successful user interactive authentication session: + # 'sso_auth_success.html'. + # + # Note that this page must include the JavaScript which notifies of a successful authentication + # (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback). + # + # This template has no additional variables. + # + # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) + # attempts to login: 'sso_account_deactivated.html'. + # + # This template has no additional variables. 
+ # + # * HTML page to display to users if something goes wrong during the + # OpenID Connect authentication process: 'sso_error.html'. + # + # When rendering, this template is given two variables: + # * error: the technical name of the error + # * error_description: a human-readable message for the error + # + # You can see the default templates at: + # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates + # + #template_dir: "res/templates" + """ diff --git a/synapse/config/stats.py b/synapse/config/stats.py index b18ddbd1fa..62485189ea 100644 --- a/synapse/config/stats.py +++ b/synapse/config/stats.py @@ -25,6 +25,8 @@ class StatsConfig(Config): Configuration for the behaviour of synapse's stats engine """ + section = "stats" + def read_config(self, config, **kwargs): self.stats_enabled = True self.stats_bucket_size = 86400 * 1000 diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py index b3431441b9..10a99c792e 100644 --- a/synapse/config/third_party_event_rules.py +++ b/synapse/config/third_party_event_rules.py @@ -19,6 +19,8 @@ from ._base import Config class ThirdPartyRulesConfig(Config): + section = "thirdpartyrules" + def read_config(self, config, **kwargs): self.third_party_event_rules = None diff --git a/synapse/config/tls.py b/synapse/config/tls.py index fc47ba3e9a..a65538562b 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -18,6 +18,7 @@ import os import warnings from datetime import datetime from hashlib import sha256 +from typing import List import six @@ -31,9 +32,22 @@ from synapse.util import glob_to_regex logger = logging.getLogger(__name__) +ACME_SUPPORT_ENABLED_WARN = """\ +This server uses Synapse's built-in ACME support. Note that ACME v1 has been +deprecated by Let's Encrypt, and that Synapse doesn't currently support ACME v2, +which means that this feature will not work with Synapse installs set up after +November 2019, and that it may stop working on June 2020 for installs set up +before that date. 
+ +For more info and alternative solutions, see +https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1 +--------------------------------------------------------------------------------""" + class TlsConfig(Config): - def read_config(self, config, config_dir_path, **kwargs): + section = "tls" + + def read_config(self, config: dict, config_dir_path: str, **kwargs): acme_config = config.get("acme", None) if acme_config is None: @@ -41,6 +55,9 @@ class TlsConfig(Config): self.acme_enabled = acme_config.get("enabled", False) + if self.acme_enabled: + logger.warning(ACME_SUPPORT_ENABLED_WARN) + # hyperlink complains on py2 if this is not a Unicode self.acme_url = six.text_type( acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory") @@ -57,7 +74,7 @@ class TlsConfig(Config): self.tls_certificate_file = self.abspath(config.get("tls_certificate_path")) self.tls_private_key_file = self.abspath(config.get("tls_private_key_path")) - if self.has_tls_listener(): + if self.root.server.has_tls_listener(): if not self.tls_certificate_file: raise ConfigError( "tls_certificate_path must be specified if TLS-enabled listeners are " @@ -106,9 +123,11 @@ class TlsConfig(Config): fed_whitelist_entries = config.get( "federation_certificate_verification_whitelist", [] ) + if fed_whitelist_entries is None: + fed_whitelist_entries = [] # Support globs (*) in whitelist values - self.federation_certificate_verification_whitelist = [] + self.federation_certificate_verification_whitelist = [] # type: List[str] for entry in fed_whitelist_entries: try: entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii")) @@ -241,7 +260,7 @@ class TlsConfig(Config): crypto.FILETYPE_ASN1, self.tls_certificate ) sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest()) - sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints) + sha256_fingerprints = {f["sha256"] for f in self.tls_fingerprints} if sha256_fingerprint not in sha256_fingerprints: self.tls_fingerprints.append({"sha256": sha256_fingerprint}) @@ -286,6 +305,9 @@ class TlsConfig(Config): "http://localhost:8009/.well-known/acme-challenge" ) + # flake8 doesn't recognise that variables are used in the below string + _ = tls_enabled, proxypassline, acme_enabled, default_acme_account_file + return ( """\ ## TLS ## @@ -354,6 +376,11 @@ class TlsConfig(Config): # ACME support: This will configure Synapse to request a valid TLS certificate # for your configured `server_name` via Let's Encrypt. # + # Note that ACME v1 is now deprecated, and Synapse currently doesn't support + # ACME v2. This means that this feature currently won't work with installs set + # up after November 2019. For more info, and alternative solutions, see + # https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1 + # # Note that provisioning a certificate in this way requires port 80 to be # routed to Synapse so that it can complete the http-01 ACME challenge. 
# By default, if you enable ACME support, Synapse will attempt to listen on @@ -448,7 +475,11 @@ class TlsConfig(Config): #tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}] """ - % locals() + # Lowercase the string representation of boolean values + % { + x[0]: str(x[1]).lower() if isinstance(x[1], bool) else x[1] + for x in locals().items() + } ) def read_tls_certificate(self): diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py index 85d99a3166..8be1346113 100644 --- a/synapse/config/tracer.py +++ b/synapse/config/tracer.py @@ -19,6 +19,8 @@ from ._base import Config, ConfigError class TracerConfig(Config): + section = "tracing" + def read_config(self, config, **kwargs): opentracing_config = config.get("opentracing") if opentracing_config is None: diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index f6313e17d4..c8d19c5d6b 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -21,6 +21,8 @@ class UserDirectoryConfig(Config): Configuration for the behaviour of the /user_directory API """ + section = "userdirectory" + def read_config(self, config, **kwargs): self.user_directory_search_enabled = True self.user_directory_search_all_users = False diff --git a/synapse/config/voip.py b/synapse/config/voip.py index 2ca0e1cf70..b313bff140 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -16,6 +16,8 @@ from ._base import Config class VoipConfig(Config): + section = "voip" + def read_config(self, config, **kwargs): self.turn_uris = config.get("turn_uris", []) self.turn_shared_secret = config.get("turn_shared_secret") @@ -54,5 +56,5 @@ class VoipConfig(Config): # connect to arbitrary endpoints without having first signed up for a # valid account (e.g. by passing a CAPTCHA). # - #turn_allow_guests: True + #turn_allow_guests: true """ diff --git a/synapse/config/workers.py b/synapse/config/workers.py index bc0fc165e3..ed06b91a54 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2016 matrix.org +# Copyright 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +import attr + +from ._base import Config, ConfigError + + +@attr.s +class InstanceLocationConfig: + """The host and port to talk to an instance via HTTP replication. + """ + + host = attr.ib(type=str) + port = attr.ib(type=int) + + +@attr.s +class WriterLocations: + """Specifies the instances that write various streams. + + Attributes: + events: The instance that writes to the event and backfill streams. + """ + + events = attr.ib(default="master", type=str) class WorkerConfig(Config): @@ -21,6 +43,8 @@ class WorkerConfig(Config): They have their own pid_file and listener configuration. They use the replication_url to talk to the main synapse process.""" + section = "worker" + def read_config(self, config, **kwargs): self.worker_app = config.get("worker_app") @@ -69,6 +93,27 @@ class WorkerConfig(Config): elif not bind_addresses: bind_addresses.append("") + # A map from instance name to host/port of their HTTP replication endpoint. 
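The comment above introduces the new instance_map; the code that follows parses it together with stream_writers and checks that the configured events writer is actually listed. As a rough illustration of how such a map might be consumed (not Synapse's actual code; the instance name and port below are invented):

# Illustration only: resolve the configured events writer to a host/port
# using the same "writer must appear in instance_map" rule.
import attr


@attr.s
class InstanceLocation:
    host = attr.ib(type=str)
    port = attr.ib(type=int)


def events_writer_location(config: dict):
    """Return the location of the events writer, or None if the main
    process ("master") writes events itself."""
    instance_map = {
        name: InstanceLocation(**loc)
        for name, loc in (config.get("instance_map") or {}).items()
    }
    writer = (config.get("stream_writers") or {}).get("events", "master")
    if writer == "master":
        return None
    if writer not in instance_map:
        raise ValueError("events writer %r is not in instance_map" % (writer,))
    return instance_map[writer]


loc = events_writer_location(
    {
        "instance_map": {"event_persister1": {"host": "localhost", "port": 8034}},
        "stream_writers": {"events": "event_persister1"},
    }
)
print("http://%s:%d" % (loc.host, loc.port))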
+ instance_map = config.get("instance_map") or {} + self.instance_map = { + name: InstanceLocationConfig(**c) for name, c in instance_map.items() + } + + # Map from type of streams to source, c.f. WriterLocations. + writers = config.get("stream_writers") or {} + self.writers = WriterLocations(**writers) + + # Check that the configured writer for events also appears in + # `instance_map`. + if ( + self.writers.events != "master" + and self.writers.events not in self.instance_map + ): + raise ConfigError( + "Instance %r is configured to write events but does not appear in `instance_map` config." + % (self.writers.events,) + ) + def read_arguments(self, args): # We support a bunch of command line arguments that override options in # the config. A lot of these options have a worker_* prefix when running diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index e93f0b3705..a5a2a7815d 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -75,7 +75,7 @@ class ServerContextFactory(ContextFactory): @implementer(IPolicyForHTTPS) -class ClientTLSOptionsFactory(object): +class FederationPolicyForHTTPS(object): """Factory for Twisted SSLClientConnectionCreators that are used to make connections to remote servers for federation. @@ -103,15 +103,15 @@ class ClientTLSOptionsFactory(object): # let us do). minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version] - self._verify_ssl = CertificateOptions( + _verify_ssl = CertificateOptions( trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS ) - self._verify_ssl_context = self._verify_ssl.getContext() - self._verify_ssl_context.set_info_callback(self._context_info_cb) + self._verify_ssl_context = _verify_ssl.getContext() + self._verify_ssl_context.set_info_callback(_context_info_cb) - self._no_verify_ssl = CertificateOptions(insecurelyLowerMinimumTo=minTLS) - self._no_verify_ssl_context = self._no_verify_ssl.getContext() - self._no_verify_ssl_context.set_info_callback(self._context_info_cb) + _no_verify_ssl = CertificateOptions(insecurelyLowerMinimumTo=minTLS) + self._no_verify_ssl_context = _no_verify_ssl.getContext() + self._no_verify_ssl_context.set_info_callback(_context_info_cb) def get_options(self, host: bytes): @@ -136,23 +136,6 @@ class ClientTLSOptionsFactory(object): return SSLClientConnectionCreator(host, ssl_context, should_verify) - @staticmethod - def _context_info_cb(ssl_connection, where, ret): - """The 'information callback' for our openssl context object.""" - # we assume that the app_data on the connection object has been set to - # a TLSMemoryBIOProtocol object. (This is done by SSLClientConnectionCreator) - tls_protocol = ssl_connection.get_app_data() - try: - # ... we further assume that SSLClientConnectionCreator has set the - # '_synapse_tls_verifier' attribute to a ConnectionVerifier object. - tls_protocol._synapse_tls_verifier.verify_context_info_cb( - ssl_connection, where - ) - except: # noqa: E722, taken from the twisted implementation - logger.exception("Error during info_callback") - f = Failure() - tls_protocol.failVerification(f) - def creatorForNetloc(self, hostname, port): """Implements the IPolicyForHTTPS interace so that this can be passed directly to agents. 
@@ -160,6 +143,43 @@ class ClientTLSOptionsFactory(object): return self.get_options(hostname) +@implementer(IPolicyForHTTPS) +class RegularPolicyForHTTPS(object): + """Factory for Twisted SSLClientConnectionCreators that are used to make connections + to remote servers, for other than federation. + + Always uses the same OpenSSL context object, which uses the default OpenSSL CA + trust root. + """ + + def __init__(self): + trust_root = platformTrust() + self._ssl_context = CertificateOptions(trustRoot=trust_root).getContext() + self._ssl_context.set_info_callback(_context_info_cb) + + def creatorForNetloc(self, hostname, port): + return SSLClientConnectionCreator(hostname, self._ssl_context, True) + + +def _context_info_cb(ssl_connection, where, ret): + """The 'information callback' for our openssl context objects. + + Note: Once this is set as the info callback on a Context object, the Context should + only be used with the SSLClientConnectionCreator. + """ + # we assume that the app_data on the connection object has been set to + # a TLSMemoryBIOProtocol object. (This is done by SSLClientConnectionCreator) + tls_protocol = ssl_connection.get_app_data() + try: + # ... we further assume that SSLClientConnectionCreator has set the + # '_synapse_tls_verifier' attribute to a ConnectionVerifier object. + tls_protocol._synapse_tls_verifier.verify_context_info_cb(ssl_connection, where) + except: # noqa: E722, taken from the twisted implementation + logger.exception("Error during info_callback") + f = Failure() + tls_protocol.failVerification(f) + + @implementer(IOpenSSLClientConnectionCreator) class SSLClientConnectionCreator(object): """Creates openssl connection objects for client connections. diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 694fb2c816..0422c43fab 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- - +# # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,16 +15,20 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import collections.abc import hashlib import logging +from typing import Dict from canonicaljson import encode_canonical_json from signedjson.sign import sign_json +from signedjson.types import SigningKey from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.errors import Codes, SynapseError +from synapse.api.room_versions import RoomVersion from synapse.events.utils import prune_event, prune_event_dict +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -40,8 +45,11 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): # some malformed events lack a 'hashes'. Protect against it being missing # or a weird type by basically treating it the same as an unhashed event. 
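For context on check_event_content_hash below: a Matrix content hash is the SHA-256 of the event's canonical JSON with the hashes, signatures and unsigned keys removed, encoded as unpadded base64. A small sketch of that calculation, following the spec rather than Synapse's exact compute_content_hash helper (the example event is made up):

# Sketch: compute a content hash as described in the server-server spec.
import hashlib

from canonicaljson import encode_canonical_json
from unpaddedbase64 import encode_base64


def content_hash(event_dict: dict) -> str:
    stripped = dict(event_dict)
    stripped.pop("hashes", None)
    stripped.pop("signatures", None)
    stripped.pop("unsigned", None)
    return encode_base64(hashlib.sha256(encode_canonical_json(stripped)).digest())


print(content_hash({"type": "m.room.message", "content": {"body": "hi"}}))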
hashes = event.get("hashes") - if not isinstance(hashes, dict): - raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED) + # nb it might be a frozendict or a dict + if not isinstance(hashes, collections.abc.Mapping): + raise SynapseError( + 400, "Malformed 'hashes': %s" % (type(hashes),), Codes.UNAUTHORIZED + ) if name not in hashes: raise SynapseError( @@ -109,46 +117,61 @@ def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): return hashed.name, hashed.digest() -def compute_event_signature(event_dict, signature_name, signing_key): +def compute_event_signature( + room_version: RoomVersion, + event_dict: JsonDict, + signature_name: str, + signing_key: SigningKey, +) -> Dict[str, Dict[str, str]]: """Compute the signature of the event for the given name and key. Args: - event_dict (dict): The event as a dict - signature_name (str): The name of the entity signing the event + room_version: the version of the room that this event is in. + (the room version determines the redaction algorithm and hence the + json to be signed) + + event_dict: The event as a dict + + signature_name: The name of the entity signing the event (typically the server's hostname). - signing_key (syutil.crypto.SigningKey): The key to sign with + + signing_key: The key to sign with Returns: - dict[str, dict[str, str]]: Returns a dictionary in the same format of - an event's signatures field. + a dictionary in the same format of an event's signatures field. """ - redact_json = prune_event_dict(event_dict) + redact_json = prune_event_dict(room_version, event_dict) redact_json.pop("age_ts", None) redact_json.pop("unsigned", None) - logger.debug("Signing event: %s", encode_canonical_json(redact_json)) + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Signing event: %s", encode_canonical_json(redact_json)) redact_json = sign_json(redact_json, signature_name, signing_key) - logger.debug("Signed event: %s", encode_canonical_json(redact_json)) + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Signed event: %s", encode_canonical_json(redact_json)) return redact_json["signatures"] def add_hashes_and_signatures( - event_dict, signature_name, signing_key, hash_algorithm=hashlib.sha256 + room_version: RoomVersion, + event_dict: JsonDict, + signature_name: str, + signing_key: SigningKey, ): """Add content hash and sign the event Args: - event_dict (dict): The event to add hashes to and sign - signature_name (str): The name of the entity signing the event + room_version: the version of the room this event is in + + event_dict: The event to add hashes to and sign + signature_name: The name of the entity signing the event (typically the server's hostname). - signing_key (syutil.crypto.SigningKey): The key to sign with - hash_algorithm: A hasher from `hashlib`, e.g. 
hashlib.sha256, to use - to hash the event + signing_key: The key to sign with """ - name, digest = compute_content_hash(event_dict, hash_algorithm=hash_algorithm) + name, digest = compute_content_hash(event_dict, hash_algorithm=hashlib.sha256) event_dict.setdefault("hashes", {})[name] = encode_base64(digest) event_dict["signatures"] = compute_event_signature( - event_dict, signature_name=signature_name, signing_key=signing_key + room_version, event_dict, signature_name=signature_name, signing_key=signing_key ) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 7cfad192e8..a9f4025bfe 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -43,8 +43,8 @@ from synapse.api.errors import ( SynapseError, ) from synapse.logging.context import ( - LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, preserve_fn, run_in_background, @@ -236,7 +236,7 @@ class Keyring(object): """ try: - ctx = LoggingContext.current_context() + ctx = current_context() # map from server name to a set of outstanding request ids server_to_request_ids = {} @@ -326,9 +326,7 @@ class Keyring(object): verify_requests (list[VerifyJsonRequest]): list of verify requests """ - remaining_requests = set( - (rq for rq in verify_requests if not rq.key_ready.called) - ) + remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} @defer.inlineCallbacks def do_iterations(): @@ -396,7 +394,7 @@ class Keyring(object): results = yield fetcher.get_keys(missing_keys) - completed = list() + completed = [] for verify_request in remaining_requests: server_name = verify_request.server_name @@ -511,17 +509,18 @@ class BaseV2KeyFetcher(object): server_name = response_json["server_name"] verified = False for key_id in response_json["signatures"].get(server_name, {}): - # each of the keys used for the signature must be present in the response - # json. key = verify_keys.get(key_id) if not key: - raise KeyLookupError( - "Key response is signed by key id %s:%s but that key is not " - "present in the response" % (server_name, key_id) - ) + # the key may not be present in verify_keys if: + # * we got the key from the notary server, and: + # * the key belongs to the notary server, and: + # * the notary server is using a different key to sign notary + # responses. + continue verify_signed_json(response_json, server_name, key.verify_key) verified = True + break if not verified: raise KeyLookupError( diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 4e91df60e6..c582355146 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +15,7 @@ # limitations under the License. 
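The keyring change above stops treating an unknown signing key id in a notary response as a fatal error and instead skips it. The same pattern can be sketched standalone with signedjson; the server name and key version here are invented:

# Sketch: verify a signed JSON object against the verify keys we hold,
# skipping signatures made with key ids we don't know about.
from signedjson.key import generate_signing_key, get_verify_key
from signedjson.sign import SignatureVerifyException, sign_json, verify_signed_json

signing_key = generate_signing_key("key1")
verify_keys = {"ed25519:key1": get_verify_key(signing_key)}

response = sign_json({"server_name": "example.org"}, "example.org", signing_key)

verified = False
for key_id in response["signatures"].get("example.org", {}):
    key = verify_keys.get(key_id)
    if key is None:
        continue  # signed by a key we don't hold; not an error in itself
    verify_signed_json(response, "example.org", key)
    verified = True
    break

if not verified:
    raise SignatureVerifyException("response not signed by any known key")
print("verified:", verified)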
import logging +from typing import List, Optional, Set, Tuple from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -22,17 +24,28 @@ from unpaddedbase64 import decode_base64 from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, EventSizeError, SynapseError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions -from synapse.types import UserID, get_domain_from_id +from synapse.api.room_versions import ( + KNOWN_ROOM_VERSIONS, + EventFormatVersions, + RoomVersion, +) +from synapse.events import EventBase +from synapse.types import StateMap, UserID, get_domain_from_id logger = logging.getLogger(__name__) -def check(room_version, event, auth_events, do_sig_check=True, do_size_check=True): +def check( + room_version_obj: RoomVersion, + event: EventBase, + auth_events: StateMap[EventBase], + do_sig_check: bool = True, + do_size_check: bool = True, +) -> None: """ Checks if this event is correctly authed. Args: - room_version (str): the version of the room + room_version_obj: the version of the room event: the event being checked. auth_events (dict: event-key -> event): the existing room state. @@ -42,12 +55,26 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru Returns: if the auth checks pass. """ + assert isinstance(auth_events, dict) + if do_size_check: _check_size_limits(event) if not hasattr(event, "room_id"): raise AuthError(500, "Event has no room_id: %s" % event) + room_id = event.room_id + + # I'm not really expecting to get auth events in the wrong room, but let's + # sanity-check it + for auth_event in auth_events.values(): + if auth_event.room_id != room_id: + raise Exception( + "During auth for event %s in room %s, found event %s in the state " + "which is in room %s" + % (event.event_id, room_id, auth_event.event_id, auth_event.room_id) + ) + if do_sig_check: sender_domain = get_domain_from_id(event.sender) @@ -74,13 +101,12 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru if not event.signatures.get(event_id_domain): raise AuthError(403, "Event not signed by sending server") - if auth_events is None: - # Oh, we don't know what the state of the room was, so we - # are trusting that this is allowed (at least for now) - logger.warn("Trusting event: %s", event.event_id) - return - + # Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules + # + # 1. If type is m.room.create: if event.type == EventTypes.Create: + # 1b. If the domain of the room_id does not match the domain of the sender, + # reject. sender_domain = get_domain_from_id(event.sender) room_id_domain = get_domain_from_id(event.room_id) if room_id_domain != sender_domain: @@ -88,37 +114,45 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru 403, "Creation event's room_id domain does not match sender's" ) - room_version = event.content.get("room_version", "1") - if room_version not in KNOWN_ROOM_VERSIONS: + # 1c. If content.room_version is present and is not a recognised version, reject + room_version_prop = event.content.get("room_version", "1") + if room_version_prop not in KNOWN_ROOM_VERSIONS: raise AuthError( - 403, "room appears to have unsupported version %s" % (room_version,) + 403, + "room appears to have unsupported version %s" % (room_version_prop,), ) - # FIXME + logger.debug("Allowing! %s", event) return + # 3. 
If event does not have a m.room.create in its auth_events, reject. creation_event = auth_events.get((EventTypes.Create, ""), None) - if not creation_event: raise AuthError(403, "No create event in auth events") + # additional check for m.federate creating_domain = get_domain_from_id(event.room_id) originating_domain = get_domain_from_id(event.sender) if creating_domain != originating_domain: if not _can_federate(event, auth_events): raise AuthError(403, "This room has been marked as unfederatable.") - # FIXME: Temp hack - if event.type == EventTypes.Aliases: + # 4. If type is m.room.aliases + if event.type == EventTypes.Aliases and room_version_obj.special_case_aliases_auth: + # 4a. If event has no state_key, reject if not event.is_state(): raise AuthError(403, "Alias event must be a state event") if not event.state_key: raise AuthError(403, "Alias event must have non-empty state_key") + + # 4b. If sender's domain doesn't matches [sic] state_key, reject sender_domain = get_domain_from_id(event.sender) if event.state_key != sender_domain: raise AuthError( 403, "Alias event's state_key does not match sender's domain" ) + + # 4c. Otherwise, allow. logger.debug("Allowing! %s", event) return @@ -148,15 +182,15 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru _can_send_event(event, auth_events) if event.type == EventTypes.PowerLevels: - _check_power_levels(event, auth_events) + _check_power_levels(room_version_obj, event, auth_events) if event.type == EventTypes.Redaction: - check_redaction(room_version, event, auth_events) + check_redaction(room_version_obj, event, auth_events) logger.debug("Allowing! %s", event) -def _check_size_limits(event): +def _check_size_limits(event: EventBase) -> None: def too_big(field): raise EventSizeError("%s too large" % (field,)) @@ -174,13 +208,18 @@ def _check_size_limits(event): too_big("event") -def _can_federate(event, auth_events): +def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool: creation_event = auth_events.get((EventTypes.Create, "")) + # There should always be a creation event, but if not don't federate. 
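For reference, the required-power-level lookup that get_send_level (a little further down in this file) performs is defined by the Matrix auth rules: take the entry for the event type from the power-levels events map, otherwise fall back to state_default (50) for state events or events_default (0) for other events. A hedged sketch of just that rule (the helper name is ours, not Synapse's):

# Sketch of the "required power level to send" rule from the auth spec.
from typing import Optional


def required_send_level(
    etype: str, state_key: Optional[str], power_levels_content: Optional[dict]
) -> int:
    if not power_levels_content:
        return 0
    default = power_levels_content.get(
        "events_default" if state_key is None else "state_default",
        0 if state_key is None else 50,
    )
    return int(power_levels_content.get("events", {}).get(etype, default))


print(required_send_level("m.room.message", None, {"events_default": 0}))  # 0
print(required_send_level("m.room.name", "", {}))  # 50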
+ if not creation_event: + return False return creation_event.content.get("m.federate", True) is True -def _is_membership_change_allowed(event, auth_events): +def _is_membership_change_allowed( + event: EventBase, auth_events: StateMap[EventBase] +) -> None: membership = event.content["membership"] # Check if this is the room creator joining: @@ -306,21 +345,25 @@ def _is_membership_change_allowed(event, auth_events): raise AuthError(500, "Unknown membership %s" % membership) -def _check_event_sender_in_room(event, auth_events): +def _check_event_sender_in_room( + event: EventBase, auth_events: StateMap[EventBase] +) -> None: key = (EventTypes.Member, event.user_id) member_event = auth_events.get(key) - return _check_joined_room(member_event, event.user_id, event.room_id) + _check_joined_room(member_event, event.user_id, event.room_id) -def _check_joined_room(member, user_id, room_id): +def _check_joined_room(member: Optional[EventBase], user_id: str, room_id: str) -> None: if not member or member.membership != Membership.JOIN: raise AuthError( 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member)) ) -def get_send_level(etype, state_key, power_levels_event): +def get_send_level( + etype: str, state_key: Optional[str], power_levels_event: Optional[EventBase] +) -> int: """Get the power level required to send an event of a given type The federation spec [1] refers to this as "Required Power Level". @@ -328,13 +371,13 @@ def get_send_level(etype, state_key, power_levels_event): https://matrix.org/docs/spec/server_server/unstable.html#definitions Args: - etype (str): type of event - state_key (str|None): state_key of state event, or None if it is not + etype: type of event + state_key: state_key of state event, or None if it is not a state event. - power_levels_event (synapse.events.EventBase|None): power levels event + power_levels_event: power levels event in force at this point in the room Returns: - int: power level required to send this event. + power level required to send this event. """ if power_levels_event: @@ -355,7 +398,7 @@ def get_send_level(etype, state_key, power_levels_event): return int(send_level) -def _can_send_event(event, auth_events): +def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool: power_levels_event = _get_power_level_event(auth_events) send_level = get_send_level(event.type, event.get("state_key"), power_levels_event) @@ -377,7 +420,9 @@ def _can_send_event(event, auth_events): return True -def check_redaction(room_version, event, auth_events): +def check_redaction( + room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase], +) -> bool: """Check whether the event sender is allowed to redact the target event. 
Returns: @@ -397,11 +442,7 @@ def check_redaction(room_version, event, auth_events): if user_level >= redact_level: return False - v = KNOWN_ROOM_VERSIONS.get(room_version) - if not v: - raise RuntimeError("Unrecognized room version %r" % (room_version,)) - - if v.event_format == EventFormatVersions.V1: + if room_version_obj.event_format == EventFormatVersions.V1: redacter_domain = get_domain_from_id(event.event_id) redactee_domain = get_domain_from_id(event.redacts) if redacter_domain == redactee_domain: @@ -413,7 +454,9 @@ def check_redaction(room_version, event, auth_events): raise AuthError(403, "You don't have permission to redact events") -def _check_power_levels(event, auth_events): +def _check_power_levels( + room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase], +) -> None: user_list = event.content.get("users", {}) # Validate users for k, v in user_list.items(): @@ -444,7 +487,7 @@ def _check_power_levels(event, auth_events): ("redact", None), ("kick", None), ("invite", None), - ] + ] # type: List[Tuple[str, Optional[str]]] old_list = current_state.content.get("users", {}) for user in set(list(old_list) + list(user_list)): @@ -455,6 +498,14 @@ def _check_power_levels(event, auth_events): for ev_id in set(list(old_list) + list(new_list)): levels_to_check.append((ev_id, "events")) + # MSC2209 specifies these checks should also be done for the "notifications" + # key. + if room_version_obj.limit_notifications_power_levels: + old_list = current_state.content.get("notifications", {}) + new_list = event.content.get("notifications", {}) + for ev_id in set(list(old_list) + list(new_list)): + levels_to_check.append((ev_id, "notifications")) + old_state = current_state.content new_state = event.content @@ -466,12 +517,12 @@ def _check_power_levels(event, auth_events): new_loc = new_loc.get(dir, {}) if level_to_check in old_loc: - old_level = int(old_loc[level_to_check]) + old_level = int(old_loc[level_to_check]) # type: Optional[int] else: old_level = None if level_to_check in new_loc: - new_level = int(new_loc[level_to_check]) + new_level = int(new_loc[level_to_check]) # type: Optional[int] else: new_level = None @@ -493,26 +544,25 @@ def _check_power_levels(event, auth_events): new_level_too_big = new_level is not None and new_level > user_level if old_level_too_big or new_level_too_big: raise AuthError( - 403, - "You don't have permission to add ops level greater " "than your own", + 403, "You don't have permission to add ops level greater than your own" ) -def _get_power_level_event(auth_events): +def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]: return auth_events.get((EventTypes.PowerLevels, "")) -def get_user_power_level(user_id, auth_events): +def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int: """Get a user's power level Args: - user_id (str): user's id to look up in power_levels - auth_events (dict[(str, str), synapse.events.EventBase]): + user_id: user's id to look up in power_levels + auth_events: state in force at this point in the room (or rather, a subset of it including at least the create event and power levels event. Returns: - int: the user's power level in this room. + the user's power level in this room. 
""" power_level_event = _get_power_level_event(auth_events) if power_level_event: @@ -538,7 +588,7 @@ def get_user_power_level(user_id, auth_events): return 0 -def _get_named_level(auth_events, name, default): +def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int: power_level_event = _get_power_level_event(auth_events) if not power_level_event: @@ -551,7 +601,7 @@ def _get_named_level(auth_events, name, default): return default -def _verify_third_party_invite(event, auth_events): +def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase]): """ Validates that the invite event is authorized by a previous third-party invite. @@ -626,7 +676,7 @@ def get_public_keys(invite_event): return public_keys -def auth_types_for_event(event): +def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]: """Given an event, return a list of (EventType, StateKey) that may be needed to auth the event. The returned list may be a superset of what would actually be required depending on the full state of the room. @@ -635,20 +685,20 @@ def auth_types_for_event(event): actually auth the event. """ if event.type == EventTypes.Create: - return [] + return set() - auth_types = [ + auth_types = { (EventTypes.PowerLevels, ""), (EventTypes.Member, event.sender), (EventTypes.Create, ""), - ] + } if event.type == EventTypes.Member: membership = event.content["membership"] if membership in [Membership.JOIN, Membership.INVITE]: - auth_types.append((EventTypes.JoinRules, "")) + auth_types.add((EventTypes.JoinRules, "")) - auth_types.append((EventTypes.Member, event.state_key)) + auth_types.add((EventTypes.Member, event.state_key)) if membership == Membership.INVITE: if "third_party_invite" in event.content: @@ -656,6 +706,6 @@ def auth_types_for_event(event): EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"], ) - auth_types.append(key) + auth_types.add(key) return auth_types diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 88ed6d764f..533ba327f5 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,15 +15,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import abc import os from distutils.util import strtobool +from typing import Dict, Optional, Type import six from unpaddedbase64 import encode_base64 -from synapse.api.errors import UnsupportedRoomVersionError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions +from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions +from synapse.types import JsonDict from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze @@ -36,34 +39,115 @@ from synapse.util.frozenutils import freeze USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0")) -class _EventInternalMetadata(object): - def __init__(self, internal_metadata_dict): - self.__dict__ = dict(internal_metadata_dict) +class DictProperty: + """An object property which delegates to the `_dict` within its parent object.""" + + __slots__ = ["key"] + + def __init__(self, key: str): + self.key = key + + def __get__(self, instance, owner=None): + # if the property is accessed as a class property rather than an instance + # property, return the property itself rather than the value + if instance is None: + return self + try: + return instance._dict[self.key] + except KeyError as e1: + # We want this to look like a regular attribute error (mostly so that + # hasattr() works correctly), so we convert the KeyError into an + # AttributeError. + # + # To exclude the KeyError from the traceback, we explicitly + # 'raise from e1.__context__' (which is better than 'raise from None', + # becuase that would omit any *earlier* exceptions). + # + raise AttributeError( + "'%s' has no '%s' property" % (type(instance), self.key) + ) from e1.__context__ + + def __set__(self, instance, v): + instance._dict[self.key] = v + + def __delete__(self, instance): + try: + del instance._dict[self.key] + except KeyError as e1: + raise AttributeError( + "'%s' has no '%s' property" % (type(instance), self.key) + ) from e1.__context__ + + +class DefaultDictProperty(DictProperty): + """An extension of DictProperty which provides a default if the property is + not present in the parent's _dict. + + Note that this means that hasattr() on the property always returns True. + """ + + __slots__ = ["default"] + + def __init__(self, key, default): + super().__init__(key) + self.default = default - def get_dict(self): - return dict(self.__dict__) + def __get__(self, instance, owner=None): + if instance is None: + return self + return instance._dict.get(self.key, self.default) - def is_outlier(self): - return getattr(self, "outlier", False) - def is_out_of_band_membership(self): +class _EventInternalMetadata(object): + __slots__ = ["_dict"] + + def __init__(self, internal_metadata_dict: JsonDict): + # we have to copy the dict, because it turns out that the same dict is + # reused. 
TODO: fix that + self._dict = dict(internal_metadata_dict) + + outlier = DictProperty("outlier") # type: bool + out_of_band_membership = DictProperty("out_of_band_membership") # type: bool + send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str + recheck_redaction = DictProperty("recheck_redaction") # type: bool + soft_failed = DictProperty("soft_failed") # type: bool + proactively_send = DictProperty("proactively_send") # type: bool + redacted = DictProperty("redacted") # type: bool + txn_id = DictProperty("txn_id") # type: str + token_id = DictProperty("token_id") # type: str + stream_ordering = DictProperty("stream_ordering") # type: int + + # XXX: These are set by StreamWorkerStore._set_before_and_after. + # I'm pretty sure that these are never persisted to the database, so shouldn't + # be here + before = DictProperty("before") # type: str + after = DictProperty("after") # type: str + order = DictProperty("order") # type: int + + def get_dict(self) -> JsonDict: + return dict(self._dict) + + def is_outlier(self) -> bool: + return self._dict.get("outlier", False) + + def is_out_of_band_membership(self) -> bool: """Whether this is an out of band membership, like an invite or an invite rejection. This is needed as those events are marked as outliers, but they still need to be processed as if they're new events (e.g. updating invite state in the database, relaying to clients, etc). """ - return getattr(self, "out_of_band_membership", False) + return self._dict.get("out_of_band_membership", False) - def get_send_on_behalf_of(self): + def get_send_on_behalf_of(self) -> Optional[str]: """Whether this server should send the event on behalf of another server. This is used by the federation "send_join" API to forward the initial join event for a server in the room. returns a str with the name of the server this event is sent on behalf of. """ - return getattr(self, "send_on_behalf_of", None) + return self._dict.get("send_on_behalf_of") - def need_to_check_redaction(self): + def need_to_check_redaction(self) -> bool: """Whether the redaction event needs to be rechecked when fetching from the database. @@ -76,9 +160,9 @@ class _EventInternalMetadata(object): Returns: bool """ - return getattr(self, "recheck_redaction", False) + return self._dict.get("recheck_redaction", False) - def is_soft_failed(self): + def is_soft_failed(self) -> bool: """Whether the event has been soft failed. Soft failed events should be handled as usual, except: @@ -90,7 +174,7 @@ class _EventInternalMetadata(object): Returns: bool """ - return getattr(self, "soft_failed", False) + return self._dict.get("soft_failed", False) def should_proactively_send(self): """Whether the event, if ours, should be sent to other clients and @@ -102,7 +186,7 @@ class _EventInternalMetadata(object): Returns: bool """ - return getattr(self, "proactively_send", True) + return self._dict.get("proactively_send", True) def is_redacted(self): """Whether the event has been redacted. @@ -113,62 +197,53 @@ class _EventInternalMetadata(object): Returns: bool """ - return getattr(self, "redacted", False) - - -def _event_dict_property(key): - # We want to be able to use hasattr with the event dict properties. - # However, (on python3) hasattr expects AttributeError to be raised. 
Hence, - # we need to transform the KeyError into an AttributeError - def getter(self): - try: - return self._event_dict[key] - except KeyError: - raise AttributeError(key) + return self._dict.get("redacted", False) - def setter(self, v): - try: - self._event_dict[key] = v - except KeyError: - raise AttributeError(key) - - def delete(self): - try: - del self._event_dict[key] - except KeyError: - raise AttributeError(key) - - return property(getter, setter, delete) +class EventBase(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def format_version(self) -> int: + """The EventFormatVersion implemented by this event""" + ... -class EventBase(object): def __init__( self, - event_dict, - signatures={}, - unsigned={}, - internal_metadata_dict={}, - rejected_reason=None, + event_dict: JsonDict, + room_version: RoomVersion, + signatures: Dict[str, Dict[str, str]], + unsigned: JsonDict, + internal_metadata_dict: JsonDict, + rejected_reason: Optional[str], ): + assert room_version.event_format == self.format_version + + self.room_version = room_version self.signatures = signatures self.unsigned = unsigned self.rejected_reason = rejected_reason - self._event_dict = event_dict + self._dict = event_dict self.internal_metadata = _EventInternalMetadata(internal_metadata_dict) - auth_events = _event_dict_property("auth_events") - depth = _event_dict_property("depth") - content = _event_dict_property("content") - hashes = _event_dict_property("hashes") - origin = _event_dict_property("origin") - origin_server_ts = _event_dict_property("origin_server_ts") - prev_events = _event_dict_property("prev_events") - redacts = _event_dict_property("redacts") - room_id = _event_dict_property("room_id") - sender = _event_dict_property("sender") - user_id = _event_dict_property("sender") + auth_events = DictProperty("auth_events") + depth = DictProperty("depth") + content = DictProperty("content") + hashes = DictProperty("hashes") + origin = DictProperty("origin") + origin_server_ts = DictProperty("origin_server_ts") + prev_events = DictProperty("prev_events") + redacts = DefaultDictProperty("redacts", None) + room_id = DictProperty("room_id") + sender = DictProperty("sender") + state_key = DictProperty("state_key") + type = DictProperty("type") + user_id = DictProperty("sender") + + @property + def event_id(self) -> str: + raise NotImplementedError() @property def membership(self): @@ -177,19 +252,19 @@ class EventBase(object): def is_state(self): return hasattr(self, "state_key") and self.state_key is not None - def get_dict(self): - d = dict(self._event_dict) + def get_dict(self) -> JsonDict: + d = dict(self._dict) d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)}) return d def get(self, key, default=None): - return self._event_dict.get(key, default) + return self._dict.get(key, default) def get_internal_metadata_dict(self): return self.internal_metadata.get_dict() - def get_pdu_json(self, time_now=None): + def get_pdu_json(self, time_now=None) -> JsonDict: pdu_json = self.get_dict() if time_now is not None and "age_ts" in pdu_json["unsigned"]: @@ -206,16 +281,16 @@ class EventBase(object): raise AttributeError("Unrecognized attribute %s" % (instance,)) def __getitem__(self, field): - return self._event_dict[field] + return self._dict[field] def __contains__(self, field): - return field in self._event_dict + return field in self._dict def items(self): - return list(self._event_dict.items()) + return list(self._dict.items()) def keys(self): - return six.iterkeys(self._event_dict) + 
return six.iterkeys(self._dict) def prev_event_ids(self): """Returns the list of prev event IDs. The order matches the order @@ -239,7 +314,13 @@ class EventBase(object): class FrozenEvent(EventBase): format_version = EventFormatVersions.V1 # All events of this type are V1 - def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None): + def __init__( + self, + event_dict: JsonDict, + room_version: RoomVersion, + internal_metadata_dict: JsonDict = {}, + rejected_reason: Optional[str] = None, + ): event_dict = dict(event_dict) # Signatures is a dict of dicts, and this is faster than doing a @@ -260,19 +341,21 @@ class FrozenEvent(EventBase): else: frozen_dict = event_dict - self.event_id = event_dict["event_id"] - self.type = event_dict["type"] - if "state_key" in event_dict: - self.state_key = event_dict["state_key"] + self._event_id = event_dict["event_id"] - super(FrozenEvent, self).__init__( + super().__init__( frozen_dict, + room_version=room_version, signatures=signatures, unsigned=unsigned, internal_metadata_dict=internal_metadata_dict, rejected_reason=rejected_reason, ) + @property + def event_id(self) -> str: + return self._event_id + def __str__(self): return self.__repr__() @@ -287,7 +370,13 @@ class FrozenEvent(EventBase): class FrozenEventV2(EventBase): format_version = EventFormatVersions.V2 # All events of this type are V2 - def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None): + def __init__( + self, + event_dict: JsonDict, + room_version: RoomVersion, + internal_metadata_dict: JsonDict = {}, + rejected_reason: Optional[str] = None, + ): event_dict = dict(event_dict) # Signatures is a dict of dicts, and this is faster than doing a @@ -311,12 +400,10 @@ class FrozenEventV2(EventBase): frozen_dict = event_dict self._event_id = None - self.type = event_dict["type"] - if "state_key" in event_dict: - self.state_key = event_dict["state_key"] - super(FrozenEventV2, self).__init__( + super().__init__( frozen_dict, + room_version=room_version, signatures=signatures, unsigned=unsigned, internal_metadata_dict=internal_metadata_dict, @@ -383,28 +470,7 @@ class FrozenEventV3(FrozenEventV2): return self._event_id -def room_version_to_event_format(room_version): - """Converts a room version string to the event format - - Args: - room_version (str) - - Returns: - int - - Raises: - UnsupportedRoomVersionError if the room version is unknown - """ - v = KNOWN_ROOM_VERSIONS.get(room_version) - - if not v: - # this can happen if support is withdrawn for a room version - raise UnsupportedRoomVersionError() - - return v.event_format - - -def event_type_from_format_version(format_version): +def _event_type_from_format_version(format_version: int) -> Type[EventBase]: """Returns the python type to use to construct an Event object for the given event format version. 
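The DictProperty descriptor introduced near the top of this file routes attribute access to a backing dict and converts KeyError into AttributeError so that hasattr() behaves as expected. A stripped-down demo of the same idea (class names here are illustrative, not the Synapse classes):

# Minimal demo of a dict-backed property descriptor.
class DictProp:
    __slots__ = ["key"]

    def __init__(self, key: str):
        self.key = key

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        try:
            return instance._dict[self.key]
        except KeyError as e:
            # Surface missing keys as AttributeError so hasattr() works.
            raise AttributeError(self.key) from e

    def __set__(self, instance, value):
        instance._dict[self.key] = value


class Thing:
    sender = DictProp("sender")

    def __init__(self, d: dict):
        self._dict = dict(d)


t = Thing({"sender": "@alice:example.org"})
print(t.sender)                          # @alice:example.org
t.sender = "@bob:example.org"
print(t._dict["sender"])                 # @bob:example.org
print(hasattr(Thing({}), "sender"))      # False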
@@ -424,3 +490,14 @@ def event_type_from_format_version(format_version): return FrozenEventV3 else: raise Exception("No event format %r" % (format_version,)) + + +def make_event_from_dict( + event_dict: JsonDict, + room_version: RoomVersion = RoomVersions.V1, + internal_metadata_dict: JsonDict = {}, + rejected_reason: Optional[str] = None, +) -> EventBase: + """Construct an EventBase from the given event dict""" + event_type = _event_type_from_format_version(room_version.event_format) + return event_type(event_dict, room_version, internal_metadata_dict, rejected_reason) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 3997751337..a0c4a40c27 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -12,8 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional import attr +from nacl.signing import SigningKey from twisted.internet import defer @@ -23,13 +25,14 @@ from synapse.api.room_versions import ( KNOWN_EVENT_FORMAT_VERSIONS, KNOWN_ROOM_VERSIONS, EventFormatVersions, + RoomVersion, ) from synapse.crypto.event_signing import add_hashes_and_signatures -from synapse.types import EventID +from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict +from synapse.types import EventID, JsonDict +from synapse.util import Clock from synapse.util.stringutils import random_string -from . import _EventInternalMetadata, event_type_from_format_version - @attr.s(slots=True, cmp=False, frozen=True) class EventBuilder(object): @@ -40,7 +43,7 @@ class EventBuilder(object): content/unsigned/internal_metadata fields are still mutable) Attributes: - format_version (int): Event format version + room_version: Version of the target room room_id (str) type (str) sender (str) @@ -63,7 +66,7 @@ class EventBuilder(object): _hostname = attr.ib() _signing_key = attr.ib() - format_version = attr.ib() + room_version = attr.ib(type=RoomVersion) room_id = attr.ib() type = attr.ib() @@ -108,7 +111,8 @@ class EventBuilder(object): ) auth_ids = yield self._auth.compute_auth_events(self, state_ids) - if self.format_version == EventFormatVersions.V1: + format_version = self.room_version.event_format + if format_version == EventFormatVersions.V1: auth_events = yield self._store.add_event_hashes(auth_ids) prev_events = yield self._store.add_event_hashes(prev_event_ids) else: @@ -148,7 +152,7 @@ class EventBuilder(object): clock=self._clock, hostname=self._hostname, signing_key=self._signing_key, - format_version=self.format_version, + room_version=self.room_version, event_dict=event_dict, internal_metadata_dict=self.internal_metadata.get_dict(), ) @@ -201,7 +205,7 @@ class EventBuilderFactory(object): clock=self.clock, hostname=self.hostname, signing_key=self.signing_key, - format_version=room_version.event_format, + room_version=room_version, type=key_values["type"], state_key=key_values.get("state_key"), room_id=key_values["room_id"], @@ -214,29 +218,19 @@ class EventBuilderFactory(object): def create_local_event_from_event_dict( - clock, - hostname, - signing_key, - format_version, - event_dict, - internal_metadata_dict=None, -): + clock: Clock, + hostname: str, + signing_key: SigningKey, + room_version: RoomVersion, + event_dict: JsonDict, + internal_metadata_dict: Optional[JsonDict] = None, +) -> EventBase: """Takes a fully formed event dict, ensuring that fields like `origin` and `origin_server_ts` 
have correct values for a locally produced event, then signs and hashes it. - - Args: - clock (Clock) - hostname (str) - signing_key - format_version (int) - event_dict (dict) - internal_metadata_dict (dict|None) - - Returns: - FrozenEvent """ + format_version = room_version.event_format if format_version not in KNOWN_EVENT_FORMAT_VERSIONS: raise Exception("No event format defined for version %r" % (format_version,)) @@ -257,9 +251,9 @@ def create_local_event_from_event_dict( event_dict.setdefault("signatures", {}) - add_hashes_and_signatures(event_dict, hostname, signing_key) - return event_type_from_format_version(format_version)( - event_dict, internal_metadata_dict=internal_metadata_dict + add_hashes_and_signatures(room_version, event_dict, hostname, signing_key) + return make_event_from_dict( + event_dict, room_version, internal_metadata_dict=internal_metadata_dict ) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index acbcbeeced..7c5f620d09 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -12,104 +12,124 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Union from six import iteritems +import attr from frozendict import frozendict from twisted.internet import defer +from synapse.appservice import ApplicationService from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.types import StateMap -class EventContext(object): +@attr.s(slots=True) +class EventContext: """ + Holds information relevant to persisting an event + Attributes: - state_group (int|None): state group id, if the state has been stored - as a state group. This is usually only None if e.g. the event is - an outlier. - rejected (bool|str): A rejection reason if the event was rejected, else - False - - push_actions (list[(str, list[object])]): list of (user_id, actions) - tuples - - prev_group (int): Previously persisted state group. ``None`` for an - outlier. - delta_ids (dict[(str, str), str]): Delta from ``prev_group``. - (type, state_key) -> event_id. ``None`` for an outlier. - - prev_state_events (?): XXX: is this ever set to anything other than - the empty list? - - _current_state_ids (dict[(str, str), str]|None): - The current state map including the current event. None if outlier - or we haven't fetched the state from DB yet. - (type, state_key) -> event_id + rejected: A rejection reason if the event was rejected, else False + + _state_group: The ID of the state group for this event. Note that state events + are persisted with a state group which includes the new event, so this is + effectively the state *after* the event in question. + + For a *rejected* state event, where the state of the rejected event is + ignored, this state_group should never make it into the + event_to_state_groups table. Indeed, inspecting this value for a rejected + state event is almost certainly incorrect. + + For an outlier, where we don't have the state at the event, this will be + None. + + Note that this is a private attribute: it should be accessed via + the ``state_group`` property. + + state_group_before_event: The ID of the state group representing the state + of the room before this event. + + If this is a non-state event, this will be the same as ``state_group``. If + it's a state event, it will be the same as ``prev_group``. 
+ + If ``state_group`` is None (ie, the event is an outlier), + ``state_group_before_event`` will always also be ``None``. + + prev_group: If it is known, ``state_group``'s prev_group. Note that this being + None does not necessarily mean that ``state_group`` does not have + a prev_group! + + If the event is a state event, this is normally the same as ``prev_group``. + + If ``state_group`` is None (ie, the event is an outlier), ``prev_group`` + will always also be ``None``. + + Note that this *not* (necessarily) the state group associated with + ``_prev_state_ids``. + + delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group`` + and ``state_group``. + + app_service: If this event is being sent by a (local) application service, that + app service. + + _current_state_ids: The room state map, including this event - ie, the state + in ``state_group``. - _prev_state_ids (dict[(str, str), str]|None): - The current state map excluding the current event. None if outlier - or we haven't fetched the state from DB yet. (type, state_key) -> event_id - _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have - been calculated. None if we haven't started calculating yet + FIXME: what is this for an outlier? it seems ill-defined. It seems like + it could be either {}, or the state we were given by the remote + server, depending on $THINGS - _event_type (str): The type of the event the context is associated with. - Only set when state has not been fetched yet. + Note that this is a private attribute: it should be accessed via + ``get_current_state_ids``. _AsyncEventContext impl calculates this + on-demand: it will be None until that happens. - _event_state_key (str|None): The state_key of the event the context is - associated with. Only set when state has not been fetched yet. + _prev_state_ids: The room state map, excluding this event - ie, the state + in ``state_group_before_event``. For a non-state + event, this will be the same as _current_state_events. - _prev_state_id (str|None): If the event associated with the context is - a state event, then `_prev_state_id` is the event_id of the state - that was replaced. - Only set when state has not been fetched yet. - """ + Note that it is a completely different thing to prev_group! - __slots__ = [ - "state_group", - "rejected", - "prev_group", - "delta_ids", - "prev_state_events", - "app_service", - "_current_state_ids", - "_prev_state_ids", - "_prev_state_id", - "_event_type", - "_event_state_key", - "_fetching_state_deferred", - ] - - def __init__(self): - self.prev_state_events = [] - self.rejected = False - self.app_service = None + (type, state_key) -> event_id - @staticmethod - def with_state( - state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None - ): - context = EventContext() + FIXME: again, what is this for an outlier? - # The current state including the current event - context._current_state_ids = current_state_ids - # The current state excluding the current event - context._prev_state_ids = prev_state_ids - context.state_group = state_group + As with _current_state_ids, this is a private attribute. It should be + accessed via get_prev_state_ids. 
+ """ - context._prev_state_id = None - context._event_type = None - context._event_state_key = None - context._fetching_state_deferred = defer.succeed(None) + rejected = attr.ib(default=False, type=Union[bool, str]) + _state_group = attr.ib(default=None, type=Optional[int]) + state_group_before_event = attr.ib(default=None, type=Optional[int]) + prev_group = attr.ib(default=None, type=Optional[int]) + delta_ids = attr.ib(default=None, type=Optional[StateMap[str]]) + app_service = attr.ib(default=None, type=Optional[ApplicationService]) - # A previously persisted state group and a delta between that - # and this state. - context.prev_group = prev_group - context.delta_ids = delta_ids + _current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]]) + _prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]]) - return context + @staticmethod + def with_state( + state_group, + state_group_before_event, + current_state_ids, + prev_state_ids, + prev_group=None, + delta_ids=None, + ): + return EventContext( + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + state_group=state_group, + state_group_before_event=state_group_before_event, + prev_group=prev_group, + delta_ids=delta_ids, + ) @defer.inlineCallbacks def serialize(self, event, store): @@ -128,7 +148,7 @@ class EventContext(object): # the prev_state_ids, so if we're a state event we include the event # id that we replaced in the state. if event.is_state(): - prev_state_ids = yield self.get_prev_state_ids(store) + prev_state_ids = yield self.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) else: prev_state_id = None @@ -137,74 +157,93 @@ class EventContext(object): "prev_state_id": prev_state_id, "event_type": event.type, "event_state_key": event.state_key if event.is_state() else None, - "state_group": self.state_group, + "state_group": self._state_group, + "state_group_before_event": self.state_group_before_event, "rejected": self.rejected, "prev_group": self.prev_group, "delta_ids": _encode_state_dict(self.delta_ids), - "prev_state_events": self.prev_state_events, "app_service_id": self.app_service.id if self.app_service else None, } @staticmethod - def deserialize(store, input): + def deserialize(storage, input): """Converts a dict that was produced by `serialize` back into a EventContext. Args: - store (DataStore): Used to convert AS ID to AS object + storage (Storage): Used to convert AS ID to AS object and fetch + state. input (dict): A dict produced by `serialize` Returns: EventContext """ - context = EventContext() + context = _AsyncEventContextImpl( + # We use the state_group and prev_state_id stuff to pull the + # current_state_ids out of the DB and construct prev_state_ids. + storage=storage, + prev_state_id=input["prev_state_id"], + event_type=input["event_type"], + event_state_key=input["event_state_key"], + state_group=input["state_group"], + state_group_before_event=input["state_group_before_event"], + prev_group=input["prev_group"], + delta_ids=_decode_state_dict(input["delta_ids"]), + rejected=input["rejected"], + ) - # We use the state_group and prev_state_id stuff to pull the - # current_state_ids out of the DB and construct prev_state_ids. 
- context._prev_state_id = input["prev_state_id"] - context._event_type = input["event_type"] - context._event_state_key = input["event_state_key"] + app_service_id = input["app_service_id"] + if app_service_id: + context.app_service = storage.main.get_app_service_by_id(app_service_id) - context._current_state_ids = None - context._prev_state_ids = None - context._fetching_state_deferred = None + return context - context.state_group = input["state_group"] - context.prev_group = input["prev_group"] - context.delta_ids = _decode_state_dict(input["delta_ids"]) + @property + def state_group(self) -> Optional[int]: + """The ID of the state group for this event. - context.rejected = input["rejected"] - context.prev_state_events = input["prev_state_events"] + Note that state events are persisted with a state group which includes the new + event, so this is effectively the state *after* the event in question. - app_service_id = input["app_service_id"] - if app_service_id: - context.app_service = store.get_app_service_by_id(app_service_id) + For an outlier, where we don't have the state at the event, this will be None. - return context + It is an error to access this for a rejected event, since rejected state should + not make it into the room state. Accessing this property will raise an exception + if ``rejected`` is set. + """ + if self.rejected: + raise RuntimeError("Attempt to access state_group of rejected event") + + return self._state_group @defer.inlineCallbacks - def get_current_state_ids(self, store): - """Gets the current state IDs + def get_current_state_ids(self): + """ + Gets the room state map, including this event - ie, the state in ``state_group`` + + It is an error to access this for a rejected event, since rejected state should + not make it into the room state. This method will raise an exception if + ``rejected`` is set. Returns: Deferred[dict[(str, str), str]|None]: Returns None if state_group is None, which happens when the associated event is an outlier. + Maps a (type, state_key) to the event ID of the state event matching this tuple. """ + if self.rejected: + raise RuntimeError("Attempt to access state_ids of rejected event") - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) - - yield make_deferred_yieldable(self._fetching_state_deferred) - + yield self._ensure_fetched() return self._current_state_ids @defer.inlineCallbacks - def get_prev_state_ids(self, store): - """Gets the prev state IDs + def get_prev_state_ids(self): + """ + Gets the room state map, excluding this event. + + For a non-state event, this will be the same as get_current_state_ids(). Returns: Deferred[dict[(str, str), str]|None]: Returns None if state_group @@ -212,65 +251,88 @@ class EventContext(object): Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) - - yield make_deferred_yieldable(self._fetching_state_deferred) - + yield self._ensure_fetched() return self._prev_state_ids def get_cached_current_state_ids(self): """Gets the current state IDs if we have them already cached. + It is an error to access this for a rejected event, since rejected state should + not make it into the room state. This method will raise an exception if + ``rejected`` is set. 
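The pattern above, storing the state group in a private attribute and exposing it through a property that refuses to serve rejected events, can be reduced to a few lines. A minimal standalone sketch follows; the class and field names are invented for illustration and are not Synapse's:

import attr
from typing import Optional, Union

@attr.s(slots=True)
class GuardedContext:
    """Toy version of the private-attribute + guarded-property pattern."""

    rejected = attr.ib(default=False, type=Union[bool, str])
    _state_group = attr.ib(default=None, type=Optional[int])

    @property
    def state_group(self) -> Optional[int]:
        # Rejected state must never leak into the room state, so reading the
        # state group of a rejected event is treated as a programming error.
        if self.rejected:
            raise RuntimeError("Attempt to access state_group of rejected event")
        return self._state_group

ok = GuardedContext(rejected=False, state_group=42)
assert ok.state_group == 42
bad = GuardedContext(rejected="spam", state_group=43)
# Reading bad.state_group here would raise RuntimeError.

Note that attrs strips the leading underscore when generating the constructor, which is why the keyword argument is state_group while the stored attribute stays private.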
+ Returns: dict[(str, str), str]|None: Returns None if we haven't cached the state or if state_group is None, which happens when the associated event is an outlier. """ + if self.rejected: + raise RuntimeError("Attempt to access state_ids of rejected event") return self._current_state_ids + def _ensure_fetched(self): + return defer.succeed(None) + + +@attr.s(slots=True) +class _AsyncEventContextImpl(EventContext): + """ + An implementation of EventContext which fetches _current_state_ids and + _prev_state_ids from the database on demand. + + Attributes: + + _storage (Storage) + + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have + been calculated. None if we haven't started calculating yet + + _event_type (str): The type of the event the context is associated with. + + _event_state_key (str): The state_key of the event the context is + associated with. + + _prev_state_id (str|None): If the event associated with the context is + a state event, then `_prev_state_id` is the event_id of the state + that was replaced. + """ + + # This needs to have a default as we're inheriting + _storage = attr.ib(default=None) + _prev_state_id = attr.ib(default=None) + _event_type = attr.ib(default=None) + _event_state_key = attr.ib(default=None) + _fetching_state_deferred = attr.ib(default=None) + + def _ensure_fetched(self): + if not self._fetching_state_deferred: + self._fetching_state_deferred = run_in_background(self._fill_out_state) + + return make_deferred_yieldable(self._fetching_state_deferred) + @defer.inlineCallbacks - def _fill_out_state(self, store): + def _fill_out_state(self): """Called to populate the _current_state_ids and _prev_state_ids attributes by loading from the database. """ if self.state_group is None: return - self._current_state_ids = yield store.get_state_ids_for_group(self.state_group) - if self._prev_state_id and self._event_state_key is not None: + self._current_state_ids = yield self._storage.state.get_state_ids_for_group( + self.state_group + ) + if self._event_state_key is not None: self._prev_state_ids = dict(self._current_state_ids) key = (self._event_type, self._event_state_key) - self._prev_state_ids[key] = self._prev_state_id + if self._prev_state_id: + self._prev_state_ids[key] = self._prev_state_id + else: + self._prev_state_ids.pop(key, None) else: self._prev_state_ids = self._current_state_ids - @defer.inlineCallbacks - def update_state( - self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids - ): - """Replace the state in the context - """ - - # We need to make sure we wait for any ongoing fetching of state - # to complete so that the updated state doesn't get clobbered - if self._fetching_state_deferred: - yield make_deferred_yieldable(self._fetching_state_deferred) - - self.state_group = state_group - self._prev_state_ids = prev_state_ids - self.prev_group = prev_group - self._current_state_ids = current_state_ids - self.delta_ids = delta_ids - - # We need to ensure that that we've marked as having fetched the state - self._fetching_state_deferred = defer.succeed(None) - def _encode_state_dict(state_dict): """Since dicts of (type, state_key) -> event_id cannot be serialized in diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 129771f183..1ffc9525d1 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,22 +14,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect +from typing import Any, Dict, List -class SpamChecker(object): - def __init__(self, hs): - self.spam_checker = None +from synapse.spam_checker_api import SpamCheckerApi - module = None - config = None - try: - module, config = hs.config.spam_checker - except Exception: - pass +MYPY = False +if MYPY: + import synapse.server - if module is not None: - self.spam_checker = module(config=config) - def check_event_for_spam(self, event): +class SpamChecker(object): + def __init__(self, hs: "synapse.server.HomeServer"): + self.spam_checkers = [] # type: List[Any] + + for module, config in hs.config.spam_checkers: + # Older spam checkers don't accept the `api` argument, so we + # try and detect support. + spam_args = inspect.getfullargspec(module) + if "api" in spam_args.args: + api = SpamCheckerApi(hs) + self.spam_checkers.append(module(config=config, api=api)) + else: + self.spam_checkers.append(module(config=config)) + + def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool: """Checks if a given event is considered "spammy" by this server. If the server considers an event spammy, then it will be rejected if @@ -36,80 +46,117 @@ class SpamChecker(object): users receive a blank event. Args: - event (synapse.events.EventBase): the event to be checked + event: the event to be checked Returns: - bool: True if the event is spammy. + True if the event is spammy. """ - if self.spam_checker is None: - return False + for spam_checker in self.spam_checkers: + if spam_checker.check_event_for_spam(event): + return True - return self.spam_checker.check_event_for_spam(event) + return False - def user_may_invite(self, inviter_userid, invitee_userid, room_id): + def user_may_invite( + self, inviter_userid: str, invitee_userid: str, room_id: str + ) -> bool: """Checks if a given user may send an invite If this method returns false, the invite will be rejected. Args: - userid (string): The sender's user ID + inviter_userid: The user ID of the sender of the invitation + invitee_userid: The user ID targeted in the invitation + room_id: The room ID Returns: - bool: True if the user may send an invite, otherwise False + True if the user may send an invite, otherwise False """ - if self.spam_checker is None: - return True + for spam_checker in self.spam_checkers: + if ( + spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id) + is False + ): + return False - return self.spam_checker.user_may_invite( - inviter_userid, invitee_userid, room_id - ) + return True - def user_may_create_room(self, userid): + def user_may_create_room(self, userid: str) -> bool: """Checks if a given user may create a room If this method returns false, the creation request will be rejected. 
Args: - userid (string): The sender's user ID + userid: The ID of the user attempting to create a room Returns: - bool: True if the user may create a room, otherwise False + True if the user may create a room, otherwise False """ - if self.spam_checker is None: - return True + for spam_checker in self.spam_checkers: + if spam_checker.user_may_create_room(userid) is False: + return False - return self.spam_checker.user_may_create_room(userid) + return True - def user_may_create_room_alias(self, userid, room_alias): + def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool: """Checks if a given user may create a room alias If this method returns false, the association request will be rejected. Args: - userid (string): The sender's user ID - room_alias (string): The alias to be created + userid: The ID of the user attempting to create a room alias + room_alias: The alias to be created Returns: - bool: True if the user may create a room alias, otherwise False + True if the user may create a room alias, otherwise False """ - if self.spam_checker is None: - return True + for spam_checker in self.spam_checkers: + if spam_checker.user_may_create_room_alias(userid, room_alias) is False: + return False - return self.spam_checker.user_may_create_room_alias(userid, room_alias) + return True - def user_may_publish_room(self, userid, room_id): + def user_may_publish_room(self, userid: str, room_id: str) -> bool: """Checks if a given user may publish a room to the directory If this method returns false, the publish request will be rejected. Args: - userid (string): The sender's user ID - room_id (string): The ID of the room that would be published + userid: The user ID attempting to publish the room + room_id: The ID of the room that would be published Returns: - bool: True if the user may publish the room, otherwise False + True if the user may publish the room, otherwise False """ - if self.spam_checker is None: - return True + for spam_checker in self.spam_checkers: + if spam_checker.user_may_publish_room(userid, room_id) is False: + return False + + return True + + def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: + """Checks if a user ID or display name are considered "spammy" by this server. + + If the server considers a username spammy, then it will not be included in + user directory results. - return self.spam_checker.user_may_publish_room(userid, room_id) + Args: + user_profile: The user information to check, it contains the keys: + * user_id + * display_name + * avatar_url + + Returns: + True if the user is spammy. + """ + for spam_checker in self.spam_checkers: + # For backwards compatibility, only run if the method exists on the + # spam checker + checker = getattr(spam_checker, "check_username_for_spam", None) + if checker: + # Make a copy of the user profile object to ensure the spam checker + # cannot modify it. + if checker(user_profile.copy()): + return True + + return False diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 714a9b1579..459132d388 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -53,7 +53,7 @@ class ThirdPartyEventRules(object): if self.third_party_rules is None: return True - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() # Retrieve the state events from the database. 
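With the loader above, a spam-checker module is simply a class exposing some or all of these callbacks: the api argument is passed only when the constructor declares it, and check_username_for_spam is looked up with getattr so older modules keep working. A hypothetical module, with all names invented for illustration, might look like this:

class ExampleSpamChecker:
    """A toy spam checker compatible with the duck-typed interface above."""

    def __init__(self, config, api=None):
        # `api` is a SpamCheckerApi if the homeserver detected that we accept it.
        self._api = api
        self._blocked_words = config.get("blocked_words", [])

    def check_event_for_spam(self, event):
        body = event.content.get("body", "")
        return any(word in body for word in self._blocked_words)

    def user_may_invite(self, inviter_userid, invitee_userid, room_id):
        return True

    def user_may_create_room(self, userid):
        return True

    def user_may_create_room_alias(self, userid, room_alias):
        return True

    def user_may_publish_room(self, userid, room_id):
        return True

    # Optional: only called if present on the checker.
    def check_username_for_spam(self, user_profile):
        display_name = user_profile.get("display_name") or ""
        return "spam" in display_name.lower()

Because the wrapper only blocks when a callback explicitly returns False (or, for check_event_for_spam, True), a permissive checker like this simply defers the decision to the other checkers in the list.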
state_events = {} @@ -74,15 +74,16 @@ class ThirdPartyEventRules(object): is_requester_admin (bool): If the requester is an admin Returns: - defer.Deferred + defer.Deferred[bool]: Whether room creation is allowed or denied. """ if self.third_party_rules is None: - return + return True - yield self.third_party_rules.on_create_room( + ret = yield self.third_party_rules.on_create_room( requester, config, is_requester_admin ) + return ret @defer.inlineCallbacks def check_threepid_can_be_invited(self, medium, address, room_id): diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 07d1c5bcf0..dd340be9a7 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -12,8 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import collections import re +from typing import Any, Mapping, Union from six import string_types @@ -22,6 +23,8 @@ from frozendict import frozendict from twisted.internet import defer from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.errors import Codes, SynapseError +from synapse.api.room_versions import RoomVersion from synapse.util.async_helpers import yieldable_gather_results from . import EventBase @@ -34,26 +37,20 @@ from . import EventBase SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.") -def prune_event(event): +def prune_event(event: EventBase) -> EventBase: """ Returns a pruned version of the given event, which removes all keys we don't know about or think could potentially be dodgy. This is used when we "redact" an event. We want to remove all fields that the user has specified, but we do want to keep necessary information like type, state_key etc. - - Args: - event (FrozenEvent) - - Returns: - FrozenEvent """ - pruned_event_dict = prune_event_dict(event.get_dict()) + pruned_event_dict = prune_event_dict(event.room_version, event.get_dict()) - from . import event_type_from_format_version + from . 
import make_event_from_dict - pruned_event = event_type_from_format_version(event.format_version)( - pruned_event_dict, event.internal_metadata.get_dict() + pruned_event = make_event_from_dict( + pruned_event_dict, event.room_version, event.internal_metadata.get_dict() ) # Mark the event as redacted @@ -62,15 +59,12 @@ def prune_event(event): return pruned_event -def prune_event_dict(event_dict): +def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: """Redacts the event_dict in the same way as `prune_event`, except it operates on dicts rather than event objects - Args: - event_dict (dict) - Returns: - dict: A copy of the pruned event dict + A copy of the pruned event dict """ allowed_keys = [ @@ -117,7 +111,7 @@ def prune_event_dict(event_dict): "kick", "redact", ) - elif event_type == EventTypes.Aliases: + elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth: add_fields("aliases") elif event_type == EventTypes.RoomHistoryVisibility: add_fields("history_visibility") @@ -422,3 +416,69 @@ class EventClientSerializer(object): return yieldable_gather_results( self.serialize_event, events, time_now=time_now, **kwargs ) + + +def copy_power_levels_contents( + old_power_levels: Mapping[str, Union[int, Mapping[str, int]]] +): + """Copy the content of a power_levels event, unfreezing frozendicts along the way + + Raises: + TypeError if the input does not look like a valid power levels event content + """ + if not isinstance(old_power_levels, collections.Mapping): + raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,)) + + power_levels = {} + for k, v in old_power_levels.items(): + + if isinstance(v, int): + power_levels[k] = v + continue + + if isinstance(v, collections.Mapping): + power_levels[k] = h = {} + for k1, v1 in v.items(): + # we should only have one level of nesting + if not isinstance(v1, int): + raise TypeError( + "Invalid power_levels value for %s.%s: %r" % (k, k1, v1) + ) + h[k1] = v1 + continue + + raise TypeError("Invalid power_levels value for %s: %r" % (k, v)) + + return power_levels + + +def validate_canonicaljson(value: Any): + """ + Ensure that the JSON object is valid according to the rules of canonical JSON. + + See the appendix section 3.1: Canonical JSON. + + This rejects JSON that has: + * An integer outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1] + * Floats + * NaN, Infinity, -Infinity + """ + if isinstance(value, int): + if value <= -(2 ** 53) or 2 ** 53 <= value: + raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON) + + elif isinstance(value, float): + # Note that Infinity, -Infinity, and NaN are also considered floats. + raise SynapseError(400, "Bad JSON value: float", Codes.BAD_JSON) + + elif isinstance(value, (dict, frozendict)): + for v in value.values(): + validate_canonicaljson(v) + + elif isinstance(value, (list, tuple)): + for i in value: + validate_canonicaljson(i) + + elif not isinstance(value, (bool, str)) and value is not None: + # Other potential JSON values (bool, None, str) are safe. + raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 272426e105..b001c64bb4 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -13,20 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
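The rules that validate_canonicaljson enforces are easiest to see with a few throwaway values. A quick illustration, assuming a Synapse checkout is importable; the sample values are invented:

from synapse.api.errors import SynapseError
from synapse.events.utils import validate_canonicaljson

# Integers within +/- (2**53 - 1), strings, bools, None and nesting are all fine.
validate_canonicaljson({"a": 1, "b": ["x", True, None], "c": {"d": -5}})

# Floats and out-of-range integers are rejected with 400 / M_BAD_JSON.
for bad in ({"f": 1.5}, {"n": 2 ** 53}, {"n": -(2 ** 53)}):
    try:
        validate_canonicaljson(bad)
    except SynapseError as e:
        print("rejected %r: %s" % (bad, e.errcode))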
-from six import string_types +from six import integer_types, string_types from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import EventFormatVersions +from synapse.events.utils import validate_canonicaljson from synapse.types import EventID, RoomID, UserID class EventValidator(object): - def validate_new(self, event): + def validate_new(self, event, config): """Validates the event has roughly the right format Args: - event (FrozenEvent) + event (FrozenEvent): The event to validate. + config (Config): The homeserver's configuration. """ self.validate_builder(event) @@ -54,6 +56,12 @@ class EventValidator(object): if not isinstance(getattr(event, s), string_types): raise SynapseError(400, "'%s' not a string type" % (s,)) + # Depending on the room version, ensure the data is spec compliant JSON. + if event.room_version.strict_canonicaljson: + # Note that only the client controlled portion of the event is + # checked, since we trust the portions of the event we created. + validate_canonicaljson(event.content) + if event.type == EventTypes.Aliases: if "aliases" in event.content: for alias in event.content["aliases"]: @@ -67,6 +75,99 @@ class EventValidator(object): Codes.INVALID_PARAM, ) + if event.type == EventTypes.Retention: + self._validate_retention(event, config) + + def _validate_retention(self, event, config): + """Checks that an event that defines the retention policy for a room respects the + boundaries imposed by the server's administrator. + + Args: + event (FrozenEvent): The event to validate. + config (Config): The homeserver's configuration. + """ + min_lifetime = event.content.get("min_lifetime") + max_lifetime = event.content.get("max_lifetime") + + if min_lifetime is not None: + if not isinstance(min_lifetime, integer_types): + raise SynapseError( + code=400, + msg="'min_lifetime' must be an integer", + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_min is not None + and min_lifetime < config.retention_allowed_lifetime_min + ): + raise SynapseError( + code=400, + msg=( + "'min_lifetime' can't be lower than the minimum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_max is not None + and min_lifetime > config.retention_allowed_lifetime_max + ): + raise SynapseError( + code=400, + msg=( + "'min_lifetime' can't be greater than the maximum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if max_lifetime is not None: + if not isinstance(max_lifetime, integer_types): + raise SynapseError( + code=400, + msg="'max_lifetime' must be an integer", + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_min is not None + and max_lifetime < config.retention_allowed_lifetime_min + ): + raise SynapseError( + code=400, + msg=( + "'max_lifetime' can't be lower than the minimum allowed value" + " enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_max is not None + and max_lifetime > config.retention_allowed_lifetime_max + ): + raise SynapseError( + code=400, + msg=( + "'max_lifetime' can't be greater than the maximum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + min_lifetime is not None + and max_lifetime is not None + and min_lifetime > max_lifetime + ): + raise SynapseError( + 
code=400, + msg="'min_lifetime' can't be greater than 'max_lifetime", + errcode=Codes.BAD_JSON, + ) + def validate_builder(self, event): """Validates that the builder/event has roughly the right format. Only checks values that we expect a proto event to have, rather than all the diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 5a1e23a145..c0012c6872 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,27 +15,28 @@ # limitations under the License. import logging from collections import namedtuple +from typing import Iterable, List import six from twisted.internet import defer -from twisted.internet.defer import DeferredList +from twisted.internet.defer import Deferred, DeferredList +from twisted.python.failure import Failure from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions +from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.crypto.event_signing import check_event_content_hash -from synapse.events import event_type_from_format_version -from synapse.events.utils import prune_event +from synapse.crypto.keyring import Keyring +from synapse.events import EventBase, make_event_from_dict +from synapse.events.utils import prune_event, validate_canonicaljson from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( - LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, - preserve_fn, ) -from synapse.types import get_domain_from_id -from synapse.util import unwrapFirstError +from synapse.types import JsonDict, get_domain_from_id logger = logging.getLogger(__name__) @@ -49,92 +51,25 @@ class FederationBase(object): self.store = hs.get_datastore() self._clock = hs.get_clock() - @defer.inlineCallbacks - def _check_sigs_and_hash_and_fetch( - self, origin, pdus, room_version, outlier=False, include_none=False - ): - """Takes a list of PDUs and checks the signatures and hashs of each - one. If a PDU fails its signature check then we check if we have it in - the database and if not then request if from the originating server of - that PDU. - - If a PDU fails its content hash check then it is redacted. - - The given list of PDUs are not modified, instead the function returns - a new list. - - Args: - origin (str) - pdu (list) - room_version (str) - outlier (bool): Whether the events are outliers or not - include_none (str): Whether to include None in the returned list - for events that have failed their checks - - Returns: - Deferred : A list of PDUs that have valid signatures and hashes. - """ - deferreds = self._check_sigs_and_hashes(room_version, pdus) - - @defer.inlineCallbacks - def handle_check_result(pdu, deferred): - try: - res = yield make_deferred_yieldable(deferred) - except SynapseError: - res = None - - if not res: - # Check local db. 
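Stepping back to the retention checks above: min_lifetime and max_lifetime are integers (milliseconds), both must sit inside the window the administrator configured via retention_allowed_lifetime_min / retention_allowed_lifetime_max, and min_lifetime must not exceed max_lifetime. A hedged illustration with made-up values:

ONE_DAY_MS = 24 * 60 * 60 * 1000

# Hypothetical server-side bounds (config.retention_allowed_lifetime_min/max).
allowed_min = ONE_DAY_MS            # 1 day
allowed_max = 365 * ONE_DAY_MS      # 1 year

# Accepted: both lifetimes are within [allowed_min, allowed_max] and ordered.
ok = {"min_lifetime": 7 * ONE_DAY_MS, "max_lifetime": 30 * ONE_DAY_MS}

# Rejected with 400 / M_BAD_JSON by the checks above:
too_short = {"max_lifetime": ONE_DAY_MS // 2}                       # below allowed_min
too_long = {"min_lifetime": 2 * 365 * ONE_DAY_MS}                   # above allowed_max
inverted = {"min_lifetime": 30 * ONE_DAY_MS, "max_lifetime": 7 * ONE_DAY_MS}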
- res = yield self.store.get_event( - pdu.event_id, allow_rejected=True, allow_none=True - ) - - if not res and pdu.origin != origin: - try: - res = yield self.get_pdu( - destinations=[pdu.origin], - event_id=pdu.event_id, - room_version=room_version, - outlier=outlier, - timeout=10000, - ) - except SynapseError: - pass - - if not res: - logger.warn( - "Failed to find copy of %s with valid signature", pdu.event_id - ) - - return res - - handle = preserve_fn(handle_check_result) - deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)] - - valid_pdus = yield make_deferred_yieldable( - defer.gatherResults(deferreds2, consumeErrors=True) - ).addErrback(unwrapFirstError) - - if include_none: - return valid_pdus - else: - return [p for p in valid_pdus if p] - - def _check_sigs_and_hash(self, room_version, pdu): + def _check_sigs_and_hash( + self, room_version: RoomVersion, pdu: EventBase + ) -> Deferred: return make_deferred_yieldable( self._check_sigs_and_hashes(room_version, [pdu])[0] ) - def _check_sigs_and_hashes(self, room_version, pdus): + def _check_sigs_and_hashes( + self, room_version: RoomVersion, pdus: List[EventBase] + ) -> List[Deferred]: """Checks that each of the received events is correctly signed by the sending server. Args: - room_version (str): The room version of the PDUs - pdus (list[FrozenEvent]): the events to be checked + room_version: The room version of the PDUs + pdus: the events to be checked Returns: - list[Deferred]: for each input event, a deferred which: + For each input event, a deferred which: * returns the original event if the checks pass * returns a redacted version of the event (if the signature matched but the hash did not) @@ -143,9 +78,9 @@ class FederationBase(object): """ deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus) - ctx = LoggingContext.current_context() + ctx = current_context() - def callback(_, pdu): + def callback(_, pdu: EventBase): with PreserveLoggingContext(ctx): if not check_event_content_hash(pdu): # let's try to distinguish between failures because the event was @@ -173,7 +108,7 @@ class FederationBase(object): return redacted_event if self.spam_checker.check_event_for_spam(pdu): - logger.warn( + logger.warning( "Event contains spam, redacting %s: %s", pdu.event_id, pdu.get_pdu_json(), @@ -182,10 +117,10 @@ class FederationBase(object): return pdu - def errback(failure, pdu): + def errback(failure: Failure, pdu: EventBase): failure.trap(SynapseError) with PreserveLoggingContext(ctx): - logger.warn( + logger.warning( "Signature check failed for %s: %s", pdu.event_id, failure.getErrorMessage(), @@ -208,16 +143,18 @@ class PduToCheckSig( pass -def _check_sigs_on_pdus(keyring, room_version, pdus): +def _check_sigs_on_pdus( + keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase] +) -> List[Deferred]: """Check that the given events are correctly signed Args: - keyring (synapse.crypto.Keyring): keyring object to do the checks - room_version (str): the room version of the PDUs - pdus (Collection[EventBase]): the events to be checked + keyring: keyring object to do the checks + room_version: the room version of the PDUs + pdus: the events to be checked Returns: - List[Deferred]: a Deferred for each event in pdus, which will either succeed if + A Deferred for each event in pdus, which will either succeed if the signatures are valid, or fail (with a SynapseError) if not. 
""" @@ -252,10 +189,6 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): for p in pdus ] - v = KNOWN_ROOM_VERSIONS.get(room_version) - if not v: - raise RuntimeError("Unrecognized room version %s" % (room_version,)) - # First we check that the sender event is signed by the sender's domain # (except if its a 3pid invite, in which case it may be sent by any server) pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)] @@ -265,7 +198,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): ( p.sender_domain, p.redacted_pdu_json, - p.pdu.origin_server_ts if v.enforce_key_validity else 0, + p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, p.pdu.event_id, ) for p in pdus_to_check_sender @@ -278,9 +211,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): pdu_to_check.sender_domain, e.getErrorMessage(), ) - # XX not really sure if these are the right codes, but they are what - # we've done for ages - raise SynapseError(400, errmsg, Codes.UNAUTHORIZED) + raise SynapseError(403, errmsg, Codes.FORBIDDEN) for p, d in zip(pdus_to_check_sender, more_deferreds): d.addErrback(sender_err, p) @@ -290,7 +221,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): # event id's domain (normally only the case for joins/leaves), and add additional # checks. Only do this if the room version has a concept of event ID domain # (ie, the room version uses old-style non-hash event IDs). - if v.event_format == EventFormatVersions.V1: + if room_version.event_format == EventFormatVersions.V1: pdus_to_check_event_id = [ p for p in pdus_to_check @@ -302,7 +233,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): ( get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, - p.pdu.origin_server_ts if v.enforce_key_validity else 0, + p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, p.pdu.event_id, ) for p in pdus_to_check_event_id @@ -314,8 +245,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): "event id %s: unable to verify signature for event id domain: %s" % (pdu_to_check.pdu.event_id, e.getErrorMessage()) ) - # XX as above: not really sure if these are the right codes - raise SynapseError(400, errmsg, Codes.UNAUTHORIZED) + raise SynapseError(403, errmsg, Codes.FORBIDDEN) for p, d in zip(pdus_to_check_event_id, more_deferreds): d.addErrback(event_err, p) @@ -325,7 +255,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check] -def _flatten_deferred_list(deferreds): +def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred: """Given a list of deferreds, either return the single deferred, combine into a DeferredList, or return an already resolved deferred. 
""" @@ -337,7 +267,7 @@ def _flatten_deferred_list(deferreds): return defer.succeed(None) -def _is_invite_via_3pid(event): +def _is_invite_via_3pid(event: EventBase) -> bool: return ( event.type == EventTypes.Member and event.membership == Membership.INVITE @@ -345,16 +275,15 @@ def _is_invite_via_3pid(event): ) -def event_from_pdu_json(pdu_json, event_format_version, outlier=False): - """Construct a FrozenEvent from an event json received over federation +def event_from_pdu_json( + pdu_json: JsonDict, room_version: RoomVersion, outlier: bool = False +) -> EventBase: + """Construct an EventBase from an event json received over federation Args: - pdu_json (object): pdu as received over federation - event_format_version (int): The event format version - outlier (bool): True to mark this event as an outlier - - Returns: - FrozenEvent + pdu_json: pdu as received over federation + room_version: The version of the room this event belongs to + outlier: True to mark this event as an outlier Raises: SynapseError: if the pdu is missing required fields or is otherwise @@ -373,8 +302,11 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False): elif depth > MAX_DEPTH: raise SynapseError(400, "Depth too large", Codes.BAD_JSON) - event = event_type_from_format_version(event_format_version)(pdu_json) + # Validate that the JSON conforms to the specification. + if room_version.strict_canonicaljson: + validate_canonicaljson(pdu_json) + event = make_event_from_dict(pdu_json, room_version) event.internal_metadata.outlier = outlier return event diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 6ee6216660..687cd841ac 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -17,12 +17,23 @@ import copy import itertools import logging - -from six.moves import range +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + TypeVar, +) from prometheus_client import Counter from twisted.internet import defer +from twisted.internet.defer import Deferred from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( @@ -31,16 +42,19 @@ from synapse.api.errors import ( FederationDeniedError, HttpResponseException, SynapseError, + UnsupportedRoomVersionError, ) from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, + RoomVersion, RoomVersions, ) -from synapse.events import builder, room_version_to_event_format +from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json -from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.logging.utils import log_function +from synapse.types import JsonDict from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -52,6 +66,8 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t PDU_RETRY_TIME_MS = 1 * 60 * 1000 +T = TypeVar("T") + class InvalidResponseError(RuntimeError): """Helper for _try_destination_list: indicates that the server returned a response @@ -170,56 +186,54 @@ class FederationClient(FederationBase): sent_queries_counter.labels("client_one_time_keys").inc() return self.transport_layer.claim_client_keys(destination, content, timeout) - 
@defer.inlineCallbacks - @log_function - def backfill(self, dest, room_id, limit, extremities): - """Requests some more historic PDUs for the given context from the + async def backfill( + self, dest: str, room_id: str, limit: int, extremities: Iterable[str] + ) -> Optional[List[EventBase]]: + """Requests some more historic PDUs for the given room from the given destination server. Args: - dest (str): The remote home server to ask. + dest (str): The remote homeserver to ask. room_id (str): The room_id to backfill. - limit (int): The maximum number of PDUs to return. - extremities (list): List of PDU id and origins of the first pdus - we have seen from the context - - Returns: - Deferred: Results in the received PDUs. + limit (int): The maximum number of events to return. + extremities (list): our current backwards extremities, to backfill from """ logger.debug("backfill extrem=%s", extremities) - # If there are no extremeties then we've (probably) reached the start. + # If there are no extremities then we've (probably) reached the start. if not extremities: - return + return None - transaction_data = yield self.transport_layer.backfill( + transaction_data = await self.transport_layer.backfill( dest, room_id, extremities, limit ) - logger.debug("backfill transaction_data=%s", repr(transaction_data)) + logger.debug("backfill transaction_data=%r", transaction_data) - room_version = yield self.store.get_room_version(room_id) - format_ver = room_version_to_event_format(room_version) + room_version = await self.store.get_room_version(room_id) pdus = [ - event_from_pdu_json(p, format_ver, outlier=False) + event_from_pdu_json(p, room_version, outlier=False) for p in transaction_data["pdus"] ] # FIXME: We should handle signature failures more gracefully. - pdus[:] = yield make_deferred_yieldable( + pdus[:] = await make_deferred_yieldable( defer.gatherResults( - self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True + self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True, ).addErrback(unwrapFirstError) ) return pdus - @defer.inlineCallbacks - @log_function - def get_pdu( - self, destinations, event_id, room_version, outlier=False, timeout=None - ): + async def get_pdu( + self, + destinations: Iterable[str], + event_id: str, + room_version: RoomVersion, + outlier: bool = False, + timeout: Optional[int] = None, + ) -> Optional[EventBase]: """Requests the PDU with given origin and ID from the remote home servers. @@ -227,18 +241,17 @@ class FederationClient(FederationBase): one succeeds. Args: - destinations (list): Which home servers to query - event_id (str): event to fetch - room_version (str): version of the room - outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if + destinations: Which homeservers to query + event_id: event to fetch + room_version: version of the room + outlier: Indicates whether the PDU is an `outlier`, i.e. if it's from an arbitary point in the context as opposed to part of the current block of PDUs. Defaults to `False` - timeout (int): How long to try (in ms) each destination for before + timeout: How long to try (in ms) each destination for before moving to the next destination. None indicates no timeout. Returns: - Deferred: Results in the requested PDU, or None if we were unable to find - it. + The requested PDU, or None if we were unable to find it. """ # TODO: Rate limit the number of times we try and get the same event. 
@@ -249,8 +262,6 @@ class FederationClient(FederationBase): pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) - format_ver = room_version_to_event_format(room_version) - signed_pdu = None for destination in destinations: now = self._clock.time_msec() @@ -259,7 +270,7 @@ class FederationClient(FederationBase): continue try: - transaction_data = yield self.transport_layer.get_event( + transaction_data = await self.transport_layer.get_event( destination, event_id, timeout=timeout ) @@ -271,15 +282,15 @@ class FederationClient(FederationBase): ) pdu_list = [ - event_from_pdu_json(p, format_ver, outlier=outlier) + event_from_pdu_json(p, room_version, outlier=outlier) for p in transaction_data["pdus"] - ] + ] # type: List[EventBase] if pdu_list and pdu_list[0]: pdu = pdu_list[0] # Check signatures are correct. - signed_pdu = yield self._check_sigs_and_hash(room_version, pdu) + signed_pdu = await self._check_sigs_and_hash(room_version, pdu) break @@ -309,177 +320,117 @@ class FederationClient(FederationBase): return signed_pdu - @defer.inlineCallbacks - @log_function - def get_state_for_room(self, destination, room_id, event_id): - """Requests all of the room state at a given event from a remote home server. - - Args: - destination (str): The remote homeserver to query for the state. - room_id (str): The id of the room we're interested in. - event_id (str): The id of the event we want the state at. + async def get_room_state_ids( + self, destination: str, room_id: str, event_id: str + ) -> Tuple[List[str], List[str]]: + """Calls the /state_ids endpoint to fetch the state at a particular point + in the room, and the auth events for the given event Returns: - Deferred[Tuple[List[EventBase], List[EventBase]]]: - A list of events in the state, and a list of events in the auth chain - for the given event. + a tuple of (state event_ids, auth event_ids) """ - try: - # First we try and ask for just the IDs, as thats far quicker if - # we have most of the state and auth_chain already. - # However, this may 404 if the other side has an old synapse. 
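Since get_pdu is now a coroutine taking a full RoomVersion, a caller looks roughly like the following. This is a sketch under the assumption that store.get_room_version returns a RoomVersion object, as the surrounding code does; the function name is invented:

async def fetch_missing_event(federation_client, room_id, event_id, servers):
    """Try each candidate server in turn; returns None if no verifiable copy is found."""
    room_version = await federation_client.store.get_room_version(room_id)
    return await federation_client.get_pdu(
        destinations=servers,
        event_id=event_id,
        room_version=room_version,
        outlier=True,
        timeout=10000,   # ms to spend on each destination before moving on
    )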
- result = yield self.transport_layer.get_room_state_ids( - destination, room_id, event_id=event_id - ) - - state_event_ids = result["pdu_ids"] - auth_event_ids = result.get("auth_chain_ids", []) - - fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest( - destination, room_id, set(state_event_ids + auth_event_ids) - ) - - if failed_to_fetch: - logger.warning( - "Failed to fetch missing state/auth events for %s: %s", - room_id, - failed_to_fetch, - ) - - event_map = {ev.event_id: ev for ev in fetched_events} - - pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] - auth_chain = [ - event_map[e_id] for e_id in auth_event_ids if e_id in event_map - ] - - auth_chain.sort(key=lambda e: e.depth) - - return pdus, auth_chain - except HttpResponseException as e: - if e.code == 400 or e.code == 404: - logger.info("Failed to use get_room_state_ids API, falling back") - else: - raise e - - result = yield self.transport_layer.get_room_state( + result = await self.transport_layer.get_room_state_ids( destination, room_id, event_id=event_id ) - room_version = yield self.store.get_room_version(room_id) - format_ver = room_version_to_event_format(room_version) + state_event_ids = result["pdu_ids"] + auth_event_ids = result.get("auth_chain_ids", []) - pdus = [ - event_from_pdu_json(p, format_ver, outlier=True) for p in result["pdus"] - ] + if not isinstance(state_event_ids, list) or not isinstance( + auth_event_ids, list + ): + raise Exception("invalid response from /state_ids") - auth_chain = [ - event_from_pdu_json(p, format_ver, outlier=True) - for p in result.get("auth_chain", []) - ] - - seen_events = yield self.store.get_events( - [ev.event_id for ev in itertools.chain(pdus, auth_chain)] - ) - - signed_pdus = yield self._check_sigs_and_hash_and_fetch( - destination, - [p for p in pdus if p.event_id not in seen_events], - outlier=True, - room_version=room_version, - ) - signed_pdus.extend( - seen_events[p.event_id] for p in pdus if p.event_id in seen_events - ) - - signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, - [p for p in auth_chain if p.event_id not in seen_events], - outlier=True, - room_version=room_version, - ) - signed_auth.extend( - seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events - ) - - signed_auth.sort(key=lambda e: e.depth) + return state_event_ids, auth_event_ids - return signed_pdus, signed_auth - - @defer.inlineCallbacks - def get_events_from_store_or_dest(self, destination, room_id, event_ids): - """Fetch events from a remote destination, checking if we already have them. + async def _check_sigs_and_hash_and_fetch( + self, + origin: str, + pdus: List[EventBase], + room_version: RoomVersion, + outlier: bool = False, + include_none: bool = False, + ) -> List[EventBase]: + """Takes a list of PDUs and checks the signatures and hashs of each + one. If a PDU fails its signature check then we check if we have it in + the database and if not then request if from the originating server of + that PDU. + + If a PDU fails its content hash check then it is redacted. + + The given list of PDUs are not modified, instead the function returns + a new list. 
Args: - destination (str) - room_id (str) - event_ids (list) + origin + pdu + room_version + outlier: Whether the events are outliers or not + include_none: Whether to include None in the returned list + for events that have failed their checks Returns: - Deferred: A deferred resolving to a 2-tuple where the first is a list of - events and the second is a list of event ids that we failed to fetch. + Deferred : A list of PDUs that have valid signatures and hashes. """ - seen_events = yield self.store.get_events(event_ids, allow_rejected=True) - signed_events = list(seen_events.values()) - - failed_to_fetch = set() + deferreds = self._check_sigs_and_hashes(room_version, pdus) - missing_events = set(event_ids) - for k in seen_events: - missing_events.discard(k) - - if not missing_events: - return signed_events, failed_to_fetch - - logger.debug( - "Fetching unknown state/auth events %s for room %s", - missing_events, - event_ids, - ) - - room_version = yield self.store.get_room_version(room_id) + @defer.inlineCallbacks + def handle_check_result(pdu: EventBase, deferred: Deferred): + try: + res = yield make_deferred_yieldable(deferred) + except SynapseError: + res = None + + if not res: + # Check local db. + res = yield self.store.get_event( + pdu.event_id, allow_rejected=True, allow_none=True + ) - batch_size = 20 - missing_events = list(missing_events) - for i in range(0, len(missing_events), batch_size): - batch = set(missing_events[i : i + batch_size]) + if not res and pdu.origin != origin: + try: + res = yield defer.ensureDeferred( + self.get_pdu( + destinations=[pdu.origin], + event_id=pdu.event_id, + room_version=room_version, + outlier=outlier, + timeout=10000, + ) + ) + except SynapseError: + pass - deferreds = [ - run_in_background( - self.get_pdu, - destinations=[destination], - event_id=e_id, - room_version=room_version, + if not res: + logger.warning( + "Failed to find copy of %s with valid signature", pdu.event_id ) - for e_id in batch - ] - res = yield make_deferred_yieldable( - defer.DeferredList(deferreds, consumeErrors=True) - ) - for success, result in res: - if success and result: - signed_events.append(result) - batch.discard(result.event_id) + return res - # We removed all events we successfully fetched from `batch` - failed_to_fetch.update(batch) + handle = preserve_fn(handle_check_result) + deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)] - return signed_events, failed_to_fetch + valid_pdus = await make_deferred_yieldable( + defer.gatherResults(deferreds2, consumeErrors=True) + ).addErrback(unwrapFirstError) - @defer.inlineCallbacks - @log_function - def get_event_auth(self, destination, room_id, event_id): - res = yield self.transport_layer.get_event_auth(destination, room_id, event_id) + if include_none: + return valid_pdus + else: + return [p for p in valid_pdus if p] + + async def get_event_auth(self, destination, room_id, event_id): + res = await self.transport_layer.get_event_auth(destination, room_id, event_id) - room_version = yield self.store.get_room_version(room_id) - format_ver = room_version_to_event_format(room_version) + room_version = await self.store.get_room_version(room_id) auth_chain = [ - event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"] + event_from_pdu_json(p, room_version, outlier=True) + for p in res["auth_chain"] ] - signed_auth = yield self._check_sigs_and_hash_and_fetch( + signed_auth = await self._check_sigs_and_hash_and_fetch( destination, auth_chain, outlier=True, 
room_version=room_version ) @@ -487,17 +438,21 @@ class FederationClient(FederationBase): return signed_auth - @defer.inlineCallbacks - def _try_destination_list(self, description, destinations, callback): + async def _try_destination_list( + self, + description: str, + destinations: Iterable[str], + callback: Callable[[str], Awaitable[T]], + ) -> T: """Try an operation on a series of servers, until it succeeds Args: - description (unicode): description of the operation we're doing, for logging + description: description of the operation we're doing, for logging - destinations (Iterable[unicode]): list of server_names to try + destinations: list of server_names to try - callback (callable): Function to run for each server. Passed a single - argument: the server_name to try. May return a deferred. + callback: Function to run for each server. Passed a single + argument: the server_name to try. If the callback raises a CodeMessageException with a 300/400 code, attempts to perform the operation stop immediately and the exception is @@ -508,7 +463,7 @@ class FederationClient(FederationBase): suppressed if the exception is an InvalidResponseError. Returns: - The [Deferred] result of callback, if it succeeds + The result of callback, if it succeeds Raises: SynapseError if the chosen remote server returns a 300/400 code, or @@ -519,15 +474,17 @@ class FederationClient(FederationBase): continue try: - res = yield callback(destination) + res = await callback(destination) return res except InvalidResponseError as e: - logger.warn("Failed to %s via %s: %s", description, destination, e) + logger.warning("Failed to %s via %s: %s", description, destination, e) + except UnsupportedRoomVersionError: + raise except HttpResponseException as e: if not 500 <= e.code < 600: raise e.to_synapse_error() else: - logger.warn( + logger.warning( "Failed to %s via %s: %i %s", description, destination, @@ -535,13 +492,21 @@ class FederationClient(FederationBase): e.args[0], ) except Exception: - logger.warn("Failed to %s via %s", description, destination, exc_info=1) + logger.warning( + "Failed to %s via %s", description, destination, exc_info=True + ) raise SynapseError(502, "Failed to %s via any server" % (description,)) - def make_membership_event( - self, destinations, room_id, user_id, membership, content, params - ): + async def make_membership_event( + self, + destinations: Iterable[str], + room_id: str, + user_id: str, + membership: str, + content: dict, + params: Dict[str, str], + ) -> Tuple[str, EventBase, RoomVersion]: """ Creates an m.room.member event, with context, without participating in the room. @@ -553,26 +518,28 @@ class FederationClient(FederationBase): Note that this does not append any events to any graphs. Args: - destinations (str): Candidate homeservers which are probably + destinations: Candidate homeservers which are probably participating in the room. - room_id (str): The room in which the event will happen. - user_id (str): The user whose membership is being evented. - membership (str): The "membership" property of the event. Must be - one of "join" or "leave". - content (dict): Any additional data to put into the content field - of the event. - params (dict[str, str|Iterable[str]]): Query parameters to include in the - request. - Return: - Deferred[tuple[str, FrozenEvent, int]]: resolves to a tuple of - `(origin, event, event_format)` where origin is the remote - homeserver which generated the event, and event_format is one of - `synapse.api.room_versions.EventFormatVersions`. 
- - Fails with a ``SynapseError`` if the chosen remote server - returns a 300/400 code. - - Fails with a ``RuntimeError`` if no servers were reachable. + room_id: The room in which the event will happen. + user_id: The user whose membership is being evented. + membership: The "membership" property of the event. Must be one of + "join" or "leave". + content: Any additional data to put into the content field of the + event. + params: Query parameters to include in the request. + + Returns: + `(origin, event, room_version)` where origin is the remote + homeserver which generated the event, and room_version is the + version of the room. + + Raises: + UnsupportedRoomVersionError: if remote responds with + a room version we don't understand. + + SynapseError: if the chosen remote server returns a 300/400 code. + + RuntimeError: if no servers were reachable. """ valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: @@ -581,16 +548,17 @@ class FederationClient(FederationBase): % (membership, ",".join(valid_memberships)) ) - @defer.inlineCallbacks - def send_request(destination): - ret = yield self.transport_layer.make_membership_event( + async def send_request(destination: str) -> Tuple[str, EventBase, RoomVersion]: + ret = await self.transport_layer.make_membership_event( destination, room_id, user_id, membership, params ) # Note: If not supplied, the room version may be either v1 or v2, # however either way the event format version will be v1. - room_version = ret.get("room_version", RoomVersions.V1.identifier) - event_format = room_version_to_event_format(room_version) + room_version_id = ret.get("room_version", RoomVersions.V1.identifier) + room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version: + raise UnsupportedRoomVersionError() pdu_dict = ret.get("event", None) if not isinstance(pdu_dict, dict): @@ -610,94 +578,83 @@ class FederationClient(FederationBase): self._clock, self.hostname, self.signing_key, - format_version=event_format, + room_version=room_version, event_dict=pdu_dict, ) - return (destination, ev, event_format) + return destination, ev, room_version - return self._try_destination_list( + return await self._try_destination_list( "make_" + membership, destinations, send_request ) - def send_join(self, destinations, pdu, event_format_version): + async def send_join( + self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion + ) -> Dict[str, Any]: """Sends a join event to one of a list of homeservers. Doing so will cause the remote server to add the event to the graph, and send the event out to the rest of the federation. Args: - destinations (str): Candidate homeservers which are probably + destinations: Candidate homeservers which are probably participating in the room. - pdu (BaseEvent): event to be sent - event_format_version (int): The event format version + pdu: event to be sent + room_version: the version of the room (according to the server that + did the make_join) - Return: - Deferred: resolves to a dict with members ``origin`` (a string - giving the serer the event was sent to, ``state`` (?) and + Returns: + a dict with members ``origin`` (a string + giving the server the event was sent to, ``state`` (?) and ``auth_chain``. - Fails with a ``SynapseError`` if the chosen remote server - returns a 300/400 code. + Raises: + SynapseError: if the chosen remote server returns a 300/400 code. - Fails with a ``RuntimeError`` if no servers were reachable. 
+ RuntimeError: if no servers were reachable. """ - def check_authchain_validity(signed_auth_chain): - for e in signed_auth_chain: - if e.type == EventTypes.Create: - create_event = e - break - else: - raise InvalidResponseError("no %s in auth chain" % (EventTypes.Create,)) - - # the room version should be sane. - room_version = create_event.content.get("room_version", "1") - if room_version not in KNOWN_ROOM_VERSIONS: - # This shouldn't be possible, because the remote server should have - # rejected the join attempt during make_join. - raise InvalidResponseError( - "room appears to have unsupported version %s" % (room_version,) - ) - - @defer.inlineCallbacks - def send_request(destination): - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_join( - destination=destination, - room_id=pdu.room_id, - event_id=pdu.event_id, - content=pdu.get_pdu_json(time_now), - ) + async def send_request(destination) -> Dict[str, Any]: + content = await self._do_send_join(destination, pdu) logger.debug("Got content: %s", content) state = [ - event_from_pdu_json(p, event_format_version, outlier=True) + event_from_pdu_json(p, room_version, outlier=True) for p in content.get("state", []) ] auth_chain = [ - event_from_pdu_json(p, event_format_version, outlier=True) + event_from_pdu_json(p, room_version, outlier=True) for p in content.get("auth_chain", []) ] pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)} - room_version = None + create_event = None for e in state: if (e.type, e.state_key) == (EventTypes.Create, ""): - room_version = e.content.get( - "room_version", RoomVersions.V1.identifier - ) + create_event = e break - if room_version is None: + if create_event is None: # If the state doesn't have a create event then the room is # invalid, and it would fail auth checks anyway. raise SynapseError(400, "No create event in state") - valid_pdus = yield self._check_sigs_and_hash_and_fetch( + # the room version should be sane. + create_room_version = create_event.content.get( + "room_version", RoomVersions.V1.identifier + ) + if create_room_version != room_version.identifier: + # either the server that fulfilled the make_join, or the server that is + # handling the send_join, is lying. 
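`send_join` now cross-checks the create event it gets back: the returned state must contain an `m.room.create` event, and the room version recorded in it must match the version negotiated during make_join. A simplified sketch of that validation over plain dict events, using an illustrative `InvalidResponseError`:

class InvalidResponseError(Exception):
    pass

def check_create_event(state_events: list, expected_version: str) -> dict:
    create_event = None
    for event in state_events:
        if event.get("type") == "m.room.create" and event.get("state_key") == "":
            create_event = event
            break
    if create_event is None:
        # A room with no create event in its state cannot pass auth checks anyway.
        raise InvalidResponseError("No create event in state")
    create_version = create_event.get("content", {}).get("room_version", "1")
    if create_version != expected_version:
        # Either the make_join server or the send_join server is lying.
        raise InvalidResponseError(
            "Unexpected room version %s in create event" % (create_version,)
        )
    return create_event

state = [
    {"type": "m.room.member", "state_key": "@alice:example.org", "content": {}},
    {"type": "m.room.create", "state_key": "", "content": {"room_version": "5"}},
]
print(check_create_event(state, "5")["content"])  # -> {'room_version': '5'}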
+ raise InvalidResponseError( + "Unexpected room version %s in create event" + % (create_room_version,) + ) + + valid_pdus = await self._check_sigs_and_hash_and_fetch( destination, list(pdus.values()), outlier=True, @@ -725,7 +682,17 @@ class FederationClient(FederationBase): for s in signed_state: s.internal_metadata = copy.deepcopy(s.internal_metadata) - check_authchain_validity(signed_auth) + # double-check that the same create event has ended up in the auth chain + auth_chain_create_events = [ + e.event_id + for e in signed_auth + if (e.type, e.state_key) == (EventTypes.Create, "") + ] + if auth_chain_create_events != [create_event.event_id]: + raise InvalidResponseError( + "Unexpected create event(s) in auth chain: %s" + % (auth_chain_create_events,) + ) return { "state": signed_state, @@ -733,53 +700,84 @@ class FederationClient(FederationBase): "origin": destination, } - return self._try_destination_list("send_join", destinations, send_request) + return await self._try_destination_list("send_join", destinations, send_request) - @defer.inlineCallbacks - def send_invite(self, destination, room_id, event_id, pdu): - room_version = yield self.store.get_room_version(room_id) + async def _do_send_join(self, destination: str, pdu: EventBase): + time_now = self._clock.time_msec() + + try: + content = await self.transport_layer.send_join_v2( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + return content + except HttpResponseException as e: + if e.code in [400, 404]: + err = e.to_synapse_error() + + # If we receive an error response that isn't a generic error, or an + # unrecognised endpoint error, we assume that the remote understands + # the v2 invite API and this is a legitimate error. + if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: + raise err + else: + raise e.to_synapse_error() - content = yield self._do_send_invite(destination, pdu, room_version) + logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API") + + resp = await self.transport_layer.send_join_v1( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + # We expect the v1 API to respond with [200, content], so we only return the + # content. + return resp[1] + + async def send_invite( + self, destination: str, room_id: str, event_id: str, pdu: EventBase, + ) -> EventBase: + room_version = await self.store.get_room_version(room_id) + + content = await self._do_send_invite(destination, pdu, room_version) pdu_dict = content["event"] logger.debug("Got response to send_invite: %s", pdu_dict) - room_version = yield self.store.get_room_version(room_id) - format_ver = room_version_to_event_format(room_version) - - pdu = event_from_pdu_json(pdu_dict, format_ver) + pdu = event_from_pdu_json(pdu_dict, room_version) # Check signatures are correct. - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) # FIXME: We should handle signature failures more gracefully. return pdu - @defer.inlineCallbacks - def _do_send_invite(self, destination, pdu, room_version): + async def _do_send_invite( + self, destination: str, pdu: EventBase, room_version: RoomVersion + ) -> JsonDict: """Actually sends the invite, first trying v2 API and falling back to v1 API if necessary. 
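`_do_send_join` above (and `_do_send_leave` further down) share one shape: try the v2 endpoint first, and fall back to v1 only when the remote answers 400/404 with a generic or unrecognised-endpoint error code; the v1 endpoint additionally wraps its body in a `[200, content]` pair that has to be unwrapped. A sketch of that fallback with hypothetical transport callables rather than Synapse's transport layer:

import asyncio

class HttpResponseError(Exception):
    """Hypothetical HTTP error carrying a status code and a Matrix errcode."""
    def __init__(self, code: int, errcode: str) -> None:
        super().__init__("%d %s" % (code, errcode))
        self.code = code
        self.errcode = errcode

async def send_join_v2(destination: str, body: dict) -> dict:
    # Illustrative remote that has not implemented the v2 endpoint yet.
    raise HttpResponseError(404, "M_UNRECOGNIZED")

async def send_join_v1(destination: str, body: dict) -> list:
    return [200, {"state": [], "auth_chain": [], "origin": destination}]

async def do_send_join(destination: str, body: dict) -> dict:
    try:
        return await send_join_v2(destination, body)
    except HttpResponseError as e:
        # Only a 400/404 with a generic or "unrecognised endpoint" errcode means
        # "no such API"; anything else is a real error about this request.
        if e.code not in (400, 404) or e.errcode not in ("M_UNKNOWN", "M_UNRECOGNIZED"):
            raise
    # The v1 endpoint responds with [200, content], so unwrap the content.
    _, content = await send_join_v1(destination, body)
    return content

print(asyncio.run(do_send_join("remote.example.org", {})))
# -> {'state': [], 'auth_chain': [], 'origin': 'remote.example.org'}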
- Args: - destination (str): Target server - pdu (FrozenEvent) - room_version (str) - Returns: - dict: The event as a dict as returned by the remote server + The event as a dict as returned by the remote server """ time_now = self._clock.time_msec() try: - content = yield self.transport_layer.send_invite_v2( + content = await self.transport_layer.send_invite_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content={ "event": pdu.get_pdu_json(time_now), - "room_version": room_version, + "room_version": room_version.identifier, "invite_room_state": pdu.unsigned.get("invite_room_state", []), }, ) @@ -797,8 +795,7 @@ class FederationClient(FederationBase): # Otherwise, we assume that the remote server doesn't understand # the v2 invite API. That's ok provided the room uses old-style event # IDs. - v = KNOWN_ROOM_VERSIONS.get(room_version) - if v.event_format != EventFormatVersions.V1: + if room_version.event_format != EventFormatVersions.V1: raise SynapseError( 400, "User's homeserver does not support this room version", @@ -812,7 +809,7 @@ class FederationClient(FederationBase): # Didn't work, try v1 API. # Note the v1 API returns a tuple of `(200, content)` - _, content = yield self.transport_layer.send_invite_v1( + _, content = await self.transport_layer.send_invite_v1( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, @@ -820,7 +817,7 @@ class FederationClient(FederationBase): ) return content - def send_leave(self, destinations, pdu): + async def send_leave(self, destinations: Iterable[str], pdu: EventBase) -> None: """Sends a leave event to one of a list of homeservers. Doing so will cause the remote server to add the event to the graph, @@ -829,48 +826,94 @@ class FederationClient(FederationBase): This is mostly useful to reject received invites. Args: - destinations (str): Candidate homeservers which are probably + destinations: Candidate homeservers which are probably participating in the room. - pdu (BaseEvent): event to be sent + pdu: event to be sent - Return: - Deferred: resolves to None. - - Fails with a ``SynapseError`` if the chosen remote server - returns a 300/400 code. + Raises: + SynapseError if the chosen remote server returns a 300/400 code. - Fails with a ``RuntimeError`` if no servers were reachable. + RuntimeError if no servers were reachable. """ - @defer.inlineCallbacks - def send_request(destination): - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_leave( + async def send_request(destination: str) -> None: + content = await self._do_send_leave(destination, pdu) + logger.debug("Got content: %s", content) + + return await self._try_destination_list( + "send_leave", destinations, send_request + ) + + async def _do_send_leave(self, destination, pdu): + time_now = self._clock.time_msec() + + try: + content = await self.transport_layer.send_leave_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - logger.debug("Got content: %s", content) - return None + return content + except HttpResponseException as e: + if e.code in [400, 404]: + err = e.to_synapse_error() - return self._try_destination_list("send_leave", destinations, send_request) + # If we receive an error response that isn't a generic error, or an + # unrecognised endpoint error, we assume that the remote understands + # the v2 invite API and this is a legitimate error. 
+ if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: + raise err + else: + raise e.to_synapse_error() + + logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API") + + resp = await self.transport_layer.send_leave_v1( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + # We expect the v1 API to respond with [200, content], so we only return the + # content. + return resp[1] def get_public_rooms( self, - destination, - limit=None, - since_token=None, - search_filter=None, - include_all_networks=False, - third_party_instance_id=None, + remote_server: str, + limit: Optional[int] = None, + since_token: Optional[str] = None, + search_filter: Optional[Dict] = None, + include_all_networks: bool = False, + third_party_instance_id: Optional[str] = None, ): - if destination == self.server_name: - return + """Get the list of public rooms from a remote homeserver + + Args: + remote_server: The name of the remote server + limit: Maximum amount of rooms to return + since_token: Used for result pagination + search_filter: A filter dictionary to send the remote homeserver + and filter the result set + include_all_networks: Whether to include results from all third party instances + third_party_instance_id: Whether to only include results from a specific third + party instance + + Returns: + Deferred[Dict[str, Any]]: The response from the remote server, or None if + `remote_server` is the same as the local server_name + + Raises: + HttpResponseException: There was an exception returned from the remote server + SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom + requests over federation + """ return self.transport_layer.get_public_rooms( - destination, + remote_server, limit, since_token, search_filter, @@ -878,72 +921,33 @@ class FederationClient(FederationBase): third_party_instance_id=third_party_instance_id, ) - @defer.inlineCallbacks - def query_auth(self, destination, room_id, event_id, local_auth): - """ - Params: - destination (str) - event_it (str) - local_auth (list) - """ - time_now = self._clock.time_msec() - - send_content = {"auth_chain": [e.get_pdu_json(time_now) for e in local_auth]} - - code, content = yield self.transport_layer.send_query_auth( - destination=destination, - room_id=room_id, - event_id=event_id, - content=send_content, - ) - - room_version = yield self.store.get_room_version(room_id) - format_ver = room_version_to_event_format(room_version) - - auth_chain = [event_from_pdu_json(e, format_ver) for e in content["auth_chain"]] - - signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True, room_version=room_version - ) - - signed_auth.sort(key=lambda e: e.depth) - - ret = { - "auth_chain": signed_auth, - "rejects": content.get("rejects", []), - "missing": content.get("missing", []), - } - - return ret - - @defer.inlineCallbacks - def get_missing_events( + async def get_missing_events( self, - destination, - room_id, - earliest_events_ids, - latest_events, - limit, - min_depth, - timeout, - ): + destination: str, + room_id: str, + earliest_events_ids: Sequence[str], + latest_events: Iterable[EventBase], + limit: int, + min_depth: int, + timeout: int, + ) -> List[EventBase]: """Tries to fetch events we are missing. This is called when we receive an event without having received all of its ancestors. Args: - destination (str) - room_id (str) - earliest_events_ids (list): List of event ids. 
Effectively the + destination + room_id + earliest_events_ids: List of event ids. Effectively the events we expected to receive, but haven't. `get_missing_events` should only return events that didn't happen before these. - latest_events (list): List of events we have received that we don't + latest_events: List of events we have received that we don't have all previous events for. - limit (int): Maximum number of events to return. - min_depth (int): Minimum depth of events tor return. - timeout (int): Max time to wait in ms + limit: Maximum number of events to return. + min_depth: Minimum depth of events to return. + timeout: Max time to wait in ms """ try: - content = yield self.transport_layer.get_missing_events( + content = await self.transport_layer.get_missing_events( destination=destination, room_id=room_id, earliest_events=earliest_events_ids, @@ -953,14 +957,13 @@ class FederationClient(FederationBase): timeout=timeout, ) - room_version = yield self.store.get_room_version(room_id) - format_ver = room_version_to_event_format(room_version) + room_version = await self.store.get_room_version(room_id) events = [ - event_from_pdu_json(e, format_ver) for e in content.get("events", []) + event_from_pdu_json(e, room_version) for e in content.get("events", []) ] - signed_events = yield self._check_sigs_and_hash_and_fetch( + signed_events = await self._check_sigs_and_hash_and_fetch( destination, events, outlier=False, room_version=room_version ) except HttpResponseException as e: @@ -973,14 +976,13 @@ class FederationClient(FederationBase): return signed_events - @defer.inlineCallbacks - def forward_third_party_invite(self, destinations, room_id, event_dict): + async def forward_third_party_invite(self, destinations, room_id, event_dict): for destination in destinations: if destination == self.server_name: continue try: - yield self.transport_layer.exchange_third_party_invite( + await self.transport_layer.exchange_third_party_invite( destination=destination, room_id=room_id, event_dict=event_dict ) return None diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index da06ab379d..32a8a2ee46 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging +from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union import six from six import iteritems @@ -36,21 +38,24 @@ from synapse.api.errors import ( UnsupportedRoomVersionError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.crypto.event_signing import compute_event_signature -from synapse.events import room_version_to_event_format +from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction from synapse.http.endpoint import parse_server_name -from synapse.logging.context import nested_logging_context +from synapse.logging.context import ( + make_deferred_yieldable, + nested_logging_context, + run_in_background, +) from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace from synapse.logging.utils import log_function from synapse.replication.http.federation import ( ReplicationFederationSendEduRestServlet, ReplicationGetQueryRestServlet, ) -from synapse.types import get_domain_from_id -from synapse.util import glob_to_regex +from synapse.types import JsonDict, get_domain_from_id +from synapse.util import glob_to_regex, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache @@ -75,6 +80,9 @@ class FederationServer(FederationBase): self.auth = hs.get_auth() self.handler = hs.get_handlers().federation_handler + self.state = hs.get_state_handler() + + self.device_handler = hs.get_device_handler() self._server_linearizer = Linearizer("fed_server") self._transaction_linearizer = Linearizer("fed_txn_handler") @@ -87,14 +95,14 @@ class FederationServer(FederationBase): # come in waves. self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) - @defer.inlineCallbacks - @log_function - def on_backfill_request(self, origin, room_id, versions, limit): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_backfill_request( + self, origin: str, room_id: str, versions: List[str], limit: int + ) -> Tuple[int, Dict[str, Any]]: + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - pdus = yield self.handler.on_backfill_request( + pdus = await self.handler.on_backfill_request( origin, room_id, versions, limit ) @@ -102,76 +110,114 @@ class FederationServer(FederationBase): return 200, res - @defer.inlineCallbacks - @log_function - def on_incoming_transaction(self, origin, transaction_data): + async def on_incoming_transaction( + self, origin: str, transaction_data: JsonDict + ) -> Tuple[int, Dict[str, Any]]: # keep this as early as possible to make the calculated origin ts as # accurate as possible. request_time = self._clock.time_msec() transaction = Transaction(**transaction_data) - if not transaction.transaction_id: + if not transaction.transaction_id: # type: ignore raise Exception("Transaction missing transaction_id") - logger.debug("[%s] Got transaction", transaction.transaction_id) + logger.debug("[%s] Got transaction", transaction.transaction_id) # type: ignore # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. 
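The transaction handler relies on a Linearizer keyed on `(origin, transaction_id)` so that retransmissions of the same transaction are processed one at a time. A cut-down sketch of that idea using per-key asyncio locks; Synapse's real Linearizer also supports limited concurrency per key and evicts idle keys, which this `SimpleLinearizer` deliberately omits:

import asyncio
from typing import Dict, Hashable

class SimpleLinearizer:
    """Serialise coroutines that share a key; different keys run independently."""
    def __init__(self) -> None:
        self._locks = {}  # type: Dict[Hashable, asyncio.Lock]

    def queue(self, key: Hashable) -> asyncio.Lock:
        # NB: unlike Synapse's Linearizer this never cleans up unused keys.
        return self._locks.setdefault(key, asyncio.Lock())

async def handle_transaction(linearizer: SimpleLinearizer, origin: str, txn_id: str) -> str:
    async with linearizer.queue((origin, txn_id)):
        await asyncio.sleep(0)  # stand-in for the actual processing
        return "processed %s from %s" % (txn_id, origin)

async def _demo() -> None:
    linearizer = SimpleLinearizer()
    results = await asyncio.gather(
        handle_transaction(linearizer, "remote.example.org", "txn1"),
        handle_transaction(linearizer, "remote.example.org", "txn1"),  # serialised with the above
        handle_transaction(linearizer, "other.example.org", "txn9"),   # independent key
    )
    print(results)

asyncio.run(_demo())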
with ( - yield self._transaction_linearizer.queue( - (origin, transaction.transaction_id) + await self._transaction_linearizer.queue( + (origin, transaction.transaction_id) # type: ignore ) ): - result = yield self._handle_incoming_transaction( + result = await self._handle_incoming_transaction( origin, transaction, request_time ) return result - @defer.inlineCallbacks - def _handle_incoming_transaction(self, origin, transaction, request_time): + async def _handle_incoming_transaction( + self, origin: str, transaction: Transaction, request_time: int + ) -> Tuple[int, Dict[str, Any]]: """ Process an incoming transaction and return the HTTP response Args: - origin (unicode): the server making the request - transaction (Transaction): incoming transaction - request_time (int): timestamp that the HTTP request arrived at + origin: the server making the request + transaction: incoming transaction + request_time: timestamp that the HTTP request arrived at Returns: - Deferred[(int, object)]: http response code and body + HTTP response code and body """ - response = yield self.transaction_actions.have_responded(origin, transaction) + response = await self.transaction_actions.have_responded(origin, transaction) if response: logger.debug( "[%s] We've already responded to this request", - transaction.transaction_id, + transaction.transaction_id, # type: ignore ) return response - logger.debug("[%s] Transaction is new", transaction.transaction_id) + logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore - # Reject if PDU count > 50 and EDU count > 100 - if len(transaction.pdus) > 50 or ( - hasattr(transaction, "edus") and len(transaction.edus) > 100 + # Reject if PDU count > 50 or EDU count > 100 + if len(transaction.pdus) > 50 or ( # type: ignore + hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore ): logger.info("Transaction PDU or EDU count too large. Returning 400") response = {} - yield self.transaction_actions.set_response( + await self.transaction_actions.set_response( origin, transaction, 400, response ) return 400, response - received_pdus_counter.inc(len(transaction.pdus)) + # We process PDUs and EDUs in parallel. This is important as we don't + # want to block things like to device messages from reaching clients + # behind the potentially expensive handling of PDUs. + pdu_results, _ = await make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self._handle_pdus_in_txn, origin, transaction, request_time + ), + run_in_background(self._handle_edus_in_txn, origin, transaction), + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) + + response = {"pdus": pdu_results} + + logger.debug("Returning: %s", str(response)) + + await self.transaction_actions.set_response(origin, transaction, 200, response) + return 200, response + + async def _handle_pdus_in_txn( + self, origin: str, transaction: Transaction, request_time: int + ) -> Dict[str, dict]: + """Process the PDUs in a received transaction. + + Args: + origin: the server making the request + transaction: incoming transaction + request_time: timestamp that the HTTP request arrived at + + Returns: + A map from event ID of a processed PDU to any errors we should + report back to the sending server. 
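The handler now splits PDU and EDU processing into two concurrent tasks so that, for example, to-device messages are not stuck behind slow event handling. The diff uses Twisted's `gatherResults` with `run_in_background`; the equivalent shape in plain asyncio looks roughly like this (the handler bodies are stand-ins):

import asyncio
from typing import Dict, List

async def handle_pdus(pdus: List[dict]) -> Dict[str, dict]:
    await asyncio.sleep(0.05)                     # stand-in for expensive event handling
    return {pdu["event_id"]: {} for pdu in pdus}  # per-event result map

async def handle_edus(edus: List[dict]) -> None:
    await asyncio.sleep(0.01)                     # EDUs are comparatively cheap

async def handle_transaction(transaction: dict) -> dict:
    # Run both halves concurrently; if either raises, gather propagates the error.
    pdu_results, _ = await asyncio.gather(
        handle_pdus(transaction.get("pdus", [])),
        handle_edus(transaction.get("edus", [])),
    )
    return {"pdus": pdu_results}

txn = {"pdus": [{"event_id": "$abc"}], "edus": [{"edu_type": "m.typing", "content": {}}]}
print(asyncio.run(handle_transaction(txn)))  # -> {'pdus': {'$abc': {}}}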
+ """ + + received_pdus_counter.inc(len(transaction.pdus)) # type: ignore origin_host, _ = parse_server_name(origin) - pdus_by_room = {} + pdus_by_room = {} # type: Dict[str, List[EventBase]] - for p in transaction.pdus: + for p in transaction.pdus: # type: ignore if "unsigned" in p: unsigned = p["unsigned"] if "age" in unsigned: @@ -196,24 +242,17 @@ class FederationServer(FederationBase): continue try: - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) except NotFoundError: logger.info("Ignoring PDU for unknown room_id: %s", room_id) continue - - try: - format_ver = room_version_to_event_format(room_version) - except UnsupportedRoomVersionError: + except UnsupportedRoomVersionError as e: # this can happen if support for a given room version is withdrawn, # so that we still get events for said room. - logger.info( - "Ignoring PDU for room %s with unknown version %s", - room_id, - room_version, - ) + logger.info("Ignoring PDU: %s", e) continue - event = event_from_pdu_json(p, format_ver) + event = event_from_pdu_json(p, room_version) pdus_by_room.setdefault(room_id, []).append(event) pdu_results = {} @@ -222,13 +261,12 @@ class FederationServer(FederationBase): # require callouts to other servers to fetch missing events), but # impose a limit to avoid going too crazy with ram/cpu. - @defer.inlineCallbacks - def process_pdus_for_room(room_id): + async def process_pdus_for_room(room_id: str): logger.debug("Processing PDUs for %s", room_id) try: - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) except AuthError as e: - logger.warn("Ignoring PDUs for room %s from banned server", room_id) + logger.warning("Ignoring PDUs for room %s from banned server", room_id) for pdu in pdus_by_room[room_id]: event_id = pdu.event_id pdu_results[event_id] = e.error_dict() @@ -238,10 +276,10 @@ class FederationServer(FederationBase): event_id = pdu.event_id with nested_logging_context(event_id): try: - yield self._handle_received_pdu(origin, pdu) + await self._handle_received_pdu(origin, pdu) pdu_results[event_id] = {} except FederationError as e: - logger.warn("Error handling PDU %s: %s", event_id, e) + logger.warning("Error handling PDU %s: %s", event_id, e) pdu_results[event_id] = {"error": str(e)} except Exception as e: f = failure.Failure() @@ -252,36 +290,40 @@ class FederationServer(FederationBase): exc_info=(f.type, f.value, f.getTracebackObject()), ) - yield concurrently_execute( + await concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) - if hasattr(transaction, "edus"): - for edu in (Edu(**x) for x in transaction.edus): - yield self.received_edu(origin, edu.edu_type, edu.content) - - response = {"pdus": pdu_results} + return pdu_results - logger.debug("Returning: %s", str(response)) + async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): + """Process the EDUs in a received transaction. 
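PDUs are grouped by room and the rooms are then processed with `concurrently_execute`, which caps how many rooms are handled at once (`TRANSACTION_CONCURRENCY_LIMIT`) while collecting a per-event result or error. A rough asyncio equivalent using a semaphore in place of Synapse's helper:

import asyncio
from typing import Dict, List

CONCURRENCY_LIMIT = 10  # analogue of TRANSACTION_CONCURRENCY_LIMIT

async def process_rooms(pdus_by_room: Dict[str, List[dict]]) -> Dict[str, dict]:
    pdu_results = {}  # type: Dict[str, dict]
    sem = asyncio.Semaphore(CONCURRENCY_LIMIT)

    async def process_room(room_id: str) -> None:
        async with sem:
            for pdu in pdus_by_room[room_id]:
                try:
                    await asyncio.sleep(0)  # stand-in for _handle_received_pdu
                    pdu_results[pdu["event_id"]] = {}
                except Exception as e:
                    # Report the failure for this event but keep going with the rest.
                    pdu_results[pdu["event_id"]] = {"error": str(e)}

    await asyncio.gather(*(process_room(r) for r in pdus_by_room))
    return pdu_results

rooms = {"!a:example.org": [{"event_id": "$1"}], "!b:example.org": [{"event_id": "$2"}]}
print(sorted(asyncio.run(process_rooms(rooms))))  # -> ['$1', '$2']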
+ """ - yield self.transaction_actions.set_response(origin, transaction, 200, response) - return 200, response + async def _process_edu(edu_dict): + received_edus_counter.inc() - @defer.inlineCallbacks - def received_edu(self, origin, edu_type, content): - received_edus_counter.inc() - yield self.registry.on_edu(edu_type, origin, content) + edu = Edu( + origin=origin, + destination=self.server_name, + edu_type=edu_dict["edu_type"], + content=edu_dict["content"], + ) + await self.registry.on_edu(edu.edu_type, origin, edu.content) - @defer.inlineCallbacks - @log_function - def on_context_state_request(self, origin, room_id, event_id): - if not event_id: - raise NotImplementedError("Specify an event") + await concurrently_execute( + _process_edu, + getattr(transaction, "edus", []), + TRANSACTION_CONCURRENCY_LIMIT, + ) + async def on_context_state_request( + self, origin: str, room_id: str, event_id: str + ) -> Tuple[int, Dict[str, Any]]: origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -290,237 +332,196 @@ class FederationServer(FederationBase): # in the cache so we could return it without waiting for the linearizer # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. - with (yield self._server_linearizer.queue((origin, room_id))): - resp = yield self._state_resp_cache.wrap( - (room_id, event_id), - self._on_context_state_request_compute, - room_id, - event_id, + with (await self._server_linearizer.queue((origin, room_id))): + resp = dict( + await self._state_resp_cache.wrap( + (room_id, event_id), + self._on_context_state_request_compute, + room_id, + event_id, + ) ) + room_version = await self.store.get_room_version_id(room_id) + resp["room_version"] = room_version + return 200, resp - @defer.inlineCallbacks - def on_state_ids_request(self, origin, room_id, event_id): + async def on_state_ids_request( + self, origin: str, room_id: str, event_id: str + ) -> Tuple[int, Dict[str, Any]]: if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id) - auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) + state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) + auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} - @defer.inlineCallbacks - def _on_context_state_request_compute(self, room_id, event_id): - pdus = yield self.handler.get_state_for_pdu(room_id, event_id) - auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus]) - - for event in auth_chain: - # We sign these again because there was a bug where we - # incorrectly signed things the first time round - if self.hs.is_mine_id(event.event_id): - event.signatures.update( - compute_event_signature( - event.get_pdu_json(), - self.hs.hostname, - 
self.hs.config.signing_key[0], - ) - ) + async def _on_context_state_request_compute( + self, room_id: str, event_id: str + ) -> Dict[str, list]: + if event_id: + pdus = await self.handler.get_state_for_pdu(room_id, event_id) + else: + pdus = (await self.state.get_current_state(room_id)).values() + + auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) return { "pdus": [pdu.get_pdu_json() for pdu in pdus], "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], } - @defer.inlineCallbacks - @log_function - def on_pdu_request(self, origin, event_id): - pdu = yield self.handler.get_persisted_pdu(origin, event_id) + async def on_pdu_request( + self, origin: str, event_id: str + ) -> Tuple[int, Union[JsonDict, str]]: + pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: return 200, self._transaction_from_pdus([pdu]).get_dict() else: return 404, "" - @defer.inlineCallbacks - def on_query_request(self, query_type, args): + async def on_query_request( + self, query_type: str, args: Dict[str, str] + ) -> Tuple[int, Dict[str, Any]]: received_queries_counter.labels(query_type).inc() - resp = yield self.registry.on_query(query_type, args) + resp = await self.registry.on_query(query_type, args) return 200, resp - @defer.inlineCallbacks - def on_make_join_request(self, origin, room_id, user_id, supported_versions): + async def on_make_join_request( + self, origin: str, room_id: str, user_id: str, supported_versions: List[str] + ) -> Dict[str, Any]: origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) if room_version not in supported_versions: - logger.warn("Room version %s not in %s", room_version, supported_versions) + logger.warning( + "Room version %s not in %s", room_version, supported_versions + ) raise IncompatibleRoomVersionError(room_version=room_version) - pdu = yield self.handler.on_make_join_request(origin, room_id, user_id) + pdu = await self.handler.on_make_join_request(origin, room_id, user_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_invite_request(self, origin, content, room_version): - if room_version not in KNOWN_ROOM_VERSIONS: + async def on_invite_request( + self, origin: str, content: JsonDict, room_version_id: str + ) -> Dict[str, Any]: + room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version: raise SynapseError( 400, "Homeserver does not support this room version", Codes.UNSUPPORTED_ROOM_VERSION, ) - format_ver = room_version_to_event_format(room_version) - - pdu = event_from_pdu_json(content, format_ver) + pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) - ret_pdu = yield self.handler.on_invite_request(origin, pdu) + await self.check_server_matches_acl(origin_host, pdu.room_id) + pdu = await self._check_sigs_and_hash(room_version, pdu) + ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version) time_now = self._clock.time_msec() return {"event": ret_pdu.get_pdu_json(time_now)} - @defer.inlineCallbacks - def on_send_join_request(self, origin, content, room_id): + async def on_send_join_request( + self, origin: str, content: JsonDict, room_id: str + ) -> Dict[str, 
Any]: logger.debug("on_send_join_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) - format_ver = room_version_to_event_format(room_version) - pdu = event_from_pdu_json(content, format_ver) + room_version = await self.store.get_room_version(room_id) + pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - res_pdus = yield self.handler.on_send_join_request(origin, pdu) + + pdu = await self._check_sigs_and_hash(room_version, pdu) + + res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() - return ( - 200, - { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [ - p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] - ], - }, - ) + return { + "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], + "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]], + } - @defer.inlineCallbacks - def on_make_leave_request(self, origin, room_id, user_id): + async def on_make_leave_request( + self, origin: str, room_id: str, user_id: str + ) -> Dict[str, Any]: origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) - pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id) + await self.check_server_matches_acl(origin_host, room_id) + pdu = await self.handler.on_make_leave_request(origin, room_id, user_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_send_leave_request(self, origin, content, room_id): + async def on_send_leave_request( + self, origin: str, content: JsonDict, room_id: str + ) -> dict: logger.debug("on_send_leave_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) - format_ver = room_version_to_event_format(room_version) - pdu = event_from_pdu_json(content, format_ver) + room_version = await self.store.get_room_version(room_id) + pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) - yield self.handler.on_send_leave_request(origin, pdu) - return 200, {} - @defer.inlineCallbacks - def on_event_auth(self, origin, room_id, event_id): - with (yield self._server_linearizer.queue((origin, room_id))): - origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + pdu = await self._check_sigs_and_hash(room_version, pdu) - time_now = self._clock.time_msec() - auth_pdus = yield self.handler.on_event_auth(event_id) - res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} - return 200, res + await self.handler.on_send_leave_request(origin, pdu) + return {} - @defer.inlineCallbacks - def on_query_auth_request(self, origin, content, room_id, event_id): - """ - Content is a dict with keys:: - auth_chain (list): A list of events that give the auth chain. 
- missing (list): A list of event_ids indicating what the other - side (`origin`) think we're missing. - rejects (dict): A mapping from event_id to a 2-tuple of reason - string and a proof (or None) of why the event was rejected. - The keys of this dict give the list of events the `origin` has - rejected. - - Args: - origin (str) - content (dict) - event_id (str) - - Returns: - Deferred: Results in `dict` with the same format as `content` - """ - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_event_auth( + self, origin: str, room_id: str, event_id: str + ) -> Tuple[int, Dict[str, Any]]: + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) - - room_version = yield self.store.get_room_version(room_id) - format_ver = room_version_to_event_format(room_version) - - auth_chain = [ - event_from_pdu_json(e, format_ver) for e in content["auth_chain"] - ] - - signed_auth = yield self._check_sigs_and_hash_and_fetch( - origin, auth_chain, outlier=True, room_version=room_version - ) - - ret = yield self.handler.on_query_auth( - origin, - event_id, - room_id, - signed_auth, - content.get("rejects", []), - content.get("missing", []), - ) + await self.check_server_matches_acl(origin_host, room_id) time_now = self._clock.time_msec() - send_content = { - "auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]], - "rejects": ret.get("rejects", []), - "missing": ret.get("missing", []), - } - - return 200, send_content + auth_pdus = await self.handler.on_event_auth(event_id) + res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} + return 200, res @log_function - def on_query_client_keys(self, origin, content): - return self.on_query_request("client_keys", content) + async def on_query_client_keys( + self, origin: str, content: Dict[str, str] + ) -> Tuple[int, Dict[str, Any]]: + return await self.on_query_request("client_keys", content) - def on_query_user_devices(self, origin, user_id): - return self.on_query_request("user_devices", user_id) + async def on_query_user_devices( + self, origin: str, user_id: str + ) -> Tuple[int, Dict[str, Any]]: + keys = await self.device_handler.on_federation_query_user_devices(user_id) + return 200, keys @trace - @defer.inlineCallbacks - @log_function - def on_claim_client_keys(self, origin, content): + async def on_claim_client_keys( + self, origin: str, content: JsonDict + ) -> Dict[str, Any]: query = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): query.append((user_id, device_id, algorithm)) log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) - results = yield self.store.claim_e2e_one_time_keys(query) + results = await self.store.claim_e2e_one_time_keys(query) - json_result = {} + json_result = {} # type: Dict[str, Dict[str, dict]] for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): for key_id, json_bytes in keys.items(): @@ -542,16 +543,19 @@ class FederationServer(FederationBase): return {"one_time_keys": json_result} - @defer.inlineCallbacks - @log_function - def on_get_missing_events( - self, origin, room_id, earliest_events, latest_events, limit - ): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_get_missing_events( + self, + origin: str, + room_id: str, + earliest_events: List[str], + latest_events: List[str], + limit: int, + ) -> 
Dict[str, list]: + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - logger.info( + logger.debug( "on_get_missing_events: earliest_events: %r, latest_events: %r," " limit: %d", earliest_events, @@ -559,27 +563,27 @@ class FederationServer(FederationBase): limit, ) - missing_events = yield self.handler.on_get_missing_events( + missing_events = await self.handler.on_get_missing_events( origin, room_id, earliest_events, latest_events, limit ) if len(missing_events) < 5: - logger.info( + logger.debug( "Returning %d events: %r", len(missing_events), missing_events ) else: - logger.info("Returning %d events", len(missing_events)) + logger.debug("Returning %d events", len(missing_events)) time_now = self._clock.time_msec() return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]} @log_function - def on_openid_userinfo(self, token): + async def on_openid_userinfo(self, token: str) -> Optional[str]: ts_now_ms = self._clock.time_msec() - return self.store.get_user_id_for_open_id_token(token, ts_now_ms) + return await self.store.get_user_id_for_open_id_token(token, ts_now_ms) - def _transaction_from_pdus(self, pdu_list): + def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction: """Returns a new Transaction containing the given PDUs suitable for transmission. """ @@ -592,8 +596,7 @@ class FederationServer(FederationBase): destination=None, ) - @defer.inlineCallbacks - def _handle_received_pdu(self, origin, pdu): + async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None: """ Process a PDU received in a federation /send/ transaction. If the event is invalid, then this method throws a FederationError. @@ -614,10 +617,8 @@ class FederationServer(FederationBase): until we try to backfill across the discontinuity. Args: - origin (str): server which sent the pdu - pdu (FrozenEvent): received pdu - - Returns (Deferred): completes with None + origin: server which sent the pdu + pdu: received pdu Raises: FederationError if the signatures / hash do not match, or if the event was unacceptable for any other reason (eg, too large, @@ -646,68 +647,67 @@ class FederationServer(FederationBase): logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point - room_version = yield self.store.get_room_version(pdu.room_id) + room_version = await self.store.get_room_version(pdu.room_id) # Check signature. 
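Both the client-side `get_missing_events` earlier in this diff and the `on_get_missing_events` handler above deal in the same exchange: given events the caller already has (`earliest_events`) and events whose ancestry is incomplete (`latest_events`), return up to `limit` intermediate events, respecting `min_depth`. A much-simplified sketch of that backwards walk over an in-memory event graph; the real server does this against the database and applies visibility checks, which this ignores:

from collections import deque
from typing import Dict, List, Set

def get_missing_events(
    events: Dict[str, dict],    # event_id -> {"prev_events": [...], "depth": int}
    earliest_events: Set[str],  # events the caller already has
    latest_events: List[str],   # events whose ancestry is incomplete
    limit: int,
    min_depth: int,
) -> List[dict]:
    seen = set(earliest_events)
    queue = deque(latest_events)
    missing = []  # type: List[dict]
    while queue and len(missing) < limit:
        event_id = queue.popleft()
        for prev_id in events.get(event_id, {}).get("prev_events", []):
            if prev_id in seen or prev_id not in events:
                continue
            seen.add(prev_id)
            if events[prev_id]["depth"] < min_depth:
                continue
            missing.append(events[prev_id])
            queue.append(prev_id)
            if len(missing) >= limit:
                break
    return missing

graph = {
    "$c": {"prev_events": ["$b"], "depth": 3},
    "$b": {"prev_events": ["$a"], "depth": 2},
    "$a": {"prev_events": [], "depth": 1},
}
print([e["depth"] for e in get_missing_events(graph, {"$a"}, ["$c"], limit=10, min_depth=0)])
# -> [2]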
try: - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) except SynapseError as e: raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) + await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name - @defer.inlineCallbacks - def exchange_third_party_invite( - self, sender_user_id, target_user_id, room_id, signed + async def exchange_third_party_invite( + self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict ): - ret = yield self.handler.exchange_third_party_invite( + ret = await self.handler.exchange_third_party_invite( sender_user_id, target_user_id, room_id, signed ) return ret - @defer.inlineCallbacks - def on_exchange_third_party_invite_request(self, room_id, event_dict): - ret = yield self.handler.on_exchange_third_party_invite_request( + async def on_exchange_third_party_invite_request( + self, room_id: str, event_dict: Dict + ): + ret = await self.handler.on_exchange_third_party_invite_request( room_id, event_dict ) return ret - @defer.inlineCallbacks - def check_server_matches_acl(self, server_name, room_id): + async def check_server_matches_acl(self, server_name: str, room_id: str): """Check if the given server is allowed by the server ACLs in the room Args: - server_name (str): name of server, *without any port part* - room_id (str): ID of the room to check + server_name: name of server, *without any port part* + room_id: ID of the room to check Raises: AuthError if the server does not match the ACL """ - state_ids = yield self.store.get_current_state_ids(room_id) + state_ids = await self.store.get_current_state_ids(room_id) acl_event_id = state_ids.get((EventTypes.ServerACL, "")) if not acl_event_id: return - acl_event = yield self.store.get_event(acl_event_id) + acl_event = await self.store.get_event(acl_event_id) if server_matches_acl_event(server_name, acl_event): return raise AuthError(code=403, msg="Server is banned from room") -def server_matches_acl_event(server_name, acl_event): +def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: """Check if the given server is allowed by the ACL event Args: - server_name (str): name of server, without any port part - acl_event (EventBase): m.room.server_acl event + server_name: name of server, without any port part + acl_event: m.room.server_acl event Returns: - bool: True if this server is allowed by the ACLs + True if this server is allowed by the ACLs """ logger.debug("Checking %s against acl %s", server_name, acl_event.content) @@ -715,7 +715,7 @@ def server_matches_acl_event(server_name, acl_event): # server name is a literal IP allow_ip_literals = acl_event.content.get("allow_ip_literals", True) if not isinstance(allow_ip_literals, bool): - logger.warn("Ignorning non-bool allow_ip_literals flag") + logger.warning("Ignorning non-bool allow_ip_literals flag") allow_ip_literals = True if not allow_ip_literals: # check for ipv6 literals. These start with '['. 
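`server_matches_acl_event` rejects IP-literal server names when `allow_ip_literals` is false, then applies the deny list followed by the allow list (that deny/allow handling follows just below), treating each entry as a glob over the server name. A compact sketch of the same shape using the standard library's `fnmatch` in place of Synapse's `glob_to_regex`, and simplified IP-literal detection:

from fnmatch import fnmatch

def server_matches_acl(server_name: str, acl_content: dict) -> bool:
    # Crude IP-literal check: IPv6 literals start with '[', IPv4 literals are digits and dots.
    if not acl_content.get("allow_ip_literals", True):
        if server_name.startswith("[") or server_name.replace(".", "").isdigit():
            return False

    # The deny list wins over everything else.
    for entry in acl_content.get("deny", []):
        if isinstance(entry, str) and fnmatch(server_name, entry):
            return False

    # Then the server must match at least one allow entry.
    for entry in acl_content.get("allow", []):
        if isinstance(entry, str) and fnmatch(server_name, entry):
            return True

    return False

acl = {"allow_ip_literals": False, "deny": ["*.evil.example"], "allow": ["*"]}
print(server_matches_acl("matrix.example.org", acl))  # -> True
print(server_matches_acl("spam.evil.example", acl))   # -> False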
@@ -729,7 +729,7 @@ def server_matches_acl_event(server_name, acl_event): # next, check the deny list deny = acl_event.content.get("deny", []) if not isinstance(deny, (list, tuple)): - logger.warn("Ignorning non-list deny ACL %s", deny) + logger.warning("Ignorning non-list deny ACL %s", deny) deny = [] for e in deny: if _acl_entry_matches(server_name, e): @@ -739,7 +739,7 @@ def server_matches_acl_event(server_name, acl_event): # then the allow list. allow = acl_event.content.get("allow", []) if not isinstance(allow, (list, tuple)): - logger.warn("Ignorning non-list allow ACL %s", allow) + logger.warning("Ignorning non-list allow ACL %s", allow) allow = [] for e in allow: if _acl_entry_matches(server_name, e): @@ -751,9 +751,9 @@ def server_matches_acl_event(server_name, acl_event): return False -def _acl_entry_matches(server_name, acl_entry): +def _acl_entry_matches(server_name: str, acl_entry: str) -> Match: if not isinstance(acl_entry, six.string_types): - logger.warn( + logger.warning( "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) ) return False @@ -770,13 +770,13 @@ class FederationHandlerRegistry(object): self.edu_handlers = {} self.query_handlers = {} - def register_edu_handler(self, edu_type, handler): + def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]): """Sets the handler callable that will be used to handle an incoming federation EDU of the given type. Args: - edu_type (str): The type of the incoming EDU to register handler for - handler (Callable[[str, dict]]): A callable invoked on incoming EDU + edu_type: The type of the incoming EDU to register handler for + handler: A callable invoked on incoming EDU of the given type. The arguments are the origin server name and the EDU contents. """ @@ -787,14 +787,16 @@ class FederationHandlerRegistry(object): self.edu_handlers[edu_type] = handler - def register_query_handler(self, query_type, handler): + def register_query_handler( + self, query_type: str, handler: Callable[[dict], defer.Deferred] + ): """Sets the handler callable that will be used to handle an incoming federation query of the given type. Args: - query_type (str): Category name of the query, which should match + query_type: Category name of the query, which should match the string used by make_query. - handler (Callable[[dict], Deferred[dict]]): Invoked to handle + handler: Invoked to handle incoming queries of this type. The return will be yielded on and the result used as the response to the query request. 
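The `FederationHandlerRegistry` keeps two plain dicts, one for EDU handlers and one for query handlers, refuses duplicate registrations, and (after this change) logs and drops EDUs with no registered handler instead of raising; the dispatch side follows below. A stripped-down sketch of that registry, with toy handlers:

import asyncio
from typing import Awaitable, Callable, Dict

class HandlerRegistry:
    def __init__(self) -> None:
        self.edu_handlers = {}    # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
        self.query_handlers = {}  # type: Dict[str, Callable[[dict], Awaitable[dict]]]

    def register_edu_handler(self, edu_type: str, handler) -> None:
        if edu_type in self.edu_handlers:
            raise KeyError("Already have an EDU handler for %s" % (edu_type,))
        self.edu_handlers[edu_type] = handler

    def register_query_handler(self, query_type: str, handler) -> None:
        if query_type in self.query_handlers:
            raise KeyError("Already have a Query handler for %s" % (query_type,))
        self.query_handlers[query_type] = handler

    async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
        handler = self.edu_handlers.get(edu_type)
        if not handler:
            # Unknown EDU types are logged and dropped rather than treated as errors.
            print("No handler registered for EDU type %s" % (edu_type,))
            return
        await handler(origin, content)

    async def on_query(self, query_type: str, args: dict) -> dict:
        handler = self.query_handlers.get(query_type)
        if not handler:
            raise KeyError("No handler for Query type '%s'" % (query_type,))
        return await handler(args)

async def _demo() -> None:
    registry = HandlerRegistry()

    async def on_typing(origin: str, content: dict) -> None:
        print("typing EDU from %s: %s" % (origin, content))

    registry.register_edu_handler("m.typing", on_typing)
    await registry.on_edu("m.typing", "remote.example.org", {"typing": True})
    await registry.on_edu("m.presence", "remote.example.org", {})  # logged and dropped

asyncio.run(_demo())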
""" @@ -805,24 +807,24 @@ class FederationHandlerRegistry(object): self.query_handlers[query_type] = handler - @defer.inlineCallbacks - def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type: str, origin: str, content: dict): handler = self.edu_handlers.get(edu_type) if not handler: - logger.warn("No handler registered for EDU type %s", edu_type) + logger.warning("No handler registered for EDU type %s", edu_type) + return with start_active_span_from_edu(content, "handle_edu"): try: - yield handler(origin, content) + await handler(origin, content) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) except Exception: logger.exception("Failed to handle edu %r", edu_type) - def on_query(self, query_type, args): + def on_query(self, query_type: str, args: dict) -> defer.Deferred: handler = self.query_handlers.get(query_type) if not handler: - logger.warn("No handler registered for query type %s", query_type) + logger.warning("No handler registered for query type %s", query_type) raise NotFoundError("No handler for Query type '%s'" % (query_type,)) return handler(args) @@ -846,7 +848,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): super(ReplicationFederationHandlerRegistry, self).__init__() - def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type: str, origin: str, content: dict): """Overrides FederationHandlerRegistry """ if not self.config.use_presence and edu_type == "m.presence": @@ -854,17 +856,17 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): handler = self.edu_handlers.get(edu_type) if handler: - return super(ReplicationFederationHandlerRegistry, self).on_edu( + return await super(ReplicationFederationHandlerRegistry, self).on_edu( edu_type, origin, content ) - return self._send_edu(edu_type=edu_type, origin=origin, content=content) + return await self._send_edu(edu_type=edu_type, origin=origin, content=content) - def on_query(self, query_type, args): + async def on_query(self, query_type: str, args: dict): """Overrides FederationHandlerRegistry """ handler = self.query_handlers.get(query_type) if handler: - return handler(args) + return await handler(args) - return self._get_query_client(query_type=query_type, args=args) + return await self._get_query_client(query_type=query_type, args=args) diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 44edcabed4..d68b4bd670 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -44,7 +44,7 @@ class TransactionActions(object): response code and response body. """ if not transaction.transaction_id: - raise RuntimeError("Cannot persist a transaction with no " "transaction_id") + raise RuntimeError("Cannot persist a transaction with no transaction_id") return self.store.get_received_txn_response(transaction.transaction_id, origin) @@ -56,7 +56,7 @@ class TransactionActions(object): Deferred """ if not transaction.transaction_id: - raise RuntimeError("Cannot persist a transaction with no " "transaction_id") + raise RuntimeError("Cannot persist a transaction with no transaction_id") return self.store.set_received_txn_response( transaction.transaction_id, origin, code, response diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 454456a52d..52f4f54215 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -31,11 +31,14 @@ Events are replicated via a separate events stream. 
import logging from collections import namedtuple +from typing import Dict, List, Tuple, Type from six import iteritems from sortedcontainers import SortedDict +from twisted.internet import defer + from synapse.metrics import LaterGauge from synapse.storage.presence import UserPresenceState from synapse.util.metrics import Measure @@ -54,23 +57,35 @@ class FederationRemoteSendQueue(object): self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id - self.presence_map = {} # Pending presence map user_id -> UserPresenceState - self.presence_changed = SortedDict() # Stream position -> list[user_id] + # Pending presence map user_id -> UserPresenceState + self.presence_map = {} # type: Dict[str, UserPresenceState] + + # Stream position -> list[user_id] + self.presence_changed = SortedDict() # type: SortedDict[int, List[str]] # Stores the destinations we need to explicitly send presence to about a # given user. # Stream position -> (user_id, destinations) - self.presence_destinations = SortedDict() + self.presence_destinations = ( + SortedDict() + ) # type: SortedDict[int, Tuple[str, List[str]]] - self.keyed_edu = {} # (destination, key) -> EDU - self.keyed_edu_changed = SortedDict() # stream position -> (destination, key) + # (destination, key) -> EDU + self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu] - self.edus = SortedDict() # stream position -> Edu + # stream position -> (destination, key) + self.keyed_edu_changed = ( + SortedDict() + ) # type: SortedDict[int, Tuple[str, tuple]] - self.device_messages = SortedDict() # stream position -> destination + self.edus = SortedDict() # type: SortedDict[int, Edu] + # stream ID for the next entry into presence_changed/keyed_edu_changed/edus. self.pos = 1 - self.pos_time = SortedDict() + + # map from stream ID to the time that stream entry was generated, so that we + # can clear out entries after a while + self.pos_time = SortedDict() # type: SortedDict[int, int] # EVERYTHING IS SAD. 
In particular, python only makes new scopes when # we make a new function, so we need to make a new function so the inner @@ -90,7 +105,6 @@ class FederationRemoteSendQueue(object): "keyed_edu", "keyed_edu_changed", "edus", - "device_messages", "pos_time", "presence_destinations", ]: @@ -130,9 +144,9 @@ class FederationRemoteSendQueue(object): for key in keys[:i]: del self.presence_changed[key] - user_ids = set( + user_ids = { user_id for uids in self.presence_changed.values() for user_id in uids - ) + } keys = self.presence_destinations.keys() i = self.presence_destinations.bisect_left(position_to_delete) @@ -159,8 +173,10 @@ class FederationRemoteSendQueue(object): for edu_key in self.keyed_edu_changed.values(): live_keys.add(edu_key) - to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys] - for edu_key in to_del: + keys_to_del = [ + edu_key for edu_key in self.keyed_edu if edu_key not in live_keys + ] + for edu_key in keys_to_del: del self.keyed_edu[edu_key] # Delete things out of edu map @@ -169,12 +185,6 @@ class FederationRemoteSendQueue(object): for key in keys[:i]: del self.edus[key] - # Delete things out of device map - keys = self.device_messages.keys() - i = self.device_messages.bisect_left(position_to_delete) - for key in keys[:i]: - del self.device_messages[key] - def notify_new_events(self, current_id): """As per FederationSender""" # We don't need to replicate this as it gets sent down a different @@ -212,7 +222,7 @@ class FederationRemoteSendQueue(object): receipt (synapse.types.ReadReceipt): """ # nothing to do here: the replication listener will handle it. - pass + return defer.succeed(None) def send_presence(self, states): """As per FederationSender @@ -247,9 +257,8 @@ class FederationRemoteSendQueue(object): def send_device_messages(self, destination): """As per FederationSender""" - pos = self._next_pos() - self.device_messages[pos] = destination - self.notifier.on_new_replication_data() + # We don't need to replicate this as it gets sent down a different + # stream. def get_current_token(self): return self.pos - 1 @@ -257,18 +266,24 @@ class FederationRemoteSendQueue(object): def federation_ack(self, token): self._clear_queue_before_pos(token) - def get_replication_rows(self, from_token, to_token, limit, federation_ack=None): + async def get_replication_rows( + self, instance_name: str, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple[int, Tuple]], int, bool]: """Get rows to be sent over federation between the two tokens Args: - from_token (int) - to_token(int) - limit (int) - federation_ack (int): Optional. The position where the worker is - explicitly acknowledged it has handled. Allows us to drop - data from before that point + instance_name: the name of the current process + from_token: the previous stream token: the starting point for fetching the + updates + to_token: the new stream token: the point to get updates up to + target_row_count: a target for the number of rows to be returned. + + Returns: a triplet `(updates, new_last_token, limited)`, where: + * `updates` is a list of `(token, row)` entries. + * `new_last_token` is the new position in stream. + * `limited` is whether there are more updates to fetch. """ - # TODO: Handle limit. + # TODO: Handle target_row_count. # To handle restarts where we wrap around if from_token > self.pos: @@ -276,12 +291,7 @@ class FederationRemoteSendQueue(object): # list of tuple(int, BaseFederationRow), where the first is the position # of the federation stream. 
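`get_replication_rows` now returns a `(updates, new_last_token, limited)` triple, built from every queued entry whose stream position lies in `(from_token, to_token]`. A tiny sketch of that windowing over a position-keyed map, using the standard `bisect` module instead of `sortedcontainers`; like the diff's TODO, it never truncates, so `limited` is always False:

import bisect
from typing import Any, Dict, List, Tuple

def get_replication_rows(
    queued: Dict[int, Any], from_token: int, to_token: int
) -> Tuple[List[Tuple[int, Any]], int, bool]:
    positions = sorted(queued)
    # Half-open on the left, closed on the right: (from_token, to_token].
    start = bisect.bisect_right(positions, from_token)
    end = bisect.bisect_right(positions, to_token)
    updates = [(pos, queued[pos]) for pos in positions[start:end]]
    return updates, to_token, False

queue = {3: "edu-a", 5: "edu-b", 9: "edu-c"}
print(get_replication_rows(queue, 3, 9))  # -> ([(5, 'edu-b'), (9, 'edu-c')], 9, False)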
- rows = [] - - # There should be only one reader, so lets delete everything its - # acknowledged its seen. - if federation_ack: - self._clear_queue_before_pos(federation_ack) + rows = [] # type: List[Tuple[int, BaseFederationRow]] # Fetch changed presence i = self.presence_changed.bisect_right(from_token) @@ -335,18 +345,14 @@ class FederationRemoteSendQueue(object): for (pos, edu) in edus: rows.append((pos, EduRow(edu))) - # Fetch changed device messages - i = self.device_messages.bisect_right(from_token) - j = self.device_messages.bisect_right(to_token) + 1 - device_messages = {v: k for k, v in self.device_messages.items()[i:j]} - - for (destination, pos) in iteritems(device_messages): - rows.append((pos, DeviceRow(destination=destination))) - # Sort rows based on pos rows.sort() - return [(pos, row.TypeId, row.to_data()) for pos, row in rows] + return ( + [(pos, (row.TypeId, row.to_data())) for pos, row in rows], + to_token, + False, + ) class BaseFederationRow(object): @@ -355,7 +361,7 @@ class BaseFederationRow(object): Specifies how to identify, serialize and deserialize the different types. """ - TypeId = None # Unique string that ids the type. Must be overriden in sub classes. + TypeId = "" # Unique string that ids the type. Must be overriden in sub classes. @staticmethod def from_data(data): @@ -468,29 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu buff.edus.setdefault(self.edu.destination, []).append(self.edu) -class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ("destination",))): # str - """Streams the fact that either a) there is pending to device messages for - users on the remote, or b) a local users device has changed and needs to - be sent to the remote. - """ +_rowtypes = ( + PresenceRow, + PresenceDestinationsRow, + KeyedEduRow, + EduRow, +) # type: Tuple[Type[BaseFederationRow], ...] - TypeId = "d" - - @staticmethod - def from_data(data): - return DeviceRow(destination=data["destination"]) - - def to_data(self): - return {"destination": self.destination} - - def add_to_buffer(self, buff): - buff.device_destinations.add(self.destination) - - -TypeToRow = { - Row.TypeId: Row - for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow, DeviceRow) -} +TypeToRow = {Row.TypeId: Row for Row in _rowtypes} ParsedFederationStreamData = namedtuple( @@ -500,7 +491,6 @@ ParsedFederationStreamData = namedtuple( "presence_destinations", # list of tuples of UserPresenceState and destinations "keyed_edus", # dict of destination -> { key -> Edu } "edus", # dict of destination -> [Edu] - "device_destinations", # set of destinations ), ) @@ -511,7 +501,7 @@ def process_rows_for_federation(transaction_queue, rows): Args: transaction_queue (FederationSender) - rows (list(synapse.replication.tcp.streams.FederationStreamRow)) + rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow)) """ # The federation stream contains a bunch of different types of @@ -519,11 +509,7 @@ def process_rows_for_federation(transaction_queue, rows): # them into the appropriate collection and then send them off. 
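On the receiving side, `process_rows_for_federation` looks up each row's `TypeId` in `TypeToRow`, rebuilds the row with `from_data`, and lets the row append itself to the right collection in the buffer; with `DeviceRow` gone, only the remaining row kinds are registered. A small sketch of that dispatch, reduced to a single illustrative row type and a toy buffer:

from collections import namedtuple
from typing import Dict, List

class BaseRow:
    TypeId = ""  # unique string identifying the row type

class EduRow(BaseRow, namedtuple("EduRow", ("edu",))):
    TypeId = "e"

    @staticmethod
    def from_data(data: dict) -> "EduRow":
        return EduRow(edu=data)

    def add_to_buffer(self, buff: "Buffer") -> None:
        buff.edus.setdefault(self.edu["destination"], []).append(self.edu)

TYPE_TO_ROW = {row.TypeId: row for row in (EduRow,)}

class Buffer:
    def __init__(self) -> None:
        self.edus = {}  # type: Dict[str, List[dict]]

def process_rows(rows: List[tuple]) -> Buffer:
    buff = Buffer()
    for _position, (type_id, data) in rows:
        row = TYPE_TO_ROW[type_id].from_data(data)
        row.add_to_buffer(buff)
    return buff

rows = [(7, ("e", {"destination": "remote.example.org", "edu_type": "m.typing"}))]
print(process_rows(rows).edus)
# -> {'remote.example.org': [{'destination': 'remote.example.org', 'edu_type': 'm.typing'}]}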
buff = ParsedFederationStreamData( - presence=[], - presence_destinations=[], - keyed_edus={}, - edus={}, - device_destinations=set(), + presence=[], presence_destinations=[], keyed_edus={}, edus={}, ) # Parse the rows in the stream and add to the buffer @@ -551,6 +537,3 @@ def process_rows_for_federation(transaction_queue, rows): for destination, edu_list in iteritems(buff.edus): for edu in edu_list: transaction_queue.send_edu(edu, None) - - for destination in buff.device_destinations: - transaction_queue.send_device_messages(destination) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index d46f4aaeb1..d473576902 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple from six import itervalues @@ -21,7 +22,9 @@ from prometheus_client import Counter from twisted.internet import defer +import synapse import synapse.metrics +from synapse.events import EventBase from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager from synapse.federation.units import Edu @@ -38,7 +41,9 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.metrics import measure_func +from synapse.storage.presence import UserPresenceState +from synapse.types import ReadReceipt +from synapse.util.metrics import Measure, measure_func logger = logging.getLogger(__name__) @@ -49,12 +54,12 @@ sent_pdus_destination_dist_count = Counter( sent_pdus_destination_dist_total = Counter( "synapse_federation_client_sent_pdu_destinations:total", - "" "Total number of PDUs queued for sending across all destinations", + "Total number of PDUs queued for sending across all destinations", ) class FederationSender(object): - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.hs = hs self.server_name = hs.hostname @@ -67,7 +72,7 @@ class FederationSender(object): self._transaction_manager = TransactionManager(hs) # map from destination to PerDestinationQueue - self._per_destination_queues = {} # type: dict[str, PerDestinationQueue] + self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue] LaterGauge( "synapse_federation_transaction_queue_pending_destinations", @@ -83,7 +88,7 @@ class FederationSender(object): # Map of user_id -> UserPresenceState for all the pending presence # to be sent out by user_id. Entries here get processed and put in # pending_presence_by_dest - self.pending_presence = {} + self.pending_presence = {} # type: Dict[str, UserPresenceState] LaterGauge( "synapse_federation_transaction_queue_pending_pdus", @@ -115,20 +120,17 @@ class FederationSender(object): # and that there is a pending call to _flush_rrs_for_room in the system. 
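The read-receipt flush machinery introduced here backs off in proportion to the number of remote domains in the room, since one flush costs roughly one transaction per domain. A small arithmetic sketch, assuming an example config value of 50 transactions per room per second (an assumption for illustration, not necessarily the configured default):

```python
# federation_rr_transactions_per_room_per_second is a homeserver config value;
# 50 here is just an assumed example.
federation_rr_transactions_per_room_per_second = 50
rr_txn_interval_per_room_ms = 1000.0 / federation_rr_transactions_per_room_per_second


def rr_flush_backoff_ms(n_domains: int) -> float:
    # Flushing read receipts for a room costs roughly one transaction per
    # remote domain, so back off proportionally before the next flush.
    return rr_txn_interval_per_room_ms * n_domains


print(rr_flush_backoff_ms(25))  # 25 remote servers in the room -> 500.0 ms
```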
self._queues_awaiting_rr_flush_by_room = ( {} - ) # type: dict[str, set[PerDestinationQueue]] + ) # type: Dict[str, Set[PerDestinationQueue]] self._rr_txn_interval_per_room_ms = ( - 1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second + 1000.0 / hs.config.federation_rr_transactions_per_room_per_second ) - def _get_per_destination_queue(self, destination): + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination Args: - destination (str): server_name of remote server - - Returns: - PerDestinationQueue + destination: server_name of remote server """ queue = self._per_destination_queues.get(destination) if not queue: @@ -136,7 +138,7 @@ class FederationSender(object): self._per_destination_queues[destination] = queue return queue - def notify_new_events(self, current_id): + def notify_new_events(self, current_id: int) -> None: """This gets called when we have some new events we might want to send out to other servers. """ @@ -150,13 +152,12 @@ class FederationSender(object): "process_event_queue_for_federation", self._process_event_queue_loop ) - @defer.inlineCallbacks - def _process_event_queue_loop(self): + async def _process_event_queue_loop(self) -> None: try: self._is_processing = True while True: - last_token = yield self.store.get_federation_out_pos("events") - next_token, events = yield self.store.get_all_new_events_stream( + last_token = await self.store.get_federation_out_pos("events") + next_token, events = await self.store.get_all_new_events_stream( last_token, self._last_poked_id, limit=100 ) @@ -165,8 +166,7 @@ class FederationSender(object): if not events and next_token >= self._last_poked_id: break - @defer.inlineCallbacks - def handle_event(event): + async def handle_event(event: EventBase) -> None: # Only send events for this server. send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of() is_mine = self.is_mine_id(event.sender) @@ -183,8 +183,8 @@ class FederationSender(object): # Otherwise if the last member on a server in a room is # banned then it won't receive the event because it won't # be in the room after the ban. 
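`_process_event_queue_loop` becomes a native coroutine, but the shape of the loop is unchanged: read a batch of up to 100 events from the stream, fan them out, persist the new position, and stop once caught up. A toy sketch of that shape, with a stand-in store rather than Synapse's datastore:

```python
import asyncio
from typing import List, Tuple


class ToyStore:
    """Stand-in store; not Synapse's datastore."""

    def __init__(self, events: List[int]):
        self._events = events
        self.out_pos = 0

    async def get_federation_out_pos(self, name: str) -> int:
        return self.out_pos

    async def get_all_new_events_stream(
        self, from_token: int, upto_token: int, limit: int
    ) -> Tuple[int, List[int]]:
        batch = [e for e in self._events if from_token < e <= upto_token][:limit]
        next_token = batch[-1] if batch else upto_token
        return next_token, batch

    async def update_federation_out_pos(self, name: str, token: int) -> None:
        self.out_pos = token


async def process_event_queue(store: ToyStore, last_poked_id: int) -> None:
    while True:
        last_token = await store.get_federation_out_pos("events")
        next_token, events = await store.get_all_new_events_stream(
            last_token, last_poked_id, limit=100
        )
        if not events and next_token >= last_poked_id:
            break  # fully caught up
        # ... compute destinations and queue PDUs for each event here ...
        await store.update_federation_out_pos("events", next_token)


store = ToyStore(list(range(1, 251)))
asyncio.run(process_event_queue(store, last_poked_id=250))
print(store.out_pos)  # 250
```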
- destinations = yield self.state.get_current_hosts_in_room( - event.room_id, latest_event_ids=event.prev_event_ids() + destinations = await self.state.get_hosts_in_room_at_events( + event.room_id, event_ids=event.prev_event_ids() ) except Exception: logger.exception( @@ -205,16 +205,16 @@ class FederationSender(object): self._send_pdu(event, destinations) - @defer.inlineCallbacks - def handle_room_events(events): - for event in events: - yield handle_event(event) + async def handle_room_events(events: Iterable[EventBase]) -> None: + with Measure(self.clock, "handle_room_events"): + for event in events: + await handle_event(event) - events_by_room = {} + events_by_room = {} # type: Dict[str, List[EventBase]] for event in events: events_by_room.setdefault(event.room_id, []).append(event) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(handle_room_events, evs) @@ -224,11 +224,11 @@ class FederationSender(object): ) ) - yield self.store.update_federation_out_pos("events", next_token) + await self.store.update_federation_out_pos("events", next_token) if events: now = self.clock.time_msec() - ts = yield self.store.get_received_ts(events[-1].event_id) + ts = await self.store.get_received_ts(events[-1].event_id) synapse.metrics.event_processing_lag.labels( "federation_sender" @@ -252,7 +252,7 @@ class FederationSender(object): finally: self._is_processing = False - def _send_pdu(self, pdu, destinations): + def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None: # We loop through all destinations to see whether we already have # a transaction in progress. If we do, stick it in the pending_pdus # table and we'll get back to it later. @@ -274,11 +274,11 @@ class FederationSender(object): self._get_per_destination_queue(destination).send_pdu(pdu, order) @defer.inlineCallbacks - def send_read_receipt(self, receipt): + def send_read_receipt(self, receipt: ReadReceipt): """Send a RR to any other servers in the room Args: - receipt (synapse.types.ReadReceipt): receipt to be sent + receipt: receipt to be sent """ # Some background on the rate-limiting going on here. @@ -341,7 +341,7 @@ class FederationSender(object): else: queue.flush_read_receipts_for_room(room_id) - def _schedule_rr_flush_for_room(self, room_id, n_domains): + def _schedule_rr_flush_for_room(self, room_id: str, n_domains: int) -> None: # that is going to cause approximately len(domains) transactions, so now back # off for that multiplied by RR_TXN_INTERVAL_PER_ROOM backoff_ms = self._rr_txn_interval_per_room_ms * n_domains @@ -350,7 +350,7 @@ class FederationSender(object): self.clock.call_later(backoff_ms, self._flush_rrs_for_room, room_id) self._queues_awaiting_rr_flush_by_room[room_id] = set() - def _flush_rrs_for_room(self, room_id): + def _flush_rrs_for_room(self, room_id: str) -> None: queues = self._queues_awaiting_rr_flush_by_room.pop(room_id) logger.debug("Flushing RRs in %s to %s", room_id, queues) @@ -366,14 +366,11 @@ class FederationSender(object): @preserve_fn # the caller should not yield on this @defer.inlineCallbacks - def send_presence(self, states): + def send_presence(self, states: List[UserPresenceState]): """Send the new presence states to the appropriate destinations. This actually queues up the presence states ready for sending and triggers a background task to process them and send out the transactions. - - Args: - states (list(UserPresenceState)) """ if not self.hs.config.use_presence: # No-op if presence is disabled. 
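Within each batch, events are grouped by room so that rooms can be processed concurrently while per-room ordering is preserved; `handle_room_events` (now wrapped in a `Measure` block) walks each room's events in order. A sketch of the same fan-out pattern, using `asyncio.gather` as a stand-in for Twisted's `gatherResults`:

```python
import asyncio
from typing import Dict, List, Tuple

Event = Tuple[str, int]  # (room_id, event number) -- toy stand-in for EventBase


async def handle_event(event: Event) -> None:
    # in the real sender this computes destinations and queues the PDU
    await asyncio.sleep(0)


async def handle_room_events(events: List[Event]) -> None:
    for event in events:  # per-room ordering is preserved
        await handle_event(event)


async def fan_out(events: List[Event]) -> None:
    events_by_room = {}  # type: Dict[str, List[Event]]
    for event in events:
        events_by_room.setdefault(event[0], []).append(event)
    # rooms run concurrently, each room's events sequentially
    await asyncio.gather(*(handle_room_events(evs) for evs in events_by_room.values()))


asyncio.run(fan_out([("!a:hs", 1), ("!b:hs", 1), ("!a:hs", 2)]))
```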
@@ -410,11 +407,10 @@ class FederationSender(object): finally: self._processing_pending_presence = False - def send_presence_to_destinations(self, states, destinations): + def send_presence_to_destinations( + self, states: List[UserPresenceState], destinations: List[str] + ) -> None: """Send the given presence states to the given destinations. - - Args: - states (list[UserPresenceState]) destinations (list[str]) """ @@ -429,12 +425,9 @@ class FederationSender(object): @measure_func("txnqueue._process_presence") @defer.inlineCallbacks - def _process_presence_inner(self, states): + def _process_presence_inner(self, states: List[UserPresenceState]): """Given a list of states populate self.pending_presence_by_dest and poke to send a new transaction to each destination - - Args: - states (list(UserPresenceState)) """ hosts_and_states = yield get_interested_remotes(self.store, states, self.state) @@ -444,14 +437,20 @@ class FederationSender(object): continue self._get_per_destination_queue(destination).send_presence(states) - def build_and_send_edu(self, destination, edu_type, content, key=None): + def build_and_send_edu( + self, + destination: str, + edu_type: str, + content: dict, + key: Optional[Hashable] = None, + ): """Construct an Edu object, and queue it for sending Args: - destination (str): name of server to send to - edu_type (str): type of EDU to send - content (dict): content of EDU - key (Any|None): clobbering key for this edu + destination: name of server to send to + edu_type: type of EDU to send + content: content of EDU + key: clobbering key for this edu """ if destination == self.server_name: logger.info("Not sending EDU to ourselves") @@ -466,12 +465,12 @@ class FederationSender(object): self.send_edu(edu, key) - def send_edu(self, edu, key): + def send_edu(self, edu: Edu, key: Optional[Hashable]): """Queue an EDU for sending Args: - edu (Edu): edu to send - key (Any|None): clobbering key for this edu + edu: edu to send + key: clobbering key for this edu """ queue = self._get_per_destination_queue(edu.destination) if key: @@ -479,12 +478,36 @@ class FederationSender(object): else: queue.send_edu(edu) - def send_device_messages(self, destination): + def send_device_messages(self, destination: str): if destination == self.server_name: - logger.info("Not sending device update to ourselves") + logger.warning("Not sending device update to ourselves") return self._get_per_destination_queue(destination).attempt_new_transaction() - def get_current_token(self): + def wake_destination(self, destination: str): + """Called when we want to retry sending transactions to a remote. + + This is mainly useful if the remote server has been down and we think it + might have come back. + """ + + if destination == self.server_name: + logger.warning("Not waking up ourselves") + return + + self._get_per_destination_queue(destination).attempt_new_transaction() + + @staticmethod + def get_current_token() -> int: + # Dummy implementation for case where federation sender isn't offloaded + # to a worker. return 0 + + @staticmethod + async def get_replication_rows( + instance_name: str, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple[int, Tuple]], int, bool]: + # Dummy implementation for case where federation sender isn't offloaded + # to a worker. 
+ return [], 0, False diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index fad980b893..4e698981a4 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -15,11 +15,10 @@ # limitations under the License. import datetime import logging +from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple from prometheus_client import Counter -from twisted.internet import defer - from synapse.api.errors import ( FederationDeniedError, HttpResponseException, @@ -30,9 +29,13 @@ from synapse.federation.units import Edu from synapse.handlers.presence import format_user_presence_state from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage import UserPresenceState +from synapse.storage.presence import UserPresenceState +from synapse.types import ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter +if TYPE_CHECKING: + import synapse.server + # This is defined in the Matrix spec and enforced by the receiver. MAX_EDUS_PER_TRANSACTION = 100 @@ -55,13 +58,18 @@ class PerDestinationQueue(object): Manages the per-destination transmission queues. Args: - hs (synapse.HomeServer): - transaction_sender (TransactionManager): - destination (str): the server_name of the destination that we are managing + hs + transaction_sender + destination: the server_name of the destination that we are managing transmission for. """ - def __init__(self, hs, transaction_manager, destination): + def __init__( + self, + hs: "synapse.server.HomeServer", + transaction_manager: "synapse.federation.sender.TransactionManager", + destination: str, + ): self._server_name = hs.hostname self._clock = hs.get_clock() self._store = hs.get_datastore() @@ -71,20 +79,23 @@ class PerDestinationQueue(object): self.transmission_loop_running = False # a list of tuples of (pending pdu, order) - self._pending_pdus = [] # type: list[tuple[EventBase, int]] - self._pending_edus = [] # type: list[Edu] + self._pending_pdus = [] # type: List[Tuple[EventBase, int]] + + # XXX this is never actually used: see + # https://github.com/matrix-org/synapse/issues/7549 + self._pending_edus = [] # type: List[Edu] # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered # based on their key (e.g. typing events by room_id) # Map of (edu_type, key) -> Edu - self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu] + self._pending_edus_keyed = {} # type: Dict[Tuple[str, Hashable], Edu] # Map of user_id -> UserPresenceState of pending presence to be sent to this # destination - self._pending_presence = {} # type: dict[str, UserPresenceState] + self._pending_presence = {} # type: Dict[str, UserPresenceState] # room_id -> receipt_type -> user_id -> receipt_dict - self._pending_rrs = {} + self._pending_rrs = {} # type: Dict[str, Dict[str, Dict[str, dict]]] self._rrs_pending_flush = False # stream_id of last successfully sent to-device message. @@ -94,50 +105,50 @@ class PerDestinationQueue(object): # stream_id of last successfully sent device list update. 
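The per-destination read-receipt buffer is a three-level mapping (`room_id -> receipt_type -> user_id -> receipt body`). A toy sketch of how such a buffer gets filled, approximating `synapse.types.ReadReceipt` with a namedtuple:

```python
from collections import namedtuple
from typing import Dict

ReadReceipt = namedtuple(
    "ReadReceipt", ("room_id", "receipt_type", "user_id", "event_ids", "data")
)

pending_rrs = {}  # type: Dict[str, Dict[str, Dict[str, dict]]]


def queue_read_receipt(receipt: ReadReceipt) -> None:
    # room_id -> receipt_type -> user_id -> {"event_ids": ..., "data": ...}
    pending_rrs.setdefault(receipt.room_id, {}).setdefault(receipt.receipt_type, {})[
        receipt.user_id
    ] = {"event_ids": receipt.event_ids, "data": receipt.data}


queue_read_receipt(ReadReceipt("!room:hs", "m.read", "@alice:hs", ["$ev1"], {"ts": 1}))
queue_read_receipt(ReadReceipt("!room:hs", "m.read", "@bob:hs", ["$ev2"], {"ts": 2}))
print(pending_rrs)
```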
self._last_device_list_stream_id = 0 - def __str__(self): + def __str__(self) -> str: return "PerDestinationQueue[%s]" % self._destination - def pending_pdu_count(self): + def pending_pdu_count(self) -> int: return len(self._pending_pdus) - def pending_edu_count(self): + def pending_edu_count(self) -> int: return ( len(self._pending_edus) + len(self._pending_presence) + len(self._pending_edus_keyed) ) - def send_pdu(self, pdu, order): + def send_pdu(self, pdu: EventBase, order: int) -> None: """Add a PDU to the queue, and start the transmission loop if neccessary Args: - pdu (EventBase): pdu to send - order (int): + pdu: pdu to send + order """ self._pending_pdus.append((pdu, order)) self.attempt_new_transaction() - def send_presence(self, states): + def send_presence(self, states: Iterable[UserPresenceState]) -> None: """Add presence updates to the queue. Start the transmission loop if neccessary. Args: - states (iterable[UserPresenceState]): presence to send + states: presence to send """ self._pending_presence.update({state.user_id: state for state in states}) self.attempt_new_transaction() - def queue_read_receipt(self, receipt): + def queue_read_receipt(self, receipt: ReadReceipt) -> None: """Add a RR to the list to be sent. Doesn't start the transmission loop yet (see flush_read_receipts_for_room) Args: - receipt (synapse.api.receipt_info.ReceiptInfo): receipt to be queued + receipt: receipt to be queued """ self._pending_rrs.setdefault(receipt.room_id, {}).setdefault( receipt.receipt_type, {} )[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data} - def flush_read_receipts_for_room(self, room_id): + def flush_read_receipts_for_room(self, room_id: str) -> None: # if we don't have any read-receipts for this room, it may be that we've already # sent them out, so we don't need to flush. if room_id not in self._pending_rrs: @@ -145,15 +156,15 @@ class PerDestinationQueue(object): self._rrs_pending_flush = True self.attempt_new_transaction() - def send_keyed_edu(self, edu, key): + def send_keyed_edu(self, edu: Edu, key: Hashable) -> None: self._pending_edus_keyed[(edu.edu_type, key)] = edu self.attempt_new_transaction() - def send_edu(self, edu): + def send_edu(self, edu) -> None: self._pending_edus.append(edu) self.attempt_new_transaction() - def attempt_new_transaction(self): + def attempt_new_transaction(self) -> None: """Try to start a new transaction to this destination If there is already a transaction in progress to this destination, @@ -176,31 +187,31 @@ class PerDestinationQueue(object): self._transaction_transmission_loop, ) - @defer.inlineCallbacks - def _transaction_transmission_loop(self): - pending_pdus = [] + async def _transaction_transmission_loop(self) -> None: + pending_pdus = [] # type: List[Tuple[EventBase, int]] try: self.transmission_loop_running = True # This will throw if we wouldn't retry. We do this here so we fail # quickly, but we will later check this again in the http client, # hence why we throw the result away. 
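Keyed EDUs are stored under `(edu_type, key)` so that a newer EDU with the same key clobbers the queued one, for example typing notifications keyed by room ID. A toy illustration (the `Edu` here is a namedtuple stand-in, not Synapse's class):

```python
from collections import namedtuple
from typing import Dict, Hashable, Tuple

Edu = namedtuple("Edu", ("edu_type", "content"))

pending_edus_keyed = {}  # type: Dict[Tuple[str, Hashable], Edu]


def send_keyed_edu(edu: Edu, key: Hashable) -> None:
    pending_edus_keyed[(edu.edu_type, key)] = edu


send_keyed_edu(Edu("m.typing", {"user_ids": ["@a:hs"]}), key="!room:hs")
send_keyed_edu(Edu("m.typing", {"user_ids": []}), key="!room:hs")

# Only the most recent typing state for the room remains queued:
print(len(pending_edus_keyed))                               # 1
print(pending_edus_keyed[("m.typing", "!room:hs")].content)  # {'user_ids': []}
```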
- yield get_retry_limiter(self._destination, self._clock, self._store) + await get_retry_limiter(self._destination, self._clock, self._store) pending_pdus = [] while True: # We have to keep 2 free slots for presence and rr_edus limit = MAX_EDUS_PER_TRANSACTION - 2 - device_update_edus, dev_list_id = ( - yield self._get_device_update_edus(limit) + device_update_edus, dev_list_id = await self._get_device_update_edus( + limit ) limit -= len(device_update_edus) - to_device_edus, device_stream_id = ( - yield self._get_to_device_message_edus(limit) - ) + ( + to_device_edus, + device_stream_id, + ) = await self._get_to_device_message_edus(limit) pending_edus = device_update_edus + to_device_edus @@ -267,7 +278,7 @@ class PerDestinationQueue(object): # END CRITICAL SECTION - success = yield self._transaction_manager.send_new_transaction( + success = await self._transaction_manager.send_new_transaction( self._destination, pending_pdus, pending_edus ) if success: @@ -278,7 +289,7 @@ class PerDestinationQueue(object): # Remove the acknowledged device messages from the database # Only bother if we actually sent some device messages if to_device_edus: - yield self._store.delete_device_msgs_for_remote( + await self._store.delete_device_msgs_for_remote( self._destination, device_stream_id ) @@ -287,7 +298,7 @@ class PerDestinationQueue(object): logger.info( "Marking as sent %r %r", self._destination, dev_list_id ) - yield self._store.mark_as_sent_devices_by_remote( + await self._store.mark_as_sent_devices_by_remote( self._destination, dev_list_id ) @@ -332,7 +343,7 @@ class PerDestinationQueue(object): # We want to be *very* sure we clear this after we stop processing self.transmission_loop_running = False - def _get_rr_edus(self, force_flush): + def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: if not self._pending_rrs: return if not force_flush and not self._rrs_pending_flush: @@ -349,38 +360,36 @@ class PerDestinationQueue(object): self._rrs_pending_flush = False yield edu - def _pop_pending_edus(self, limit): + def _pop_pending_edus(self, limit: int) -> List[Edu]: pending_edus = self._pending_edus pending_edus, self._pending_edus = pending_edus[:limit], pending_edus[limit:] return pending_edus - @defer.inlineCallbacks - def _get_device_update_edus(self, limit): + async def _get_device_update_edus(self, limit: int) -> Tuple[List[Edu], int]: last_device_list = self._last_device_list_stream_id # Retrieve list of new device updates to send to the destination - now_stream_id, results = yield self._store.get_devices_by_remote( + now_stream_id, results = await self._store.get_device_updates_by_remote( self._destination, last_device_list, limit=limit ) edus = [ Edu( origin=self._server_name, destination=self._destination, - edu_type="m.device_list_update", + edu_type=edu_type, content=content, ) - for content in results + for (edu_type, content) in results ] - assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs" + assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs" return (edus, now_stream_id) - @defer.inlineCallbacks - def _get_to_device_message_edus(self, limit): + async def _get_to_device_message_edus(self, limit: int) -> Tuple[List[Edu], int]: last_device_stream_id = self._last_device_stream_id to_device_stream_id = self._store.get_to_device_stream_token() - contents, stream_id = yield self._store.get_new_device_msgs_for_remote( + contents, stream_id = await self._store.get_new_device_msgs_for_remote( self._destination, last_device_stream_id, 
to_device_stream_id, limit ) edus = [ diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 5b6c79c51a..a2752a54a5 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING, List from canonicaljson import json -from twisted.internet import defer - from synapse.api.errors import HttpResponseException +from synapse.events import EventBase from synapse.federation.persistence import TransactionActions -from synapse.federation.units import Transaction +from synapse.federation.units import Edu, Transaction from synapse.logging.opentracing import ( extract_text_map, set_tag, @@ -30,6 +30,9 @@ from synapse.logging.opentracing import ( ) from synapse.util.metrics import measure_func +if TYPE_CHECKING: + import synapse.server + logger = logging.getLogger(__name__) @@ -39,7 +42,7 @@ class TransactionManager(object): shared between PerDestinationQueue objects """ - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self._server_name = hs.hostname self.clock = hs.get_clock() # nb must be called this for @measure_func self._store = hs.get_datastore() @@ -50,8 +53,9 @@ class TransactionManager(object): self._next_txn_id = int(self.clock.time_msec()) @measure_func("_send_new_transaction") - @defer.inlineCallbacks - def send_new_transaction(self, destination, pending_pdus, pending_edus): + async def send_new_transaction( + self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu] + ): # Make a transaction-sending opentracing span. This span follows on from # all the edus in that transaction. This needs to be done since there is @@ -84,7 +88,7 @@ class TransactionManager(object): txn_id = str(self._next_txn_id) logger.debug( - "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)", + "TX [%s] {%s} Attempting new transaction (pdus: %d, edus: %d)", destination, txn_id, len(pdus), @@ -103,7 +107,7 @@ class TransactionManager(object): self._next_txn_id += 1 logger.info( - "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)", + "TX [%s] {%s} Sending transaction [%s], (PDUs: %d, EDUs: %d)", destination, txn_id, transaction.transaction_id, @@ -127,7 +131,7 @@ class TransactionManager(object): return data try: - response = yield self._transport_layer.send_transaction( + response = await self._transport_layer.send_transaction( transaction, json_data_cb ) code = 200 @@ -146,7 +150,7 @@ class TransactionManager(object): if code == 200: for e_id, r in response.get("pdus", {}).items(): if "error" in r: - logger.warn( + logger.warning( "TX [%s] {%s} Remote returned error for %s: %s", destination, txn_id, @@ -155,7 +159,7 @@ class TransactionManager(object): ) else: for p in pdus: - logger.warn( + logger.warning( "TX [%s] {%s} Failed to send event %s", destination, txn_id, diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py index d9fcc520a0..5db733af98 100644 --- a/synapse/federation/transport/__init__.py +++ b/synapse/federation/transport/__init__.py @@ -14,9 +14,9 @@ # limitations under the License. """The transport layer is responsible for both sending transactions to remote -home servers and receiving a variety of requests from other home servers. 
+homeservers and receiving a variety of requests from other homeservers. -By default this is done over HTTPS (and all home servers are required to +By default this is done over HTTPS (and all homeservers are required to support HTTPS), however individual pairings of servers may decide to communicate over a different (albeit still reliable) protocol. """ diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 482a101c09..060bf07197 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,12 +15,14 @@ # limitations under the License. import logging +from typing import Any, Dict, Optional from six.moves import urllib from twisted.internet import defer from synapse.api.constants import Membership +from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.urls import ( FEDERATION_UNSTABLE_PREFIX, FEDERATION_V1_PREFIX, @@ -39,36 +41,12 @@ class TransportLayerClient(object): self.client = hs.get_http_client() @log_function - def get_room_state(self, destination, room_id, event_id): - """ Requests all state for a given room from the given server at the - given event. - - Args: - destination (str): The host name of the remote home server we want - to get the state from. - context (str): The name of the context we want the state of - event_id (str): The event we want the context at. - - Returns: - Deferred: Results in a dict received from the remote homeserver. - """ - logger.debug("get_room_state dest=%s, room=%s", destination, room_id) - - path = _create_v1_path("/state/%s", room_id) - return self.client.get_json( - destination, - path=path, - args={"event_id": event_id}, - try_trailing_slash_on_400=True, - ) - - @log_function def get_room_state_ids(self, destination, room_id, event_id): """ Requests all state for a given room from the given server at the given event. Returns the state's event_id's Args: - destination (str): The host name of the remote home server we want + destination (str): The host name of the remote homeserver we want to get the state from. context (str): The name of the context we want the state of event_id (str): The event we want the context at. @@ -91,7 +69,7 @@ class TransportLayerClient(object): """ Requests the pdu with give id and origin from the given server. Args: - destination (str): The host name of the remote home server we want + destination (str): The host name of the remote homeserver we want to get the state from. event_id (str): The id of the event being requested. timeout (int): How long to try (in ms) the destination for before @@ -122,10 +100,10 @@ class TransportLayerClient(object): Deferred: Results in a dict received from the remote homeserver. 
""" logger.debug( - "backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s", + "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s", destination, room_id, - repr(event_tuples), + event_tuples, str(limit), ) @@ -267,7 +245,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def send_join(self, destination, room_id, event_id, content): + def send_join_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_join/%s/%s", room_id, event_id) response = yield self.client.put_json( @@ -278,7 +256,18 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def send_leave(self, destination, room_id, event_id, content): + def send_join_v2(self, destination, room_id, event_id, content): + path = _create_v2_path("/send_join/%s/%s", room_id, event_id) + + response = yield self.client.put_json( + destination=destination, path=path, data=content + ) + + return response + + @defer.inlineCallbacks + @log_function + def send_leave_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_leave/%s/%s", room_id, event_id) response = yield self.client.put_json( @@ -296,6 +285,24 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function + def send_leave_v2(self, destination, room_id, event_id, content): + path = _create_v2_path("/send_leave/%s/%s", room_id, event_id) + + response = yield self.client.put_json( + destination=destination, + path=path, + data=content, + # we want to do our best to send this through. The problem is + # that if it fails, we won't retry it later, so if the remote + # server was just having a momentary blip, the room will be out of + # sync. + ignore_backoff=True, + ) + + return response + + @defer.inlineCallbacks + @log_function def send_invite_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/invite/%s/%s", room_id, event_id) @@ -320,18 +327,25 @@ class TransportLayerClient(object): @log_function def get_public_rooms( self, - remote_server, - limit, - since_token, - search_filter=None, - include_all_networks=False, - third_party_instance_id=None, + remote_server: str, + limit: Optional[int] = None, + since_token: Optional[str] = None, + search_filter: Optional[Dict] = None, + include_all_networks: bool = False, + third_party_instance_id: Optional[str] = None, ): + """Get the list of public rooms from a remote homeserver + + See synapse.federation.federation_client.FederationClient.get_public_rooms for + more information. 
+ """ if search_filter: # this uses MSC2197 (Search Filtering over Federation) path = _create_v1_path("/publicRooms") - data = {"include_all_networks": "true" if include_all_networks else "false"} + data = { + "include_all_networks": "true" if include_all_networks else "false" + } # type: Dict[str, Any] if third_party_instance_id: data["third_party_instance_id"] = third_party_instance_id if limit: @@ -341,13 +355,25 @@ class TransportLayerClient(object): data["filter"] = search_filter - response = yield self.client.post_json( - destination=remote_server, path=path, data=data, ignore_backoff=True - ) + try: + response = yield self.client.post_json( + destination=remote_server, path=path, data=data, ignore_backoff=True + ) + except HttpResponseException as e: + if e.code == 403: + raise SynapseError( + 403, + "You are not allowed to view the public rooms list of %s" + % (remote_server,), + errcode=Codes.FORBIDDEN, + ) + raise else: path = _create_v1_path("/publicRooms") - args = {"include_all_networks": "true" if include_all_networks else "false"} + args = { + "include_all_networks": "true" if include_all_networks else "false" + } # type: Dict[str, Any] if third_party_instance_id: args["third_party_instance_id"] = (third_party_instance_id,) if limit: @@ -355,9 +381,19 @@ class TransportLayerClient(object): if since_token: args["since"] = [since_token] - response = yield self.client.get_json( - destination=remote_server, path=path, args=args, ignore_backoff=True - ) + try: + response = yield self.client.get_json( + destination=remote_server, path=path, args=args, ignore_backoff=True + ) + except HttpResponseException as e: + if e.code == 403: + raise SynapseError( + 403, + "You are not allowed to view the public rooms list of %s" + % (remote_server,), + errcode=Codes.FORBIDDEN, + ) + raise return response @@ -383,17 +419,6 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def send_query_auth(self, destination, room_id, event_id, content): - path = _create_v1_path("/query_auth/%s/%s", room_id, event_id) - - content = yield self.client.post_json( - destination=destination, path=path, data=content - ) - - return content - - @defer.inlineCallbacks - @log_function def query_client_keys(self, destination, query_content, timeout): """Query the device keys for a list of user ids hosted on a remote server. @@ -402,20 +427,30 @@ class TransportLayerClient(object): { "device_keys": { "<user_id>": ["<device_id>"] - } } + } + } Response: { "device_keys": { "<user_id>": { "<device_id>": {...} - } } } + } + }, + "master_key": { + "<user_id>": {...} + } + }, + "self_signing_key": { + "<user_id>": {...} + } + } Args: destination(str): The server to query. query_content(dict): The user ids to query. Returns: - A dict containg the device keys. + A dict containing device and cross-signing keys. """ path = _create_v1_path("/user/keys/query") @@ -432,14 +467,30 @@ class TransportLayerClient(object): Response: { "stream_id": "...", - "devices": [ { ... } ] + "devices": [ { ... } ], + "master_key": { + "user_id": "<user_id>", + "usage": [...], + "keys": {...}, + "signatures": { + "<user_id>": {...} + } + }, + "self_signing_key": { + "user_id": "<user_id>", + "usage": [...], + "keys": {...}, + "signatures": { + "<user_id>": {...} + } + } } Args: destination(str): The server to query. query_content(dict): The user ids to query. Returns: - A dict containg the device keys. + A dict containing device and cross-signing keys. 
""" path = _create_v1_path("/user/devices/%s", user_id) @@ -457,8 +508,10 @@ class TransportLayerClient(object): { "one_time_keys": { "<user_id>": { - "<device_id>": "<algorithm>" - } } } + "<device_id>": "<algorithm>" + } + } + } Response: { @@ -466,13 +519,16 @@ class TransportLayerClient(object): "<user_id>": { "<device_id>": { "<algorithm>:<key_id>": "<key_base64>" - } } } } + } + } + } + } Args: destination(str): The server to query. query_content(dict): The user ids to query. Returns: - A dict containg the one-time keys. + A dict containing the one-time keys. """ path = _create_v1_path("/user/keys/claim") diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 7dc696c7ae..af4595498c 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -18,6 +18,7 @@ import functools import logging import re +from typing import Optional, Tuple, Type from twisted.internet.defer import maybeDeferred @@ -44,6 +45,7 @@ from synapse.logging.opentracing import ( tags, whitelisted_homeserver, ) +from synapse.server import HomeServer from synapse.types import ThirdPartyInstanceID, get_domain_from_id from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string @@ -101,12 +103,17 @@ class NoAuthenticationError(AuthenticationError): class Authenticator(object): - def __init__(self, hs): + def __init__(self, hs: HomeServer): self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname self.store = hs.get_datastore() self.federation_domain_whitelist = hs.config.federation_domain_whitelist + self.notifer = hs.get_notifier() + + self.replication_client = None + if hs.config.worker.worker_app: + self.replication_client = hs.get_tcp_replication() # A method just so we can pass 'self' as the authenticator to the Servlets async def authenticate_request(self, request, content): @@ -151,7 +158,7 @@ class Authenticator(object): origin, json_request, now, "Incoming request" ) - logger.info("Request from %s", origin) + logger.debug("Request from %s", origin) request.authenticated_entity = origin # If we get a valid signed request from the other side, its probably @@ -165,7 +172,18 @@ class Authenticator(object): async def _reset_retry_timings(self, origin): try: logger.info("Marking origin %r as up", origin) - await self.store.set_destination_retry_timings(origin, 0, 0) + await self.store.set_destination_retry_timings(origin, None, 0, 0) + + # Inform the relevant places that the remote server is back up. + self.notifer.notify_remote_server_up(origin) + if self.replication_client: + # If we're on a worker we try and inform master about this. The + # replication client doesn't hook into the notifier to avoid + # infinite loops where we send a `REMOTE_SERVER_UP` command to + # master, which then echoes it back to us which in turn pokes + # the notifier. + self.replication_client.send_remote_server_up(origin) + except Exception: logger.exception("Error resetting retry timings on %s", origin) @@ -202,7 +220,7 @@ def _parse_auth_header(header_bytes): sig = strip_quotes(param_dict["sig"]) return origin, key, sig except Exception as e: - logger.warn( + logger.warning( "Error parsing auth header '%s': %s", header_bytes.decode("ascii", "replace"), e, @@ -250,6 +268,8 @@ class BaseFederationServlet(object): returned. """ + PATH = "" # Overridden in subclasses, the regex to match against the path. 
+ REQUIRE_AUTH = True PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version @@ -287,10 +307,12 @@ class BaseFederationServlet(object): except NoAuthenticationError: origin = None if self.REQUIRE_AUTH: - logger.warn("authenticate_request failed: missing authentication") + logger.warning( + "authenticate_request failed: missing authentication" + ) raise except Exception as e: - logger.warn("authenticate_request failed: %s", e) + logger.warning("authenticate_request failed: %s", e) raise request_tags = { @@ -328,9 +350,6 @@ class BaseFederationServlet(object): return response - # Extra logic that functools.wraps() doesn't finish - new_func.__self__ = func.__self__ - return new_func def register(self, server): @@ -419,7 +438,7 @@ class FederationEventServlet(BaseFederationServlet): return await self.handler.on_pdu_request(origin, event_id) -class FederationStateServlet(BaseFederationServlet): +class FederationStateV1Servlet(BaseFederationServlet): PATH = "/state/(?P<context>[^/]*)/?" # This is when someone asks for all data for a given context. @@ -427,7 +446,7 @@ class FederationStateServlet(BaseFederationServlet): return await self.handler.on_context_state_request( origin, context, - parse_string_from_args(query, "event_id", None, required=True), + parse_string_from_args(query, "event_id", None, required=False), ) @@ -504,11 +523,21 @@ class FederationMakeLeaveServlet(BaseFederationServlet): return 200, content -class FederationSendLeaveServlet(BaseFederationServlet): +class FederationV1SendLeaveServlet(BaseFederationServlet): PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): content = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, (200, content) + + +class FederationV2SendLeaveServlet(BaseFederationServlet): + PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + + async def on_PUT(self, origin, content, query, room_id, event_id): + content = await self.handler.on_send_leave_request(origin, content, room_id) return 200, content @@ -519,13 +548,25 @@ class FederationEventAuthServlet(BaseFederationServlet): return await self.handler.on_event_auth(origin, context, event_id) -class FederationSendJoinServlet(BaseFederationServlet): +class FederationV1SendJoinServlet(BaseFederationServlet): PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" async def on_PUT(self, origin, content, query, context, event_id): # TODO(paul): assert that context/event_id parsed from path actually # match those given in content content = await self.handler.on_send_join_request(origin, content, context) + return 200, (200, content) + + +class FederationV2SendJoinServlet(BaseFederationServlet): + PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + + async def on_PUT(self, origin, content, query, context, event_id): + # TODO(paul): assert that context/event_id parsed from path actually + # match those given in content + content = await self.handler.on_send_join_request(origin, content, context) return 200, content @@ -538,7 +579,7 @@ class FederationV1InviteServlet(BaseFederationServlet): # state resolution algorithm, and we don't use that for processing # invites content = await self.handler.on_invite_request( - origin, content, room_version=RoomVersions.V1.identifier + origin, content, room_version_id=RoomVersions.V1.identifier ) # V1 federation API is defined to return a content of `[200, {...}]` @@ -565,7 
+606,7 @@ class FederationV2InviteServlet(BaseFederationServlet): event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state content = await self.handler.on_invite_request( - origin, event, room_version=room_version + origin, event, room_version_id=room_version ) return 200, content @@ -602,17 +643,6 @@ class FederationClientKeysClaimServlet(BaseFederationServlet): return 200, response -class FederationQueryAuthServlet(BaseFederationServlet): - PATH = "/query_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)" - - async def on_POST(self, origin, content, query, context, event_id): - new_content = await self.handler.on_query_auth_request( - origin, content, context, event_id - ) - - return 200, new_content - - class FederationGetMissingEventsServlet(BaseFederationServlet): # TODO(paul): Why does this path alone end with "/?" optional? PATH = "/get_missing_events/(?P<room_id>[^/]*)/?" @@ -712,7 +742,7 @@ class PublicRoomList(BaseFederationServlet): This API returns information in the same format as /publicRooms on the client API, but will only ever include local public rooms and hence is - intended for consumption by other home servers. + intended for consumption by other homeservers. GET /publicRooms HTTP/1.1 @@ -765,6 +795,10 @@ class PublicRoomList(BaseFederationServlet): else: network_tuple = ThirdPartyInstanceID(None, None) + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + data = await maybeDeferred( self.handler.get_local_public_room_list, limit, @@ -779,7 +813,7 @@ class PublicRoomList(BaseFederationServlet): if not self.allow_access: raise FederationDeniedError(origin) - limit = int(content.get("limit", 100)) + limit = int(content.get("limit", 100)) # type: Optional[int] since_token = content.get("since", None) search_filter = content.get("filter", None) @@ -800,6 +834,10 @@ class PublicRoomList(BaseFederationServlet): if search_filter is None: logger.warning("Nonefilter") + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + data = await self.handler.get_local_public_room_list( limit=limit, since_token=since_token, @@ -922,7 +960,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - result = await self.groups_handler.update_room_in_group( + result = await self.handler.update_room_in_group( group_id, requester_user_id, room_id, config_key, content ) @@ -1350,18 +1388,19 @@ class RoomComplexityServlet(BaseFederationServlet): FEDERATION_SERVLET_CLASSES = ( FederationSendServlet, FederationEventServlet, - FederationStateServlet, + FederationStateV1Servlet, FederationStateIdsServlet, FederationBackfillServlet, FederationQueryServlet, FederationMakeJoinServlet, FederationMakeLeaveServlet, FederationEventServlet, - FederationSendJoinServlet, - FederationSendLeaveServlet, + FederationV1SendJoinServlet, + FederationV2SendJoinServlet, + FederationV1SendLeaveServlet, + FederationV2SendLeaveServlet, FederationV1InviteServlet, FederationV2InviteServlet, - FederationQueryAuthServlet, FederationGetMissingEventsServlet, FederationEventAuthServlet, FederationClientKeysQueryServlet, @@ -1371,11 +1410,13 @@ FEDERATION_SERVLET_CLASSES = ( On3pidBindServlet, FederationVersionServlet, RoomComplexityServlet, -) +) # type: Tuple[Type[BaseFederationServlet], ...] 
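The v1 `send_join`/`send_leave` servlets keep the legacy `[200, {...}]` response body while their new v2 counterparts return the body directly, which is why the v1 handlers return `200, (200, content)` and the v2 handlers return `200, content`. A small sketch of how a caller supporting both shapes might unwrap the response:

```python
from typing import Any, Dict, List, Union


def unwrap_send_join_response(resp: Union[List, Dict[str, Any]]) -> Dict[str, Any]:
    if isinstance(resp, list):
        # v1 shape: [200, {...}]
        return resp[1]
    return resp  # v2 shape: {...}


print(unwrap_send_join_response([200, {"state": [], "auth_chain": []}]))
print(unwrap_send_join_response({"state": [], "auth_chain": []}))
```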
-OPENID_SERVLET_CLASSES = (OpenIdUserInfo,) +OPENID_SERVLET_CLASSES = ( + OpenIdUserInfo, +) # type: Tuple[Type[BaseFederationServlet], ...] -ROOM_LIST_CLASSES = (PublicRoomList,) +ROOM_LIST_CLASSES = (PublicRoomList,) # type: Tuple[Type[PublicRoomList], ...] GROUP_SERVER_SERVLET_CLASSES = ( FederationGroupsProfileServlet, @@ -1396,17 +1437,19 @@ GROUP_SERVER_SERVLET_CLASSES = ( FederationGroupsAddRoomsServlet, FederationGroupsAddRoomsConfigServlet, FederationGroupsSettingJoinPolicyServlet, -) +) # type: Tuple[Type[BaseFederationServlet], ...] GROUP_LOCAL_SERVLET_CLASSES = ( FederationGroupsLocalInviteServlet, FederationGroupsRemoveLocalUserServlet, FederationGroupsBulkPublicisedServlet, -) +) # type: Tuple[Type[BaseFederationServlet], ...] -GROUP_ATTESTATION_SERVLET_CLASSES = (FederationGroupsRenewAttestaionServlet,) +GROUP_ATTESTATION_SERVLET_CLASSES = ( + FederationGroupsRenewAttestaionServlet, +) # type: Tuple[Type[BaseFederationServlet], ...] DEFAULT_SERVLET_GROUPS = ( "federation", diff --git a/synapse/federation/units.py b/synapse/federation/units.py index b4d743cde7..6b32e0dcbf 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -19,11 +19,15 @@ server protocol. import logging +import attr + +from synapse.types import JsonDict from synapse.util.jsonobject import JsonEncodedObject logger = logging.getLogger(__name__) +@attr.s(slots=True) class Edu(JsonEncodedObject): """ An Edu represents a piece of data sent from one homeserver to another. @@ -32,11 +36,24 @@ class Edu(JsonEncodedObject): internal ID or previous references graph. """ - valid_keys = ["origin", "destination", "edu_type", "content"] + edu_type = attr.ib(type=str) + content = attr.ib(type=dict) + origin = attr.ib(type=str) + destination = attr.ib(type=str) - required_keys = ["edu_type"] + def get_dict(self) -> JsonDict: + return { + "edu_type": self.edu_type, + "content": self.content, + } - internal_keys = ["origin", "destination"] + def get_internal_dict(self) -> JsonDict: + return { + "edu_type": self.edu_type, + "content": self.content, + "origin": self.origin, + "destination": self.destination, + } def get_context(self): return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}") diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index dfd7ae041b..27b0c02655 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -37,13 +37,13 @@ An attestation is a signed blob of json that looks like: import logging import random +from typing import Tuple from signedjson.sign import sign_json from twisted.internet import defer from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError -from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import get_domain_from_id @@ -162,26 +162,26 @@ class GroupAttestionRenewer(object): def _start_renew_attestations(self): return run_as_background_process("renew_attestations", self._renew_attestations) - @defer.inlineCallbacks - def _renew_attestations(self): + async def _renew_attestations(self): """Called periodically to check if we need to update any of our attestations """ now = self.clock.time_msec() - rows = yield self.store.get_attestations_need_renewals( + rows = await self.store.get_attestations_need_renewals( now + UPDATE_ATTESTATION_TIME_MS ) @defer.inlineCallbacks - def _renew_attestation(group_id, user_id): + def _renew_attestation(group_user: Tuple[str, 
str]): + group_id, user_id = group_user try: if not self.is_mine_id(group_id): destination = get_domain_from_id(group_id) elif not self.is_mine_id(user_id): destination = get_domain_from_id(user_id) else: - logger.warn( + logger.warning( "Incorrectly trying to do attestations for user: %r in %r", user_id, group_id, @@ -208,7 +208,4 @@ class GroupAttestionRenewer(object): ) for row in rows: - group_id = row["group_id"] - user_id = row["user_id"] - - run_in_background(_renew_attestation, group_id, user_id) + await _renew_attestation((row["group_id"], row["user_id"])) diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index d50e691436..8a9de913b3 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2017 Vector Creations Ltd # Copyright 2018 New Vector Ltd +# Copyright 2019 Michael Telatynski <7t3chguy@gmail.com> # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,24 +19,22 @@ import logging from six import string_types -from twisted.internet import defer - -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError from synapse.types import GroupID, RoomID, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute logger = logging.getLogger(__name__) -# TODO: Allow users to "knock" or simpkly join depending on rules +# TODO: Allow users to "knock" or simply join depending on rules # TODO: Federation admin APIs -# TODO: is_priveged flag to users and is_public to users and rooms +# TODO: is_privileged flag to users and is_public to users and rooms # TODO: Audit log for admins (profile updates, membership changes, users who tried # to join but were rejected, etc) # TODO: Flairs -class GroupsServerHandler(object): +class GroupsServerWorkerHandler(object): def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() @@ -50,11 +49,7 @@ class GroupsServerHandler(object): self.transport_client = hs.get_federation_transport_client() self.profile_handler = hs.get_profile_handler() - # Ensure attestations get renewed - hs.get_groups_attestation_renewer() - - @defer.inlineCallbacks - def check_group_is_ours( + async def check_group_is_ours( self, group_id, requester_user_id, and_exists=False, and_is_admin=None ): """Check that the group is ours, and optionally if it exists. @@ -70,25 +65,24 @@ class GroupsServerHandler(object): if not self.is_mine_id(group_id): raise SynapseError(400, "Group not on this server") - group = yield self.store.get_group(group_id) + group = await self.store.get_group(group_id) if and_exists and not group: raise SynapseError(404, "Unknown group") - is_user_in_group = yield self.store.is_user_in_group( + is_user_in_group = await self.store.is_user_in_group( requester_user_id, group_id ) if group and not is_user_in_group and not group["is_public"]: raise SynapseError(404, "Unknown group") if and_is_admin: - is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin) + is_admin = await self.store.is_user_admin_in_group(group_id, and_is_admin) if not is_admin: raise SynapseError(403, "User is not admin in group") return group - @defer.inlineCallbacks - def get_group_summary(self, group_id, requester_user_id): + async def get_group_summary(self, group_id, requester_user_id): """Get the summary for a group as seen by requester_user_id. 
The group summary consists of the profile of the room, and a curated @@ -97,28 +91,28 @@ class GroupsServerHandler(object): A user/room may appear in multiple roles/categories. """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group( + is_user_in_group = await self.store.is_user_in_group( requester_user_id, group_id ) - profile = yield self.get_group_profile(group_id, requester_user_id) + profile = await self.get_group_profile(group_id, requester_user_id) - users, roles = yield self.store.get_users_for_summary_by_role( + users, roles = await self.store.get_users_for_summary_by_role( group_id, include_private=is_user_in_group ) # TODO: Add profiles to users - rooms, categories = yield self.store.get_rooms_for_summary_by_category( + rooms, categories = await self.store.get_rooms_for_summary_by_category( group_id, include_private=is_user_in_group ) for room_entry in rooms: room_id = room_entry["room_id"] - joined_users = yield self.store.get_users_in_room(room_id) - entry = yield self.room_list_handler.generate_room_entry( + joined_users = await self.store.get_users_in_room(room_id) + entry = await self.room_list_handler.generate_room_entry( room_id, len(joined_users), with_alias=False, allow_private=True ) entry = dict(entry) # so we don't change whats cached @@ -132,7 +126,7 @@ class GroupsServerHandler(object): user_id = entry["user_id"] if not self.is_mine_id(requester_user_id): - attestation = yield self.store.get_remote_attestation(group_id, user_id) + attestation = await self.store.get_remote_attestation(group_id, user_id) if not attestation: continue @@ -142,12 +136,12 @@ class GroupsServerHandler(object): group_id, user_id ) - user_profile = yield self.profile_handler.get_profile_from_cache(user_id) + user_profile = await self.profile_handler.get_profile_from_cache(user_id) entry.update(user_profile) users.sort(key=lambda e: e.get("order", 0)) - membership_info = yield self.store.get_users_membership_info_in_group( + membership_info = await self.store.get_users_membership_info_in_group( group_id, requester_user_id ) @@ -166,13 +160,195 @@ class GroupsServerHandler(object): "user": membership_info, } - @defer.inlineCallbacks - def update_group_summary_room( + async def get_group_categories(self, group_id, requester_user_id): + """Get all categories in a group (as seen by user) + """ + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + categories = await self.store.get_group_categories(group_id=group_id) + return {"categories": categories} + + async def get_group_category(self, group_id, requester_user_id, category_id): + """Get a specific category in a group (as seen by user) + """ + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + res = await self.store.get_group_category( + group_id=group_id, category_id=category_id + ) + + logger.info("group %s", res) + + return res + + async def get_group_roles(self, group_id, requester_user_id): + """Get all roles in a group (as seen by user) + """ + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + roles = await self.store.get_group_roles(group_id=group_id) + return {"roles": roles} + + async def get_group_role(self, group_id, requester_user_id, role_id): + """Get a specific role in a group (as seen by user) + """ + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + 
res = await self.store.get_group_role(group_id=group_id, role_id=role_id) + return res + + async def get_group_profile(self, group_id, requester_user_id): + """Get the group profile as seen by requester_user_id + """ + + await self.check_group_is_ours(group_id, requester_user_id) + + group = await self.store.get_group(group_id) + + if group: + cols = [ + "name", + "short_description", + "long_description", + "avatar_url", + "is_public", + ] + group_description = {key: group[key] for key in cols} + group_description["is_openly_joinable"] = group["join_policy"] == "open" + + return group_description + else: + raise SynapseError(404, "Unknown group") + + async def get_users_in_group(self, group_id, requester_user_id): + """Get the users in group as seen by requester_user_id. + + The ordering is arbitrary at the moment + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = await self.store.is_user_in_group( + requester_user_id, group_id + ) + + user_results = await self.store.get_users_in_group( + group_id, include_private=is_user_in_group + ) + + chunk = [] + for user_result in user_results: + g_user_id = user_result["user_id"] + is_public = user_result["is_public"] + is_privileged = user_result["is_admin"] + + entry = {"user_id": g_user_id} + + profile = await self.profile_handler.get_profile_from_cache(g_user_id) + entry.update(profile) + + entry["is_public"] = bool(is_public) + entry["is_privileged"] = bool(is_privileged) + + if not self.is_mine_id(g_user_id): + attestation = await self.store.get_remote_attestation( + group_id, g_user_id + ) + if not attestation: + continue + + entry["attestation"] = attestation + else: + entry["attestation"] = self.attestations.create_attestation( + group_id, g_user_id + ) + + chunk.append(entry) + + # TODO: If admin add lists of users whose attestations have timed out + + return {"chunk": chunk, "total_user_count_estimate": len(user_results)} + + async def get_invited_users_in_group(self, group_id, requester_user_id): + """Get the users that have been invited to a group as seen by requester_user_id. 
+ + The ordering is arbitrary at the moment + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = await self.store.is_user_in_group( + requester_user_id, group_id + ) + + if not is_user_in_group: + raise SynapseError(403, "User not in group") + + invited_users = await self.store.get_invited_users_in_group(group_id) + + user_profiles = [] + + for user_id in invited_users: + user_profile = {"user_id": user_id} + try: + profile = await self.profile_handler.get_profile_from_cache(user_id) + user_profile.update(profile) + except Exception as e: + logger.warning("Error getting profile for %s: %s", user_id, e) + user_profiles.append(user_profile) + + return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} + + async def get_rooms_in_group(self, group_id, requester_user_id): + """Get the rooms in group as seen by requester_user_id + + This returns rooms in order of decreasing number of joined users + """ + + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_user_in_group = await self.store.is_user_in_group( + requester_user_id, group_id + ) + + room_results = await self.store.get_rooms_in_group( + group_id, include_private=is_user_in_group + ) + + chunk = [] + for room_result in room_results: + room_id = room_result["room_id"] + + joined_users = await self.store.get_users_in_room(room_id) + entry = await self.room_list_handler.generate_room_entry( + room_id, len(joined_users), with_alias=False, allow_private=True + ) + + if not entry: + continue + + entry["is_public"] = bool(room_result["is_public"]) + + chunk.append(entry) + + chunk.sort(key=lambda e: -e["num_joined_members"]) + + return {"chunk": chunk, "total_room_count_estimate": len(room_results)} + + +class GroupsServerHandler(GroupsServerWorkerHandler): + def __init__(self, hs): + super(GroupsServerHandler, self).__init__(hs) + + # Ensure attestations get renewed + hs.get_groups_attestation_renewer() + + async def update_group_summary_room( self, group_id, requester_user_id, room_id, category_id, content ): """Add/update a room to the group summary """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -182,7 +358,7 @@ class GroupsServerHandler(object): is_public = _parse_visibility_from_contents(content) - yield self.store.add_room_to_summary( + await self.store.add_room_to_summary( group_id=group_id, room_id=room_id, category_id=category_id, @@ -192,31 +368,29 @@ class GroupsServerHandler(object): return {} - @defer.inlineCallbacks - def delete_group_summary_room( + async def delete_group_summary_room( self, group_id, requester_user_id, room_id, category_id ): """Remove a room from the summary """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_room_from_summary( + await self.store.remove_room_from_summary( group_id=group_id, room_id=room_id, category_id=category_id ) return {} - @defer.inlineCallbacks - def set_group_join_policy(self, group_id, requester_user_id, content): + async def set_group_join_policy(self, group_id, requester_user_id, content): """Sets the group join policy. Currently supported policies are: - "invite": an invite must be received and accepted in order to join. - "open": anyone can join. 
""" - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -224,43 +398,23 @@ class GroupsServerHandler(object): if join_policy is None: raise SynapseError(400, "No value specified for 'm.join_policy'") - yield self.store.set_group_join_policy(group_id, join_policy=join_policy) + await self.store.set_group_join_policy(group_id, join_policy=join_policy) return {} - @defer.inlineCallbacks - def get_group_categories(self, group_id, requester_user_id): - """Get all categories in a group (as seen by user) - """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - categories = yield self.store.get_group_categories(group_id=group_id) - return {"categories": categories} - - @defer.inlineCallbacks - def get_group_category(self, group_id, requester_user_id, category_id): - """Get a specific category in a group (as seen by user) - """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - res = yield self.store.get_group_category( - group_id=group_id, category_id=category_id - ) - - return res - - @defer.inlineCallbacks - def update_group_category(self, group_id, requester_user_id, category_id, content): + async def update_group_category( + self, group_id, requester_user_id, category_id, content + ): """Add/Update a group category """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) is_public = _parse_visibility_from_contents(content) profile = content.get("profile") - yield self.store.upsert_group_category( + await self.store.upsert_group_category( group_id=group_id, category_id=category_id, is_public=is_public, @@ -269,43 +423,23 @@ class GroupsServerHandler(object): return {} - @defer.inlineCallbacks - def delete_group_category(self, group_id, requester_user_id, category_id): + async def delete_group_category(self, group_id, requester_user_id, category_id): """Delete a group category """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_group_category( + await self.store.remove_group_category( group_id=group_id, category_id=category_id ) return {} - @defer.inlineCallbacks - def get_group_roles(self, group_id, requester_user_id): - """Get all roles in a group (as seen by user) - """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - roles = yield self.store.get_group_roles(group_id=group_id) - return {"roles": roles} - - @defer.inlineCallbacks - def get_group_role(self, group_id, requester_user_id, role_id): - """Get a specific role in a group (as seen by user) - """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - res = yield self.store.get_group_role(group_id=group_id, role_id=role_id) - return res - - @defer.inlineCallbacks - def update_group_role(self, group_id, requester_user_id, role_id, content): + async def update_group_role(self, group_id, requester_user_id, role_id, content): """Add/update a role in a group """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -313,31 +447,29 @@ class GroupsServerHandler(object): profile = content.get("profile") - yield self.store.upsert_group_role( + await self.store.upsert_group_role( group_id=group_id, role_id=role_id, 
is_public=is_public, profile=profile ) return {} - @defer.inlineCallbacks - def delete_group_role(self, group_id, requester_user_id, role_id): + async def delete_group_role(self, group_id, requester_user_id, role_id): """Remove role from group """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_group_role(group_id=group_id, role_id=role_id) + await self.store.remove_group_role(group_id=group_id, role_id=role_id) return {} - @defer.inlineCallbacks - def update_group_summary_user( + async def update_group_summary_user( self, group_id, requester_user_id, user_id, role_id, content ): """Add/update a users entry in the group summary """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -345,7 +477,7 @@ class GroupsServerHandler(object): is_public = _parse_visibility_from_contents(content) - yield self.store.add_user_to_summary( + await self.store.add_user_to_summary( group_id=group_id, user_id=user_id, role_id=role_id, @@ -355,49 +487,25 @@ class GroupsServerHandler(object): return {} - @defer.inlineCallbacks - def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id): + async def delete_group_summary_user( + self, group_id, requester_user_id, user_id, role_id + ): """Remove a user from the group summary """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_user_from_summary( + await self.store.remove_user_from_summary( group_id=group_id, user_id=user_id, role_id=role_id ) return {} - @defer.inlineCallbacks - def get_group_profile(self, group_id, requester_user_id): - """Get the group profile as seen by requester_user_id - """ - - yield self.check_group_is_ours(group_id, requester_user_id) - - group = yield self.store.get_group(group_id) - - if group: - cols = [ - "name", - "short_description", - "long_description", - "avatar_url", - "is_public", - ] - group_description = {key: group[key] for key in cols} - group_description["is_openly_joinable"] = group["join_policy"] == "open" - - return group_description - else: - raise SynapseError(404, "Unknown group") - - @defer.inlineCallbacks - def update_group_profile(self, group_id, requester_user_id, content): + async def update_group_profile(self, group_id, requester_user_id, content): """Update the group profile """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -409,158 +517,38 @@ class GroupsServerHandler(object): raise SynapseError(400, "%r value is not a string" % (keyname,)) profile[keyname] = value - yield self.store.update_group_profile(group_id, profile) - - @defer.inlineCallbacks - def get_users_in_group(self, group_id, requester_user_id): - """Get the users in group as seen by requester_user_id. 
+ await self.store.update_group_profile(group_id, profile) - The ordering is arbitrary at the moment - """ - - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_user_in_group = yield self.store.is_user_in_group( - requester_user_id, group_id - ) - - user_results = yield self.store.get_users_in_group( - group_id, include_private=is_user_in_group - ) - - chunk = [] - for user_result in user_results: - g_user_id = user_result["user_id"] - is_public = user_result["is_public"] - is_privileged = user_result["is_admin"] - - entry = {"user_id": g_user_id} - - profile = yield self.profile_handler.get_profile_from_cache(g_user_id) - entry.update(profile) - - entry["is_public"] = bool(is_public) - entry["is_privileged"] = bool(is_privileged) - - if not self.is_mine_id(g_user_id): - attestation = yield self.store.get_remote_attestation( - group_id, g_user_id - ) - if not attestation: - continue - - entry["attestation"] = attestation - else: - entry["attestation"] = self.attestations.create_attestation( - group_id, g_user_id - ) - - chunk.append(entry) - - # TODO: If admin add lists of users whose attestations have timed out - - return {"chunk": chunk, "total_user_count_estimate": len(user_results)} - - @defer.inlineCallbacks - def get_invited_users_in_group(self, group_id, requester_user_id): - """Get the users that have been invited to a group as seen by requester_user_id. - - The ordering is arbitrary at the moment - """ - - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_user_in_group = yield self.store.is_user_in_group( - requester_user_id, group_id - ) - - if not is_user_in_group: - raise SynapseError(403, "User not in group") - - invited_users = yield self.store.get_invited_users_in_group(group_id) - - user_profiles = [] - - for user_id in invited_users: - user_profile = {"user_id": user_id} - try: - profile = yield self.profile_handler.get_profile_from_cache(user_id) - user_profile.update(profile) - except Exception as e: - logger.warn("Error getting profile for %s: %s", user_id, e) - user_profiles.append(user_profile) - - return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} - - @defer.inlineCallbacks - def get_rooms_in_group(self, group_id, requester_user_id): - """Get the rooms in group as seen by requester_user_id - - This returns rooms in order of decreasing number of joined users - """ - - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_user_in_group = yield self.store.is_user_in_group( - requester_user_id, group_id - ) - - room_results = yield self.store.get_rooms_in_group( - group_id, include_private=is_user_in_group - ) - - chunk = [] - for room_result in room_results: - room_id = room_result["room_id"] - - joined_users = yield self.store.get_users_in_room(room_id) - entry = yield self.room_list_handler.generate_room_entry( - room_id, len(joined_users), with_alias=False, allow_private=True - ) - - if not entry: - continue - - entry["is_public"] = bool(room_result["is_public"]) - - chunk.append(entry) - - chunk.sort(key=lambda e: -e["num_joined_members"]) - - return {"chunk": chunk, "total_room_count_estimate": len(room_results)} - - @defer.inlineCallbacks - def add_room_to_group(self, group_id, requester_user_id, room_id, content): + async def add_room_to_group(self, group_id, requester_user_id, room_id, content): """Add room to group """ RoomID.from_string(room_id) # Ensure valid room id - yield self.check_group_is_ours( + await self.check_group_is_ours( 
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) is_public = _parse_visibility_from_contents(content) - yield self.store.add_room_to_group(group_id, room_id, is_public=is_public) + await self.store.add_room_to_group(group_id, room_id, is_public=is_public) return {} - @defer.inlineCallbacks - def update_room_in_group( + async def update_room_in_group( self, group_id, requester_user_id, room_id, config_key, content ): """Update room in group """ RoomID.from_string(room_id) # Ensure valid room id - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) if config_key == "m.visibility": is_public = _parse_visibility_dict(content) - yield self.store.update_room_in_group_visibility( + await self.store.update_room_in_group_visibility( group_id, room_id, is_public=is_public ) else: @@ -568,29 +556,38 @@ class GroupsServerHandler(object): return {} - @defer.inlineCallbacks - def remove_room_from_group(self, group_id, requester_user_id, room_id): + async def remove_room_from_group(self, group_id, requester_user_id, room_id): """Remove room from group """ - yield self.check_group_is_ours( + await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_room_from_group(group_id, room_id) + await self.store.remove_room_from_group(group_id, room_id) return {} - @defer.inlineCallbacks - def invite_to_group(self, group_id, user_id, requester_user_id, content): + async def invite_to_group(self, group_id, user_id, requester_user_id, content): """Invite user to group """ - group = yield self.check_group_is_ours( + group = await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) # TODO: Check if user knocked - # TODO: Check if user is already invited + + invited_users = await self.store.get_invited_users_in_group(group_id) + if user_id in invited_users: + raise SynapseError( + 400, "User already invited to group", errcode=Codes.BAD_STATE + ) + + user_results = await self.store.get_users_in_group( + group_id, include_private=True + ) + if user_id in (user_result["user_id"] for user_result in user_results): + raise SynapseError(400, "User already in group") content = { "profile": {"name": group["name"], "avatar_url": group["avatar_url"]}, @@ -599,18 +596,18 @@ class GroupsServerHandler(object): if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() - res = yield groups_local.on_invite(group_id, user_id, content) + res = await groups_local.on_invite(group_id, user_id, content) local_attestation = None else: local_attestation = self.attestations.create_attestation(group_id, user_id) content.update({"attestation": local_attestation}) - res = yield self.transport_client.invite_to_group_notification( + res = await self.transport_client.invite_to_group_notification( get_domain_from_id(user_id), group_id, user_id, content ) user_profile = res.get("user_profile", {}) - yield self.store.add_remote_profile_cache( + await self.store.add_remote_profile_cache( user_id, displayname=user_profile.get("displayname"), avatar_url=user_profile.get("avatar_url"), @@ -620,13 +617,13 @@ class GroupsServerHandler(object): if not self.hs.is_mine_id(user_id): remote_attestation = res["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, user_id=user_id, group_id=group_id ) else: 
remote_attestation = None - yield self.store.add_user_to_group( + await self.store.add_user_to_group( group_id, user_id, is_admin=False, @@ -635,15 +632,14 @@ class GroupsServerHandler(object): remote_attestation=remote_attestation, ) elif res["state"] == "invite": - yield self.store.add_group_invite(group_id, user_id) + await self.store.add_group_invite(group_id, user_id) return {"state": "invite"} elif res["state"] == "reject": return {"state": "reject"} else: raise SynapseError(502, "Unknown state returned by HS") - @defer.inlineCallbacks - def _add_user(self, group_id, user_id, content): + async def _add_user(self, group_id, user_id, content): """Add a user to a group based on a content dict. See accept_invite, join_group. @@ -653,7 +649,7 @@ class GroupsServerHandler(object): remote_attestation = content["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, user_id=user_id, group_id=group_id ) else: @@ -662,7 +658,7 @@ class GroupsServerHandler(object): is_public = _parse_visibility_from_contents(content) - yield self.store.add_user_to_group( + await self.store.add_user_to_group( group_id, user_id, is_admin=False, @@ -673,73 +669,70 @@ class GroupsServerHandler(object): return local_attestation - @defer.inlineCallbacks - def accept_invite(self, group_id, requester_user_id, content): + async def accept_invite(self, group_id, requester_user_id, content): """User tries to accept an invite to the group. This is different from them asking to join, and so should error if no invite exists (and they're not a member of the group) """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_invited = yield self.store.is_user_invited_to_local_group( + is_invited = await self.store.is_user_invited_to_local_group( group_id, requester_user_id ) if not is_invited: raise SynapseError(403, "User not invited to group") - local_attestation = yield self._add_user(group_id, requester_user_id, content) + local_attestation = await self._add_user(group_id, requester_user_id, content) return {"state": "join", "attestation": local_attestation} - @defer.inlineCallbacks - def join_group(self, group_id, requester_user_id, content): + async def join_group(self, group_id, requester_user_id, content): """User tries to join the group. This will error if the group requires an invite/knock to join """ - group_info = yield self.check_group_is_ours( + group_info = await self.check_group_is_ours( group_id, requester_user_id, and_exists=True ) if group_info["join_policy"] != "open": raise SynapseError(403, "Group is not publicly joinable") - local_attestation = yield self._add_user(group_id, requester_user_id, content) + local_attestation = await self._add_user(group_id, requester_user_id, content) return {"state": "join", "attestation": local_attestation} - @defer.inlineCallbacks - def knock(self, group_id, requester_user_id, content): + async def knock(self, group_id, requester_user_id, content): """A user requests becoming a member of the group """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) raise NotImplementedError() - @defer.inlineCallbacks - def accept_knock(self, group_id, requester_user_id, content): + async def accept_knock(self, group_id, requester_user_id, content): """Accept a users knock to the room. 
Errors if the user hasn't knocked, rather than inviting them. """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) raise NotImplementedError() - @defer.inlineCallbacks - def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + async def remove_user_from_group( + self, group_id, user_id, requester_user_id, content + ): """Remove a user from the group; either a user is leaving or an admin kicked them. """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_kick = False if requester_user_id != user_id: - is_admin = yield self.store.is_user_admin_in_group( + is_admin = await self.store.is_user_admin_in_group( group_id, requester_user_id ) if not is_admin: @@ -747,25 +740,29 @@ class GroupsServerHandler(object): is_kick = True - yield self.store.remove_user_from_group(group_id, user_id) + await self.store.remove_user_from_group(group_id, user_id) if is_kick: if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() - yield groups_local.user_removed_from_group(group_id, user_id, {}) + await groups_local.user_removed_from_group(group_id, user_id, {}) else: - yield self.transport_client.remove_user_from_group_notification( + await self.transport_client.remove_user_from_group_notification( get_domain_from_id(user_id), group_id, user_id, {} ) if not self.hs.is_mine_id(user_id): - yield self.store.maybe_delete_remote_profile_cache(user_id) + await self.store.maybe_delete_remote_profile_cache(user_id) + + # Delete group if the last user has left + users = await self.store.get_users_in_group(group_id, include_private=True) + if not users: + await self.store.delete_group(group_id) return {} - @defer.inlineCallbacks - def create_group(self, group_id, requester_user_id, content): - group = yield self.check_group_is_ours(group_id, requester_user_id) + async def create_group(self, group_id, requester_user_id, content): + group = await self.check_group_is_ours(group_id, requester_user_id) logger.info("Attempting to create group with ID: %r", group_id) @@ -775,7 +772,7 @@ class GroupsServerHandler(object): if group: raise SynapseError(400, "Group already exists") - is_admin = yield self.auth.is_server_admin( + is_admin = await self.auth.is_server_admin( UserID.from_string(requester_user_id) ) if not is_admin: @@ -798,7 +795,7 @@ class GroupsServerHandler(object): long_description = profile.get("long_description") user_profile = content.get("user_profile", {}) - yield self.store.create_group( + await self.store.create_group( group_id, requester_user_id, name=name, @@ -810,7 +807,7 @@ class GroupsServerHandler(object): if not self.hs.is_mine_id(requester_user_id): remote_attestation = content["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, user_id=requester_user_id, group_id=group_id ) @@ -821,7 +818,7 @@ class GroupsServerHandler(object): local_attestation = None remote_attestation = None - yield self.store.add_user_to_group( + await self.store.add_user_to_group( group_id, requester_user_id, is_admin=True, @@ -831,7 +828,7 @@ class GroupsServerHandler(object): ) if not self.hs.is_mine_id(requester_user_id): - yield self.store.add_remote_profile_cache( + await self.store.add_remote_profile_cache( requester_user_id, displayname=user_profile.get("displayname"), 
avatar_url=user_profile.get("avatar_url"), @@ -839,8 +836,7 @@ class GroupsServerHandler(object): return {"group_id": group_id} - @defer.inlineCallbacks - def delete_group(self, group_id, requester_user_id): + async def delete_group(self, group_id, requester_user_id): """Deletes a group, kicking out all current members. Only group admins or server admins can call this request @@ -849,18 +845,16 @@ class GroupsServerHandler(object): group_id (str) request_user_id (str) - Returns: - Deferred """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) # Only server admins or group admins can delete groups. - is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id) + is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id) if not is_admin: - is_admin = yield self.auth.is_server_admin( + is_admin = await self.auth.is_server_admin( UserID.from_string(requester_user_id) ) @@ -868,18 +862,17 @@ class GroupsServerHandler(object): raise SynapseError(403, "User is not an admin") # Before deleting the group lets kick everyone out of it - users = yield self.store.get_users_in_group(group_id, include_private=True) + users = await self.store.get_users_in_group(group_id, include_private=True) - @defer.inlineCallbacks - def _kick_user_from_group(user_id): + async def _kick_user_from_group(user_id): if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() - yield groups_local.user_removed_from_group(group_id, user_id, {}) + await groups_local.user_removed_from_group(group_id, user_id, {}) else: - yield self.transport_client.remove_user_from_group_notification( + await self.transport_client.remove_user_from_group_notification( get_domain_from_id(user_id), group_id, user_id, {} ) - yield self.store.maybe_delete_remote_profile_cache(user_id) + await self.store.maybe_delete_remote_profile_cache(user_id) # We kick users out in the order of: # 1. 
Non-admins @@ -898,11 +891,11 @@ class GroupsServerHandler(object): else: non_admins.append(u["user_id"]) - yield concurrently_execute(_kick_user_from_group, non_admins, 10) - yield concurrently_execute(_kick_user_from_group, admins, 10) - yield _kick_user_from_group(requester_user_id) + await concurrently_execute(_kick_user_from_group, non_admins, 10) + await concurrently_execute(_kick_user_from_group, admins, 10) + await _kick_user_from_group(requester_user_id) - yield self.store.delete_group(group_id) + await self.store.delete_group(group_id) def _parse_join_policy_from_contents(content): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d15c6282fb..61dc4beafe 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -19,7 +19,7 @@ from twisted.internet import defer import synapse.types from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import LimitExceededError +from synapse.api.ratelimiting import Ratelimiter from synapse.types import UserID logger = logging.getLogger(__name__) @@ -44,11 +44,26 @@ class BaseHandler(object): self.notifier = hs.get_notifier() self.state_handler = hs.get_state_handler() self.distributor = hs.get_distributor() - self.ratelimiter = hs.get_ratelimiter() - self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter() self.clock = hs.get_clock() self.hs = hs + # The rate_hz and burst_count are overridden on a per-user basis + self.request_ratelimiter = Ratelimiter( + clock=self.clock, rate_hz=0, burst_count=0 + ) + self._rc_message = self.hs.config.rc_message + + # Check whether ratelimiting room admin message redaction is enabled + # by the presence of rate limits in the config + if self.hs.config.rc_admin_redaction: + self.admin_redaction_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=self.hs.config.rc_admin_redaction.per_second, + burst_count=self.hs.config.rc_admin_redaction.burst_count, + ) + else: + self.admin_redaction_ratelimiter = None + self.server_name = hs.hostname self.event_builder_factory = hs.get_event_builder_factory() @@ -70,7 +85,6 @@ class BaseHandler(object): Raises: LimitExceededError if the request should be ratelimited """ - time_now = self.clock.time() user_id = requester.user.to_string() # The AS user itself is never rate limited. @@ -83,73 +97,55 @@ class BaseHandler(object): if requester.app_service and not requester.app_service.is_rate_limited(): return + messages_per_second = self._rc_message.per_second + burst_count = self._rc_message.burst_count + # Check if there is a per user override in the DB. 
override = yield self.store.get_ratelimit_for_user(user_id) if override: - # If overriden with a null Hz then ratelimiting has been entirely + # If overridden with a null Hz then ratelimiting has been entirely # disabled for the user if not override.messages_per_second: return messages_per_second = override.messages_per_second burst_count = override.burst_count + + if is_admin_redaction and self.admin_redaction_ratelimiter: + # If we have separate config for admin redactions, use a separate + # ratelimiter as to not have user_ids clash + self.admin_redaction_ratelimiter.ratelimit(user_id, update=update) else: - # We default to different values if this is an admin redaction and - # the config is set - if is_admin_redaction and self.hs.config.rc_admin_redaction: - messages_per_second = self.hs.config.rc_admin_redaction.per_second - burst_count = self.hs.config.rc_admin_redaction.burst_count - else: - messages_per_second = self.hs.config.rc_message.per_second - burst_count = self.hs.config.rc_message.burst_count - - if is_admin_redaction and self.hs.config.rc_admin_redaction: - # If we have separate config for admin redactions we use a separate - # ratelimiter - allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action( - user_id, - time_now, - rate_hz=messages_per_second, - burst_count=burst_count, - update=update, - ) - else: - allowed, time_allowed = self.ratelimiter.can_do_action( + # Override rate and burst count per-user + self.request_ratelimiter.ratelimit( user_id, - time_now, rate_hz=messages_per_second, burst_count=burst_count, update=update, ) - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) - @defer.inlineCallbacks - def maybe_kick_guest_users(self, event, context=None): + async def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. # Hopefully this isn't that important to the caller. if event.type == EventTypes.GuestAccess: guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: - current_state_ids = yield context.get_current_state_ids(self.store) - current_state = yield self.store.get_events( + current_state_ids = await context.get_current_state_ids() + current_state = await self.store.get_events( list(current_state_ids.values()) ) else: - current_state = yield self.state_handler.get_current_state( + current_state = await self.state_handler.get_current_state( event.room_id ) current_state = list(current_state.values()) logger.info("maybe_kick_guest_users %r", current_state) - yield self.kick_guest_users(current_state) + await self.kick_guest_users(current_state) - @defer.inlineCallbacks - def kick_guest_users(self, current_state): + async def kick_guest_users(self, current_state): for member_event in current_state: try: if member_event.type != EventTypes.Member: @@ -180,7 +176,7 @@ class BaseHandler(object): # homeserver. requester = synapse.types.create_requester(target_user, is_guest=True) handler = self.hs.get_room_member_handler() - yield handler.update_membership( + await handler.update_membership( requester, target_user, member_event.room_id, diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 38bc67191c..a8d3fbc6de 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
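The change running through the handler hunks above and below is mechanical: methods decorated with @defer.inlineCallbacks that yield Deferreds become native coroutines that await them, and defer.returnValue(x) becomes a plain return x. A minimal before/after sketch of that pattern, using a hypothetical handler and store rather than code from this patch:

from twisted.internet import defer


class OldStyleHandler:
    """Pre-conversion shape: a generator-based coroutine via inlineCallbacks."""

    def __init__(self, store):
        self.store = store

    @defer.inlineCallbacks
    def get_thing(self, thing_id):
        # Each yield suspends until the Deferred returned by the store fires.
        thing = yield self.store.get_thing(thing_id)
        defer.returnValue(thing)


class NewStyleHandler:
    """Post-conversion shape: a native async/await coroutine."""

    def __init__(self, store):
        self.store = store

    async def get_thing(self, thing_id):
        # Deferreds are awaitable, so the store call can keep returning a
        # Deferred or become a coroutine itself without further changes here.
        thing = await self.store.get_thing(thing_id)
        return thing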
-from twisted.internet import defer - class AccountDataEventSource(object): def __init__(self, hs): @@ -23,24 +21,24 @@ class AccountDataEventSource(object): def get_current_key(self, direction="f"): return self.store.get_max_account_data_stream_id() - @defer.inlineCallbacks - def get_new_events(self, user, from_key, **kwargs): + async def get_new_events(self, user, from_key, **kwargs): user_id = user.to_string() last_stream_id = from_key - current_stream_id = yield self.store.get_max_account_data_stream_id() + current_stream_id = self.store.get_max_account_data_stream_id() results = [] - tags = yield self.store.get_updated_tags(user_id, last_stream_id) + tags = await self.store.get_updated_tags(user_id, last_stream_id) for room_id, room_tags in tags.items(): results.append( {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id} ) - account_data, room_account_data = ( - yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) - ) + ( + account_data, + room_account_data, + ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id) for account_data_type, content in account_data.items(): results.append({"type": account_data_type, "content": content}) @@ -52,7 +50,3 @@ class AccountDataEventSource(object): ) return results, current_stream_id - - @defer.inlineCallbacks - def get_pagination_rows(self, user, config, key): - return [], config.to_id diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index d04e0fe576..590135d19c 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -18,8 +18,7 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText - -from twisted.internet import defer +from typing import List from synapse.api.errors import StoreError from synapse.logging.context import make_deferred_yieldable @@ -45,7 +44,11 @@ class AccountValidityHandler(object): self._account_validity = self.hs.config.account_validity - if self._account_validity.renew_by_email_enabled and load_jinja2_templates: + if ( + self._account_validity.enabled + and self._account_validity.renew_by_email_enabled + and load_jinja2_templates + ): # Don't do email-specific configuration if renewal by email is disabled. try: app_name = self.hs.config.email_app_name @@ -78,42 +81,39 @@ class AccountValidityHandler(object): # run as a background process to make sure that the database transactions # have a logcontext to report to return run_as_background_process( - "send_renewals", self.send_renewal_emails + "send_renewals", self._send_renewal_emails ) self.clock.looping_call(send_emails, 30 * 60 * 1000) - @defer.inlineCallbacks - def send_renewal_emails(self): + async def _send_renewal_emails(self): """Gets the list of users whose account is expiring in the amount of time configured in the ``renew_at`` parameter from the ``account_validity`` configuration, and sends renewal emails to all of these users as long as they have an email 3PID attached to their account. 
""" - expiring_users = yield self.store.get_users_expiring_soon() + expiring_users = await self.store.get_users_expiring_soon() if expiring_users: for user in expiring_users: - yield self._send_renewal_email( + await self._send_renewal_email( user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] ) - @defer.inlineCallbacks - def send_renewal_email_to_user(self, user_id): - expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) - yield self._send_renewal_email(user_id, expiration_ts) + async def send_renewal_email_to_user(self, user_id: str): + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) + await self._send_renewal_email(user_id, expiration_ts) - @defer.inlineCallbacks - def _send_renewal_email(self, user_id, expiration_ts): + async def _send_renewal_email(self, user_id: str, expiration_ts: int): """Sends out a renewal email to every email address attached to the given user with a unique link allowing them to renew their account. Args: - user_id (str): ID of the user to send email(s) to. - expiration_ts (int): Timestamp in milliseconds for the expiration date of + user_id: ID of the user to send email(s) to. + expiration_ts: Timestamp in milliseconds for the expiration date of this user's account (used in the email templates). """ - addresses = yield self._get_email_addresses_for_user(user_id) + addresses = await self._get_email_addresses_for_user(user_id) # Stop right here if the user doesn't have at least one email address. # In this case, they will have to ask their server admin to renew their @@ -125,7 +125,7 @@ class AccountValidityHandler(object): return try: - user_display_name = yield self.store.get_profile_displayname( + user_display_name = await self.store.get_profile_displayname( UserID.from_string(user_id).localpart ) if user_display_name is None: @@ -133,7 +133,7 @@ class AccountValidityHandler(object): except StoreError: user_display_name = user_id - renewal_token = yield self._get_renewal_token(user_id) + renewal_token = await self._get_renewal_token(user_id) url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % ( self.hs.config.public_baseurl, renewal_token, @@ -165,7 +165,7 @@ class AccountValidityHandler(object): logger.info("Sending renewal email to %s", address) - yield make_deferred_yieldable( + await make_deferred_yieldable( self.sendmail( self.hs.config.email_smtp_host, self._raw_from, @@ -180,19 +180,18 @@ class AccountValidityHandler(object): ) ) - yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) + await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) - @defer.inlineCallbacks - def _get_email_addresses_for_user(self, user_id): + async def _get_email_addresses_for_user(self, user_id: str) -> List[str]: """Retrieve the list of email addresses attached to a user's account. Args: - user_id (str): ID of the user to lookup email addresses for. + user_id: ID of the user to lookup email addresses for. Returns: - defer.Deferred[list[str]]: Email addresses for this account. + Email addresses for this account. 
""" - threepids = yield self.store.user_get_threepids(user_id) + threepids = await self.store.user_get_threepids(user_id) addresses = [] for threepid in threepids: @@ -201,16 +200,15 @@ class AccountValidityHandler(object): return addresses - @defer.inlineCallbacks - def _get_renewal_token(self, user_id): + async def _get_renewal_token(self, user_id: str) -> str: """Generates a 32-byte long random string that will be inserted into the user's renewal email's unique link, then saves it into the database. Args: - user_id (str): ID of the user to generate a string for. + user_id: ID of the user to generate a string for. Returns: - defer.Deferred[str]: The generated string. + The generated string. Raises: StoreError(500): Couldn't generate a unique string after 5 attempts. @@ -219,52 +217,52 @@ class AccountValidityHandler(object): while attempts < 5: try: renewal_token = stringutils.random_string(32) - yield self.store.set_renewal_token_for_user(user_id, renewal_token) + await self.store.set_renewal_token_for_user(user_id, renewal_token) return renewal_token except StoreError: attempts += 1 raise StoreError(500, "Couldn't generate a unique string as refresh string.") - @defer.inlineCallbacks - def renew_account(self, renewal_token): + async def renew_account(self, renewal_token: str) -> bool: """Renews the account attached to a given renewal token by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token (str): Token sent with the renewal request. + renewal_token: Token sent with the renewal request. Returns: - bool: Whether the provided token is valid. + Whether the provided token is valid. """ try: - user_id = yield self.store.get_user_from_renewal_token(renewal_token) + user_id = await self.store.get_user_from_renewal_token(renewal_token) except StoreError: - defer.returnValue(False) + return False logger.debug("Renewing an account for user %s", user_id) - yield self.renew_account_for_user(user_id) + await self.renew_account_for_user(user_id) - defer.returnValue(True) + return True - @defer.inlineCallbacks - def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False): + async def renew_account_for_user( + self, user_id: str, expiration_ts: int = None, email_sent: bool = False + ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token (str): Token sent with the renewal request. - expiration_ts (int): New expiration date. Defaults to now + validity period. - email_sent (bool): Whether an email has been sent for this validity period. + renewal_token: Token sent with the renewal request. + expiration_ts: New expiration date. Defaults to now + validity period. + email_sen: Whether an email has been sent for this validity period. Defaults to False. Returns: - defer.Deferred[int]: New expiration date for this account, as a timestamp - in milliseconds since epoch. + New expiration date for this account, as a timestamp in + milliseconds since epoch. 
""" if expiration_ts is None: expiration_ts = self.clock.time_msec() + self._account_validity.period - yield self.store.set_account_validity_for_user( + await self.store.set_account_validity_for_user( user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent ) diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py index 46ac73106d..a2d7959abe 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py @@ -25,6 +25,15 @@ from synapse.app import check_bind_error logger = logging.getLogger(__name__) +ACME_REGISTER_FAIL_ERROR = """ +-------------------------------------------------------------------------------- +Failed to register with the ACME provider. This is likely happening because the installation +is new, and ACME v1 has been deprecated by Let's Encrypt and disabled for +new installations since November 2019. +At the moment, Synapse doesn't support ACME v2. For more information and alternative +solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1 +--------------------------------------------------------------------------------""" + class AcmeHandler(object): def __init__(self, hs): @@ -71,7 +80,12 @@ class AcmeHandler(object): # want it to control where we save the certificates, we have to reach in # and trigger the registration machinery ourselves. self._issuer._registered = False - yield self._issuer._ensure_registered() + + try: + yield self._issuer._ensure_registered() + except Exception: + logger.error(ACME_REGISTER_FAIL_ERROR) + raise @defer.inlineCallbacks def provision_certificate(self): diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 1a87b58838..f3c0aeceb6 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -14,11 +14,11 @@ # limitations under the License. import logging - -from twisted.internet import defer +from typing import List from synapse.api.constants import Membership -from synapse.types import RoomStreamToken +from synapse.events import FrozenEvent +from synapse.types import RoomStreamToken, StateMap from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -30,11 +30,13 @@ class AdminHandler(BaseHandler): def __init__(self, hs): super(AdminHandler, self).__init__(hs) - @defer.inlineCallbacks - def get_whois(self, user): + self.storage = hs.get_storage() + self.state_store = self.storage.state + + async def get_whois(self, user): connections = [] - sessions = yield self.store.get_user_ip_and_agents(user) + sessions = await self.store.get_user_ip_and_agents(user) for session in sessions: connections.append( { @@ -51,70 +53,18 @@ class AdminHandler(BaseHandler): return ret - @defer.inlineCallbacks - def get_users(self): - """Function to reterive a list of users in users table. - - Args: - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - ret = yield self.store.get_users() - - return ret - - @defer.inlineCallbacks - def get_users_paginate(self, order, start, limit): - """Function to reterive a paginated list of users from - users list. This will return a json object, which contains - list of users and the total number of users in users table. 
- - Args: - order (str): column name to order the select by this column - start (int): start number to begin the query from - limit (int): number of rows to reterive - Returns: - defer.Deferred: resolves to json object {list[dict[str, Any]], count} - """ - ret = yield self.store.get_users_paginate(order, start, limit) - + async def get_user(self, user): + """Function to get user details""" + ret = await self.store.get_user_by_id(user.to_string()) + if ret: + profile = await self.store.get_profileinfo(user.localpart) + threepids = await self.store.user_get_threepids(user.to_string()) + ret["displayname"] = profile.display_name + ret["avatar_url"] = profile.avatar_url + ret["threepids"] = threepids return ret - @defer.inlineCallbacks - def search_users(self, term): - """Function to search users list for one or more users with - the matched term. - - Args: - term (str): search term - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - ret = yield self.store.search_users(term) - - return ret - - def get_user_server_admin(self, user): - """ - Get the admin bit on a user. - - Args: - user_id (UserID): the (necessarily local) user to manipulate - """ - return self.store.is_server_admin(user) - - def set_user_server_admin(self, user, admin): - """ - Set the admin bit on a user. - - Args: - user_id (UserID): the (necessarily local) user to manipulate - admin (bool): whether or not the user should be an admin of this server - """ - return self.store.set_server_admin(user, admin) - - @defer.inlineCallbacks - def export_user_data(self, user_id, writer): + async def export_user_data(self, user_id, writer): """Write all data we have on the user to the given writer. Args: @@ -126,7 +76,7 @@ class AdminHandler(BaseHandler): The returned value is that returned by `writer.finished()`. """ # Get all rooms the user is in or has been in - rooms = yield self.store.get_rooms_for_user_where_membership_is( + rooms = await self.store.get_rooms_for_local_user_where_membership_is( user_id, membership_list=( Membership.JOIN, @@ -139,7 +89,7 @@ class AdminHandler(BaseHandler): # We only try and fetch events for rooms the user has been in. If # they've been e.g. invited to a room without joining then we handle # those seperately. - rooms_user_has_been_in = yield self.store.get_rooms_user_has_been_in(user_id) + rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id) for index, room in enumerate(rooms): room_id = room.room_id @@ -148,7 +98,7 @@ class AdminHandler(BaseHandler): "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms) ) - forgotten = yield self.store.did_forget(user_id, room_id) + forgotten = await self.store.did_forget(user_id, room_id) if forgotten: logger.info("[%s] User forgot room %d, ignoring", user_id, room_id) continue @@ -160,7 +110,7 @@ class AdminHandler(BaseHandler): if room.membership == Membership.INVITE: event_id = room.event_id - invite = yield self.store.get_event(event_id, allow_none=True) + invite = await self.store.get_event(event_id, allow_none=True) if invite: invited_state = invite.unsigned["invite_room_state"] writer.write_invite(room_id, invite, invited_state) @@ -171,7 +121,7 @@ class AdminHandler(BaseHandler): # were joined. We estimate that point by looking at the # stream_ordering of the last membership if it wasn't a join. 
if room.membership == Membership.JOIN: - stream_ordering = yield self.store.get_room_max_stream_ordering() + stream_ordering = self.store.get_room_max_stream_ordering() else: stream_ordering = room.stream_ordering @@ -197,7 +147,7 @@ class AdminHandler(BaseHandler): # events that we have and then filtering, this isn't the most # efficient method perhaps but it does guarantee we get everything. while True: - events, _ = yield self.store.paginate_room_events( + events, _ = await self.store.paginate_room_events( room_id, from_key, to_key, limit=100, direction="f" ) if not events: @@ -205,7 +155,7 @@ class AdminHandler(BaseHandler): from_key = events[-1].internal_metadata.after - events = yield filter_events_for_client(self.store, user_id, events) + events = await filter_events_for_client(self.storage, user_id, events) writer.write_events(room_id, events) @@ -241,7 +191,7 @@ class AdminHandler(BaseHandler): for event_id in extremities: if not event_to_unseen_prevs[event_id]: continue - state = yield self.store.get_state_for_event(event_id) + state = await self.state_store.get_state_for_event(event_id) writer.write_state(room_id, event_id, state) return writer.finished() @@ -251,35 +201,26 @@ class ExfiltrationWriter(object): """Interface used to specify how to write exported data. """ - def write_events(self, room_id, events): + def write_events(self, room_id: str, events: List[FrozenEvent]): """Write a batch of events for a room. - - Args: - room_id (str) - events (list[FrozenEvent]) """ pass - def write_state(self, room_id, event_id, state): + def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]): """Write the state at the given event in the room. This only gets called for backward extremities rather than for each event. - - Args: - room_id (str) - event_id (str) - state (dict[tuple[str, str], FrozenEvent]) """ pass - def write_invite(self, room_id, event, state): + def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]): """Write an invite for the room, with associated invite state. 
Args: - room_id (str) - event (FrozenEvent) - state (dict[tuple[str, str], dict]): A subset of the state at the + room_id + event + state: A subset of the state at the invite, with a subset of the event keys (type, state_key content and sender) """ diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 3e9b298154..fe62f78e67 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -73,7 +73,10 @@ class ApplicationServicesHandler(object): try: limit = 100 while True: - upper_bound, events = yield self.store.get_new_events_for_appservice( + ( + upper_bound, + events, + ) = yield self.store.get_new_events_for_appservice( self.current_max, limit ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 374372b69e..119678e67b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -17,14 +17,12 @@ import logging import time import unicodedata +import urllib.parse +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import attr -import bcrypt +import bcrypt # type: ignore[import] import pymacaroons -from canonicaljson import json - -from twisted.internet import defer -from twisted.web.client import PartialDownloadError import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType @@ -38,11 +36,15 @@ from synapse.api.errors import ( UserDeactivatedError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.config.emailconfig import ThreepidBehaviour +from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS +from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker +from synapse.http.server import finish_request +from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi -from synapse.types import UserID -from synapse.util.caches.expiringcache import ExpiringCache +from synapse.push.mailer import load_jinja2_templates +from synapse.types import Requester, UserID from ._base import BaseHandler @@ -58,23 +60,14 @@ class AuthHandler(BaseHandler): hs (synapse.server.HomeServer): """ super(AuthHandler, self).__init__(hs) - self.checkers = { - LoginType.RECAPTCHA: self._check_recaptcha, - LoginType.EMAIL_IDENTITY: self._check_email_identity, - LoginType.MSISDN: self._check_msisdn, - LoginType.DUMMY: self._check_dummy_auth, - LoginType.TERMS: self._check_terms_auth, - } - self.bcrypt_rounds = hs.config.bcrypt_rounds - # This is not a cache per se, but a store of all current sessions that - # expire after N hours - self.sessions = ExpiringCache( - cache_name="register_sessions", - clock=hs.get_clock(), - expiry_ms=self.SESSION_EXPIRE_MS, - reset_expiry_on_get=True, - ) + self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] + for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: + inst = auth_checker_class(hs) + if inst.is_enabled(): + self.checkers[inst.AUTH_TYPE] = inst # type: ignore + + self.bcrypt_rounds = hs.config.bcrypt_rounds account_handler = ModuleApi(hs, self) self.password_providers = [ @@ -87,6 +80,9 @@ class AuthHandler(BaseHandler): self.hs = hs # FIXME better possibility to access registrationHandler later? 
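In the AuthHandler hunk above, the hard-coded map of checker methods is replaced by instances built from a list of checker classes, keeping only those whose is_enabled() returns True. A rough sketch of that registration pattern follows; DummyAuthChecker and build_checkers are illustrative stand-ins, not the real classes from synapse.handlers.ui_auth.checkers:

class DummyAuthChecker:
    """Stand-in checker: advertises its auth type and whether it can run."""

    AUTH_TYPE = "m.login.dummy"

    def is_enabled(self) -> bool:
        # A real checker would consult homeserver config here.
        return True

    async def check_auth(self, authdict: dict, clientip: str) -> bool:
        # A real checker would validate a captcha response, email token, etc.
        return True


INTERACTIVE_AUTH_CHECKERS = [DummyAuthChecker]


def build_checkers() -> dict:
    """Instantiate every registered checker and keep only the enabled ones."""
    checkers = {}
    for checker_class in INTERACTIVE_AUTH_CHECKERS:
        inst = checker_class()
        if inst.is_enabled():
            checkers[inst.AUTH_TYPE] = inst
    return checkers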
self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled + self._sso_enabled = ( + hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled + ) # we keep this as a list despite the O(N^2) implication so that we can # keep PASSWORD first and avoid confusing clients which pick the first @@ -102,14 +98,70 @@ class AuthHandler(BaseHandler): if t not in login_types: login_types.append(t) self._supported_login_types = login_types - - self._account_ratelimiter = Ratelimiter() - self._failed_attempts_ratelimiter = Ratelimiter() + # Login types and UI Auth types have a heavy overlap, but are not + # necessarily identical. Login types have SSO (and other login types) + # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET. + ui_auth_types = login_types.copy() + if self._sso_enabled: + ui_auth_types.append(LoginType.SSO) + self._supported_ui_auth_types = ui_auth_types + + # Ratelimiter for failed auth during UIA. Uses same ratelimit config + # as per `rc_login.failed_attempts`. + self._failed_uia_attempts_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) self._clock = self.hs.get_clock() - @defer.inlineCallbacks - def validate_user_via_ui_auth(self, requester, request_body, clientip): + # Expire old UI auth sessions after a period of time. + if hs.config.worker_app is None: + self._clock.looping_call( + run_as_background_process, + 5 * 60 * 1000, + "expire_old_sessions", + self._expire_old_sessions, + ) + + # Load the SSO HTML templates. + + # The following template is shown to the user during a client login via SSO, + # after the SSO completes and before redirecting them back to their client. + # It notifies the user they are about to give access to their matrix account + # to the client. + self._sso_redirect_confirm_template = load_jinja2_templates( + hs.config.sso_template_dir, ["sso_redirect_confirm.html"], + )[0] + # The following template is shown during user interactive authentication + # in the fallback auth scenario. It notifies the user that they are + # authenticating for an operation to occur on their account. + self._sso_auth_confirm_template = load_jinja2_templates( + hs.config.sso_template_dir, ["sso_auth_confirm.html"], + )[0] + # The following template is shown after a successful user interactive + # authentication session. It tells the user they can close the window. + self._sso_auth_success_template = hs.config.sso_auth_success_template + # The following template is shown during the SSO authentication process if + # the account is deactivated. + self._sso_account_deactivated_template = ( + hs.config.sso_account_deactivated_template + ) + + self._server_name = hs.config.server_name + + # cast to tuple for use with str.startswith + self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) + + async def validate_user_via_ui_auth( + self, + requester: Requester, + request: SynapseRequest, + request_body: Dict[str, Any], + clientip: str, + description: str, + ) -> dict: """ Checks that the user is who they claim to be, via a UI auth. @@ -118,14 +170,19 @@ class AuthHandler(BaseHandler): that it isn't stolen by re-authenticating them. Args: - requester (Requester): The user, as given by the access token + requester: The user, as given by the access token + + request: The request sent by the client. 
- request_body (dict): The body of the request sent by the client + request_body: The body of the request sent by the client - clientip (str): The IP address of the client. + clientip: The IP address of the client. + + description: A human readable string to be displayed to the user that + describes the operation happening on their account. Returns: - defer.Deferred[dict]: the parameters for this request (which may + The parameters for this request (which may have been given only in a previous call). Raises: @@ -134,15 +191,31 @@ class AuthHandler(BaseHandler): AuthError if the client has completed a login flow, and it gives a different user to `requester` + + LimitExceededError if the ratelimiter's failed request count for this + user is too high to proceed + """ + user_id = requester.user.to_string() + + # Check if we should be ratelimited due to too many previous failed attempts + self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) + # build a list of supported flows - flows = [[login_type] for login_type in self._supported_login_types] + flows = [[login_type] for login_type in self._supported_ui_auth_types] - result, params, _ = yield self.check_auth(flows, request_body, clientip) + try: + result, params, _ = await self.check_auth( + flows, request, request_body, clientip, description + ) + except LoginError: + # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). + self._failed_uia_attempts_ratelimiter.can_do_action(user_id) + raise # find the completed login type - for login_type in self._supported_login_types: + for login_type in self._supported_ui_auth_types: if login_type not in result: continue @@ -158,34 +231,48 @@ class AuthHandler(BaseHandler): return params - @defer.inlineCallbacks - def check_auth(self, flows, clientdict, clientip): + def get_enabled_auth_types(self): + """Return the enabled user-interactive authentication types + + Returns the UI-Auth types which are supported by the homeserver's current + config. + """ + return self.checkers.keys() + + async def check_auth( + self, + flows: List[List[str]], + request: SynapseRequest, + clientdict: Dict[str, Any], + clientip: str, + description: str, + ) -> Tuple[dict, dict, str]: """ Takes a dictionary sent by the client in the login / registration protocol and handles the User-Interactive Auth flow. - As a side effect, this function fills in the 'creds' key on the user's - session with a map, which maps each auth-type (str) to the relevant - identity authenticated by that auth-type (mostly str, but for captcha, bool). - If no auth flows have been completed successfully, raises an InteractiveAuthIncompleteError. To handle this, you can use synapse.rest.client.v2_alpha._base.interactive_auth_handler as a decorator. Args: - flows (list): A list of login flows. Each flow is an ordered list of - strings representing auth-types. At least one full - flow must be completed in order for auth to be successful. + flows: A list of login flows. Each flow is an ordered list of + strings representing auth-types. At least one full + flow must be completed in order for auth to be successful. + + request: The request sent by the client. clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. - clientip (str): The IP address of the client. + clientip: The IP address of the client. + + description: A human readable string to be displayed to the user that + describes the operation happening on their account. 
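As background for the flows argument documented above: a flow is an ordered list of auth-type strings, and authentication succeeds once every stage of at least one flow has been completed. A small sketch of that completion check, mirroring the set arithmetic used further down in check_auth (the flow and stage names are illustrative):

# Illustrative flow-completion check: auth is done when some flow's stages
# are all present in the set of completed stages ("creds").
from typing import Dict, List, Optional


def completed_flow(
    flows: List[List[str]], creds: Dict[str, object]
) -> Optional[List[str]]:
    """Return the first fully-completed flow, or None if auth is still incomplete."""
    for flow in flows:
        if len(set(flow) - set(creds)) == 0:
            return flow
    return None


flows = [
    ["m.login.password"],
    ["m.login.recaptcha", "m.login.terms"],
]
print(completed_flow(flows, {"m.login.recaptcha": True}))                        # None
print(completed_flow(flows, {"m.login.recaptcha": True, "m.login.terms": True})) # second flow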
Returns: - defer.Deferred[dict, dict, str]: a deferred tuple of - (creds, params, session_id). + A tuple of (creds, params, session_id). 'creds' contains the authenticated credentials of each stage. @@ -201,47 +288,93 @@ class AuthHandler(BaseHandler): """ authdict = None - sid = None + sid = None # type: Optional[str] if clientdict and "auth" in clientdict: authdict = clientdict["auth"] del clientdict["auth"] if "session" in authdict: sid = authdict["session"] - session = self._get_session_info(sid) - if len(clientdict) > 0: + # Convert the URI and method to strings. + uri = request.uri.decode("utf-8") + method = request.uri.decode("utf-8") + + # If there's no session ID, create a new session. + if not sid: + session = await self.store.create_ui_auth_session( + clientdict, uri, method, description + ) + + else: + try: + session = await self.store.get_ui_auth_session(sid) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (sid,)) + + # If the client provides parameters, update what is persisted, + # otherwise use whatever was last provided. + # # This was designed to allow the client to omit the parameters # and just supply the session in subsequent calls so it split # auth between devices by just sharing the session, (eg. so you # could continue registration from your phone having clicked the # email auth link on there). It's probably too open to abuse # because it lets unauthenticated clients store arbitrary objects - # on a home server. - # Revisit: Assumimg the REST APIs do sensible validation, the data - # isn't arbintrary. - session["clientdict"] = clientdict - self._save_session(session) - elif "clientdict" in session: - clientdict = session["clientdict"] + # on a homeserver. + # + # Revisit: Assuming the REST APIs do sensible validation, the data + # isn't arbitrary. + # + # Note that the registration endpoint explicitly removes the + # "initial_device_display_name" parameter if it is provided + # without a "password" parameter. See the changes to + # synapse.rest.client.v2_alpha.register.RegisterRestServlet.on_POST + # in commit 544722bad23fc31056b9240189c3cbbbf0ffd3f9. + if not clientdict: + clientdict = session.clientdict + + # Ensure that the queried operation does not vary between stages of + # the UI authentication session. This is done by generating a stable + # comparator and storing it during the initial query. Subsequent + # queries ensure that this comparator has not changed. + # + # The comparator is based on the requested URI and HTTP method. The + # client dict (minus the auth dict) should also be checked, but some + # clients are not spec compliant, just warn for now if the client + # dict changes. + if (session.uri, session.method) != (uri, method): + raise SynapseError( + 403, + "Requested operation has changed during the UI authentication session.", + ) + + if session.clientdict != clientdict: + logger.warning( + "Requested operation has changed during the UI " + "authentication session. A future version of Synapse " + "will remove this capability." + ) + + # For backwards compatibility, changes to the client dict are + # persisted as clients modify them throughout their user interactive + # authentication flow. 
+ await self.store.set_ui_auth_clientdict(sid, clientdict) if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session) + self._auth_dict_for_flows(flows, session.session_id) ) - if "creds" not in session: - session["creds"] = {} - creds = session["creds"] - # check auth type currently being presented - errordict = {} + errordict = {} # type: Dict[str, Any] if "type" in authdict: - login_type = authdict["type"] + login_type = authdict["type"] # type: str try: - result = yield self._check_auth_dict(authdict, clientip) + result = await self._check_auth_dict(authdict, clientip) if result: - creds[login_type] = result - self._save_session(session) + await self.store.mark_ui_auth_stage_complete( + session.session_id, login_type, result + ) except LoginError as e: if login_type == LoginType.EMAIL_IDENTITY: # riot used to have a bug where it would request a new @@ -257,6 +390,7 @@ class AuthHandler(BaseHandler): # so that the client can have another go. errordict = e.error_dict() + creds = await self.store.get_completed_ui_auth_stages(session.session_id) for f in flows: if len(set(f) - set(creds)) == 0: # it's very useful to know what args are stored, but this can @@ -269,15 +403,17 @@ class AuthHandler(BaseHandler): creds, list(clientdict), ) - return creds, clientdict, session["id"] - ret = self._auth_dict_for_flows(flows, session) + return creds, clientdict, session.session_id + + ret = self._auth_dict_for_flows(flows, session.session_id) ret["completed"] = list(creds) ret.update(errordict) raise InteractiveAuthIncompleteError(ret) - @defer.inlineCallbacks - def add_oob_auth(self, stagetype, authdict, clientip): + async def add_oob_auth( + self, stagetype: str, authdict: Dict[str, Any], clientip: str + ) -> bool: """ Adds the result of out-of-band authentication into an existing auth session. Currently used for adding the result of fallback auth. @@ -287,19 +423,15 @@ class AuthHandler(BaseHandler): if "session" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) - sess = self._get_session_info(authdict["session"]) - if "creds" not in sess: - sess["creds"] = {} - creds = sess["creds"] - - result = yield self.checkers[stagetype](authdict, clientip) + result = await self.checkers[stagetype].check_auth(authdict, clientip) if result: - creds[stagetype] = result - self._save_session(sess) + await self.store.mark_ui_auth_stage_complete( + authdict["session"], stagetype, result + ) return True return False - def get_session_id(self, clientdict): + def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]: """ Gets the session ID for a client given the client dictionary @@ -307,7 +439,7 @@ class AuthHandler(BaseHandler): clientdict: The dictionary sent by the client in the request Returns: - str|None: The string session ID the client sent. If the client did + The string session ID the client sent. If the client did not send a session ID, returns None. """ sid = None @@ -317,43 +449,57 @@ class AuthHandler(BaseHandler): sid = authdict["session"] return sid - def set_session_data(self, session_id, key, value): + async def set_session_data(self, session_id: str, key: str, value: Any) -> None: """ Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by the client. 
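For reference, when no flow is yet complete check_auth raises InteractiveAuthIncompleteError carrying a challenge body built from _auth_dict_for_flows plus the list of completed stages. A rough sketch of that body's shape, with made-up session and parameter values:

# Rough shape of the 401 challenge body returned while UI auth is incomplete.
# The session ID and params below are made-up example values.
from typing import Any, Dict, List


def auth_challenge(
    flows: List[List[str]],
    session_id: str,
    params: Dict[str, Any],
    completed: List[str],
) -> Dict[str, Any]:
    return {
        "session": session_id,
        "flows": [{"stages": flow} for flow in flows],
        "params": params,          # per-stage parameters, e.g. the reCAPTCHA public key
        "completed": completed,    # stages the client has already passed
    }


print(
    auth_challenge(
        flows=[["m.login.recaptcha", "m.login.terms"]],
        session_id="abcdefghij",
        params={"m.login.recaptcha": {"public_key": "recaptcha-site-key"}},
        completed=["m.login.recaptcha"],
    )
)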
Args: - session_id (string): The ID of this session as returned from check_auth - key (string): The key to store the data under - value (any): The data to store + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + value: The data to store """ - sess = self._get_session_info(session_id) - sess.setdefault("serverdict", {})[key] = value - self._save_session(sess) + try: + await self.store.set_ui_auth_session_data(session_id, key, value) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) - def get_session_data(self, session_id, key, default=None): + async def get_session_data( + self, session_id: str, key: str, default: Optional[Any] = None + ) -> Any: """ Retrieve data stored with set_session_data Args: - session_id (string): The ID of this session as returned from check_auth - key (string): The key to store the data under - default (any): Value to return if the key has not been set + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + default: Value to return if the key has not been set """ - sess = self._get_session_info(session_id) - return sess.setdefault("serverdict", {}).get(key, default) + try: + return await self.store.get_ui_auth_session_data(session_id, key, default) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) - @defer.inlineCallbacks - def _check_auth_dict(self, authdict, clientip): + async def _expire_old_sessions(self): + """ + Invalidate any user interactive authentication sessions that have expired. + """ + now = self._clock.time_msec() + expiration_time = now - self.SESSION_EXPIRE_MS + await self.store.delete_old_ui_auth_sessions(expiration_time) + + async def _check_auth_dict( + self, authdict: Dict[str, Any], clientip: str + ) -> Union[Dict[str, Any], str]: """Attempt to validate the auth dict provided by a client Args: - authdict (object): auth dict provided by the client - clientip (str): IP address of the client + authdict: auth dict provided by the client + clientip: IP address of the client Returns: - Deferred: result of the stage verification. + Result of the stage verification. Raises: StoreError if there was a problem accessing the database @@ -363,7 +509,7 @@ class AuthHandler(BaseHandler): login_type = authdict["type"] checker = self.checkers.get(login_type) if checker is not None: - res = yield checker(authdict, clientip=clientip) + res = await checker.check_auth(authdict, clientip=clientip) return res # build a v1-login-style dict out of the authdict and fall back to the @@ -373,132 +519,13 @@ class AuthHandler(BaseHandler): if user_id is None: raise SynapseError(400, "", Codes.MISSING_PARAM) - (canonical_id, callback) = yield self.validate_login(user_id, authdict) + (canonical_id, callback) = await self.validate_login(user_id, authdict) return canonical_id - @defer.inlineCallbacks - def _check_recaptcha(self, authdict, clientip, **kwargs): - try: - user_response = authdict["response"] - except KeyError: - # Client tried to provide captcha but didn't give the parameter: - # bad request. 
- raise LoginError( - 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED - ) - - logger.info( - "Submitting recaptcha response %s with remoteip %s", user_response, clientip - ) - - # TODO: get this from the homeserver rather than creating a new one for - # each request - try: - client = self.hs.get_simple_http_client() - resp_body = yield client.post_urlencoded_get_json( - self.hs.config.recaptcha_siteverify_api, - args={ - "secret": self.hs.config.recaptcha_private_key, - "response": user_response, - "remoteip": clientip, - }, - ) - except PartialDownloadError as pde: - # Twisted is silly - data = pde.response - resp_body = json.loads(data) - - if "success" in resp_body: - # Note that we do NOT check the hostname here: we explicitly - # intend the CAPTCHA to be presented by whatever client the - # user is using, we just care that they have completed a CAPTCHA. - logger.info( - "%s reCAPTCHA from hostname %s", - "Successful" if resp_body["success"] else "Failed", - resp_body.get("hostname"), - ) - if resp_body["success"]: - return True - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - - def _check_email_identity(self, authdict, **kwargs): - return self._check_threepid("email", authdict, **kwargs) - - def _check_msisdn(self, authdict, **kwargs): - return self._check_threepid("msisdn", authdict) - - def _check_dummy_auth(self, authdict, **kwargs): - return defer.succeed(True) - - def _check_terms_auth(self, authdict, **kwargs): - return defer.succeed(True) - - @defer.inlineCallbacks - def _check_threepid(self, medium, authdict, **kwargs): - if "threepid_creds" not in authdict: - raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) - - threepid_creds = authdict["threepid_creds"] - - identity_handler = self.hs.get_handlers().identity_handler - - logger.info("Getting validated threepid. 
threepidcreds: %r", (threepid_creds,)) - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - if medium == "email": - threepid = yield identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds - ) - elif medium == "msisdn": - threepid = yield identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds - ) - else: - raise SynapseError(400, "Unrecognized threepid medium: %s" % (medium,)) - elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - row = yield self.store.get_threepid_validation_session( - medium, - threepid_creds["client_secret"], - sid=threepid_creds["sid"], - validated=True, - ) - - threepid = ( - { - "medium": row["medium"], - "address": row["address"], - "validated_at": row["validated_at"], - } - if row - else None - ) - - if row: - # Valid threepid returned, delete from the db - yield self.store.delete_threepid_session(threepid_creds["sid"]) - else: - raise SynapseError( - 400, "Password resets are not enabled on this homeserver" - ) - - if not threepid: - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - - if threepid["medium"] != medium: - raise LoginError( - 401, - "Expecting threepid of type '%s', got '%s'" - % (medium, threepid["medium"]), - errcode=Codes.UNAUTHORIZED, - ) - - threepid["threepid_creds"] = authdict["threepid_creds"] - - return threepid - - def _get_params_recaptcha(self): + def _get_params_recaptcha(self) -> dict: return {"public_key": self.hs.config.recaptcha_public_key} - def _get_params_terms(self): + def _get_params_terms(self) -> dict: return { "policies": { "privacy_policy": { @@ -515,7 +542,9 @@ class AuthHandler(BaseHandler): } } - def _auth_dict_for_flows(self, flows, session): + def _auth_dict_for_flows( + self, flows: List[List[str]], session_id: str, + ) -> Dict[str, Any]: public_flows = [] for f in flows: public_flows.append(f) @@ -525,7 +554,7 @@ class AuthHandler(BaseHandler): LoginType.TERMS: self._get_params_terms, } - params = {} + params = {} # type: Dict[str, Any] for f in public_flows: for stage in f: @@ -533,25 +562,14 @@ class AuthHandler(BaseHandler): params[stage] = get_params[stage]() return { - "session": session["id"], + "session": session_id, "flows": [{"stages": f} for f in public_flows], "params": params, } - def _get_session_info(self, session_id): - if session_id not in self.sessions: - session_id = None - - if not session_id: - # create a new session - while session_id is None or session_id in self.sessions: - session_id = stringutils.random_string(24) - self.sessions[session_id] = {"id": session_id} - - return self.sessions[session_id] - - @defer.inlineCallbacks - def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms): + async def get_access_token_for_user_id( + self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] + ): """ Creates a new access token for the user with the given user ID. @@ -561,11 +579,11 @@ class AuthHandler(BaseHandler): The device will be recorded in the table if it is not there already. Args: - user_id (str): canonical User ID - device_id (str|None): the device ID to associate with the tokens. + user_id: canonical User ID + device_id: the device ID to associate with the tokens. None to leave the tokens unassociated with a device (deprecated: we should always have a device ID) - valid_until_ms (int|None): when the token is valid until. None for + valid_until_ms: when the token is valid until. None for no expiry. 
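The access tokens and short-term login tokens issued here are serialized macaroons whose scope is constrained by first-party caveats (see MacaroonGenerator further down). A standalone sketch of that idea, assuming a placeholder secret key; the real generator derives its key from server config and binds additional context (such as the user ID) in its base macaroon:

# Standalone macaroon sketch; the secret key and caveat set here are examples,
# not the exact output of Synapse's MacaroonGenerator.
import time

import pymacaroons


def make_login_token(server_name: str, user_id: str, duration_ms: int = 2 * 60 * 1000) -> str:
    macaroon = pymacaroons.Macaroon(
        location=server_name,
        identifier="key",
        key="not-a-real-secret",  # placeholder; Synapse uses a configured secret
    )
    # Caveats restrict what the bearer may do and for how long.
    macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
    macaroon.add_first_party_caveat("type = login")
    expiry = int(time.time() * 1000) + duration_ms
    macaroon.add_first_party_caveat("time < %d" % (expiry,))
    return macaroon.serialize()


token = make_login_token("example.com", "@alice:example.com")
print(token[:32], "...")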
Returns: The access token for the user's session. @@ -579,10 +597,10 @@ class AuthHandler(BaseHandler): ) logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry) - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) access_token = self.macaroon_gen.generate_access_token(user_id) - yield self.store.add_access_token_to_user( + await self.store.add_access_token_to_user( user_id, access_token, device_id, valid_until_ms ) @@ -592,52 +610,45 @@ class AuthHandler(BaseHandler): # device, so we double-check it here. if device_id is not None: try: - yield self.store.get_device(user_id, device_id) + await self.store.get_device(user_id, device_id) except StoreError: - yield self.store.delete_access_token(access_token) + await self.store.delete_access_token(access_token) raise StoreError(400, "Login raced against device deletion") return access_token - @defer.inlineCallbacks - def check_user_exists(self, user_id): + async def check_user_exists(self, user_id: str) -> Optional[str]: """ Checks to see if a user with the given id exists. Will check case insensitively, but return None if there are multiple inexact matches. Args: - (unicode|bytes) user_id: complete @user:id + user_id: complete @user:id Returns: - defer.Deferred: (unicode) canonical_user_id, or None if zero or - multiple matches - - Raises: - LimitExceededError if the ratelimiter's login requests count for this - user is too high too proceed. - UserDeactivatedError if a user is found but is deactivated. + The canonical_user_id, or None if zero or multiple matches """ - self.ratelimit_login_per_account(user_id) - res = yield self._find_user_id_and_pwd_hash(user_id) + res = await self._find_user_id_and_pwd_hash(user_id) if res is not None: return res[0] return None - @defer.inlineCallbacks - def _find_user_id_and_pwd_hash(self, user_id): + async def _find_user_id_and_pwd_hash( + self, user_id: str + ) -> Optional[Tuple[str, str]]: """Checks to see if a user with the given id exists. Will check case insensitively, but will return None if there are multiple inexact matches. Returns: - tuple: A 2-tuple of `(canonical_user_id, password_hash)` - None: if there is not exactly one match + A 2-tuple of `(canonical_user_id, password_hash)` or `None` + if there is not exactly one match """ - user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) + user_infos = await self.store.get_users_by_id_case_insensitive(user_id) result = None if not user_infos: - logger.warn("Attempted to login as %s but they do not exist", user_id) + logger.warning("Attempted to login as %s but they do not exist", user_id) elif len(user_infos) == 1: # a single match (possibly not exact) result = user_infos.popitem() @@ -646,7 +657,7 @@ class AuthHandler(BaseHandler): result = (user_id, user_infos[user_id]) else: # multiple matches, none of them exact - logger.warn( + logger.warning( "Attempted to login as %s but it matches more than one user " "inexactly: %r", user_id, @@ -654,7 +665,7 @@ class AuthHandler(BaseHandler): ) return result - def get_supported_login_types(self): + def get_supported_login_types(self) -> Iterable[str]: """Get a the login types supported for the /login API By default this is just 'm.login.password' (unless password_enabled is @@ -662,30 +673,29 @@ class AuthHandler(BaseHandler): other login types. 
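validate_login below consults each configured password provider by duck-typing: if the provider object has a check_password attribute, it is awaited with the qualified user ID and password and should return a truthy value on success. A minimal sketch of a provider exposing just that hook; the class and its lookup table are hypothetical, not a real module:

# Hypothetical minimal password provider exposing only check_password, the hook
# that validate_login probes for with hasattr() further down in this file.
# For illustration only; a real provider should use a proper password hash/KDF.
import asyncio
import hashlib
from typing import Dict


class StaticPasswordProvider:
    def __init__(self, users: Dict[str, str]):
        # Map of fully-qualified user id -> sha256 hex digest of the password.
        self._users = users

    async def check_password(self, user_id: str, password: str) -> bool:
        expected = self._users.get(user_id.lower())
        if expected is None:
            return False
        return hashlib.sha256(password.encode("utf-8")).hexdigest() == expected


# Example: accept "@alice:example.com" with password "wonderland".
provider = StaticPasswordProvider(
    {"@alice:example.com": hashlib.sha256(b"wonderland").hexdigest()}
)
print(asyncio.run(provider.check_password("@alice:example.com", "wonderland")))  # True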
Returns: - Iterable[str]: login types + login types """ return self._supported_login_types - @defer.inlineCallbacks - def validate_login(self, username, login_submission): + async def validate_login( + self, username: str, login_submission: Dict[str, Any] + ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate m.login.password auth types. Args: - username (str): username supplied by the user - login_submission (dict): the whole of the login submission + username: username supplied by the user + login_submission: the whole of the login submission (including 'type' and other relevant fields) Returns: - Deferred[str, func]: canonical user id, and optional callback + A tuple of the canonical user id, and optional callback to be called once the access token and device id are issued Raises: StoreError if there was a problem accessing the database SynapseError if there was a problem with the request LoginError if there was an authentication problem. - LimitExceededError if the ratelimiter's login requests count for this - user is too high too proceed. """ if username.startswith("@"): @@ -693,8 +703,6 @@ class AuthHandler(BaseHandler): else: qualified_user_id = UserID(username, self.hs.hostname).to_string() - self.ratelimit_login_per_account(qualified_user_id) - login_type = login_submission.get("type") known_login_type = False @@ -711,7 +719,7 @@ class AuthHandler(BaseHandler): for provider in self.password_providers: if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: known_login_type = True - is_valid = yield provider.check_password(qualified_user_id, password) + is_valid = await provider.check_password(qualified_user_id, password) if is_valid: return qualified_user_id, None @@ -743,7 +751,7 @@ class AuthHandler(BaseHandler): % (login_type, missing_fields), ) - result = yield provider.check_auth(username, login_type, login_dict) + result = await provider.check_auth(username, login_type, login_dict) if result: if isinstance(result, str): result = (result, None) @@ -752,8 +760,8 @@ class AuthHandler(BaseHandler): if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: known_login_type = True - canonical_user_id = yield self._check_local_password( - qualified_user_id, password + canonical_user_id = await self._check_local_password( + qualified_user_id, password # type: ignore ) if canonical_user_id: @@ -762,32 +770,23 @@ class AuthHandler(BaseHandler): if not known_login_type: raise SynapseError(400, "Unknown login type %s" % login_type) - # unknown username or invalid password. - self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=True, - ) - # We raise a 403 here, but note that if we're doing user-interactive # login, it turns all LoginErrors into a 401 anyway. raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) - @defer.inlineCallbacks - def check_password_provider_3pid(self, medium, address, password): + async def check_password_provider_3pid( + self, medium: str, address: str, password: str + ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]: """Check if a password provider is able to validate a thirdparty login Args: - medium (str): The medium of the 3pid (ex. email). - address (str): The address of the 3pid (ex. 
jdoe@example.com). - password (str): The password of the user. + medium: The medium of the 3pid (ex. email). + address: The address of the 3pid (ex. jdoe@example.com). + password: The password of the user. Returns: - Deferred[(str|None, func|None)]: A tuple of `(user_id, - callback)`. If authentication is successful, `user_id` is a `str` - containing the authenticated, canonical user ID. `callback` is + A tuple of `(user_id, callback)`. If authentication is successful, + `user_id`is the authenticated, canonical user ID. `callback` is then either a function to be later run after the server has completed login/registration, or `None`. If authentication was unsuccessful, `user_id` and `callback` are both `None`. @@ -799,7 +798,7 @@ class AuthHandler(BaseHandler): # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = yield provider.check_3pid_auth(medium, address, password) + result = await provider.check_3pid_auth(medium, address, password) if result: # Check if the return value is a str or a tuple if isinstance(result, str): @@ -809,43 +808,36 @@ class AuthHandler(BaseHandler): return None, None - @defer.inlineCallbacks - def _check_local_password(self, user_id, password): + async def _check_local_password(self, user_id: str, password: str) -> Optional[str]: """Authenticate a user against the local password database. user_id is checked case insensitively, but will return None if there are multiple inexact matches. Args: - user_id (unicode): complete @user:id - password (unicode): the provided password + user_id: complete @user:id + password: the provided password Returns: - Deferred[unicode] the canonical_user_id, or Deferred[None] if - unknown user/bad password - - Raises: - LimitExceededError if the ratelimiter's login requests count for this - user is too high too proceed. 
+ The canonical_user_id, or None if unknown user/bad password """ - lookupres = yield self._find_user_id_and_pwd_hash(user_id) + lookupres = await self._find_user_id_and_pwd_hash(user_id) if not lookupres: return None (user_id, password_hash) = lookupres # If the password hash is None, the account has likely been deactivated if not password_hash: - deactivated = yield self.store.get_user_deactivated_status(user_id) + deactivated = await self.store.get_user_deactivated_status(user_id) if deactivated: raise UserDeactivatedError("This account has been deactivated") - result = yield self.validate_hash(password, password_hash) + result = await self.validate_hash(password, password_hash) if not result: - logger.warn("Failed password login for user %s", user_id) + logger.warning("Failed password login for user %s", user_id) return None return user_id - @defer.inlineCallbacks - def validate_short_term_login_token_and_get_user_id(self, login_token): + async def validate_short_term_login_token_and_get_user_id(self, login_token: str): auth_api = self.hs.get_auth() user_id = None try: @@ -854,27 +846,24 @@ class AuthHandler(BaseHandler): auth_api.validate_macaroon(macaroon, "login", user_id) except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) - self.ratelimit_login_per_account(user_id) - yield self.auth.check_auth_blocking(user_id) + + await self.auth.check_auth_blocking(user_id) return user_id - @defer.inlineCallbacks - def delete_access_token(self, access_token): + async def delete_access_token(self, access_token: str): """Invalidate a single access token Args: - access_token (str): access token to be deleted + access_token: access token to be deleted - Returns: - Deferred """ - user_info = yield self.auth.get_user_by_access_token(access_token) - yield self.store.delete_access_token(access_token) + user_info = await self.auth.get_user_by_access_token(access_token) + await self.store.delete_access_token(access_token) # see if any of our auth providers want to know about this for provider in self.password_providers: if hasattr(provider, "on_logged_out"): - yield provider.on_logged_out( + await provider.on_logged_out( user_id=str(user_info["user"]), device_id=user_info["device_id"], access_token=access_token, @@ -882,27 +871,26 @@ class AuthHandler(BaseHandler): # delete pushers associated with this access token if user_info["token_id"] is not None: - yield self.hs.get_pusherpool().remove_pushers_by_access_token( + await self.hs.get_pusherpool().remove_pushers_by_access_token( str(user_info["user"]), (user_info["token_id"],) ) - @defer.inlineCallbacks - def delete_access_tokens_for_user( - self, user_id, except_token_id=None, device_id=None + async def delete_access_tokens_for_user( + self, + user_id: str, + except_token_id: Optional[str] = None, + device_id: Optional[str] = None, ): """Invalidate access tokens belonging to a user Args: - user_id (str): ID of user the tokens belong to - except_token_id (str|None): access_token ID which should *not* be - deleted - device_id (str|None): ID of device the tokens are associated with. + user_id: ID of user the tokens belong to + except_token_id: access_token ID which should *not* be deleted + device_id: ID of device the tokens are associated with. 
If None, tokens associated with any device (or no device) will be deleted - Returns: - Deferred """ - tokens_and_devices = yield self.store.user_delete_access_tokens( + tokens_and_devices = await self.store.user_delete_access_tokens( user_id, except_token_id=except_token_id, device_id=device_id ) @@ -910,19 +898,28 @@ class AuthHandler(BaseHandler): for provider in self.password_providers: if hasattr(provider, "on_logged_out"): for token, token_id, device_id in tokens_and_devices: - yield provider.on_logged_out( + await provider.on_logged_out( user_id=user_id, device_id=device_id, access_token=token ) # delete pushers associated with the access tokens - yield self.hs.get_pusherpool().remove_pushers_by_access_token( + await self.hs.get_pusherpool().remove_pushers_by_access_token( user_id, (token_id for _, token_id, _ in tokens_and_devices) ) - @defer.inlineCallbacks - def add_threepid(self, user_id, medium, address, validated_at): + async def add_threepid( + self, user_id: str, medium: str, address: str, validated_at: int + ): + # check if medium has a valid value + if medium not in ["email", "msisdn"]: + raise SynapseError( + code=400, + msg=("'%s' is not a valid value for 'medium'" % (medium,)), + errcode=Codes.INVALID_PARAM, + ) + # 'Canonicalise' email addresses down to lower case. - # We've now moving towards the Home Server being the entity that + # We've now moving towards the homeserver being the entity that # is responsible for validating threepids used for resetting passwords # on accounts, so in future Synapse will gain knowledge of specific # types (mediums) of threepid. For now, we still use the existing @@ -933,26 +930,26 @@ class AuthHandler(BaseHandler): if medium == "email": address = address.lower() - yield self.store.user_add_threepid( + await self.store.user_add_threepid( user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) - @defer.inlineCallbacks - def delete_threepid(self, user_id, medium, address, id_server=None): + async def delete_threepid( + self, user_id: str, medium: str, address: str, id_server: Optional[str] = None + ) -> bool: """Attempts to unbind the 3pid on the identity servers and deletes it from the local database. Args: - user_id (str) - medium (str) - address (str) - id_server (str|None): Use the given identity server when unbinding + user_id: ID of user to remove the 3pid from. + medium: The medium of the 3pid being removed: "email" or "msisdn". + address: The 3pid address to remove. + id_server: Use the given identity server when unbinding any threepids. If None then will attempt to unbind using the identity server specified when binding (if known). - Returns: - Deferred[bool]: Returns True if successfully unbound the 3pid on + Returns True if successfully unbound the 3pid on the identity server, False if identity server doesn't support the unbind API. 
""" @@ -962,27 +959,21 @@ class AuthHandler(BaseHandler): address = address.lower() identity_handler = self.hs.get_handlers().identity_handler - result = yield identity_handler.try_unbind_threepid( + result = await identity_handler.try_unbind_threepid( user_id, {"medium": medium, "address": address, "id_server": id_server} ) - yield self.store.user_delete_threepid(user_id, medium, address) + await self.store.user_delete_threepid(user_id, medium, address) return result - def _save_session(self, session): - # TODO: Persistent storage - logger.debug("Saving session %s", session) - session["last_used"] = self.hs.get_clock().time_msec() - self.sessions[session["id"]] = session - - def hash(self, password): + async def hash(self, password: str) -> str: """Computes a secure hash of password. Args: - password (unicode): Password to hash. + password: Password to hash. Returns: - Deferred(unicode): Hashed password. + Hashed password. """ def _do_hash(): @@ -994,17 +985,19 @@ class AuthHandler(BaseHandler): bcrypt.gensalt(self.bcrypt_rounds), ).decode("ascii") - return defer_to_thread(self.hs.get_reactor(), _do_hash) + return await defer_to_thread(self.hs.get_reactor(), _do_hash) - def validate_hash(self, password, stored_hash): + async def validate_hash( + self, password: str, stored_hash: Union[bytes, str] + ) -> bool: """Validates that self.hash(password) == stored_hash. Args: - password (unicode): Password to hash. - stored_hash (bytes): Expected hash value. + password: Password to hash. + stored_hash: Expected hash value. Returns: - Deferred(bool): Whether self.hash(password) == stored_hash. + Whether self.hash(password) == stored_hash. """ def _do_validate_hash(): @@ -1020,46 +1013,150 @@ class AuthHandler(BaseHandler): if not isinstance(stored_hash, bytes): stored_hash = stored_hash.encode("ascii") - return defer_to_thread(self.hs.get_reactor(), _do_validate_hash) + return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash) else: - return defer.succeed(False) + return False - def ratelimit_login_per_account(self, user_id): - """Checks whether the process must be stopped because of ratelimiting. + async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: + """ + Get the HTML for the SSO redirect confirmation page. + + Args: + redirect_url: The URL to redirect to the SSO provider. + session_id: The user interactive authentication session ID. - Checks against two ratelimiters: the generic one for login attempts per - account and the one specific to failed attempts. + Returns: + The HTML to render. + """ + try: + session = await self.store.get_ui_auth_session(session_id) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) + return self._sso_auth_confirm_template.render( + description=session.description, redirect_url=redirect_url, + ) + + async def complete_sso_ui_auth( + self, registered_user_id: str, session_id: str, request: SynapseRequest, + ): + """Having figured out a mxid for this user, complete the HTTP request Args: - user_id (unicode): complete @user:id + registered_user_id: The registered user ID to complete SSO login for. + request: The request to complete. + client_redirect_url: The URL to which to redirect the user at the end of the + process. + """ + # Mark the stage of the authentication as successful. + # Save the user who authenticated with SSO, this will be used to ensure + # that the account be modified is also the person who logged in. 
+ await self.store.mark_ui_auth_stage_complete( + session_id, LoginType.SSO, registered_user_id + ) - Raises: - LimitExceededError if one of the ratelimiters' login requests count - for this user is too high too proceed. + # Render the HTML and return. + html_bytes = self._sso_auth_success_template.encode("utf-8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + + async def complete_sso_login( + self, + registered_user_id: str, + request: SynapseRequest, + client_redirect_url: str, + ): + """Having figured out a mxid for this user, complete the HTTP request + + Args: + registered_user_id: The registered user ID to complete SSO login for. + request: The request to complete. + client_redirect_url: The URL to which to redirect the user at the end of the + process. """ - self._failed_attempts_ratelimiter.ratelimit( - user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=False, + # If the account has been deactivated, do not proceed with the login + # flow. + deactivated = await self.store.get_user_deactivated_status(registered_user_id) + if deactivated: + html_bytes = self._sso_account_deactivated_template.encode("utf-8") + + request.setResponseCode(403) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + request.write(html_bytes) + finish_request(request) + return + + self._complete_sso_login(registered_user_id, request, client_redirect_url) + + def _complete_sso_login( + self, + registered_user_id: str, + request: SynapseRequest, + client_redirect_url: str, + ): + """ + The synchronous portion of complete_sso_login. + + This exists purely for backwards compatibility of synapse.module_api.ModuleApi. + """ + # Create a login token + login_token = self.macaroon_gen.generate_short_term_login_token( + registered_user_id ) - self._account_ratelimiter.ratelimit( - user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_account.per_second, - burst_count=self.hs.config.rc_login_account.burst_count, - update=True, + # Append the login token to the original redirect URL (i.e. with its query + # parameters kept intact) to build the URL to which the template needs to + # redirect the users once they have clicked on the confirmation link. + redirect_url = self.add_query_param_to_url( + client_redirect_url, "loginToken", login_token ) + # if the client is whitelisted, we can redirect straight to it + if client_redirect_url.startswith(self._whitelisted_sso_clients): + request.redirect(redirect_url) + finish_request(request) + return + + # Otherwise, serve the redirect confirmation page. + + # Remove the query parameters from the redirect URL to get a shorter version of + # it. This is only to display a human-readable URL in the template, but not the + # URL we redirect users to. 
+ redirect_url_no_params = client_redirect_url.split("?")[0] + + html_bytes = self._sso_redirect_confirm_template.render( + display_url=redirect_url_no_params, + redirect_url=redirect_url, + server_name=self._server_name, + ).encode("utf-8") + + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + request.write(html_bytes) + finish_request(request) + + @staticmethod + def add_query_param_to_url(url: str, param_name: str, param: Any): + url_parts = list(urllib.parse.urlparse(url)) + query = dict(urllib.parse.parse_qsl(url_parts[4])) + query.update({param_name: param}) + url_parts[4] = urllib.parse.urlencode(query) + return urllib.parse.urlunparse(url_parts) + @attr.s class MacaroonGenerator(object): hs = attr.ib() - def generate_access_token(self, user_id, extra_caveats=None): + def generate_access_token( + self, user_id: str, extra_caveats: Optional[List[str]] = None + ) -> str: extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") @@ -1072,16 +1169,9 @@ class MacaroonGenerator(object): macaroon.add_first_party_caveat(caveat) return macaroon.serialize() - def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): - """ - - Args: - user_id (unicode): - duration_in_ms (int): - - Returns: - unicode - """ + def generate_short_term_login_token( + self, user_id: str, duration_in_ms: int = (2 * 60 * 1000) + ) -> str: macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = login") now = self.hs.get_clock().time_msec() @@ -1089,12 +1179,12 @@ class MacaroonGenerator(object): macaroon.add_first_party_caveat("time < %d" % (expiry,)) return macaroon.serialize() - def generate_delete_pusher_token(self, user_id): + def generate_delete_pusher_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = delete_pusher") return macaroon.serialize() - def _generate_base_macaroon(self, user_id): + def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon: macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py new file mode 100644 index 0000000000..64aaa1335c --- /dev/null +++ b/synapse/handlers/cas_handler.py @@ -0,0 +1,221 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
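A usage sketch for the query-parameter helper added above: the login token is appended to the client's redirect URL while any existing query parameters are preserved. The URL and token value below are examples.

# Mirrors add_query_param_to_url from the hunk above, with a worked example.
import urllib.parse
from typing import Any


def add_query_param_to_url(url: str, param_name: str, param: Any) -> str:
    url_parts = list(urllib.parse.urlparse(url))
    query = dict(urllib.parse.parse_qsl(url_parts[4]))
    query.update({param_name: param})
    url_parts[4] = urllib.parse.urlencode(query)
    return urllib.parse.urlunparse(url_parts)


print(
    add_query_param_to_url(
        "https://client.example.com/?homeserver=example.com",
        "loginToken",
        "MDAxy...",  # placeholder token value
    )
)
# -> https://client.example.com/?homeserver=example.com&loginToken=MDAxy...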
+ +import logging +import xml.etree.ElementTree as ET +from typing import Dict, Optional, Tuple + +from six.moves import urllib + +from twisted.web.client import PartialDownloadError + +from synapse.api.errors import Codes, LoginError +from synapse.http.site import SynapseRequest +from synapse.types import UserID, map_username_to_mxid_localpart + +logger = logging.getLogger(__name__) + + +class CasHandler: + """ + Utility class for to handle the response from a CAS SSO service. + + Args: + hs (synapse.server.HomeServer) + """ + + def __init__(self, hs): + self._hostname = hs.hostname + self._auth_handler = hs.get_auth_handler() + self._registration_handler = hs.get_registration_handler() + + self._cas_server_url = hs.config.cas_server_url + self._cas_service_url = hs.config.cas_service_url + self._cas_displayname_attribute = hs.config.cas_displayname_attribute + self._cas_required_attributes = hs.config.cas_required_attributes + + self._http_client = hs.get_proxied_http_client() + + def _build_service_param(self, args: Dict[str, str]) -> str: + """ + Generates a value to use as the "service" parameter when redirecting or + querying the CAS service. + + Args: + args: Additional arguments to include in the final redirect URL. + + Returns: + The URL to use as a "service" parameter. + """ + return "%s%s?%s" % ( + self._cas_service_url, + "/_matrix/client/r0/login/cas/ticket", + urllib.parse.urlencode(args), + ) + + async def _validate_ticket( + self, ticket: str, service_args: Dict[str, str] + ) -> Tuple[str, Optional[str]]: + """ + Validate a CAS ticket with the server, parse the response, and return the user and display name. + + Args: + ticket: The CAS ticket from the client. + service_args: Additional arguments to include in the service URL. + Should be the same as those passed to `get_redirect_url`. + """ + uri = self._cas_server_url + "/proxyValidate" + args = { + "ticket": ticket, + "service": self._build_service_param(service_args), + } + try: + body = await self._http_client.get_raw(uri, args) + except PartialDownloadError as pde: + # Twisted raises this error if the connection is closed, + # even if that's being used old-http style to signal end-of-data + body = pde.response + + user, attributes = self._parse_cas_response(body) + displayname = attributes.pop(self._cas_displayname_attribute, None) + + for required_attribute, required_value in self._cas_required_attributes.items(): + # If required attribute was not in CAS Response - Forbidden + if required_attribute not in attributes: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + # Also need to check value + if required_value is not None: + actual_value = attributes[required_attribute] + # If required attribute value does not match expected - Forbidden + if required_value != actual_value: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + return user, displayname + + def _parse_cas_response( + self, cas_response_body: str + ) -> Tuple[str, Dict[str, Optional[str]]]: + """ + Retrieve the user and other parameters from the CAS response. + + Args: + cas_response_body: The response from the CAS query. + + Returns: + A tuple of the user and a mapping of other attributes. 
+ """ + user = None + attributes = {} + try: + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise Exception("root of CAS response is not serviceResponse") + success = root[0].tag.endswith("authenticationSuccess") + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + for attribute in child: + # ElementTree library expands the namespace in + # attribute tags to the full URL of the namespace. + # We don't care about namespace here and it will always + # be encased in curly braces, so we remove them. + tag = attribute.tag + if "}" in tag: + tag = tag.split("}")[1] + attributes[tag] = attribute.text + if user is None: + raise Exception("CAS response does not contain user") + except Exception: + logger.exception("Error parsing CAS response") + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + if not success: + raise LoginError( + 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED + ) + return user, attributes + + def get_redirect_url(self, service_args: Dict[str, str]) -> str: + """ + Generates a URL for the CAS server where the client should be redirected. + + Args: + service_args: Additional arguments to include in the final redirect URL. + + Returns: + The URL to redirect the client to. + """ + args = urllib.parse.urlencode( + {"service": self._build_service_param(service_args)} + ) + + return "%s/login?%s" % (self._cas_server_url, args) + + async def handle_ticket( + self, + request: SynapseRequest, + ticket: str, + client_redirect_url: Optional[str], + session: Optional[str], + ) -> None: + """ + Called once the user has successfully authenticated with the SSO. + Validates a CAS ticket sent by the client and completes the auth process. + + If the user interactive authentication session is provided, marks the + UI Auth session as complete, then returns an HTML page notifying the + user they are done. + + Otherwise, this registers the user if necessary, and then returns a + redirect (with a login token) to the client. + + Args: + request: the incoming request from the browser. We'll + respond to it with a redirect or an HTML page. + + ticket: The CAS ticket provided by the client. + + client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given. + This should be the same as the redirectUrl from the original `/login/sso/redirect` request. + + session: The session parameter from the `/cas/ticket` HTTP request, if given. + This should be the UI Auth session id. 
+ """ + args = {} + if client_redirect_url: + args["redirectUrl"] = client_redirect_url + if session: + args["session"] = session + username, user_display_name = await self._validate_ticket(ticket, args) + + localpart = map_username_to_mxid_localpart(username) + user_id = UserID(localpart, self._hostname).to_string() + registered_user_id = await self._auth_handler.check_user_exists(user_id) + + if session: + await self._auth_handler.complete_sso_ui_auth( + registered_user_id, session, request, + ) + + else: + if not registered_user_id: + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, default_display_name=user_display_name + ) + + await self._auth_handler.complete_sso_login( + registered_user_id, request, client_redirect_url + ) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 5f804d1f13..2afb390a92 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -15,8 +15,6 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID, create_requester @@ -46,8 +44,7 @@ class DeactivateAccountHandler(BaseHandler): self._account_validity_enabled = hs.config.account_validity.enabled - @defer.inlineCallbacks - def deactivate_account(self, user_id, erase_data, id_server=None): + async def deactivate_account(self, user_id, erase_data, id_server=None): """Deactivate a user's account Args: @@ -73,10 +70,12 @@ class DeactivateAccountHandler(BaseHandler): # unbinding identity_server_supports_unbinding = True - threepids = yield self.store.user_get_threepids(user_id) + # Retrieve the 3PIDs this user has bound to an identity server + threepids = await self.store.user_get_bound_threepids(user_id) + for threepid in threepids: try: - result = yield self._identity_handler.try_unbind_threepid( + result = await self._identity_handler.try_unbind_threepid( user_id, { "medium": threepid["medium"], @@ -89,44 +88,83 @@ class DeactivateAccountHandler(BaseHandler): # Do we want this to be a fatal error or should we carry on? logger.exception("Failed to remove threepid from ID server") raise SynapseError(400, "Failed to remove threepid from ID server") - yield self.store.user_delete_threepid( + await self.store.user_delete_threepid( user_id, threepid["medium"], threepid["address"] ) + # Remove all 3PIDs this user has bound to the homeserver + await self.store.user_delete_threepids(user_id) + # delete any devices belonging to the user, which will also # delete corresponding access tokens. - yield self._device_handler.delete_all_devices_for_user(user_id) + await self._device_handler.delete_all_devices_for_user(user_id) # then delete any remaining access tokens which weren't associated with # a device. - yield self._auth_handler.delete_access_tokens_for_user(user_id) + await self._auth_handler.delete_access_tokens_for_user(user_id) - yield self.store.user_set_password_hash(user_id, None) + await self.store.user_set_password_hash(user_id, None) # Add the user to a table of users pending deactivation (ie. 
# removal from all the rooms they're a member of) - yield self.store.add_user_pending_deactivation(user_id) + await self.store.add_user_pending_deactivation(user_id) # delete from user directory - yield self.user_directory_handler.handle_user_deactivated(user_id) + await self.user_directory_handler.handle_user_deactivated(user_id) # Mark the user as erased, if they asked for that if erase_data: logger.info("Marking %s as erased", user_id) - yield self.store.mark_user_erased(user_id) + await self.store.mark_user_erased(user_id) # Now start the process that goes through that list and # parts users from rooms (if it isn't already running) self._start_user_parting() + # Reject all pending invites for the user, so that the user doesn't show up in the + # "invited" section of rooms' members list. + await self._reject_pending_invites_for_user(user_id) + # Remove all information on the user from the account_validity table. if self._account_validity_enabled: - yield self.store.delete_account_validity_for_user(user_id) + await self.store.delete_account_validity_for_user(user_id) # Mark the user as deactivated. - yield self.store.set_user_deactivated_status(user_id, True) + await self.store.set_user_deactivated_status(user_id, True) return identity_server_supports_unbinding + async def _reject_pending_invites_for_user(self, user_id): + """Reject pending invites addressed to a given user ID. + + Args: + user_id (str): The user ID to reject pending invites for. + """ + user = UserID.from_string(user_id) + pending_invites = await self.store.get_invited_rooms_for_local_user(user_id) + + for room in pending_invites: + try: + await self._room_member_handler.update_membership( + create_requester(user), + user, + room.room_id, + "leave", + ratelimit=False, + require_consent=False, + ) + logger.info( + "Rejected invite for deactivated user %r in room %r", + user_id, + room.room_id, + ) + except Exception: + logger.exception( + "Failed to reject invite for user %r in room %r:" + " ignoring and continuing", + user_id, + room.room_id, + ) + def _start_user_parting(self): """ Start the process that goes through the table of users @@ -138,8 +176,7 @@ class DeactivateAccountHandler(BaseHandler): if not self._user_parter_running: run_as_background_process("user_parter_loop", self._user_parter_loop) - @defer.inlineCallbacks - def _user_parter_loop(self): + async def _user_parter_loop(self): """Loop that parts deactivated users from rooms Returns: @@ -149,19 +186,18 @@ class DeactivateAccountHandler(BaseHandler): logger.info("Starting user parter") try: while True: - user_id = yield self.store.get_user_pending_deactivation() + user_id = await self.store.get_user_pending_deactivation() if user_id is None: break logger.info("User parter parting %r", user_id) - yield self._part_user(user_id) - yield self.store.del_user_pending_deactivation(user_id) + await self._part_user(user_id) + await self.store.del_user_pending_deactivation(user_id) logger.info("User parter finished parting %r", user_id) logger.info("User parter finished: stopping") finally: self._user_parter_running = False - @defer.inlineCallbacks - def _part_user(self, user_id): + async def _part_user(self, user_id): """Causes the given user_id to leave all the rooms they're joined to Returns: @@ -169,11 +205,11 @@ class DeactivateAccountHandler(BaseHandler): """ user = UserID.from_string(user_id) - rooms_for_user = yield self.store.get_rooms_for_user(user_id) + rooms_for_user = await self.store.get_rooms_for_user(user_id) for room_id in rooms_for_user: 
logger.info("User parter parting %r from %r", user_id, room_id) try: - yield self._room_member_handler.update_membership( + await self._room_member_handler.update_membership( create_requester(user), user, room_id, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 71a8f33da3..230d170258 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Any, Dict, Optional from six import iteritems, itervalues @@ -24,9 +27,15 @@ from synapse.api.errors import ( FederationDeniedError, HttpResponseException, RequestSendFailed, + SynapseError, ) from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.types import RoomStreamToken, get_domain_from_id +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import ( + RoomStreamToken, + get_domain_from_id, + get_verify_key_from_cross_signing_key, +) from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -37,6 +46,8 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) +MAX_DEVICE_DISPLAY_NAME_LEN = 100 + class DeviceWorkerHandler(BaseHandler): def __init__(self, hs): @@ -44,6 +55,7 @@ class DeviceWorkerHandler(BaseHandler): self.hs = hs self.state = hs.get_state_handler() + self.state_store = hs.get_storage().state self._auth_handler = hs.get_auth_handler() @trace @@ -119,8 +131,14 @@ class DeviceWorkerHandler(BaseHandler): users_who_share_room = yield self.store.get_users_who_share_room_with_user( user_id ) + + tracked_users = set(users_who_share_room) + + # Always tell the user about their own devices + tracked_users.add(user_id) + changed = yield self.store.get_users_whose_devices_changed( - from_token.device_list_key, users_who_share_room + from_token.device_list_key, tracked_users ) # Then work out if any users have since joined @@ -176,7 +194,7 @@ class DeviceWorkerHandler(BaseHandler): continue # mapping from event_id -> state_dict - prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) + prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids) # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. 
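One behavioural change in the hunk above is that the syncing user is always included in the set of users whose device-list changes are tracked, so they hear about their own devices even when they share no rooms with anyone. A small sketch of that filtering step; the changed_since_token set stands in for the store's get_users_whose_devices_changed lookup.

# Sketch of the "always tell the user about their own devices" tracking change;
# changed_since_token simulates store.get_users_whose_devices_changed().
from typing import Iterable, Set


def users_to_report(
    user_id: str,
    users_who_share_room: Iterable[str],
    changed_since_token: Set[str],
) -> Set[str]:
    tracked_users = set(users_who_share_room)
    # Always track the requesting user so their own device changes are returned
    # even when they are not in any rooms with other users.
    tracked_users.add(user_id)
    return tracked_users & changed_since_token


print(
    users_to_report(
        "@alice:example.com",
        users_who_share_room=[],
        changed_since_token={"@alice:example.com", "@bob:example.com"},
    )
)
# -> {'@alice:example.com'}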
@@ -222,6 +240,22 @@ class DeviceWorkerHandler(BaseHandler): return result + @defer.inlineCallbacks + def on_federation_query_user_devices(self, user_id): + stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) + master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") + self_signing_key = yield self.store.get_e2e_cross_signing_key( + user_id, "self_signing" + ) + + return { + "user_id": user_id, + "stream_id": stream_id, + "devices": devices, + "master_key": master_key, + "self_signing_key": self_signing_key, + } + class DeviceHandler(DeviceWorkerHandler): def __init__(self, hs): @@ -236,9 +270,6 @@ class DeviceHandler(DeviceWorkerHandler): federation_registry.register_edu_handler( "m.device_list_update", self.device_list_updater.incoming_device_list_update ) - federation_registry.register_query_handler( - "user_devices", self.on_federation_query_user_devices - ) hs.get_distributor().observe("user_left_room", self.user_left_room) @@ -313,8 +344,10 @@ class DeviceHandler(DeviceWorkerHandler): else: raise - yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id + yield defer.ensureDeferred( + self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id + ) ) yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) @@ -366,8 +399,10 @@ class DeviceHandler(DeviceWorkerHandler): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id + yield defer.ensureDeferred( + self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id + ) ) yield self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id @@ -388,9 +423,18 @@ class DeviceHandler(DeviceWorkerHandler): defer.Deferred: """ + # Reject a new displayname which is too long. + new_display_name = content.get("display_name") + if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN: + raise SynapseError( + 400, + "Device display name is too long (max %i)" + % (MAX_DEVICE_DISPLAY_NAME_LEN,), + ) + try: yield self.store.update_device( - user_id, device_id, new_display_name=content.get("display_name") + user_id, device_id, new_display_name=new_display_name ) yield self.notify_device_update(user_id, [device_id]) except errors.StoreError as e: @@ -428,7 +472,11 @@ class DeviceHandler(DeviceWorkerHandler): room_ids = yield self.store.get_rooms_for_user(user_id) - yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids) + # specify the user ID too since the user should always get their own device list + # updates, even if they aren't in any rooms. + yield self.notifier.on_new_event( + "device_list_key", position, users=[user_id], rooms=room_ids + ) if hosts: logger.info( @@ -439,9 +487,19 @@ class DeviceHandler(DeviceWorkerHandler): log_kv({"message": "sent device update to host", "host": host}) @defer.inlineCallbacks - def on_federation_query_user_devices(self, user_id): - stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - return {"user_id": user_id, "stream_id": stream_id, "devices": devices} + def notify_user_signature_update(self, from_user_id, user_ids): + """Notify a user that they have made new signatures of other users. 
+ + Args: + from_user_id (str): the user who made the signature + user_ids (list[str]): the users IDs that have new signatures + """ + + position = yield self.store.add_user_signature_change_to_streams( + from_user_id, user_ids + ) + + self.notifier.on_new_event("device_list_key", position, users=[from_user_id]) @defer.inlineCallbacks def user_left_room(self, user, room_id): @@ -483,6 +541,15 @@ class DeviceListUpdater(object): iterable=True, ) + # Attempt to resync out of sync device lists every 30s. + self._resync_retry_in_progress = False + self.clock.looping_call( + run_as_background_process, + 30 * 1000, + func=self._maybe_retry_device_resync, + desc="_maybe_retry_device_resync", + ) + @trace @defer.inlineCallbacks def incoming_device_list_update(self, origin, edu_content): @@ -569,7 +636,13 @@ class DeviceListUpdater(object): # happens if we've missed updates. resync = yield self._need_to_do_resync(user_id, pending_updates) - logger.debug("Need to re-sync devices for %r? %r", user_id, resync) + if logger.isEnabledFor(logging.INFO): + logger.info( + "Received device list update for %s, requiring resync: %s. Devices: %s", + user_id, + resync, + ", ".join(u[0] for u in pending_updates), + ) if resync: yield self.user_device_resync(user_id) @@ -621,25 +694,83 @@ class DeviceListUpdater(object): return False @defer.inlineCallbacks - def user_device_resync(self, user_id): + def _maybe_retry_device_resync(self): + """Retry to resync device lists that are out of sync, except if another retry is + in progress. + """ + if self._resync_retry_in_progress: + return + + try: + # Prevent another call of this function to retry resyncing device lists so + # we don't send too many requests. + self._resync_retry_in_progress = True + # Get all of the users that need resyncing. + need_resync = yield self.store.get_user_ids_requiring_device_list_resync() + # Iterate over the set of user IDs. + for user_id in need_resync: + try: + # Try to resync the current user's devices list. + result = yield self.user_device_resync( + user_id=user_id, mark_failed_as_stale=False, + ) + + # user_device_resync only returns a result if it managed to + # successfully resync and update the database. Updating the table + # of users requiring resync isn't necessary here as + # user_device_resync already does it (through + # self.store.update_remote_device_list_cache). + if result: + logger.debug( + "Successfully resynced the device list for %s", user_id, + ) + except Exception as e: + # If there was an issue resyncing this user, e.g. if the remote + # server sent a malformed result, just log the error instead of + # aborting all the subsequent resyncs. + logger.debug( + "Could not resync the device list for %s: %s", user_id, e, + ) + finally: + # Allow future calls to retry resyncinc out of sync device lists. + self._resync_retry_in_progress = False + + @defer.inlineCallbacks + def user_device_resync(self, user_id, mark_failed_as_stale=True): """Fetches all devices for a user and updates the device cache with them. Args: user_id (str): The user's id whose device_list will be updated. + mark_failed_as_stale (bool): Whether to mark the user's device list as stale + if the attempt to resync failed. 
Returns: Deferred[dict]: a dict with device info as under the "devices" in the result of this request: https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid """ + logger.debug("Attempting to resync the device list for %s", user_id) log_kv({"message": "Doing resync to update device list."}) # Fetch all devices for the user. origin = get_domain_from_id(user_id) try: result = yield self.federation.query_user_devices(origin, user_id) - except (NotRetryingDestination, RequestSendFailed, HttpResponseException): - # TODO: Remember that we are now out of sync and try again - # later - logger.warn("Failed to handle device list update for %s", user_id) + except NotRetryingDestination: + if mark_failed_as_stale: + # Mark the remote user's device list as stale so we know we need to retry + # it later. + yield self.store.mark_remote_user_device_cache_as_stale(user_id) + + return + except (RequestSendFailed, HttpResponseException) as e: + logger.warning( + "Failed to handle device list update for %s: %s", user_id, e, + ) + + if mark_failed_as_stale: + # Mark the remote user's device list as stale so we know we need to retry + # it later. + yield self.store.mark_remote_user_device_cache_as_stale(user_id) + # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list # is out of date. If we bail then we will retry the resync @@ -653,18 +784,29 @@ class DeviceListUpdater(object): logger.info(e) return except Exception as e: - # TODO: Remember that we are now out of sync and try again - # later set_tag("error", True) log_kv( {"message": "Exception raised by federation request", "exception": e} ) logger.exception("Failed to handle device list update for %s", user_id) + + if mark_failed_as_stale: + # Mark the remote user's device list as stale so we know we need to retry + # it later. + yield self.store.mark_remote_user_device_cache_as_stale(user_id) + return log_kv({"result": result}) stream_id = result["stream_id"] devices = result["devices"] + # Get the master key and the self-signing key for this user if provided in the + # response (None if not in the response). + # The response will not contain the user signing key, as this key is only used by + # its owner, thus it doesn't make sense to send it over federation. + master_key = result.get("master_key") + self_signing_key = result.get("self_signing_key") + # If the remote server has more than ~1000 devices for this user # we assume that something is going horribly wrong (e.g. a bot # that logs in and creates a new device every time it tries to @@ -677,7 +819,7 @@ class DeviceListUpdater(object): # up on storing the total list of devices and only handle the # delta instead. if len(devices) > 1000: - logger.warn( + logger.warning( "Ignoring device list snapshot for %s as it has >1K devs (%d)", user_id, len(devices), @@ -694,10 +836,54 @@ class DeviceListUpdater(object): yield self.store.update_remote_device_list_cache(user_id, devices, stream_id) device_ids = [device["device_id"] for device in devices] + + # Handle cross-signing keys. + cross_signing_device_ids = yield self.process_cross_signing_key_update( + user_id, master_key, self_signing_key, + ) + device_ids = device_ids + cross_signing_device_ids + yield self.device_handler.notify_device_update(user_id, device_ids) # We clobber the seen updates since we've re-synced from a given # point. 
- self._seen_updates[user_id] = set([stream_id]) + self._seen_updates[user_id] = {stream_id} defer.returnValue(result) + + @defer.inlineCallbacks + def process_cross_signing_key_update( + self, + user_id: str, + master_key: Optional[Dict[str, Any]], + self_signing_key: Optional[Dict[str, Any]], + ) -> list: + """Process the given new master and self-signing key for the given remote user. + + Args: + user_id: The ID of the user these keys are for. + master_key: The dict of the cross-signing master key as returned by the + remote server. + self_signing_key: The dict of the cross-signing self-signing key as returned + by the remote server. + + Return: + The device IDs for the given keys. + """ + device_ids = [] + + if master_key: + yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + _, verify_key = get_verify_key_from_cross_signing_key(master_key) + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID + device_ids.append(verify_key.version) + if self_signing_key: + yield self.store.set_e2e_cross_signing_key( + user_id, "self_signing", self_signing_key + ) + _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key) + device_ids.append(verify_key.version) + + return device_ids diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 0043cbea17..05c4b3eec0 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -14,12 +14,14 @@ # limitations under the License. import logging +from typing import Any, Dict from canonicaljson import json from twisted.internet import defer from synapse.api.errors import SynapseError +from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( get_active_span_text_map, log_kv, @@ -47,12 +49,14 @@ class DeviceMessageHandler(object): "m.direct_to_device", self.on_direct_to_device_edu ) + self._device_list_updater = hs.get_device_handler().device_list_updater + @defer.inlineCallbacks def on_direct_to_device_edu(self, origin, content): local_messages = {} sender_user_id = content["sender"] if origin != get_domain_from_id(sender_user_id): - logger.warn( + logger.warning( "Dropping device message from %r with spoofed sender %r", origin, sender_user_id, @@ -65,6 +69,9 @@ class DeviceMessageHandler(object): logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") + if not by_device: + continue + messages_by_device = { device_id: { "content": message_content, @@ -73,8 +80,11 @@ class DeviceMessageHandler(object): } for device_id, message_content in by_device.items() } - if messages_by_device: - local_messages[user_id] = messages_by_device + local_messages[user_id] = messages_by_device + + yield self._check_for_unknown_devices( + message_type, sender_user_id, by_device + ) stream_id = yield self.store.add_messages_from_remote_to_device_inbox( origin, message_id, local_messages @@ -85,6 +95,55 @@ class DeviceMessageHandler(object): ) @defer.inlineCallbacks + def _check_for_unknown_devices( + self, + message_type: str, + sender_user_id: str, + by_device: Dict[str, Dict[str, Any]], + ): + """Checks inbound device messages for unkown remote devices, and if + found marks the remote cache for the user as stale. 
+ """ + + if message_type != "m.room_key_request": + return + + # Get the sending device IDs + requesting_device_ids = set() + for message_content in by_device.values(): + device_id = message_content.get("requesting_device_id") + requesting_device_ids.add(device_id) + + # Check if we are tracking the devices of the remote user. + room_ids = yield self.store.get_rooms_for_user(sender_user_id) + if not room_ids: + logger.info( + "Received device message from remote device we don't" + " share a room with: %s %s", + sender_user_id, + requesting_device_ids, + ) + return + + # If we are tracking check that we know about the sending + # devices. + cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id) + + unknown_devices = requesting_device_ids - set(cached_devices) + if unknown_devices: + logger.info( + "Received device message from remote device not in our cache: %s %s", + sender_user_id, + unknown_devices, + ) + yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id) + + # Immediately attempt a resync in the background + run_in_background( + self._device_list_updater.user_device_resync, sender_user_id + ) + + @defer.inlineCallbacks def send_device_message(self, sender_user_id, message_type, messages): set_tag("number_of_messages", len(messages)) set_tag("sender", sender_user_id) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 526379c6f7..f2f16b1e43 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - import logging import string +from typing import Iterable, List, Optional from twisted.internet import defer @@ -28,7 +28,8 @@ from synapse.api.errors import ( StoreError, SynapseError, ) -from synapse.types import RoomAlias, UserID, get_domain_from_id +from synapse.appservice import ApplicationService +from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id from ._base import BaseHandler @@ -55,7 +56,13 @@ class DirectoryHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() @defer.inlineCallbacks - def _create_association(self, room_alias, room_id, servers=None, creator=None): + def _create_association( + self, + room_alias: RoomAlias, + room_id: str, + servers: Optional[Iterable[str]] = None, + creator: Optional[str] = None, + ): # general association creation for both human users and app services for wchar in string.whitespace: @@ -70,7 +77,7 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Check if there is a current association. 
if not servers: users = yield self.state.get_current_users_in_room(room_id) - servers = set(get_domain_from_id(u) for u in users) + servers = {get_domain_from_id(u) for u in users} if not servers: raise SynapseError(400, "Failed to get server list") @@ -79,26 +86,22 @@ class DirectoryHandler(BaseHandler): room_alias, room_id, servers, creator=creator ) - @defer.inlineCallbacks - def create_association( + async def create_association( self, - requester, - room_alias, - room_id, - servers=None, - send_event=True, - check_membership=True, + requester: Requester, + room_alias: RoomAlias, + room_id: str, + servers: Optional[List[str]] = None, + check_membership: bool = True, ): """Attempt to create a new alias Args: - requester (Requester) - room_alias (RoomAlias) - room_id (str) - servers (list[str]|None): List of servers that others servers - should try and join via - send_event (bool): Whether to send an updated m.room.aliases event - check_membership (bool): Whether to check if the user is in the room + requester + room_alias + room_id + servers: Iterable of servers that others servers should try and join via + check_membership: Whether to check if the user is in the room before the alias can be set (if the server's config requires it). Returns: @@ -119,12 +122,16 @@ class DirectoryHandler(BaseHandler): if not service.is_interested_in_alias(room_alias.to_string()): raise SynapseError( 400, - "This application service has not reserved" " this kind of alias.", + "This application service has not reserved this kind of alias.", errcode=Codes.EXCLUSIVE, ) else: - if self.require_membership and check_membership: - rooms_for_user = yield self.store.get_rooms_for_user(user_id) + # Server admins are not subject to the same constraints as normal + # users when creating an alias (e.g. being in the room). + is_admin = await self.auth.is_server_admin(requester.user) + + if (self.require_membership and check_membership) and not is_admin: + rooms_for_user = await self.store.get_rooms_for_user(user_id) if room_id not in rooms_for_user: raise AuthError( 403, "You must be in the room to create an alias for it" @@ -141,7 +148,7 @@ class DirectoryHandler(BaseHandler): # per alias creation rule? raise SynapseError(403, "Not allowed to create alias") - can_create = yield self.can_modify_alias(room_alias, user_id=user_id) + can_create = await self.can_modify_alias(room_alias, user_id=user_id) if not can_create: raise AuthError( 400, @@ -149,23 +156,17 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - yield self._create_association(room_alias, room_id, servers, creator=user_id) - if send_event: - yield self.send_room_alias_update_event(requester, room_id) + await self._create_association(room_alias, room_id, servers, creator=user_id) - @defer.inlineCallbacks - def delete_association(self, requester, room_alias, send_event=True): + async def delete_association(self, requester: Requester, room_alias: RoomAlias): """Remove an alias from the directory (this is only meant for human users; AS users should call delete_appservice_association) Args: - requester (Requester): - room_alias (RoomAlias): - send_event (bool): Whether to send an updated m.room.aliases event. 
- Note that, if we delete the canonical alias, we will always attempt - to send an m.room.canonical_alias event + requester + room_alias Returns: Deferred[unicode]: room id that the alias used to point to @@ -181,7 +182,7 @@ class DirectoryHandler(BaseHandler): user_id = requester.user.to_string() try: - can_delete = yield self._user_can_delete_alias(room_alias, user_id) + can_delete = await self._user_can_delete_alias(room_alias, user_id) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown room alias") @@ -190,7 +191,7 @@ class DirectoryHandler(BaseHandler): if not can_delete: raise AuthError(403, "You don't have permission to delete the alias.") - can_delete = yield self.can_modify_alias(room_alias, user_id=user_id) + can_delete = await self.can_modify_alias(room_alias, user_id=user_id) if not can_delete: raise SynapseError( 400, @@ -198,22 +199,19 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - room_id = yield self._delete_association(room_alias) + room_id = await self._delete_association(room_alias) try: - if send_event: - yield self.send_room_alias_update_event(requester, room_id) - - yield self._update_canonical_alias( - requester, requester.user.to_string(), room_id, room_alias - ) + await self._update_canonical_alias(requester, user_id, room_id, room_alias) except AuthError as e: logger.info("Failed to update alias events: %s", e) return room_id @defer.inlineCallbacks - def delete_appservice_association(self, service, room_alias): + def delete_appservice_association( + self, service: ApplicationService, room_alias: RoomAlias + ): if not service.is_interested_in_alias(room_alias.to_string()): raise SynapseError( 400, @@ -223,7 +221,7 @@ class DirectoryHandler(BaseHandler): yield self._delete_association(room_alias) @defer.inlineCallbacks - def _delete_association(self, room_alias): + def _delete_association(self, room_alias: RoomAlias): if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") @@ -232,7 +230,7 @@ class DirectoryHandler(BaseHandler): return room_id @defer.inlineCallbacks - def get_association(self, room_alias): + def get_association(self, room_alias: RoomAlias): room_id = None if self.hs.is_mine(room_alias): result = yield self.get_association_from_room_alias(room_alias) @@ -250,7 +248,7 @@ class DirectoryHandler(BaseHandler): ignore_backoff=True, ) except CodeMessageException as e: - logging.warn("Error retrieving alias") + logging.warning("Error retrieving alias") if e.code == 404: result = None else: @@ -268,7 +266,7 @@ class DirectoryHandler(BaseHandler): ) users = yield self.state.get_current_users_in_room(room_id) - extra_servers = set(get_domain_from_id(u) for u in users) + extra_servers = {get_domain_from_id(u) for u in users} servers = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. 
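The create_association change above exempts server admins from the membership requirement when setting an alias. A rough sketch of that decision; the wrapper function and its parameter bundle are illustrative, while the auth and store calls mirror the ones used in the handler:

from synapse.api.errors import AuthError

async def check_alias_creation_allowed(
    auth, store, requester, room_id, require_membership, check_membership
):
    # Admins may create an alias for a room they are not in.
    is_admin = await auth.is_server_admin(requester.user)
    if (require_membership and check_membership) and not is_admin:
        rooms_for_user = await store.get_rooms_for_user(requester.user.to_string())
        if room_id not in rooms_for_user:
            raise AuthError(403, "You must be in the room to create an alias for it")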
@@ -283,7 +281,7 @@ class DirectoryHandler(BaseHandler): def on_directory_query(self, args): room_alias = RoomAlias.from_string(args["room_alias"]) if not self.hs.is_mine(room_alias): - raise SynapseError(400, "Room Alias is not hosted on this Home Server") + raise SynapseError(400, "Room Alias is not hosted on this homeserver") result = yield self.get_association_from_room_alias(room_alias) @@ -296,46 +294,58 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND, ) - @defer.inlineCallbacks - def send_room_alias_update_event(self, requester, room_id): - aliases = yield self.store.get_aliases_for_room(room_id) - - yield self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Aliases, - "state_key": self.hs.hostname, - "room_id": room_id, - "sender": requester.user.to_string(), - "content": {"aliases": aliases}, - }, - ratelimit=False, - ) - - @defer.inlineCallbacks - def _update_canonical_alias(self, requester, user_id, room_id, room_alias): - alias_event = yield self.state.get_current_state( + async def _update_canonical_alias( + self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias + ): + """ + Send an updated canonical alias event if the removed alias was set as + the canonical alias or listed in the alt_aliases field. + """ + alias_event = await self.state.get_current_state( room_id, EventTypes.CanonicalAlias, "" ) - alias_str = room_alias.to_string() - if not alias_event or alias_event.content.get("alias", "") != alias_str: + # There is no canonical alias, nothing to do. + if not alias_event: return - yield self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.CanonicalAlias, - "state_key": "", - "room_id": room_id, - "sender": user_id, - "content": {}, - }, - ratelimit=False, - ) + # Obtain a mutable version of the event content. + content = dict(alias_event.content) + send_update = False + + # Remove the alias property if it matches the removed alias. + alias_str = room_alias.to_string() + if alias_event.content.get("alias", "") == alias_str: + send_update = True + content.pop("alias", "") + + # Filter the alt_aliases property for the removed alias. Note that the + # value is not modified if alt_aliases is of an unexpected form. + alt_aliases = content.get("alt_aliases") + if isinstance(alt_aliases, (list, tuple)) and alias_str in alt_aliases: + send_update = True + alt_aliases = [alias for alias in alt_aliases if alias != alias_str] + + if alt_aliases: + content["alt_aliases"] = alt_aliases + else: + del content["alt_aliases"] + + if send_update: + await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.CanonicalAlias, + "state_key": "", + "room_id": room_id, + "sender": user_id, + "content": content, + }, + ratelimit=False, + ) @defer.inlineCallbacks - def get_association_from_room_alias(self, room_alias): + def get_association_from_room_alias(self, room_alias: RoomAlias): result = yield self.store.get_association_from_room_alias(room_alias) if not result: # Query AS to see if it exists @@ -343,7 +353,7 @@ class DirectoryHandler(BaseHandler): result = yield as_handler.query_room_alias_exists(room_alias) return result - def can_modify_alias(self, alias, user_id=None): + def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None): # Any application service "interested" in an alias they are regexing on # can modify the alias. 
# Users can only modify the alias if ALL the interested services have @@ -363,23 +373,41 @@ class DirectoryHandler(BaseHandler): # either no interested services, or no service with an exclusive lock return defer.succeed(True) - @defer.inlineCallbacks - def _user_can_delete_alias(self, alias, user_id): - creator = yield self.store.get_room_alias_creator(alias.to_string()) + async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): + """Determine whether a user can delete an alias. + + One of the following must be true: + + 1. The user created the alias. + 2. The user is a server administrator. + 3. The user has a power-level sufficient to send a canonical alias event + for the current room. + + """ + creator = await self.store.get_room_alias_creator(alias.to_string()) if creator is not None and creator == user_id: return True - is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) - return is_admin + # Resolve the alias to the corresponding room. + room_mapping = await self.get_association(alias) + room_id = room_mapping["room_id"] + if not room_id: + return False - @defer.inlineCallbacks - def edit_published_room_list(self, requester, room_id, visibility): + res = await self.auth.check_can_change_room_list( + room_id, UserID.from_string(user_id) + ) + return res + + async def edit_published_room_list( + self, requester: Requester, room_id: str, visibility: str + ): """Edit the entry of the room in the published room list. requester - room_id (str) - visibility (str): "public" or "private" + room_id + visibility: "public" or "private" """ user_id = requester.user.to_string() @@ -400,16 +428,24 @@ class DirectoryHandler(BaseHandler): 403, "This user is not permitted to publish rooms to the room list" ) - room = yield self.store.get_room(room_id) + room = await self.store.get_room(room_id) if room is None: raise SynapseError(400, "Unknown room") - yield self.auth.check_can_change_room_list(room_id, requester.user) + can_change_room_list = await self.auth.check_can_change_room_list( + room_id, requester.user + ) + if not can_change_room_list: + raise AuthError( + 403, + "This server requires you to be a moderator in the room to" + " edit its room list entry", + ) making_public = visibility == "public" if making_public: - room_aliases = yield self.store.get_aliases_for_room(room_id) - canonical_alias = yield self.store.get_canonical_alias_for_room(room_id) + room_aliases = await self.store.get_aliases_for_room(room_id) + canonical_alias = await self.store.get_canonical_alias_for_room(room_id) if canonical_alias: room_aliases.append(canonical_alias) @@ -421,20 +457,20 @@ class DirectoryHandler(BaseHandler): # per alias creation rule? raise SynapseError(403, "Not allowed to publish room") - yield self.store.set_room_is_public(room_id, making_public) + await self.store.set_room_is_public(room_id, making_public) @defer.inlineCallbacks def edit_published_appservice_room_list( - self, appservice_id, network_id, room_id, visibility + self, appservice_id: str, network_id: str, room_id: str, visibility: str ): """Add or remove a room from the appservice/network specific public room list. 
Args: - appservice_id (str): ID of the appservice that owns the list - network_id (str): The ID of the network the list is associated with - room_id (str) - visibility (str): either "public" or "private" + appservice_id: ID of the appservice that owns the list + network_id: The ID of the network the list is associated with + room_id + visibility: either "public" or "private" """ if visibility not in ["public", "private"]: raise SynapseError(400, "Invalid visibility setting") @@ -442,3 +478,19 @@ class DirectoryHandler(BaseHandler): yield self.store.set_room_is_public_appservice( room_id, appservice_id, network_id, visibility == "public" ) + + async def get_aliases_for_room( + self, requester: Requester, room_id: str + ) -> List[str]: + """ + Get a list of the aliases that currently point to this room on this server + """ + # allow access to server admins and current members of the room + is_admin = await self.auth.is_server_admin(requester.user) + if not is_admin: + await self.auth.check_user_in_room_or_world_readable( + room_id, requester.user.to_string() + ) + + aliases = await self.store.get_aliases_for_room(room_id) + return aliases diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 056fb97acb..774a252619 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,15 +19,26 @@ import logging from six import iteritems +import attr from canonicaljson import encode_canonical_json, json +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import SignatureVerifyException, verify_signed_json +from unpaddedbase64 import decode_base64 from twisted.internet import defer -from synapse.api.errors import CodeMessageException, SynapseError +from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace -from synapse.types import UserID, get_domain_from_id +from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet +from synapse.types import ( + UserID, + get_domain_from_id, + get_verify_key_from_cross_signing_key, +) from synapse.util import unwrapFirstError +from synapse.util.async_helpers import Linearizer +from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -40,16 +52,35 @@ class E2eKeysHandler(object): self.is_mine = hs.is_mine self.clock = hs.get_clock() + self._edu_updater = SigningKeyEduUpdater(hs, self) + + federation_registry = hs.get_federation_registry() + + self._is_master = hs.config.worker_app is None + if not self._is_master: + self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client( + hs + ) + else: + # Only register this edu handler on master as it requires writing + # device updates to the db + # + # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + federation_registry.register_edu_handler( + "org.matrix.signing_key_update", + self._edu_updater.incoming_signing_key_update, + ) + # doesn't really work as part of the generic 
query API, because the # query request requires an object POST, but we abuse the # "query handler" interface. - hs.get_federation_registry().register_query_handler( + federation_registry.register_query_handler( "client_keys", self.on_federation_query_client_keys ) @trace @defer.inlineCallbacks - def query_devices(self, query_body, timeout): + def query_devices(self, query_body, timeout, from_user_id): """ Handle a device key query from a client { @@ -67,6 +98,11 @@ class E2eKeysHandler(object): } } } + + Args: + from_user_id (str): the user making the query. This is used when + adding cross-signing signatures to limit what signatures users + can see. """ device_keys_query = query_body.get("device_keys", {}) @@ -105,9 +141,10 @@ class E2eKeysHandler(object): else: query_list.append((user_id, None)) - user_ids_not_in_cache, remote_results = ( - yield self.store.get_user_devices_from_cache(query_list) - ) + ( + user_ids_not_in_cache, + remote_results, + ) = yield self.store.get_user_devices_from_cache(query_list) for user_id, devices in iteritems(remote_results): user_devices = results.setdefault(user_id, {}) for device_id, device in iteritems(devices): @@ -125,6 +162,11 @@ class E2eKeysHandler(object): r = remote_queries_not_in_cache.setdefault(domain, {}) r[user_id] = remote_queries[user_id] + # Get cached cross-signing keys + cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + device_keys_query, from_user_id + ) + # Now fetch any devices that we don't have in our cache @trace @defer.inlineCallbacks @@ -132,8 +174,8 @@ class E2eKeysHandler(object): """This is called when we are querying the device list of a user on a remote homeserver and their device list is not in the device list cache. If we share a room with this user and we're not querying for - specific user we will update the cache - with their device list.""" + specific user we will update the cache with their device list. + """ destination_query = remote_queries_not_in_cache[destination] @@ -160,12 +202,19 @@ class E2eKeysHandler(object): # probably be tracking their device lists. However, we haven't # done an initial sync on the device list so we do it now. 
try: - user_devices = yield self.device_handler.device_list_updater.user_device_resync( - user_id - ) + if self._is_master: + user_devices = yield self.device_handler.device_list_updater.user_device_resync( + user_id + ) + else: + user_devices = yield self._user_device_resync_client( + user_id=user_id + ) + user_devices = user_devices["devices"] + user_results = results.setdefault(user_id, {}) for device in user_devices: - results[user_id] = {device["device_id"]: device["keys"]} + user_results[device["device_id"]] = device["keys"] user_ids_updated.append(user_id) except Exception as e: failures[destination] = _exception_to_failure(e) @@ -188,6 +237,16 @@ class E2eKeysHandler(object): if user_id in destination_query: results[user_id] = keys + if "master_keys" in remote_result: + for user_id, key in remote_result["master_keys"].items(): + if user_id in destination_query: + cross_signing_keys["master_keys"][user_id] = key + + if "self_signing_keys" in remote_result: + for user_id, key in remote_result["self_signing_keys"].items(): + if user_id in destination_query: + cross_signing_keys["self_signing_keys"][user_id] = key + except Exception as e: failure = _exception_to_failure(e) failures[destination] = failure @@ -204,7 +263,58 @@ class E2eKeysHandler(object): ).addErrback(unwrapFirstError) ) - return {"device_keys": results, "failures": failures} + ret = {"device_keys": results, "failures": failures} + + ret.update(cross_signing_keys) + + return ret + + @defer.inlineCallbacks + def get_cross_signing_keys_from_cache(self, query, from_user_id): + """Get cross-signing keys for users from the database + + Args: + query (Iterable[string]) an iterable of user IDs. A dict whose keys + are user IDs satisfies this, so the query format used for + query_devices can be used here. + from_user_id (str): the user making the query. This is used when + adding cross-signing signatures to limit what signatures users + can see. 
+ + Returns: + defer.Deferred[dict[str, dict[str, dict]]]: map from + (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key + """ + master_keys = {} + self_signing_keys = {} + user_signing_keys = {} + + user_ids = list(query) + + keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) + + for user_id, user_info in keys.items(): + if user_info is None: + continue + if "master" in user_info: + master_keys[user_id] = user_info["master"] + if "self_signing" in user_info: + self_signing_keys[user_id] = user_info["self_signing"] + + if ( + from_user_id in keys + and keys[from_user_id] is not None + and "user_signing" in keys[from_user_id] + ): + # users can see other users' master and self-signing keys, but can + # only see their own user-signing keys + user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"] + + return { + "master_keys": master_keys, + "self_signing_keys": self_signing_keys, + "user_signing_keys": user_signing_keys, + } @trace @defer.inlineCallbacks @@ -248,16 +358,10 @@ class E2eKeysHandler(object): results = yield self.store.get_e2e_device_keys(local_query) - # Build the result structure, un-jsonify the results, and add the - # "unsigned" section + # Build the result structure for user_id, device_keys in results.items(): for device_id, device_info in device_keys.items(): - r = dict(device_info["keys"]) - r["unsigned"] = {} - display_name = device_info["device_display_name"] - if display_name is not None: - r["unsigned"]["device_display_name"] = display_name - result_dict[user_id][device_id] = r + result_dict[user_id][device_id] = device_info log_kv(results) return result_dict @@ -268,7 +372,16 @@ class E2eKeysHandler(object): """ device_keys_query = query_body.get("device_keys", {}) res = yield self.query_local_devices(device_keys_query) - return {"device_keys": res} + ret = {"device_keys": res} + + # add in the cross-signing keys + cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + device_keys_query, None + ) + + ret.update(cross_signing_keys) + + return ret @trace @defer.inlineCallbacks @@ -447,8 +560,633 @@ class E2eKeysHandler(object): log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) + @defer.inlineCallbacks + def upload_signing_keys_for_user(self, user_id, keys): + """Upload signing keys for cross-signing + + Args: + user_id (string): the user uploading the keys + keys (dict[string, dict]): the signing keys + """ + + # if a master key is uploaded, then check it. 
Otherwise, load the + # stored master key, to check signatures on other keys + if "master_key" in keys: + master_key = keys["master_key"] + + _check_cross_signing_key(master_key, user_id, "master") + else: + master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") + + # if there is no master key, then we can't do anything, because all the + # other cross-signing keys need to be signed by the master key + if not master_key: + raise SynapseError(400, "No master key available", Codes.MISSING_PARAM) + + try: + master_key_id, master_verify_key = get_verify_key_from_cross_signing_key( + master_key + ) + except ValueError: + if "master_key" in keys: + # the invalid key came from the request + raise SynapseError(400, "Invalid master key", Codes.INVALID_PARAM) + else: + # the invalid key came from the database + logger.error("Invalid master key found for user %s", user_id) + raise SynapseError(500, "Invalid master key") + + # for the other cross-signing keys, make sure that they have valid + # signatures from the master key + if "self_signing_key" in keys: + self_signing_key = keys["self_signing_key"] + + _check_cross_signing_key( + self_signing_key, user_id, "self_signing", master_verify_key + ) + + if "user_signing_key" in keys: + user_signing_key = keys["user_signing_key"] + + _check_cross_signing_key( + user_signing_key, user_id, "user_signing", master_verify_key + ) + + # if everything checks out, then store the keys and send notifications + deviceids = [] + if "master_key" in keys: + yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + deviceids.append(master_verify_key.version) + if "self_signing_key" in keys: + yield self.store.set_e2e_cross_signing_key( + user_id, "self_signing", self_signing_key + ) + try: + deviceids.append( + get_verify_key_from_cross_signing_key(self_signing_key)[1].version + ) + except ValueError: + raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM) + if "user_signing_key" in keys: + yield self.store.set_e2e_cross_signing_key( + user_id, "user_signing", user_signing_key + ) + # the signature stream matches the semantics that we want for + # user-signing key updates: only the user themselves is notified of + # their own user-signing key updates + yield self.device_handler.notify_user_signature_update(user_id, [user_id]) + + # master key and self-signing key updates match the semantics of device + # list updates: all users who share an encrypted room are notified + if len(deviceids): + yield self.device_handler.notify_device_update(user_id, deviceids) + + return {} + + @defer.inlineCallbacks + def upload_signatures_for_device_keys(self, user_id, signatures): + """Upload device signatures for cross-signing + + Args: + user_id (string): the user uploading the signatures + signatures (dict[string, dict[string, dict]]): map of users to + devices to signed keys. This is the submission from the user; an + exception will be raised if it is malformed. + Returns: + dict: response to be sent back to the client. The response will have + a "failures" key, which will be a dict mapping users to devices + to errors for the signatures that failed. + Raises: + SynapseError: if the signatures dict is not valid. + """ + failures = {} + + # signatures to be stored. 
Each item will be a SignatureListItem + signature_list = [] + + # split between checking signatures for own user and signatures for + # other users, since we verify them with different keys + self_signatures = signatures.get(user_id, {}) + other_signatures = {k: v for k, v in signatures.items() if k != user_id} + + self_signature_list, self_failures = yield self._process_self_signatures( + user_id, self_signatures + ) + signature_list.extend(self_signature_list) + failures.update(self_failures) + + other_signature_list, other_failures = yield self._process_other_signatures( + user_id, other_signatures + ) + signature_list.extend(other_signature_list) + failures.update(other_failures) + + # store the signature, and send the appropriate notifications for sync + logger.debug("upload signature failures: %r", failures) + yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list) + + self_device_ids = [item.target_device_id for item in self_signature_list] + if self_device_ids: + yield self.device_handler.notify_device_update(user_id, self_device_ids) + signed_users = [item.target_user_id for item in other_signature_list] + if signed_users: + yield self.device_handler.notify_user_signature_update( + user_id, signed_users + ) + + return {"failures": failures} + + @defer.inlineCallbacks + def _process_self_signatures(self, user_id, signatures): + """Process uploaded signatures of the user's own keys. + + Signatures of the user's own keys from this API come in two forms: + - signatures of the user's devices by the user's self-signing key, + - signatures of the user's master key by the user's devices. + + Args: + user_id (string): the user uploading the keys + signatures (dict[string, dict]): map of devices to signed keys + + Returns: + (list[SignatureListItem], dict[string, dict[string, dict]]): + a list of signatures to store, and a map of users to devices to failure + reasons + + Raises: + SynapseError: if the input is malformed + """ + signature_list = [] + failures = {} + if not signatures: + return signature_list, failures + + if not isinstance(signatures, dict): + raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) + + try: + # get our self-signing key to verify the signatures + ( + _, + self_signing_key_id, + self_signing_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing") + + # get our master key, since we may have received a signature of it. + # We need to fetch it here so that we know what its key ID is, so + # that we can check if a signature that was sent is a signature of + # the master key or of a device + ( + master_key, + _, + master_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master") + + # fetch our stored devices. This is used to 1. verify + # signatures on the master key, and 2. 
to compare with what + # was sent if the device was signed + devices = yield self.store.get_e2e_device_keys([(user_id, None)]) + + if user_id not in devices: + raise NotFoundError("No device keys found") + + devices = devices[user_id] + except SynapseError as e: + failure = _exception_to_failure(e) + failures[user_id] = {device: failure for device in signatures.keys()} + return signature_list, failures + + for device_id, device in signatures.items(): + # make sure submitted data is in the right form + if not isinstance(device, dict): + raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) + + try: + if "signatures" not in device or user_id not in device["signatures"]: + # no signature was sent + raise SynapseError( + 400, "Invalid signature", Codes.INVALID_SIGNATURE + ) + + if device_id == master_verify_key.version: + # The signature is of the master key. This needs to be + # handled differently from signatures of normal devices. + master_key_signature_list = self._check_master_key_signature( + user_id, device_id, device, master_key, devices + ) + signature_list.extend(master_key_signature_list) + continue + + # at this point, we have a device that should be signed + # by the self-signing key + if self_signing_key_id not in device["signatures"][user_id]: + # no signature was sent + raise SynapseError( + 400, "Invalid signature", Codes.INVALID_SIGNATURE + ) + + try: + stored_device = devices[device_id] + except KeyError: + raise NotFoundError("Unknown device") + if self_signing_key_id in stored_device.get("signatures", {}).get( + user_id, {} + ): + # we already have a signature on this device, so we + # can skip it, since it should be exactly the same + continue + + _check_device_signature( + user_id, self_signing_verify_key, device, stored_device + ) + + signature = device["signatures"][user_id][self_signing_key_id] + signature_list.append( + SignatureListItem( + self_signing_key_id, user_id, device_id, signature + ) + ) + except SynapseError as e: + failures.setdefault(user_id, {})[device_id] = _exception_to_failure(e) + + return signature_list, failures + + def _check_master_key_signature( + self, user_id, master_key_id, signed_master_key, stored_master_key, devices + ): + """Check signatures of a user's master key made by their devices. + + Args: + user_id (string): the user whose master key is being checked + master_key_id (string): the ID of the user's master key + signed_master_key (dict): the user's signed master key that was uploaded + stored_master_key (dict): our previously-stored copy of the user's master key + devices (iterable(dict)): the user's devices + + Returns: + list[SignatureListItem]: a list of signatures to store + + Raises: + SynapseError: if a signature is invalid + """ + # for each device that signed the master key, check the signature. 
+ master_key_signature_list = [] + sigs = signed_master_key["signatures"] + for signing_key_id, signature in sigs[user_id].items(): + _, signing_device_id = signing_key_id.split(":", 1) + if ( + signing_device_id not in devices + or signing_key_id not in devices[signing_device_id]["keys"] + ): + # signed by an unknown device, or the + # device does not have the key + raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) + + # get the key and check the signature + pubkey = devices[signing_device_id]["keys"][signing_key_id] + verify_key = decode_verify_key_bytes(signing_key_id, decode_base64(pubkey)) + _check_device_signature( + user_id, verify_key, signed_master_key, stored_master_key + ) + + master_key_signature_list.append( + SignatureListItem(signing_key_id, user_id, master_key_id, signature) + ) + + return master_key_signature_list + + @defer.inlineCallbacks + def _process_other_signatures(self, user_id, signatures): + """Process uploaded signatures of other users' keys. These will be the + target user's master keys, signed by the uploading user's user-signing + key. + + Args: + user_id (string): the user uploading the keys + signatures (dict[string, dict]): map of users to devices to signed keys + + Returns: + (list[SignatureListItem], dict[string, dict[string, dict]]): + a list of signatures to store, and a map of users to devices to failure + reasons + + Raises: + SynapseError: if the input is malformed + """ + signature_list = [] + failures = {} + if not signatures: + return signature_list, failures + + try: + # get our user-signing key to verify the signatures + ( + user_signing_key, + user_signing_key_id, + user_signing_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing") + except SynapseError as e: + failure = _exception_to_failure(e) + for user, devicemap in signatures.items(): + failures[user] = {device_id: failure for device_id in devicemap.keys()} + return signature_list, failures + + for target_user, devicemap in signatures.items(): + # make sure submitted data is in the right form + if not isinstance(devicemap, dict): + raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) + for device in devicemap.values(): + if not isinstance(device, dict): + raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) + + device_id = None + try: + # get the target user's master key, to make sure it matches + # what was sent + ( + master_key, + master_key_id, + _, + ) = yield self._get_e2e_cross_signing_verify_key( + target_user, "master", user_id + ) + + # make sure that the target user's master key is the one that + # was signed (and no others) + device_id = master_key_id.split(":", 1)[1] + if device_id not in devicemap: + logger.debug( + "upload signature: could not find signature for device %s", + device_id, + ) + # set device to None so that the failure gets + # marked on all the signatures + device_id = None + raise NotFoundError("Unknown device") + key = devicemap[device_id] + other_devices = [k for k in devicemap.keys() if k != device_id] + if other_devices: + # other devices were signed -- mark those as failures + logger.debug("upload signature: too many devices specified") + failure = _exception_to_failure(NotFoundError("Unknown device")) + failures[target_user] = { + device: failure for device in other_devices + } + + if user_signing_key_id in master_key.get("signatures", {}).get( + user_id, {} + ): + # we already have the signature, so we can skip it + continue + + _check_device_signature( + user_id, 
user_signing_verify_key, key, master_key + ) + + signature = key["signatures"][user_id][user_signing_key_id] + signature_list.append( + SignatureListItem( + user_signing_key_id, target_user, device_id, signature + ) + ) + except SynapseError as e: + failure = _exception_to_failure(e) + if device_id is None: + failures[target_user] = { + device_id: failure for device_id in devicemap.keys() + } + else: + failures.setdefault(target_user, {})[device_id] = failure + + return signature_list, failures + + @defer.inlineCallbacks + def _get_e2e_cross_signing_verify_key( + self, user_id: str, key_type: str, from_user_id: str = None + ): + """Fetch locally or remotely query for a cross-signing public key. + + First, attempt to fetch the cross-signing public key from storage. + If that fails, query the keys from the homeserver they belong to + and update our local copy. + + Args: + user_id: the user whose key should be fetched + key_type: the type of key to fetch + from_user_id: the user that we are fetching the keys for. + This affects what signatures are fetched. + + Returns: + dict, str, VerifyKey: the raw key data, the key ID, and the + signedjson verify key + + Raises: + NotFoundError: if the key is not found + SynapseError: if `user_id` is invalid + """ + user = UserID.from_string(user_id) + key = yield self.store.get_e2e_cross_signing_key( + user_id, key_type, from_user_id + ) + + if key: + # We found a copy of this key in our database. Decode and return it + key_id, verify_key = get_verify_key_from_cross_signing_key(key) + return key, key_id, verify_key + + # If we couldn't find the key locally, and we're looking for keys of + # another user then attempt to fetch the missing key from the remote + # user's server. + # + # We may run into this in possible edge cases where a user tries to + # cross-sign a remote user, but does not share any rooms with them yet. + # Thus, we would not have their key list yet. We instead fetch the key, + # store it and notify clients of new, associated device IDs. + if self.is_mine(user) or key_type not in ["master", "self_signing"]: + # Note that master and self_signing keys are the only cross-signing keys we + # can request over federation + raise NotFoundError("No %s key found for %s" % (key_type, user_id)) + + ( + key, + key_id, + verify_key, + ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type) + + if key is None: + raise NotFoundError("No %s key found for %s" % (key_type, user_id)) + + return key, key_id, verify_key + + @defer.inlineCallbacks + def _retrieve_cross_signing_keys_for_remote_user( + self, user: UserID, desired_key_type: str, + ): + """Queries cross-signing keys for a remote user and saves them to the database + + Only the key specified by `key_type` will be returned, while all retrieved keys + will be saved regardless + + Args: + user: The user to query remote keys for + desired_key_type: The type of key to receive. One of "master", "self_signing" + + Returns: + Deferred[Tuple[Optional[Dict], Optional[str], Optional[VerifyKey]]]: A tuple + of the retrieved key content, the key's ID and the matching VerifyKey. + If the key cannot be retrieved, all values in the tuple will instead be None. 
+ """ + try: + remote_result = yield self.federation.query_user_devices( + user.domain, user.to_string() + ) + except Exception as e: + logger.warning( + "Unable to query %s for cross-signing keys of user %s: %s %s", + user.domain, + user.to_string(), + type(e), + e, + ) + return None, None, None + + # Process each of the retrieved cross-signing keys + desired_key = None + desired_key_id = None + desired_verify_key = None + retrieved_device_ids = [] + for key_type in ["master", "self_signing"]: + key_content = remote_result.get(key_type + "_key") + if not key_content: + continue + + # Ensure these keys belong to the correct user + if "user_id" not in key_content: + logger.warning( + "Invalid %s key retrieved, missing user_id field: %s", + key_type, + key_content, + ) + continue + if user.to_string() != key_content["user_id"]: + logger.warning( + "Found %s key of user %s when querying for keys of user %s", + key_type, + key_content["user_id"], + user.to_string(), + ) + continue + + # Validate the key contents + try: + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID + key_id, verify_key = get_verify_key_from_cross_signing_key(key_content) + except ValueError as e: + logger.warning( + "Invalid %s key retrieved: %s - %s %s", + key_type, + key_content, + type(e), + e, + ) + continue + + # Note down the device ID attached to this key + retrieved_device_ids.append(verify_key.version) + + # If this is the desired key type, save it and its ID/VerifyKey + if key_type == desired_key_type: + desired_key = key_content + desired_verify_key = verify_key + desired_key_id = key_id + + # At the same time, store this key in the db for subsequent queries + yield self.store.set_e2e_cross_signing_key( + user.to_string(), key_type, key_content + ) + + # Notify clients that new devices for this user have been discovered + if retrieved_device_ids: + # XXX is this necessary? + yield self.device_handler.notify_device_update( + user.to_string(), retrieved_device_ids + ) + + return desired_key, desired_key_id, desired_verify_key + + +def _check_cross_signing_key(key, user_id, key_type, signing_key=None): + """Check a cross-signing key uploaded by a user. Performs some basic sanity + checking, and ensures that it is signed, if a signature is required. + + Args: + key (dict): the key data to verify + user_id (str): the user whose key is being checked + key_type (str): the type of key that the key should be + signing_key (VerifyKey): (optional) the signing key that the key should + be signed with. If omitted, signatures will not be checked. + """ + if ( + key.get("user_id") != user_id + or key_type not in key.get("usage", []) + or len(key.get("keys", {})) != 1 + ): + raise SynapseError(400, ("Invalid %s key" % (key_type,)), Codes.INVALID_PARAM) + + if signing_key: + try: + verify_signed_json(key, user_id, signing_key) + except SignatureVerifyException: + raise SynapseError( + 400, ("Invalid signature on %s key" % key_type), Codes.INVALID_SIGNATURE + ) + + +def _check_device_signature(user_id, verify_key, signed_device, stored_device): + """Check that a signature on a device or cross-signing key is correct and + matches the copy of the device/key that we have stored. Throws an + exception if an error is detected. 
+ + Args: + user_id (str): the user ID whose signature is being checked + verify_key (VerifyKey): the key to verify the device with + signed_device (dict): the uploaded signed device data + stored_device (dict): our previously stored copy of the device + + Raises: + SynapseError: if the signature was invalid or the sent device is not the + same as the stored device + + """ + + # make sure that the device submitted matches what we have stored + stripped_signed_device = { + k: v for k, v in signed_device.items() if k not in ["signatures", "unsigned"] + } + stripped_stored_device = { + k: v for k, v in stored_device.items() if k not in ["signatures", "unsigned"] + } + if stripped_signed_device != stripped_stored_device: + logger.debug( + "upload signatures: key does not match %s vs %s", + signed_device, + stored_device, + ) + raise SynapseError(400, "Key does not match") + + try: + verify_signed_json(signed_device, user_id, verify_key) + except SignatureVerifyException: + logger.debug("invalid signature on key") + raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) + def _exception_to_failure(e): + if isinstance(e, SynapseError): + return {"status": e.code, "errcode": e.errcode, "message": str(e)} + if isinstance(e, CodeMessageException): return {"status": e.code, "message": str(e)} @@ -476,3 +1214,99 @@ def _one_time_keys_match(old_key_json, new_key): new_key_copy.pop("signatures", None) return old_key == new_key_copy + + +@attr.s +class SignatureListItem: + """An item in the signature list as used by upload_signatures_for_device_keys. + """ + + signing_key_id = attr.ib() + target_user_id = attr.ib() + target_device_id = attr.ib() + signature = attr.ib() + + +class SigningKeyEduUpdater(object): + """Handles incoming signing key updates from federation and updates the DB""" + + def __init__(self, hs, e2e_keys_handler): + self.store = hs.get_datastore() + self.federation = hs.get_federation_client() + self.clock = hs.get_clock() + self.e2e_keys_handler = e2e_keys_handler + + self._remote_edu_linearizer = Linearizer(name="remote_signing_key") + + # user_id -> list of updates waiting to be handled. + self._pending_updates = {} + + # Recently seen stream ids. We don't bother keeping these in the DB, + # but they're useful to have them about to reduce the number of spurious + # resyncs. + self._seen_updates = ExpiringCache( + cache_name="signing_key_update_edu", + clock=self.clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + iterable=True, + ) + + @defer.inlineCallbacks + def incoming_signing_key_update(self, origin, edu_content): + """Called on incoming signing key update from federation. Responsible for + parsing the EDU and adding to pending updates list. + + Args: + origin (string): the server that sent the EDU + edu_content (dict): the contents of the EDU + """ + + user_id = edu_content.pop("user_id") + master_key = edu_content.pop("master_key", None) + self_signing_key = edu_content.pop("self_signing_key", None) + + if get_domain_from_id(user_id) != origin: + logger.warning("Got signing key update edu for %r from %r", user_id, origin) + return + + room_ids = yield self.store.get_rooms_for_user(user_id) + if not room_ids: + # We don't share any rooms with this user. Ignore update, as we + # probably won't get any further updates. 
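A note on the comparison done in _check_device_signature above: the uploaded device and our stored copy are compared with the "signatures" and "unsigned" fields stripped, since only those fields may legitimately differ. A self-contained illustration (device contents are made up):

    def _strip_unsignable(device):
        # Drop the fields that are allowed to differ between the uploaded copy
        # and the copy we already have, before comparing the two dicts.
        return {k: v for k, v in device.items() if k not in ("signatures", "unsigned")}

    stored = {"device_id": "ABC", "keys": {"curve25519:ABC": "key"}, "unsigned": {"name": "phone"}}
    uploaded = {"device_id": "ABC", "keys": {"curve25519:ABC": "key"},
                "signatures": {"@alice:example.com": {"ed25519:base": "sig"}}}
    assert _strip_unsignable(stored) == _strip_unsignable(uploaded)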
+ return + + self._pending_updates.setdefault(user_id, []).append( + (master_key, self_signing_key) + ) + + yield self._handle_signing_key_updates(user_id) + + @defer.inlineCallbacks + def _handle_signing_key_updates(self, user_id): + """Actually handle pending updates. + + Args: + user_id (string): the user whose updates we are processing + """ + + device_handler = self.e2e_keys_handler.device_handler + device_list_updater = device_handler.device_list_updater + + with (yield self._remote_edu_linearizer.queue(user_id)): + pending_updates = self._pending_updates.pop(user_id, []) + if not pending_updates: + # This can happen since we batch updates + return + + device_ids = [] + + logger.info("pending updates: %r", pending_updates) + + for master_key, self_signing_key in pending_updates: + new_device_ids = yield device_list_updater.process_cross_signing_key_update( + user_id, master_key, self_signing_key, + ) + device_ids = device_ids + new_device_ids + + yield device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index a9d80f708c..9abaf13b8f 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017, 2018 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -103,14 +104,35 @@ class E2eRoomKeysHandler(object): rooms session_id(string): session ID to delete keys for, for None to delete keys for all sessions + Raises: + NotFoundError: if the backup version does not exist Returns: - A deferred of the deletion transaction + A dict containing the count and etag for the backup version """ # lock for consistency with uploading with (yield self._upload_linearizer.queue(user_id)): + # make sure the backup version exists + try: + version_info = yield self.store.get_e2e_room_keys_version_info( + user_id, version + ) + except StoreError as e: + if e.code == 404: + raise NotFoundError("Unknown backup version") + else: + raise + yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + version_etag = version_info["etag"] + 1 + yield self.store.update_e2e_room_keys_version( + user_id, version, None, version_etag + ) + + count = yield self.store.count_e2e_room_keys(user_id, version) + return {"etag": str(version_etag), "count": count} + @trace @defer.inlineCallbacks def upload_room_keys(self, user_id, version, room_keys): @@ -138,6 +160,9 @@ class E2eRoomKeysHandler(object): } } + Returns: + A dict containing the count and etag for the backup version + Raises: NotFoundError: if there are no versions defined RoomKeysVersionError: if the uploaded version is not the current version @@ -171,59 +196,69 @@ class E2eRoomKeysHandler(object): else: raise - # go through the room_keys. - # XXX: this should/could be done concurrently, given we're in a lock. + # Fetch any existing room keys for the sessions that have been + # submitted. Then compare them with the submitted keys. If the + # key is new, insert it; if the key should be updated, then update + # it; otherwise, drop it. 
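The delete and upload paths in this hunk now report an etag and key count for the backup version; a minimal sketch of that bookkeeping (version_info and key_count here are illustrative stand-ins, not the real storage objects):

    def bump_backup_etag(version_info, changed, key_count):
        # The etag only moves when something in the backup actually changed,
        # and clients receive it back as a string alongside the key count.
        etag = version_info["etag"]
        if changed:
            etag += 1
        return {"etag": str(etag), "count": key_count}

    print(bump_backup_etag({"etag": 3}, changed=True, key_count=120))   # {'etag': '4', 'count': 120}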
+ existing_keys = yield self.store.get_e2e_room_keys_multi( + user_id, version, room_keys["rooms"] + ) + to_insert = [] # batch the inserts together + changed = False # if anything has changed, we need to update the etag for room_id, room in iteritems(room_keys["rooms"]): - for session_id, session in iteritems(room["sessions"]): - yield self._upload_room_key( - user_id, version, room_id, session_id, session + for session_id, room_key in iteritems(room["sessions"]): + if not isinstance(room_key["is_verified"], bool): + msg = ( + "is_verified must be a boolean in keys for session %s in" + "room %s" % (session_id, room_id) + ) + raise SynapseError(400, msg, Codes.INVALID_PARAM) + + log_kv( + { + "message": "Trying to upload room key", + "room_id": room_id, + "session_id": session_id, + "user_id": user_id, + } ) - - @defer.inlineCallbacks - def _upload_room_key(self, user_id, version, room_id, session_id, room_key): - """Upload a given room_key for a given room and session into a given - version of the backup. Merges the key with any which might already exist. - - Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_id(str): the ID of the room whose keys we're setting - session_id(str): the session whose room_key we're setting - room_key(dict): the room_key being set - """ - log_kv( - { - "message": "Trying to upload room key", - "room_id": room_id, - "session_id": session_id, - "user_id": user_id, - } - ) - # get the room_key for this particular row - current_room_key = None - try: - current_room_key = yield self.store.get_e2e_room_key( - user_id, version, room_id, session_id - ) - except StoreError as e: - if e.code == 404: - log_kv( - { - "message": "Room key not found.", - "room_id": room_id, - "user_id": user_id, - } + current_room_key = existing_keys.get(room_id, {}).get(session_id) + if current_room_key: + if self._should_replace_room_key(current_room_key, room_key): + log_kv({"message": "Replacing room key."}) + # updates are done one at a time in the DB, so send + # updates right away rather than batching them up, + # like we do with the inserts + yield self.store.update_e2e_room_key( + user_id, version, room_id, session_id, room_key + ) + changed = True + else: + log_kv({"message": "Not replacing room_key."}) + else: + log_kv( + { + "message": "Room key not found.", + "room_id": room_id, + "user_id": user_id, + } + ) + log_kv({"message": "Replacing room key."}) + to_insert.append((room_id, session_id, room_key)) + changed = True + + if len(to_insert): + yield self.store.add_e2e_room_keys(user_id, version, to_insert) + + version_etag = version_info["etag"] + if changed: + version_etag = version_etag + 1 + yield self.store.update_e2e_room_keys_version( + user_id, version, None, version_etag ) - else: - raise - if self._should_replace_room_key(current_room_key, room_key): - log_kv({"message": "Replacing room key."}) - yield self.store.set_e2e_room_key( - user_id, version, room_id, session_id, room_key - ) - else: - log_kv({"message": "Not replacing room_key."}) + count = yield self.store.count_e2e_room_keys(user_id, version) + return {"etag": str(version_etag), "count": count} @staticmethod def _should_replace_room_key(current_room_key, room_key): @@ -314,6 +349,8 @@ class E2eRoomKeysHandler(object): raise NotFoundError("Unknown backup version") else: raise + + res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) return res @trace @@ -352,8 +389,8 @@ class E2eRoomKeysHandler(object): A 
deferred of an empty dict. """ if "version" not in version_info: - raise SynapseError(400, "Missing version in body", Codes.MISSING_PARAM) - if version_info["version"] != version: + version_info["version"] = version + elif version_info["version"] != version: raise SynapseError( 400, "Version in body does not match", Codes.INVALID_PARAM ) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 5e748687e3..71a89f09c7 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -16,11 +16,10 @@ import logging import random -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase +from synapse.handlers.presence import format_user_presence_state from synapse.logging.utils import log_function from synapse.types import UserID from synapse.visibility import filter_events_for_client @@ -50,9 +49,8 @@ class EventStreamHandler(BaseHandler): self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks @log_function - def get_stream( + async def get_stream( self, auth_user_id, pagin_config, @@ -69,17 +67,17 @@ class EventStreamHandler(BaseHandler): """ if room_id: - blocked = yield self.store.is_room_blocked(room_id) + blocked = await self.store.is_room_blocked(room_id) if blocked: raise SynapseError(403, "This room has been blocked on this server") # send any outstanding server notices to the user. - yield self._server_notices_sender.on_user_syncing(auth_user_id) + await self._server_notices_sender.on_user_syncing(auth_user_id) auth_user = UserID.from_string(auth_user_id) presence_handler = self.hs.get_presence_handler() - context = yield presence_handler.user_syncing( + context = await presence_handler.user_syncing( auth_user_id, affect_presence=affect_presence ) with context: @@ -91,7 +89,7 @@ class EventStreamHandler(BaseHandler): # thundering herds on restart. timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) - events, tokens = yield self.notifier.get_events_for( + events, tokens = await self.notifier.get_events_for( auth_user, pagin_config, timeout, @@ -100,6 +98,8 @@ class EventStreamHandler(BaseHandler): explicit_room_id=room_id, ) + time_now = self.clock.time_msec() + # When the user joins a new room, or another user joins a currently # joined room, we need to send down presence for those users. to_add = [] @@ -112,23 +112,24 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. 
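Much of this hunk is the mechanical conversion from @defer.inlineCallbacks/yield to async/await; a minimal before-and-after sketch, with `store` standing in for Synapse's datastore (an assumed object whose method returns a Deferred or coroutine):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def old_style(store):
        # Generator-based coroutine: each yield waits on a Deferred.
        blocked = yield store.is_room_blocked("!room:example.com")
        return blocked

    async def new_style(store):
        # Native coroutine: same control flow, awaited instead of yielded.
        blocked = await store.is_room_blocked("!room:example.com")
        return blocked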
- users = yield self.state.get_current_users_in_room( + users = await self.state.get_current_users_in_room( event.room_id ) - states = yield presence_handler.get_states(users, as_event=True) - to_add.extend(states) else: + users = [event.state_key] - ev = yield presence_handler.get_state( - UserID.from_string(event.state_key), as_event=True - ) - to_add.append(ev) + states = await presence_handler.get_states(users) + to_add.extend( + { + "type": EventTypes.Presence, + "content": format_user_presence_state(state, time_now), + } + for state in states + ) events.extend(to_add) - time_now = self.clock.time_msec() - - chunks = yield self._event_serializer.serialize_events( + chunks = await self._event_serializer.serialize_events( events, time_now, as_client_event=as_client_event, @@ -147,8 +148,11 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): - @defer.inlineCallbacks - def get_event(self, user, room_id, event_id): + def __init__(self, hs): + super(EventHandler, self).__init__(hs) + self.storage = hs.get_storage() + + async def get_event(self, user, room_id, event_id): """Retrieve a single specified event. Args: @@ -163,16 +167,16 @@ class EventHandler(BaseHandler): AuthError if the user does not have the rights to inspect this event. """ - event = yield self.store.get_event(event_id, check_room_id=room_id) + event = await self.store.get_event(event_id, check_room_id=room_id) if not event: return None - users = yield self.store.get_users_in_room(event.room_id) + users = await self.store.get_users_in_room(event.room_id) is_peeking = user.to_string() not in users - filtered = yield filter_events_for_client( - self.store, user.to_string(), [event], is_peeking=is_peeking + filtered = await filter_events_for_client( + self.storage, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f72b81d419..3e60774b33 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,17 +19,20 @@ import itertools import logging +from typing import Dict, Iterable, List, Optional, Sequence, Tuple import six from six import iteritems, itervalues from six.moves import http_client, zip +import attr from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 from twisted.internet import defer +from synapse import event_auth from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.errors import ( AuthError, @@ -37,14 +40,17 @@ from synapse.api.errors import ( Codes, FederationDeniedError, FederationError, + HttpResponseException, RequestSendFailed, - StoreError, SynapseError, ) -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator +from synapse.handlers._base import BaseHandler from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -52,52 +58,49 @@ from synapse.logging.context import ( run_in_background, ) from synapse.logging.utils import log_function +from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from 
synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationFederationSendEventsRestServlet, + ReplicationStoreRoomOnInviteRestServlet, ) from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store -from synapse.types import UserID, get_domain_from_id -from synapse.util import unwrapFirstError -from synapse.util.async_helpers import Linearizer +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id +from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import shortstr from synapse.visibility import filter_events_for_server -from ._base import BaseHandler - logger = logging.getLogger(__name__) -def shortstr(iterable, maxitems=5): - """If iterable has maxitems or fewer, return the stringification of a list - containing those items. +@attr.s +class _NewEventInfo: + """Holds information about a received event, ready for passing to _handle_new_events - Otherwise, return the stringification of a a list with the first maxitems items, - followed by "...". + Attributes: + event: the received event - Args: - iterable (Iterable): iterable to truncate - maxitems (int): number of items to return before truncating + state: the state at that event - Returns: - unicode + auth_events: the auth_event map for that event """ - items = list(itertools.islice(iterable, maxitems + 1)) - if len(items) <= maxitems: - return str(items) - return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" + event = attr.ib(type=EventBase) + state = attr.ib(type=Optional[Sequence[EventBase]], default=None) + auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: a) handling received Pdus before handing them on as Events to the rest - of the home server (including auth and state conflict resoultion) + of the homeserver (including auth and state conflict resoultion) b) converting events that were produced by local clients that may need - to be sent to remote home servers. + to be sent to remote homeservers. c) doing the necessary dances to invite remote users and join remote rooms. 
""" @@ -108,6 +111,8 @@ class FederationHandler(BaseHandler): self.hs = hs self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -117,13 +122,14 @@ class FederationHandler(BaseHandler): self.pusher_pool = hs.get_pusherpool() self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() + self._instance_name = hs.get_instance_name() + self._replication = hs.get_replication_data_handler() - self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( - hs - ) + self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client( hs ) @@ -131,14 +137,26 @@ class FederationHandler(BaseHandler): hs ) + if hs.config.worker_app: + self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( + hs + ) + self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client( + hs + ) + else: + self._device_list_updater = hs.get_device_handler().device_list_updater + self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite + # When joining a room we need to queue any events for that room up self.room_queues = {} self._room_pdu_linearizer = Linearizer("fed_room_pdu") self.third_party_event_rules = hs.get_third_party_event_rules() - @defer.inlineCallbacks - def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + + async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -148,17 +166,15 @@ class FederationHandler(BaseHandler): pdu (FrozenEvent): received PDU sent_to_us_directly (bool): True if this event was pushed to us; False if we pulled it as the result of a missing prev_event. - - Returns (Deferred): completes with None """ room_id = pdu.room_id event_id = pdu.event_id - logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu) + logger.info("handling received PDU: %s", pdu) # We reprocess pdus when we have seen them only as outliers - existing = yield self.store.get_event( + existing = await self.store.get_event( event_id, allow_none=True, allow_rejected=True ) @@ -179,7 +195,7 @@ class FederationHandler(BaseHandler): try: self._sanity_check_event(pdu) except SynapseError as err: - logger.warn( + logger.warning( "[%s %s] Received event failed sanity checks", room_id, event_id ) raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) @@ -202,7 +218,7 @@ class FederationHandler(BaseHandler): # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. 
- is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name) + is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( "[%s %s] Ignoring PDU from %s as we're not in the room", @@ -213,25 +229,24 @@ class FederationHandler(BaseHandler): return None state = None - auth_chain = [] # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. - min_depth = yield self.get_min_depth_for_context(pdu.room_id) + min_depth = await self.get_min_depth_for_context(pdu.room_id) logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) prevs = set(pdu.prev_event_ids()) - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) - if min_depth and pdu.depth < min_depth: + if min_depth is not None and pdu.depth < min_depth: # This is so that we don't notify the user about this # message, to work around the fact that some events will # reference really really old events we really don't want to # send to the clients. pdu.internal_metadata.outlier = True - elif min_depth and pdu.depth > min_depth: + elif min_depth is not None and pdu.depth > min_depth: missing_prevs = prevs - seen if sent_to_us_directly and missing_prevs: # If we're missing stuff, ensure we only fetch stuff one @@ -243,7 +258,7 @@ class FederationHandler(BaseHandler): len(missing_prevs), shortstr(missing_prevs), ) - with (yield self._room_pdu_linearizer.queue(pdu.room_id)): + with (await self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( "[%s %s] Acquired room lock to fetch %d missing prev_events", room_id, @@ -251,13 +266,19 @@ class FederationHandler(BaseHandler): len(missing_prevs), ) - yield self._get_missing_events_for_pdu( - origin, pdu, prevs, min_depth - ) + try: + await self._get_missing_events_for_pdu( + origin, pdu, prevs, min_depth + ) + except Exception as e: + raise Exception( + "Error fetching missing prev_events for %s: %s" + % (event_id, e) + ) # Update the set of things we've seen after trying to # fetch the missing stuff - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if not prevs - seen: logger.info( @@ -265,14 +286,6 @@ class FederationHandler(BaseHandler): room_id, event_id, ) - elif missing_prevs: - logger.info( - "[%s %s] Not recursively fetching %d missing prev_events: %s", - room_id, - event_id, - len(missing_prevs), - shortstr(missing_prevs), - ) if prevs - seen: # We've still not been able to get all of the prev_events for this event. @@ -300,7 +313,7 @@ class FederationHandler(BaseHandler): # following. if sent_to_us_directly: - logger.warn( + logger.warning( "[%s %s] Rejecting: failed to fetch %d prev events: %s", room_id, event_id, @@ -317,18 +330,21 @@ class FederationHandler(BaseHandler): affected=pdu.event_id, ) + logger.info( + "Event %s is missing prev_events: calculating state for a " + "backwards extremity", + event_id, + ) + # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. 
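The prev_event bookkeeping above boils down to set arithmetic: anything referenced in prev_events that we have never seen still needs fetching. A tiny illustration with made-up event IDs:

    def missing_prev_events(prev_event_ids, seen_event_ids):
        # prevs - seen: the prev_events we still do not have locally.
        return set(prev_event_ids) - set(seen_event_ids)

    missing = missing_prev_events({"$a", "$b", "$c"}, {"$b"})
    print(sorted(missing))   # ['$a', '$c']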
- auth_chains = set() event_map = {event_id: pdu} try: # Get the state of the events we know about - ours = yield self.store.get_state_groups_ids(room_id, seen) + ours = await self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id - state_maps = list( - ours.values() - ) # type: list[dict[tuple[str, str], str]] + state_maps = list(ours.values()) # type: List[StateMap[str]] # we don't need this any more, let's delete it. del ours @@ -337,43 +353,17 @@ class FederationHandler(BaseHandler): # know about for p in prevs - seen: logger.info( - "[%s %s] Requesting state at missing prev_event %s", - room_id, - event_id, - p, + "Requesting state at missing prev_event %s", event_id, ) - room_version = yield self.store.get_room_version(room_id) - with nested_logging_context(p): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - remote_state, got_auth_chain = ( - yield self.federation_client.get_state_for_room( - origin, room_id, p - ) + (remote_state, _,) = await self._get_state_for_room( + origin, room_id, p, include_event_in_state=True ) - # we want the state *after* p; get_state_for_room returns the - # state *before* p. - remote_event = yield self.federation_client.get_pdu( - [origin], p, room_version, outlier=True - ) - - if remote_event is None: - raise Exception( - "Unable to get missing prev_event %s" % (p,) - ) - - if remote_event.is_state(): - remote_state.append(remote_event) - - # XXX hrm I'm not convinced that duplicate events will compare - # for equality, so I'm not sure this does what the author - # hoped. - auth_chains.update(got_auth_chain) - remote_state_map = { (x.type, x.state_key): x.event_id for x in remote_state } @@ -382,7 +372,9 @@ class FederationHandler(BaseHandler): for x in remote_state: event_map[x.event_id] = x - state_map = yield resolve_events_with_store( + room_version = await self.store.get_room_version_id(room_id) + state_map = await resolve_events_with_store( + room_id, room_version, state_maps, event_map, @@ -394,17 +386,16 @@ class FederationHandler(BaseHandler): # First though we need to fetch all the events that are in # state_map, so we can build up the state below. - evs = yield self.store.get_events( + evs = await self.store.get_events( list(state_map.values()), get_prev_content=False, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, ) event_map.update(evs) state = [event_map[e] for e in six.itervalues(state_map)] - auth_chain = list(auth_chains) except Exception: - logger.warn( + logger.warning( "[%s %s] Error attempting to resolve state at missing " "prev_events", room_id, @@ -418,12 +409,9 @@ class FederationHandler(BaseHandler): affected=event_id, ) - yield self._process_received_pdu( - origin, pdu, state=state, auth_chain=auth_chain - ) + await self._process_received_pdu(origin, pdu, state=state) - @defer.inlineCallbacks - def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): + async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): """ Args: origin (str): Origin of the pdu. 
Will be called to get the missing events @@ -435,12 +423,12 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if not prevs - seen: return - latest = yield self.store.get_latest_event_ids_in_room(room_id) + latest = await self.store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us @@ -504,7 +492,7 @@ class FederationHandler(BaseHandler): # All that said: Let's try increasing the timout to 60s and see what happens. try: - missing_events = yield self.federation_client.get_missing_events( + missing_events = await self.federation_client.get_missing_events( origin, room_id, earliest_events_ids=list(latest), @@ -513,11 +501,13 @@ class FederationHandler(BaseHandler): min_depth=min_depth, timeout=60000, ) - except RequestSendFailed as e: + except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e: # We failed to get the missing events, but since we need to handle # the case of `get_missing_events` not returning the necessary # events anyway, it is safe to simply log the error and continue. - logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e) + logger.warning( + "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e + ) return logger.info( @@ -541,10 +531,10 @@ class FederationHandler(BaseHandler): ) with nested_logging_context(ev.event_id): try: - yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) + await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: - logger.warn( + logger.warning( "[%s %s] Received prev_event %s failed history check.", room_id, event_id, @@ -553,66 +543,154 @@ class FederationHandler(BaseHandler): else: raise - @defer.inlineCallbacks - def _process_received_pdu(self, origin, event, state, auth_chain): - """ Called when we have a new pdu. We need to do auth checks and put it - through the StateHandler. + async def _get_state_for_room( + self, + destination: str, + room_id: str, + event_id: str, + include_event_in_state: bool = False, + ) -> Tuple[List[EventBase], List[EventBase]]: + """Requests all of the room state at a given event from a remote homeserver. + + Args: + destination: The remote homeserver to query for the state. + room_id: The id of the room we're interested in. + event_id: The id of the event we want the state at. + include_event_in_state: if true, the event itself will be included in the + returned state event list. + + Returns: + A list of events in the state, possibly including the event itself, and + a list of events in the auth chain for the given event. 
""" - room_id = event.room_id - event_id = event.event_id + ( + state_event_ids, + auth_event_ids, + ) = await self.federation_client.get_room_state_ids( + destination, room_id, event_id=event_id + ) - logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) + desired_events = set(state_event_ids + auth_event_ids) - event_ids = set() - if state: - event_ids |= {e.event_id for e in state} - if auth_chain: - event_ids |= {e.event_id for e in auth_chain} + if include_event_in_state: + desired_events.add(event_id) - seen_ids = yield self.store.have_seen_events(event_ids) + event_map = await self._get_events_from_store_or_dest( + destination, room_id, desired_events + ) - if state and auth_chain is not None: - # If we have any state or auth_chain given to us by the replication - # layer, then we should handle them (if we haven't before.) + failed_to_fetch = desired_events - event_map.keys() + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state/auth events for %s %s", + event_id, + failed_to_fetch, + ) - event_infos = [] + remote_state = [ + event_map[e_id] for e_id in state_event_ids if e_id in event_map + ] - for e in itertools.chain(auth_chain, state): - if e.event_id in seen_ids: - continue - e.internal_metadata.outlier = True - auth_ids = e.auth_event_ids() - auth = { - (e.type, e.state_key): e - for e in auth_chain - if e.event_id in auth_ids or e.type == EventTypes.Create - } - event_infos.append({"event": e, "auth_events": auth}) - seen_ids.add(e.event_id) + if include_event_in_state: + remote_event = event_map.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + if remote_event.is_state() and remote_event.rejected_reason is None: + remote_state.append(remote_event) - logger.info( - "[%s %s] persisting newly-received auth/state events %s", + auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] + auth_chain.sort(key=lambda e: e.depth) + + return remote_state, auth_chain + + async def _get_events_from_store_or_dest( + self, destination: str, room_id: str, event_ids: Iterable[str] + ) -> Dict[str, EventBase]: + """Fetch events from a remote destination, checking if we already have them. + + Persists any events we don't already have as outliers. + + If we fail to fetch any of the events, a warning will be logged, and the event + will be omitted from the result. Likewise, any events which turn out not to + be in the given room. + + Returns: + map from event_id to event + """ + fetched_events = await self.store.get_events(event_ids, allow_rejected=True) + + missing_events = set(event_ids) - fetched_events.keys() + + if missing_events: + logger.debug( + "Fetching unknown state/auth events %s for room %s", + missing_events, room_id, - event_id, - [e["event"].event_id for e in event_infos], ) - yield self._handle_new_events(origin, event_infos) + + await self._get_events_and_persist( + destination=destination, room_id=room_id, events=missing_events + ) + + # we need to make sure we re-load from the database to get the rejected + # state correct. + fetched_events.update( + (await self.store.get_events(missing_events, allow_rejected=True)) + ) + + # check for events which were in the wrong room. 
+ # + # this can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + + bad_events = [ + (event_id, event.room_id) + for event_id, event in fetched_events.items() + if event.room_id != room_id + ] + + for bad_event_id, bad_room_id in bad_events: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned auth/state set. + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + bad_event_id, + bad_room_id, + room_id, + ) + + del fetched_events[bad_event_id] + + return fetched_events + + async def _process_received_pdu( + self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]], + ): + """ Called when we have a new pdu. We need to do auth checks and put it + through the StateHandler. + + Args: + origin: server sending the event + + event: event to be persisted + + state: Normally None, but if we are handling a gap in the graph + (ie, we are missing one or more prev_events), the resolved state at the + event + """ + room_id = event.room_id + event_id = event.event_id + + logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) try: - context = yield self._handle_new_event(origin, event, state=state) + context = await self._handle_new_event(origin, event, state=state) except AuthError as e: raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) - room = yield self.store.get_room(room_id) - - if not room: - try: - yield self.store.store_room( - room_id=room_id, room_creator_user_id="", is_public=False - ) - except StoreError: - logger.exception("Failed to store room.") - if event.type == EventTypes.Member: if event.membership == Membership.JOIN: # Only fire user_joined_room if the user has acutally @@ -620,11 +698,11 @@ class FederationHandler(BaseHandler): # changing their profile info. newly_joined = True - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) if prev_state_id: - prev_state = yield self.store.get_event( + prev_state = await self.store.get_event( prev_state_id, allow_none=True ) if prev_state and prev_state.membership == Membership.JOIN: @@ -632,11 +710,82 @@ class FederationHandler(BaseHandler): if newly_joined: user = UserID.from_string(event.state_key) - yield self.user_joined_room(user, room_id) + await self.user_joined_room(user, room_id) + + # For encrypted messages we check that we know about the sending device, + # if we don't then we mark the device cache for that user as stale. + if event.type == EventTypes.Encrypted: + device_id = event.content.get("device_id") + sender_key = event.content.get("sender_key") + + cached_devices = await self.store.get_cached_devices_for_user(event.sender) + + resync = False # Whether we should resync device lists. + + device = None + if device_id is not None: + device = cached_devices.get(device_id) + if device is None: + logger.info( + "Received event from remote device not in our cache: %s %s", + event.sender, + device_id, + ) + resync = True + + # We also check if the `sender_key` matches what we expect. + if sender_key is not None: + # Figure out what sender key we're expecting. If we know the + # device and recognize the algorithm then we can work out the + # exact key to expect. 
Otherwise check it matches any key we + # have for that device. + if device: + keys = device.get("keys", {}).get("keys", {}) + + if event.content.get("algorithm") == "m.megolm.v1.aes-sha2": + # For this algorithm we expect a curve25519 key. + key_name = "curve25519:%s" % (device_id,) + current_keys = [keys.get(key_name)] + else: + # We don't know understand the algorithm, so we just + # check it matches a key for the device. + current_keys = keys.values() + elif device_id: + # We don't have any keys for the device ID. + current_keys = [] + else: + # The event didn't include a device ID, so we just look for + # keys across all devices. + current_keys = ( + key + for device in cached_devices + for key in device.get("keys", {}).get("keys", {}).values() + ) + + # We now check that the sender key matches (one of) the expected + # keys. + if sender_key not in current_keys: + logger.info( + "Received event from remote device with unexpected sender key: %s %s: %s", + event.sender, + device_id or "<no device_id>", + sender_key, + ) + resync = True + + if resync: + await self.store.mark_remote_user_device_cache_as_stale(event.sender) + + # Immediately attempt a resync in the background + if self.config.worker_app: + return run_in_background(self._user_device_resync, event.sender) + else: + return run_in_background( + self._device_list_updater.user_device_resync, event.sender + ) @log_function - @defer.inlineCallbacks - def backfill(self, dest, room_id, limit, extremities): + async def backfill(self, dest, room_id, limit, extremities): """ Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side @@ -653,9 +802,7 @@ class FederationHandler(BaseHandler): if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") - room_version = yield self.store.get_room_version(room_id) - - events = yield self.federation_client.backfill( + events = await self.federation_client.backfill( dest, room_id, limit=limit, extremities=extremities ) @@ -670,8 +817,8 @@ class FederationHandler(BaseHandler): # self._sanity_check_event(ev) # Don't bother processing events we already have. - seen_events = yield self.store.have_events_in_timeline( - set(e.event_id for e in events) + seen_events = await self.store.have_events_in_timeline( + {e.event_id for e in events} ) events = [e for e in events if e.event_id not in seen_events] @@ -681,8 +828,11 @@ class FederationHandler(BaseHandler): event_map = {e.event_id: e for e in events} - event_ids = set(e.event_id for e in events) + event_ids = {e.event_id for e in events} + # build a list of events whose prev_events weren't in the batch. + # (XXX: this will include events whose prev_events we already have; that doesn't + # sound right?) 
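The "edges" described in the comment above are the events whose prev_events are not all present in the backfilled batch; a self-contained illustration with made-up event IDs:

    def backwards_edges(events):
        # events: list of (event_id, prev_event_ids) pairs.  An event is an
        # edge of the batch if at least one prev_event is outside the batch.
        batch_ids = {event_id for event_id, _ in events}
        return [
            event_id
            for event_id, prev_ids in events
            if set(prev_ids) - batch_ids
        ]

    batch = [("$3", ["$2"]), ("$2", ["$1"]), ("$1", ["$0"])]
    print(backwards_edges(batch))   # ['$1'] -- its prev_event "$0" is not in the batch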
edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) @@ -693,113 +843,32 @@ class FederationHandler(BaseHandler): state_events = {} events_to_state = {} for e_id in edges: - state, auth = yield self.federation_client.get_state_for_room( - destination=dest, room_id=room_id, event_id=e_id + state, auth = await self._get_state_for_room( + destination=dest, + room_id=room_id, + event_id=e_id, + include_event_in_state=False, ) auth_events.update({a.event_id: a for a in auth}) auth_events.update({s.event_id: s for s in state}) state_events.update({s.event_id: s for s in state}) events_to_state[e_id] = state - required_auth = set( + required_auth = { a_id for event in events + list(state_events.values()) + list(auth_events.values()) for a_id in event.auth_event_ids() - ) + } auth_events.update( {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} ) - missing_auth = required_auth - set(auth_events) - failed_to_fetch = set() - - # Try and fetch any missing auth events from both DB and remote servers. - # We repeatedly do this until we stop finding new auth events. - while missing_auth - failed_to_fetch: - logger.info("Missing auth for backfill: %r", missing_auth) - ret_events = yield self.store.get_events(missing_auth - failed_to_fetch) - auth_events.update(ret_events) - - required_auth.update( - a_id for event in ret_events.values() for a_id in event.auth_event_ids() - ) - missing_auth = required_auth - set(auth_events) - - if missing_auth - failed_to_fetch: - logger.info( - "Fetching missing auth for backfill: %r", - missing_auth - failed_to_fetch, - ) - - results = yield make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.federation_client.get_pdu, - [dest], - event_id, - room_version=room_version, - outlier=True, - timeout=10000, - ) - for event_id in missing_auth - failed_to_fetch - ], - consumeErrors=True, - ) - ).addErrback(unwrapFirstError) - auth_events.update({a.event_id: a for a in results if a}) - required_auth.update( - a_id - for event in results - if event - for a_id in event.auth_event_ids() - ) - missing_auth = required_auth - set(auth_events) - - failed_to_fetch = missing_auth - set(auth_events) - - seen_events = yield self.store.have_seen_events( - set(auth_events.keys()) | set(state_events.keys()) - ) - - # We now have a chunk of events plus associated state and auth chain to - # persist. We do the persistence in two steps: - # 1. Auth events and state get persisted as outliers, plus the - # backward extremities get persisted (as non-outliers). - # 2. The rest of the events in the chunk get persisted one by one, as - # each one depends on the previous event for its state. - # - # The important thing is that events in the chunk get persisted as - # non-outliers, including when those events are also in the state or - # auth chain. Caution must therefore be taken to ensure that they are - # not accidentally marked as outliers. - # Step 1a: persist auth events that *don't* appear in the chunk ev_infos = [] - for a in auth_events.values(): - # We only want to persist auth events as outliers that we haven't - # seen and aren't about to persist as part of the backfilled chunk. 
- if a.event_id in seen_events or a.event_id in event_map: - continue - - a.internal_metadata.outlier = True - ev_infos.append( - { - "event": a, - "auth_events": { - ( - auth_events[a_id].type, - auth_events[a_id].state_key, - ): auth_events[a_id] - for a_id in a.auth_event_ids() - if a_id in auth_events - }, - } - ) - # Step 1b: persist the events in the chunk we fetched state for (i.e. - # the backwards extremities) as non-outliers. + # Step 1: persist the events in the chunk we fetched state for (i.e. + # the backwards extremities), with custom auth events and state for e_id in events_to_state: # For paranoia we ensure that these events are marked as # non-outliers @@ -807,10 +876,10 @@ class FederationHandler(BaseHandler): assert not ev.internal_metadata.is_outlier() ev_infos.append( - { - "event": ev, - "state": events_to_state[e_id], - "auth_events": { + _NewEventInfo( + event=ev, + state=events_to_state[e_id], + auth_events={ ( auth_events[a_id].type, auth_events[a_id].state_key, @@ -818,10 +887,10 @@ class FederationHandler(BaseHandler): for a_id in ev.auth_event_ids() if a_id in auth_events }, - } + ) ) - yield self._handle_new_events(dest, ev_infos, backfilled=True) + await self._handle_new_events(dest, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -837,16 +906,15 @@ class FederationHandler(BaseHandler): # We store these one at a time since each event depends on the # previous to work out the state. # TODO: We can probably do something more clever here. - yield self._handle_new_event(dest, event, backfilled=True) + await self._handle_new_event(dest, event, backfilled=True) return events - @defer.inlineCallbacks - def maybe_backfill(self, room_id, current_depth): + async def maybe_backfill(self, room_id, current_depth): """Checks the database to see if we should backfill before paginating, and if so do. """ - extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id) + extremities = await self.store.get_oldest_events_with_depth_in_room(room_id) if not extremities: logger.debug("Not backfilling as no extremeties found.") @@ -878,16 +946,18 @@ class FederationHandler(BaseHandler): # state *before* the event, ignoring the special casing certain event # types have. - forward_events = yield self.store.get_successor_events(list(extremities)) + forward_events = await self.store.get_successor_events(list(extremities)) - extremities_events = yield self.store.get_events( - forward_events, check_redacted=False, get_prev_content=False + extremities_events = await self.store.get_events( + forward_events, + redact_behaviour=EventRedactBehaviour.AS_IS, + get_prev_content=False, ) # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. - filtered_extremities = yield filter_events_for_server( - self.store, + filtered_extremities = await filter_events_for_server( + self.storage, self.server_name, list(extremities_events.values()), redact=False, @@ -916,7 +986,7 @@ class FederationHandler(BaseHandler): # First we try hosts that are already in the room # TODO: HEURISTIC ALERT. 
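The host selection in maybe_backfill ranks the homeservers of joined members; a rough, simplified sketch of that heuristic, ignoring the depth bookkeeping the real code does:

    from collections import Counter

    def rank_backfill_candidates(joined_user_ids, own_server_name):
        # Prefer the homeservers with the most joined members, never ourselves.
        counts = Counter(uid.split(":", 1)[1] for uid in joined_user_ids)
        return [domain for domain, _ in counts.most_common() if domain != own_server_name]

    users = ["@a:example.com", "@b:example.com", "@c:matrix.org", "@me:myserver.test"]
    print(rank_backfill_candidates(users, "myserver.test"))  # ['example.com', 'matrix.org']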
- curr_state = yield self.state_handler.get_current_state(room_id) + curr_state = await self.state_handler.get_current_state(room_id) def get_domains_from_state(state): """Get joined domains from state @@ -955,12 +1025,11 @@ class FederationHandler(BaseHandler): domain for domain, depth in curr_domains if domain != self.server_name ] - @defer.inlineCallbacks - def try_backfill(domains): + async def try_backfill(domains): # TODO: Should we try multiple of these at a time? for dom in domains: try: - yield self.backfill( + await self.backfill( dom, room_id, limit=100, extremities=extremities ) # If this succeeded then we probably already have the @@ -970,6 +1039,12 @@ class FederationHandler(BaseHandler): except SynapseError as e: logger.info("Failed to backfill from %s because %s", dom, e) continue + except HttpResponseException as e: + if 400 <= e.code < 500: + raise e.to_synapse_error() + + logger.info("Failed to backfill from %s because %s", dom, e) + continue except CodeMessageException as e: if 400 <= e.code < 500: raise @@ -991,7 +1066,7 @@ class FederationHandler(BaseHandler): return False - success = yield try_backfill(likely_domains) + success = await try_backfill(likely_domains) if success: return True @@ -1005,7 +1080,7 @@ class FederationHandler(BaseHandler): logger.debug("calling resolve_state_groups in _maybe_backfill") resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) - states = yield make_deferred_yieldable( + states = await make_deferred_yieldable( defer.gatherResults( [resolve(room_id, [e]) for e in event_ids], consumeErrors=True ) @@ -1015,7 +1090,7 @@ class FederationHandler(BaseHandler): # event_ids. states = dict(zip(event_ids, [s.state for s in states])) - state_map = yield self.store.get_events( + state_map = await self.store.get_events( [e_id for ids in itervalues(states) for e_id in itervalues(ids)], get_prev_content=False, ) @@ -1031,7 +1106,7 @@ class FederationHandler(BaseHandler): for e_id, _ in sorted_extremeties_tuple: likely_domains = get_domains_from_state(states[e_id]) - success = yield try_backfill( + success = await try_backfill( [dom for dom, _ in likely_domains if dom not in tried_domains] ) if success: @@ -1041,6 +1116,56 @@ class FederationHandler(BaseHandler): return False + async def _get_events_and_persist( + self, destination: str, room_id: str, events: Iterable[str] + ): + """Fetch the given events from a server, and persist them as outliers. + + Logs a warning if we can't find the given event. 
+ """ + + room_version = await self.store.get_room_version(room_id) + + event_infos = [] + + async def get_event(event_id: str): + with nested_logging_context(event_id): + try: + event = await self.federation_client.get_pdu( + [destination], event_id, room_version, outlier=True, + ) + if event is None: + logger.warning( + "Server %s didn't return event %s", destination, event_id, + ) + return + + # recursively fetch the auth events for this event + auth_events = await self._get_events_from_store_or_dest( + destination, room_id, event.auth_event_ids() + ) + auth = {} + for auth_event_id in event.auth_event_ids(): + ae = auth_events.get(auth_event_id) + if ae: + auth[(ae.type, ae.state_key)] = ae + + event_infos.append(_NewEventInfo(event, None, auth)) + + except Exception as e: + logger.warning( + "Error fetching missing state/auth event %s: %s %s", + event_id, + type(e), + e, + ) + + await concurrently_execute(get_event, events, 5) + + await self._handle_new_events( + destination, event_infos, + ) + def _sanity_check_event(self, ev): """ Do some early sanity checks of a received event @@ -1058,7 +1183,7 @@ class FederationHandler(BaseHandler): SynapseError if the event does not pass muster """ if len(ev.prev_event_ids()) > 20: - logger.warn( + logger.warning( "Rejecting event %s which has %i prev_events", ev.event_id, len(ev.prev_event_ids()), @@ -1066,20 +1191,19 @@ class FederationHandler(BaseHandler): raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events") if len(ev.auth_event_ids()) > 10: - logger.warn( + logger.warning( "Rejecting event %s which has %i auth_events", ev.event_id, len(ev.auth_event_ids()), ) raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events") - @defer.inlineCallbacks - def send_invite(self, target_host, event): + async def send_invite(self, target_host, event): """ Sends the invite to the remote server for signing. Invites must be signed by the invitee's server before distribution. """ - pdu = yield self.federation_client.send_invite( + pdu = await self.federation_client.send_invite( destination=target_host, room_id=event.room_id, event_id=event.event_id, @@ -1088,19 +1212,18 @@ class FederationHandler(BaseHandler): return pdu - @defer.inlineCallbacks - def on_event_auth(self, event_id): - event = yield self.store.get_event(event_id) - auth = yield self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], include_given=True + async def on_event_auth(self, event_id: str) -> List[EventBase]: + event = await self.store.get_event(event_id) + auth = await self.store.get_auth_chain( + list(event.auth_event_ids()), include_given=True ) - return [e for e in auth] + return list(auth) - @log_function - @defer.inlineCallbacks - def do_invite_join(self, target_hosts, room_id, joinee, content): + async def do_invite_join( + self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict + ) -> Tuple[str, int]: """ Attempts to join the `joinee` to the room `room_id` via the - server `target_host`. + servers contained in `target_hosts`. This first triggers a /make_join/ request that returns a partial event that we can fill out and sign. This is then sent to the @@ -1109,10 +1232,23 @@ class FederationHandler(BaseHandler): We suspend processing of any received events from this room until we have finished processing the join. + + Args: + target_hosts: List of servers to attempt to join the room with. + + room_id: The ID of the room to join. + + joinee: The User ID of the joining user. 
+ + content: The event content to use for the join event. """ + # TODO: We should be able to call this on workers, but the upgrading of + # room stuff after join currently doesn't work on workers. + assert self.config.worker.worker_app is None + logger.debug("Joining %s to %s", joinee, room_id) - origin, event, event_format_version = yield self._make_and_verify_event( + origin, event, room_version_obj = await self._make_and_verify_event( target_hosts, room_id, joinee, @@ -1128,7 +1264,7 @@ class FederationHandler(BaseHandler): self.room_queues[room_id] = [] - yield self._clean_room_for_join(room_id) + await self._clean_room_for_join(room_id) handled_events = set() @@ -1140,8 +1276,9 @@ class FederationHandler(BaseHandler): target_hosts.insert(0, origin) except ValueError: pass - ret = yield self.federation_client.send_join( - target_hosts, event, event_format_version + + ret = await self.federation_client.send_join( + target_hosts, event, room_version_obj ) origin = ret["origin"] @@ -1158,17 +1295,49 @@ class FederationHandler(BaseHandler): logger.debug("do_invite_join event: %s", event) - try: - yield self.store.store_room( - room_id=room_id, room_creator_user_id="", is_public=False - ) - except Exception: - # FIXME - pass + # if this is the first time we've joined this room, it's time to add + # a row to `rooms` with the correct room version. If there's already a + # row there, we should override it, since it may have been populated + # based on an invite request which lied about the room version. + # + # federation_client.send_join has already checked that the room + # version in the received create event is the same as room_version_obj, + # so we can rely on it now. + # + await self.store.upsert_room_on_join( + room_id=room_id, room_version=room_version_obj, + ) - yield self._persist_auth_tree(origin, auth_chain, state, event) + max_stream_id = await self._persist_auth_tree( + origin, auth_chain, state, event, room_version_obj + ) + + # We wait here until this instance has seen the events come down + # replication (if we're using replication) as the below uses caches. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position( + self.config.worker.writers.events, "events", max_stream_id + ) + + # Check whether this room is the result of an upgrade of a room we already know + # about. If so, migrate over user information + predecessor = await self.store.get_room_predecessor(room_id) + if not predecessor or not isinstance(predecessor.get("room_id"), str): + return event.event_id, max_stream_id + old_room_id = predecessor["room_id"] + logger.debug( + "Found predecessor for %s during remote join: %s", room_id, old_room_id + ) + + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + await member_handler.transfer_room_state_on_room_upgrade( + old_room_id, room_id + ) logger.debug("Finished joining %s to %s", joinee, room_id) + return event.event_id, max_stream_id finally: room_queue = self.room_queues[room_id] del self.room_queues[room_id] @@ -1181,10 +1350,7 @@ class FederationHandler(BaseHandler): run_in_background(self._handle_queued_pdus, room_queue) - return True - - @defer.inlineCallbacks - def _handle_queued_pdus(self, room_queue): + async def _handle_queued_pdus(self, room_queue): """Process PDUs which got queued up while we were busy send_joining. 
Args: @@ -1200,28 +1366,24 @@ class FederationHandler(BaseHandler): p.room_id, ) with nested_logging_context(p.event_id): - yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) + await self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: - logger.warn( + logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e ) - @defer.inlineCallbacks - @log_function - def on_make_join_request(self, origin, room_id, user_id): + async def on_make_join_request( + self, origin: str, room_id: str, user_id: str + ) -> EventBase: """ We've received a /make_join/ request, so we create a partial join event for the room and return that. We do *not* persist or process it until the other server has signed it and sent it back. Args: - origin (str): The (verified) server name of the requesting server. - room_id (str): Room to create join event in - user_id (str): The user to create the join for - - Returns: - Deferred[FrozenEvent] + origin: The (verified) server name of the requesting server. + room_id: Room to create join event in + user_id: The user to create the join for """ - if get_domain_from_id(user_id) != origin: logger.info( "Got /make_join request for user %r from different origin %s, ignoring", @@ -1232,7 +1394,7 @@ class FederationHandler(BaseHandler): event_content = {"membership": Membership.JOIN} - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new( room_version, @@ -1246,14 +1408,14 @@ class FederationHandler(BaseHandler): ) try: - event, context = yield self.event_creation_handler.create_new_client_event( + event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) except AuthError as e: - logger.warn("Failed to create join %r because %s", event, e) + logger.warning("Failed to create join to %s because %s", room_id, e) raise e - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -1264,26 +1426,33 @@ class FederationHandler(BaseHandler): # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` - yield self.auth.check_from_context( + await self.auth.check_from_context( room_version, event, context, do_sig_check=False ) return event - @defer.inlineCallbacks - @log_function - def on_send_join_request(self, origin, pdu): + async def on_send_join_request(self, origin, pdu): """ We have received a join event for a room. Fully process it and respond with the current state and auth chains. """ event = pdu logger.debug( - "on_send_join_request: Got event: %s, signatures: %s", + "on_send_join_request from %s: Got event: %s, signatures: %s", + origin, event.event_id, event.signatures, ) + if get_domain_from_id(event.sender) != origin: + logger.info( + "Got /send_join request for user %r from different origin %s", + event.sender, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + event.internal_metadata.outlier = False # Send this event on behalf of the origin server. # @@ -1300,9 +1469,9 @@ class FederationHandler(BaseHandler): # would introduce the danger of backwards-compatibility problems. 
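Both on_make_join_request and on_send_join_request above refuse to act for a user whose domain does not match the requesting origin; a simplified stand-in for that check (Synapse's real helper raises a Synapse error rather than ValueError):

    def get_domain_from_id(matrix_id):
        # A Matrix ID such as "@alice:example.com" carries its homeserver
        # after the first colon.
        _, _, domain = matrix_id.partition(":")
        if not domain:
            raise ValueError("ID %r has no domain part" % (matrix_id,))
        return domain

    def is_from_origin(user_id, origin):
        return get_domain_from_id(user_id) == origin

    print(is_from_origin("@alice:example.com", "example.com"))  # True
    print(is_from_origin("@mallory:evil.test", "example.com"))  # False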
event.internal_metadata.send_on_behalf_of = origin - context = yield self._handle_new_event(origin, event) + context = await self._handle_new_event(origin, event) - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -1320,29 +1489,28 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.Member: if event.content["membership"] == Membership.JOIN: user = UserID.from_string(event.state_key) - yield self.user_joined_room(user, event.room_id) + await self.user_joined_room(user, event.room_id) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids() state_ids = list(prev_state_ids.values()) - auth_chain = yield self.store.get_auth_chain(state_ids) + auth_chain = await self.store.get_auth_chain(state_ids) - state = yield self.store.get_events(list(prev_state_ids.values())) + state = await self.store.get_events(list(prev_state_ids.values())) return {"state": list(state.values()), "auth_chain": auth_chain} - @defer.inlineCallbacks - def on_invite_request(self, origin, pdu): + async def on_invite_request( + self, origin: str, event: EventBase, room_version: RoomVersion + ): """ We've got an invite event. Process and persist it. Sign it. Respond with the now signed event. """ - event = pdu - if event.state_key is None: raise SynapseError(400, "The invite event did not have a state key") - is_blocked = yield self.store.is_room_blocked(event.room_id) + is_blocked = await self.store.is_room_blocked(event.room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") @@ -1373,24 +1541,35 @@ class FederationHandler(BaseHandler): if event.state_key == self._server_notices_mxid: raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") + # keep a record of the room version, if we don't yet know it. + # (this may get overwritten if we later get a different room version in a + # join dance). + await self._maybe_store_room_on_invite( + room_id=event.room_id, room_version=room_version + ) + event.internal_metadata.outlier = True event.internal_metadata.out_of_band_membership = True event.signatures.update( compute_event_signature( - event.get_pdu_json(), self.hs.hostname, self.hs.config.signing_key[0] + room_version, + event.get_pdu_json(), + self.hs.hostname, + self.hs.config.signing_key[0], ) ) - context = yield self.state_handler.compute_event_context(event) - yield self.persist_events_and_notify([(event, context)]) + context = await self.state_handler.compute_event_context(event) + await self.persist_events_and_notify([(event, context)]) return event - @defer.inlineCallbacks - def do_remotely_reject_invite(self, target_hosts, room_id, user_id): - origin, event, event_format_version = yield self._make_and_verify_event( - target_hosts, room_id, user_id, "leave" + async def do_remotely_reject_invite( + self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict + ) -> Tuple[EventBase, int]: + origin, event, room_version = await self._make_and_verify_event( + target_hosts, room_id, user_id, "leave", content=content ) # Mark as outlier as we don't have any state for this event; we're not # even in the room. 
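Editor's sketch: on_send_join_request answers with the room state before the join plus the auth chain for that state. The same assembly on its own, where `get_events` and `get_auth_chain` are stand-ins for the datastore calls used above:

from typing import Awaitable, Callable, Dict, Iterable, List, Tuple

async def build_send_join_response(
    prev_state_ids: Dict[Tuple[str, str], str],
    get_events: Callable[[Iterable[str]], Awaitable[Dict[str, dict]]],
    get_auth_chain: Callable[[List[str]], Awaitable[List[dict]]],
) -> dict:
    # The state the new member needs is the state *before* their join event...
    state_ids = list(prev_state_ids.values())
    state = await get_events(state_ids)
    # ...plus every event required to authorise that state.
    auth_chain = await get_auth_chain(state_ids)
    return {"state": list(state.values()), "auth_chain": auth_chain}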
@@ -1405,18 +1584,27 @@ class FederationHandler(BaseHandler): except ValueError: pass - yield self.federation_client.send_leave(target_hosts, event) - - context = yield self.state_handler.compute_event_context(event) - yield self.persist_events_and_notify([(event, context)]) - - return event - - @defer.inlineCallbacks - def _make_and_verify_event( - self, target_hosts, room_id, user_id, membership, content={}, params=None - ): - origin, event, format_ver = yield self.federation_client.make_membership_event( + await self.federation_client.send_leave(target_hosts, event) + + context = await self.state_handler.compute_event_context(event) + stream_id = await self.persist_events_and_notify([(event, context)]) + + return event, stream_id + + async def _make_and_verify_event( + self, + target_hosts: Iterable[str], + room_id: str, + user_id: str, + membership: str, + content: JsonDict = {}, + params: Optional[Dict[str, str]] = None, + ) -> Tuple[str, EventBase, RoomVersion]: + ( + origin, + event, + room_version, + ) = await self.federation_client.make_membership_event( target_hosts, room_id, user_id, membership, content, params=params ) @@ -1428,22 +1616,19 @@ class FederationHandler(BaseHandler): assert event.user_id == user_id assert event.state_key == user_id assert event.room_id == room_id - return origin, event, format_ver + return origin, event, room_version - @defer.inlineCallbacks - @log_function - def on_make_leave_request(self, origin, room_id, user_id): + async def on_make_leave_request( + self, origin: str, room_id: str, user_id: str + ) -> EventBase: """ We've received a /make_leave/ request, so we create a partial leave event for the room and return that. We do *not* persist or process it until the other server has signed it and sent it back. Args: - origin (str): The (verified) server name of the requesting server. - room_id (str): Room to create leave event in - user_id (str): The user to create the leave for - - Returns: - Deferred[FrozenEvent] + origin: The (verified) server name of the requesting server. + room_id: Room to create leave event in + user_id: The user to create the leave for """ if get_domain_from_id(user_id) != origin: logger.info( @@ -1453,7 +1638,7 @@ class FederationHandler(BaseHandler): ) raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new( room_version, { @@ -1465,11 +1650,11 @@ class FederationHandler(BaseHandler): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( + event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -1481,18 +1666,16 @@ class FederationHandler(BaseHandler): try: # The remote hasn't signed it yet, obviously. 
We'll do the full checks # when we get the event back in `on_send_leave_request` - yield self.auth.check_from_context( + await self.auth.check_from_context( room_version, event, context, do_sig_check=False ) except AuthError as e: - logger.warn("Failed to create new leave %r because %s", event, e) + logger.warning("Failed to create new leave %r because %s", event, e) raise e return event - @defer.inlineCallbacks - @log_function - def on_send_leave_request(self, origin, pdu): + async def on_send_leave_request(self, origin, pdu): """ We have received a leave event for a room. Fully process it.""" event = pdu @@ -1502,11 +1685,19 @@ class FederationHandler(BaseHandler): event.signatures, ) + if get_domain_from_id(event.sender) != origin: + logger.info( + "Got /send_leave request for user %r from different origin %s", + event.sender, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + event.internal_metadata.outlier = False - context = yield self._handle_new_event(origin, event) + context = await self._handle_new_event(origin, event) - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -1523,16 +1714,15 @@ class FederationHandler(BaseHandler): return None - @defer.inlineCallbacks - def get_state_for_pdu(self, room_id, event_id): + async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: """Returns the state at the event. i.e. not including said event. """ - event = yield self.store.get_event( + event = await self.store.get_event( event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups(room_id, [event_id]) + state_groups = await self.state_store.get_state_groups(room_id, [event_id]) if state_groups: _, state = list(iteritems(state_groups)).pop() @@ -1543,7 +1733,7 @@ class FederationHandler(BaseHandler): if "replaces_state" in event.unsigned: prev_id = event.unsigned["replaces_state"] if prev_id != event.event_id: - prev_event = yield self.store.get_event(prev_id) + prev_event = await self.store.get_event(prev_id) results[(event.type, event.state_key)] = prev_event else: del results[(event.type, event.state_key)] @@ -1553,15 +1743,14 @@ class FederationHandler(BaseHandler): else: return [] - @defer.inlineCallbacks - def get_state_ids_for_pdu(self, room_id, event_id): + async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event. 
""" - event = yield self.store.get_event( + event = await self.store.get_event( event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups_ids(room_id, [event_id]) + state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) if state_groups: _, state = list(state_groups.items()).pop() @@ -1580,46 +1769,50 @@ class FederationHandler(BaseHandler): else: return [] - @defer.inlineCallbacks @log_function - def on_backfill_request(self, origin, room_id, pdu_list, limit): - in_room = yield self.auth.check_host_in_room(room_id, origin) + async def on_backfill_request( + self, origin: str, room_id: str, pdu_list: List[str], limit: int + ) -> List[EventBase]: + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - events = yield self.store.get_backfill_events(room_id, pdu_list, limit) + # Synapse asks for 100 events per backfill request. Do not allow more. + limit = min(limit, 100) - events = yield filter_events_for_server(self.store, origin, events) + events = await self.store.get_backfill_events(room_id, pdu_list, limit) + + events = await filter_events_for_server(self.storage, origin, events) return events - @defer.inlineCallbacks @log_function - def get_persisted_pdu(self, origin, event_id): + async def get_persisted_pdu( + self, origin: str, event_id: str + ) -> Optional[EventBase]: """Get an event from the database for the given server. Args: - origin [str]: hostname of server which is requesting the event; we + origin: hostname of server which is requesting the event; we will check that the server is allowed to see it. - event_id [str]: id of the event being requested + event_id: id of the event being requested Returns: - Deferred[EventBase|None]: None if we know nothing about the event; - otherwise the (possibly-redacted) event. + None if we know nothing about the event; otherwise the (possibly-redacted) event. Raises: AuthError if the server is not currently in the room """ - event = yield self.store.get_event( + event = await self.store.get_event( event_id, allow_none=True, allow_rejected=True ) if event: - in_room = yield self.auth.check_host_in_room(event.room_id, origin) + in_room = await self.auth.check_host_in_room(event.room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - events = yield filter_events_for_server(self.store, origin, [event]) + events = await filter_events_for_server(self.storage, origin, [event]) event = events[0] return event else: @@ -1628,11 +1821,10 @@ class FederationHandler(BaseHandler): def get_min_depth_for_context(self, context): return self.store.get_min_depth(context) - @defer.inlineCallbacks - def _handle_new_event( + async def _handle_new_event( self, origin, event, state=None, auth_events=None, backfilled=False ): - context = yield self._prep_event( + context = await self._prep_event( origin, event, state=state, auth_events=auth_events, backfilled=backfilled ) @@ -1640,12 +1832,16 @@ class FederationHandler(BaseHandler): # hack around with a try/finally instead. 
success = False try: - if not event.internal_metadata.is_outlier() and not backfilled: - yield self.action_generator.handle_push_actions_for_event( + if ( + not event.internal_metadata.is_outlier() + and not backfilled + and not context.rejected + ): + await self.action_generator.handle_push_actions_for_event( event, context ) - yield self.persist_events_and_notify( + await self.persist_events_and_notify( [(event, context)], backfilled=backfilled ) success = True @@ -1657,8 +1853,12 @@ class FederationHandler(BaseHandler): return context - @defer.inlineCallbacks - def _handle_new_events(self, origin, event_infos, backfilled=False): + async def _handle_new_events( + self, + origin: str, + event_infos: Iterable[_NewEventInfo], + backfilled: bool = False, + ) -> None: """Creates the appropriate contexts and persists events. The events should not depend on one another, e.g. this should be used to persist a bunch of outliers, but not a chunk of individual events that depend @@ -1667,36 +1867,41 @@ class FederationHandler(BaseHandler): Notifies about the events where appropriate. """ - @defer.inlineCallbacks - def prep(ev_info): - event = ev_info["event"] + async def prep(ev_info: _NewEventInfo): + event = ev_info.event with nested_logging_context(suffix=event.event_id): - res = yield self._prep_event( + res = await self._prep_event( origin, event, - state=ev_info.get("state"), - auth_events=ev_info.get("auth_events"), + state=ev_info.state, + auth_events=ev_info.auth_events, backfilled=backfilled, ) return res - contexts = yield make_deferred_yieldable( + contexts = await make_deferred_yieldable( defer.gatherResults( [run_in_background(prep, ev_info) for ev_info in event_infos], consumeErrors=True, ) ) - yield self.persist_events_and_notify( + await self.persist_events_and_notify( [ - (ev_info["event"], context) + (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) ], backfilled=backfilled, ) - @defer.inlineCallbacks - def _persist_auth_tree(self, origin, auth_events, state, event): + async def _persist_auth_tree( + self, + origin: str, + auth_events: List[EventBase], + state: List[EventBase], + event: EventBase, + room_version: RoomVersion, + ) -> int: """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. Persists the event separately. Notifies about the persisted events @@ -1705,18 +1910,17 @@ class FederationHandler(BaseHandler): Will attempt to fetch missing auth events. Args: - origin (str): Where the events came from - auth_events (list) - state (list) - event (Event) - - Returns: - Deferred + origin: Where the events came from + auth_events + state + event + room_version: The room version we expect this room to have, and + will raise if it doesn't match the version in the create event. """ events_to_context = {} for e in itertools.chain(auth_events, state): e.internal_metadata.outlier = True - ctx = yield self.state_handler.compute_event_context(e) + ctx = await self.state_handler.compute_event_context(e) events_to_context[e.event_id] = ctx event_map = { @@ -1734,10 +1938,13 @@ class FederationHandler(BaseHandler): # invalid, and it would fail auth checks anyway. 
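Editor's sketch: _handle_new_events prepares each event's context concurrently and then persists the whole batch. The same shape in plain asyncio, as an analogy only: Synapse uses run_in_background/make_deferred_yieldable so that log contexts are tracked correctly, which a bare gather does not do.

import asyncio
from typing import Awaitable, Callable, List, Sequence, Tuple

async def prep_and_persist(
    events: Sequence[object],
    prep: Callable[[object], Awaitable[object]],
    persist_batch: Callable[[List[Tuple[object, object]]], Awaitable[None]],
) -> None:
    # The events must not depend on one another (e.g. a bag of outliers),
    # otherwise preparing them concurrently is not safe.
    contexts = await asyncio.gather(*(prep(ev) for ev in events))
    await persist_batch(list(zip(events, contexts)))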
raise SynapseError(400, "No create event in state") - room_version = create_event.content.get( + room_version_id = create_event.content.get( "room_version", RoomVersions.V1.identifier ) + if room_version.identifier != room_version_id: + raise SynapseError(400, "Room version mismatch") + missing_auth_events = set() for e in itertools.chain(auth_events, state, [event]): for e_id in e.auth_event_ids(): @@ -1745,8 +1952,8 @@ class FederationHandler(BaseHandler): missing_auth_events.add(e_id) for e_id in missing_auth_events: - m_ev = yield self.federation_client.get_pdu( - [origin], e_id, room_version=room_version, outlier=True, timeout=10000 + m_ev = await self.federation_client.get_pdu( + [origin], e_id, room_version=room_version, outlier=True, timeout=10000, ) if m_ev and m_ev.event_id == e_id: event_map[e_id] = m_ev @@ -1763,7 +1970,7 @@ class FederationHandler(BaseHandler): auth_for_e[(EventTypes.Create, "")] = create_event try: - self.auth.check(room_version, e, auth_events=auth_for_e) + event_auth.check(room_version, e, auth_events=auth_for_e) except SynapseError as err: # we may get SynapseErrors here as well as AuthErrors. For # instance, there are a couple of (ancient) events in some @@ -1771,94 +1978,80 @@ class FederationHandler(BaseHandler): # cause SynapseErrors in auth.check. We don't want to give up # the attempt to federate altogether in such cases. - logger.warn("Rejecting %s because %s", e.event_id, err.msg) + logger.warning("Rejecting %s because %s", e.event_id, err.msg) if e == event: raise events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR - yield self.persist_events_and_notify( + await self.persist_events_and_notify( [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) ] ) - new_event_context = yield self.state_handler.compute_event_context( + new_event_context = await self.state_handler.compute_event_context( event, old_state=state ) - yield self.persist_events_and_notify([(event, new_event_context)]) - - @defer.inlineCallbacks - def _prep_event(self, origin, event, state, auth_events, backfilled): - """ + return await self.persist_events_and_notify([(event, new_event_context)]) - Args: - origin: - event: - state: - auth_events: - backfilled (bool) - - Returns: - Deferred, which resolves to synapse.events.snapshot.EventContext - """ - context = yield self.state_handler.compute_event_context(event, old_state=state) + async def _prep_event( + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]], + auth_events: Optional[StateMap[EventBase]], + backfilled: bool, + ) -> EventContext: + context = await self.state_handler.compute_event_context(event, old_state=state) if not auth_events: - prev_state_ids = yield context.get_prev_state_ids(self.store) - auth_events_ids = yield self.auth.compute_auth_events( + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = await self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. 
if event.type == EventTypes.Member and not event.auth_event_ids(): if len(event.prev_event_ids()) == 1 and event.depth < 5: - c = yield self.store.get_event( + c = await self.store.get_event( event.prev_event_ids()[0], allow_none=True ) if c and c.type == EventTypes.Create: auth_events[(c.type, c.state_key)] = c - try: - yield self.do_auth(origin, event, context, auth_events=auth_events) - except AuthError as e: - logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg) - - context.rejected = RejectedReason.AUTH_ERROR + context = await self.do_auth(origin, event, context, auth_events=auth_events) if not context.rejected: - yield self._check_for_soft_fail(event, state, backfilled) + await self._check_for_soft_fail(event, state, backfilled) if event.type == EventTypes.GuestAccess and not context.rejected: - yield self.maybe_kick_guest_users(event) + await self.maybe_kick_guest_users(event) return context - @defer.inlineCallbacks - def _check_for_soft_fail(self, event, state, backfilled): - """Checks if we should soft fail the event, if so marks the event as + async def _check_for_soft_fail( + self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool + ) -> None: + """Checks if we should soft fail the event; if so, marks the event as such. Args: - event (FrozenEvent) - state (dict|None): The state at the event if we don't have all the - event's prev events - backfilled (bool): Whether the event is from backfill - - Returns: - Deferred + event + state: The state at the event if we don't have all the event's prev events + backfilled: Whether the event is from backfill """ # For new (non-backfilled and non-outlier) events we check if the event # passes auth based on the current state. If it doesn't then we # "soft-fail" the event. do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier() if do_soft_fail_check: - extrem_ids = yield self.store.get_latest_event_ids_in_room(event.room_id) + extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id) extrem_ids = set(extrem_ids) prev_event_ids = set(event.prev_event_ids()) @@ -1869,7 +2062,8 @@ class FederationHandler(BaseHandler): do_soft_fail_check = False if do_soft_fail_check: - room_version = yield self.store.get_room_version(event.room_id) + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] # Calculate the "current state". if state is not None: @@ -1885,19 +2079,19 @@ class FederationHandler(BaseHandler): # given state at the event. This should correctly handle cases # like bans, especially with state res v2. 
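Editor's sketch: the soft-fail path only applies to new events. Backfilled events and outliers are exempt, and so is an event whose prev_events are exactly our forward extremities, since a check against the "current state" would then add nothing over the normal auth check. The decision in isolation:

from typing import Iterable, Set

def needs_soft_fail_check(
    backfilled: bool,
    is_outlier: bool,
    prev_event_ids: Iterable[str],
    forward_extremities: Set[str],
) -> bool:
    if backfilled or is_outlier:
        return False
    if set(prev_event_ids) == forward_extremities:
        return False
    return True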
- state_sets = yield self.store.get_state_groups( + state_sets = await self.state_store.get_state_groups( event.room_id, extrem_ids ) state_sets = list(state_sets.values()) state_sets.append(state) - current_state_ids = yield self.state_handler.resolve_events( + current_state_ids = await self.state_handler.resolve_events( room_version, state_sets, event ) current_state_ids = { k: e.event_id for k, e in iteritems(current_state_ids) } else: - current_state_ids = yield self.state_handler.get_current_state_ids( + current_state_ids = await self.state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids ) @@ -1913,26 +2107,27 @@ class FederationHandler(BaseHandler): e for k, e in iteritems(current_state_ids) if k in auth_types ] - current_auth_events = yield self.store.get_events(current_state_ids) + current_auth_events = await self.store.get_events(current_state_ids) current_auth_events = { (e.type, e.state_key): e for e in current_auth_events.values() } try: - self.auth.check(room_version, event, auth_events=current_auth_events) + event_auth.check( + room_version_obj, event, auth_events=current_auth_events + ) except AuthError as e: - logger.warn("Soft-failing %r because %s", event, e) + logger.warning("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True - @defer.inlineCallbacks - def on_query_auth( + async def on_query_auth( self, origin, event_id, room_id, remote_auth_chain, rejects, missing ): - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - event = yield self.store.get_event( + event = await self.store.get_event( event_id, allow_none=False, check_room_id=room_id ) @@ -1940,70 +2135,77 @@ class FederationHandler(BaseHandler): # don't want to fall into the trap of `missing` being wrong. for e in remote_auth_chain: try: - yield self._handle_new_event(origin, e) + await self._handle_new_event(origin, e) except AuthError: pass # Now get the current auth_chain for the event. - local_auth_chain = yield self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], include_given=True + local_auth_chain = await self.store.get_auth_chain( + list(event.auth_event_ids()), include_given=True ) # TODO: Check if we would now reject event_id. If so we need to tell # everyone. - ret = yield self.construct_auth_difference(local_auth_chain, remote_auth_chain) + ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain) logger.debug("on_query_auth returning: %s", ret) return ret - @defer.inlineCallbacks - def on_get_missing_events( + async def on_get_missing_events( self, origin, room_id, earliest_events, latest_events, limit ): - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") + # Only allow up to 20 events to be retrieved per request. 
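Editor's sketch: when the check against current state fails, the event is not rejected; it is only marked soft-failed. A sketch of that marking, where `check_auth` stands in for event_auth.check and `event` is assumed to carry an internal_metadata object with a soft_failed flag:

import logging
from typing import Callable, Dict

logger = logging.getLogger(__name__)

class AuthError(Exception):
    pass

def apply_soft_fail(
    event, current_auth_events: Dict[tuple, object], check_auth: Callable
) -> None:
    # `check_auth(event, auth_events=...)` raises AuthError when the event
    # does not pass auth against the current state of the room.
    try:
        check_auth(event, auth_events=current_auth_events)
    except AuthError as e:
        logger.warning("Soft-failing %r because %s", event, e)
        event.internal_metadata.soft_failed = True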
limit = min(limit, 20) - missing_events = yield self.store.get_missing_events( + missing_events = await self.store.get_missing_events( room_id=room_id, earliest_events=earliest_events, latest_events=latest_events, limit=limit, ) - missing_events = yield filter_events_for_server( - self.store, origin, missing_events + missing_events = await filter_events_for_server( + self.storage, origin, missing_events ) return missing_events - @defer.inlineCallbacks - @log_function - def do_auth(self, origin, event, context, auth_events): + async def do_auth( + self, + origin: str, + event: EventBase, + context: EventContext, + auth_events: StateMap[EventBase], + ) -> EventContext: """ Args: - origin (str): - event (synapse.events.EventBase): - context (synapse.events.snapshot.EventContext): - auth_events (dict[(str, str)->synapse.events.EventBase]): + origin: + event: + context: + auth_events: Map from (event_type, state_key) to event - What we expect the event's auth_events to be, based on the event's - position in the dag. I think? maybe?? + Normally, our calculated auth_events based on the state of the room + at the event's position in the DAG, though occasionally (eg if the + event is an outlier), may be the auth events claimed by the remote + server. Also NB that this function adds entries to it. Returns: - defer.Deferred[None] + updated context object """ - room_version = yield self.store.get_room_version(event.room_id) + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] try: - yield self._update_auth_events_and_context_for_auth( + context = await self._update_auth_events_and_context_for_auth( origin, event, context, auth_events ) except Exception: @@ -2018,15 +2220,20 @@ class FederationHandler(BaseHandler): ) try: - self.auth.check(room_version, event, auth_events=auth_events) + event_auth.check(room_version_obj, event, auth_events=auth_events) except AuthError as e: - logger.warn("Failed auth resolution for %r because %s", event, e) - raise e + logger.warning("Failed auth resolution for %r because %s", event, e) + context.rejected = RejectedReason.AUTH_ERROR - @defer.inlineCallbacks - def _update_auth_events_and_context_for_auth( - self, origin, event, context, auth_events - ): + return context + + async def _update_auth_events_and_context_for_auth( + self, + origin: str, + event: EventBase, + context: EventContext, + auth_events: StateMap[EventBase], + ) -> EventContext: """Helper for do_auth. See there for docs. Checks whether a given event has the expected auth events. If it @@ -2034,59 +2241,59 @@ class FederationHandler(BaseHandler): we can come to a consensus (e.g. if one server missed some valid state). - This attempts to resovle any potential divergence of state between + This attempts to resolve any potential divergence of state between servers, but is not essential and so failures should not block further processing of the event. Args: - origin (str): - event (synapse.events.EventBase): - context (synapse.events.snapshot.EventContext): - auth_events (dict[(str, str)->synapse.events.EventBase]): + origin: + event: + context: + + auth_events: + Map from (event_type, state_key) to event + + Normally, our calculated auth_events based on the state of the room + at the event's position in the DAG, though occasionally (eg if the + event is an outlier), may be the auth events claimed by the remote + server. + + Also NB that this function adds entries to it. 
Returns: - defer.Deferred[None] + updated context """ event_auth_events = set(event.auth_event_ids()) - if event.is_state(): - event_key = (event.type, event.state_key) - else: - event_key = None - - # if the event's auth_events refers to events which are not in our - # calculated auth_events, we need to fetch those events from somewhere. - # - # we start by fetching them from the store, and then try calling /event_auth/. + # missing_auth is the set of the event's auth_events which we don't yet have + # in auth_events. missing_auth = event_auth_events.difference( e.event_id for e in auth_events.values() ) + # if we have missing events, we need to fetch those events from somewhere. + # + # we start by checking if they are in the store, and then try calling /event_auth/. if missing_auth: - # TODO: can we use store.have_seen_events here instead? - have_events = yield self.store.get_seen_events_with_rejections(missing_auth) - logger.debug("Got events %s from store", have_events) - missing_auth.difference_update(have_events.keys()) - else: - have_events = {} - - have_events.update({e.event_id: "" for e in auth_events.values()}) + have_events = await self.store.have_seen_events(missing_auth) + logger.debug("Events %s are in the store", have_events) + missing_auth.difference_update(have_events) if missing_auth: # If we don't have all the auth events, we need to get them. logger.info("auth_events contains unknown events: %s", missing_auth) try: try: - remote_auth_chain = yield self.federation_client.get_event_auth( + remote_auth_chain = await self.federation_client.get_event_auth( origin, event.room_id, event.event_id ) except RequestSendFailed as e: # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e) - return + return context - seen_remotes = yield self.store.have_seen_events( + seen_remotes = await self.store.have_seen_events( [e.event_id for e in remote_auth_chain] ) @@ -2109,32 +2316,31 @@ class FederationHandler(BaseHandler): logger.debug( "do_auth %s missing_auth: %s", event.event_id, e.event_id ) - yield self._handle_new_event(origin, e, auth_events=auth) + await self._handle_new_event(origin, e, auth_events=auth) if e.event_id in event_auth_events: auth_events[(e.type, e.state_key)] = e except AuthError: pass - have_events = yield self.store.get_seen_events_with_rejections( - event.auth_event_ids() - ) except Exception: - # FIXME: logger.exception("Failed to get auth chain") if event.internal_metadata.is_outlier(): + # XXX: given that, for an outlier, we'll be working with the + # event's *claimed* auth events rather than those we calculated: + # (a) is there any point in this test, since different_auth below will + # obviously be empty + # (b) alternatively, why don't we do it earlier? logger.info("Skipping auth_event fetch for outlier") - return + return context - # FIXME: Assumes we have and stored all the state for all the - # prev_events different_auth = event_auth_events.difference( e.event_id for e in auth_events.values() ) if not different_auth: - return + return context logger.info( "auth_events refers to events which are not in our calculated auth " @@ -2142,175 +2348,94 @@ class FederationHandler(BaseHandler): different_auth, ) - room_version = yield self.store.get_room_version(event.room_id) + # XXX: currently this checks for redactions but I'm not convinced that is + # necessary? 
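Editor's sketch: the rewritten helper starts from two set differences: which of the event's claimed auth_events are missing from our calculated auth map, and which of those we don't even have in the store (store.have_seen_events in the handler). The arithmetic on its own:

from typing import Iterable, Set

def missing_auth_event_ids(
    claimed_auth_ids: Iterable[str],
    calculated_auth_ids: Iterable[str],
    seen_event_ids: Iterable[str],
) -> Set[str]:
    # Claimed but not calculated...
    missing = set(claimed_auth_ids) - set(calculated_auth_ids)
    # ...and not already known to us: these must be fetched from the remote.
    return missing - set(seen_event_ids)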
+ different_events = await self.store.get_events_as_list(different_auth) - different_events = yield make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.store.get_event, d, allow_none=True, allow_rejected=False - ) - for d in different_auth - if d in have_events and not have_events[d] - ], - consumeErrors=True, - ) - ).addErrback(unwrapFirstError) - - if different_events: - local_view = dict(auth_events) - remote_view = dict(auth_events) - remote_view.update( - {(d.type, d.state_key): d for d in different_events if d} - ) - - new_state = yield self.state_handler.resolve_events( - room_version, - [list(local_view.values()), list(remote_view.values())], - event, - ) - - logger.info( - "After state res: updating auth_events with new state %s", - { - (d.type, d.state_key): d.event_id - for d in new_state.values() - if auth_events.get((d.type, d.state_key)) != d - }, - ) - - auth_events.update(new_state) - - different_auth = event_auth_events.difference( - e.event_id for e in auth_events.values() - ) - - yield self._update_context_for_auth_events( - event, context, auth_events, event_key - ) + for d in different_events: + if d.room_id != event.room_id: + logger.warning( + "Event %s refers to auth_event %s which is in a different room", + event.event_id, + d.event_id, + ) - if not different_auth: - # we're done - return + # don't attempt to resolve the claimed auth events against our own + # in this case: just use our own auth events. + # + # XXX: should we reject the event in this case? It feels like we should, + # but then shouldn't we also do so if we've failed to fetch any of the + # auth events? + return context + + # now we state-resolve between our own idea of the auth events, and the remote's + # idea of them. + + local_state = auth_events.values() + remote_auth_events = dict(auth_events) + remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) + remote_state = remote_auth_events.values() + + room_version = await self.store.get_room_version_id(event.room_id) + new_state = await self.state_handler.resolve_events( + room_version, (local_state, remote_state), event + ) logger.info( - "auth_events still refers to events which are not in the calculated auth " - "chain after state resolution: %s", - different_auth, + "After state res: updating auth_events with new state %s", + { + (d.type, d.state_key): d.event_id + for d in new_state.values() + if auth_events.get((d.type, d.state_key)) != d + }, ) - # Only do auth resolution if we have something new to say. - # We can't prove an auth failure. - do_resolution = False - - for e_id in different_auth: - if e_id in have_events: - if have_events[e_id] == RejectedReason.NOT_ANCESTOR: - do_resolution = True - break - - if not do_resolution: - logger.info( - "Skipping auth resolution due to lack of provable rejection reasons" - ) - return - - logger.info("Doing auth resolution") - - prev_state_ids = yield context.get_prev_state_ids(self.store) - - # 1. Get what we think is the auth chain. - auth_ids = yield self.auth.compute_auth_events(event, prev_state_ids) - local_auth_chain = yield self.store.get_auth_chain(auth_ids, include_given=True) - - try: - # 2. Get remote difference. - try: - result = yield self.federation_client.query_auth( - origin, event.room_id, event.event_id, local_auth_chain - ) - except RequestSendFailed as e: - # The other side isn't around or doesn't implement the - # endpoint, so lets just bail out. 
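Editor's sketch: instead of the old /query_auth round-trip (removed in the next hunk), the handler now state-resolves its own view of the auth events against the remote's claimed view and adopts the result. A rough sketch of that merge, where `resolve` is a stand-in for the room's state-resolution algorithm rather than Synapse's resolver API:

from typing import Callable, Dict, Iterable, List, Tuple

StateKey = Tuple[str, str]

def merge_auth_views(
    local_auth: Dict[StateKey, dict],
    remote_only_events: Iterable[dict],
    resolve: Callable[[List[List[dict]]], Dict[StateKey, dict]],
) -> Dict[StateKey, dict]:
    # Build the remote's view: our auth events overlaid with the events it
    # claimed that we did not calculate ourselves.
    remote_view = dict(local_auth)
    remote_view.update({(e["type"], e["state_key"]): e for e in remote_only_events})
    # Let state resolution pick the winners between the two views.
    new_state = resolve([list(local_auth.values()), list(remote_view.values())])
    merged = dict(local_auth)
    merged.update(new_state)
    return merged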
- logger.info("Failed to query auth from remote: %s", e) - return - - seen_remotes = yield self.store.have_seen_events( - [e.event_id for e in result["auth_chain"]] - ) - - # 3. Process any remote auth chain events we haven't seen. - for ev in result["auth_chain"]: - if ev.event_id in seen_remotes: - continue - - if ev.event_id == event.event_id: - continue - - try: - auth_ids = ev.auth_event_ids() - auth = { - (e.type, e.state_key): e - for e in result["auth_chain"] - if e.event_id in auth_ids or event.type == EventTypes.Create - } - ev.internal_metadata.outlier = True - - logger.debug( - "do_auth %s different_auth: %s", event.event_id, e.event_id - ) - - yield self._handle_new_event(origin, ev, auth_events=auth) - - if ev.event_id in event_auth_events: - auth_events[(ev.type, ev.state_key)] = ev - except AuthError: - pass - - except Exception: - # FIXME: - logger.exception("Failed to query auth chain") - - # 4. Look at rejects and their proofs. - # TODO. + auth_events.update(new_state) - yield self._update_context_for_auth_events( - event, context, auth_events, event_key + context = await self._update_context_for_auth_events( + event, context, auth_events ) - @defer.inlineCallbacks - def _update_context_for_auth_events(self, event, context, auth_events, event_key): + return context + + async def _update_context_for_auth_events( + self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] + ) -> EventContext: """Update the state_ids in an event context after auth event resolution, storing the changes as a new state group. Args: - event (Event): The event we're handling the context for + event: The event we're handling the context for - context (synapse.events.snapshot.EventContext): event context - to be updated + context: initial event context - auth_events (dict[(str, str)->str]): Events to update in the event - context. + auth_events: Events to update in the event context. - event_key ((str, str)): (type, state_key) for the current event. - this will not be included in the current_state in the context. + Returns: + new event context """ + # exclude the state key of the new event from the current_state in the context. + if event.is_state(): + event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]] + else: + event_key = None state_updates = { k: a.event_id for k, a in iteritems(auth_events) if k != event_key } - current_state_ids = yield context.get_current_state_ids(self.store) + + current_state_ids = await context.get_current_state_ids() current_state_ids = dict(current_state_ids) current_state_ids.update(state_updates) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids() prev_state_ids = dict(prev_state_ids) prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) # create a new state group as a delta from the existing one. 
prev_group = context.state_group - state_group = yield self.store.store_state_group( + state_group = await self.state_store.store_state_group( event.event_id, event.room_id, prev_group=prev_group, @@ -2318,16 +2443,18 @@ class FederationHandler(BaseHandler): current_state_ids=current_state_ids, ) - yield context.update_state( + return EventContext.with_state( state_group=state_group, + state_group_before_event=context.state_group_before_event, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, prev_group=prev_group, delta_ids=state_updates, ) - @defer.inlineCallbacks - def construct_auth_difference(self, local_auth, remote_auth): + async def construct_auth_difference( + self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase] + ) -> Dict: """ Given a local and remote auth chain, find the differences. This assumes that we have already processed all events in remote_auth @@ -2436,7 +2563,7 @@ class FederationHandler(BaseHandler): reason_map = {} for e in base_remote_rejected: - reason = yield self.store.get_rejection_reason(e.event_id) + reason = await self.store.get_rejection_reason(e.event_id) if reason is None: # TODO: e is not in the current state, so we should # construct some proof of that. @@ -2444,15 +2571,6 @@ class FederationHandler(BaseHandler): reason_map[e.event_id] = reason - if reason == RejectedReason.AUTH_ERROR: - pass - elif reason == RejectedReason.REPLACED: - # TODO: Get proof - pass - elif reason == RejectedReason.NOT_ANCESTOR: - # TODO: Get proof. - pass - logger.debug("construct_auth_difference returning") return { @@ -2464,9 +2582,8 @@ class FederationHandler(BaseHandler): "missing": [e.event_id for e in missing_locals], } - @defer.inlineCallbacks @log_function - def exchange_third_party_invite( + async def exchange_third_party_invite( self, sender_user_id, target_user_id, room_id, signed ): third_party_invite = {"signed": signed} @@ -2482,16 +2599,16 @@ class FederationHandler(BaseHandler): "state_key": target_user_id, } - if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): - room_version = yield self.store.get_room_version(room_id) + if await self.auth.check_host_in_room(room_id, self.hs.hostname): + room_version = await self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new(room_version, event_dict) EventValidator().validate_builder(builder) - event, context = yield self.event_creation_handler.create_new_client_event( + event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -2503,58 +2620,58 @@ class FederationHandler(BaseHandler): 403, "This event is not allowed in this context", Codes.FORBIDDEN ) - event, context = yield self.add_display_name_to_third_party_invite( + event, context = await self.add_display_name_to_third_party_invite( room_version, event_dict, event, context ) - EventValidator().validate_new(event) + EventValidator().validate_new(event, self.config) # We need to tell the transaction queue to send this out, even # though the sender isn't a local user. 
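Editor's sketch: the new state group is stored as a delta (delta_ids) against the previous group rather than as a full copy. A toy model of how such delta-chained groups materialise into a full state map; this is a simplification for illustration, not the real storage layer's API:

from typing import Dict, Optional, Tuple

StateKey = Tuple[str, str]

class DeltaStateGroup:
    def __init__(self, prev: Optional["DeltaStateGroup"], delta_ids: Dict[StateKey, str]):
        self.prev = prev
        self.delta_ids = delta_ids

    def resolve(self) -> Dict[StateKey, str]:
        # Walk back through the chain of deltas, then apply them oldest-first.
        base = self.prev.resolve() if self.prev is not None else {}
        base.update(self.delta_ids)
        return base

# e.g. a group that only changes the power levels on top of its predecessor:
# group = DeltaStateGroup(prev_group, {("m.room.power_levels", ""): "$new_pl_event"})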
event.internal_metadata.send_on_behalf_of = self.hs.hostname try: - yield self.auth.check_from_context(room_version, event, context) + await self.auth.check_from_context(room_version, event, context) except AuthError as e: - logger.warn("Denying new third party invite %r because %s", event, e) + logger.warning("Denying new third party invite %r because %s", event, e) raise e - yield self._check_signature(event, context) + await self._check_signature(event, context) + + # We retrieve the room member handler here as to not cause a cyclic dependency member_handler = self.hs.get_room_member_handler() - yield member_handler.send_membership_event(None, event, context) + await member_handler.send_membership_event(None, event, context) else: - destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) - yield self.federation_client.forward_third_party_invite( + destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)} + await self.federation_client.forward_third_party_invite( destinations, room_id, event_dict ) - @defer.inlineCallbacks - @log_function - def on_exchange_third_party_invite_request(self, room_id, event_dict): + async def on_exchange_third_party_invite_request( + self, room_id: str, event_dict: JsonDict + ) -> None: """Handle an exchange_third_party_invite request from a remote server The remote server will call this when it wants to turn a 3pid invite into a normal m.room.member invite. Args: - room_id (str): The ID of the room. + room_id: The ID of the room. event_dict (dict[str, Any]): Dictionary containing the event body. - Returns: - Deferred: resolves (to None) """ - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) # NB: event_dict has a particular specced format we might need to fudge # if we change event formats too much. builder = self.event_builder_factory.new(room_version, event_dict) - event, context = yield self.event_creation_handler.create_new_client_event( + event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -2565,26 +2682,26 @@ class FederationHandler(BaseHandler): 403, "This event is not allowed in this context", Codes.FORBIDDEN ) - event, context = yield self.add_display_name_to_third_party_invite( + event, context = await self.add_display_name_to_third_party_invite( room_version, event_dict, event, context ) try: - self.auth.check_from_context(room_version, event, context) + await self.auth.check_from_context(room_version, event, context) except AuthError as e: - logger.warn("Denying third party invite %r because %s", event, e) + logger.warning("Denying third party invite %r because %s", event, e) raise e - yield self._check_signature(event, context) + await self._check_signature(event, context) # We need to tell the transaction queue to send this out, even # though the sender isn't a local user. 
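Editor's sketch: when the inviting server is not in the room, exchange_third_party_invite forwards the request to the servers implied by the sender and the room id; both identifiers carry their server name after the first colon. The extraction on its own (example identifiers are illustrative):

from typing import Set

def implied_destinations(sender_user_id: str, room_id: str) -> Set[str]:
    # "@alice:example.org" -> "example.org", "!abcdef:example.org" -> "example.org"
    return {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}

assert implied_destinations("@alice:example.org", "!room:matrix.org") == {
    "example.org",
    "matrix.org",
}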
event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender) + # We retrieve the room member handler here as to not cause a cyclic dependency member_handler = self.hs.get_room_member_handler() - yield member_handler.send_membership_event(None, event, context) + await member_handler.send_membership_event(None, event, context) - @defer.inlineCallbacks - def add_display_name_to_third_party_invite( + async def add_display_name_to_third_party_invite( self, room_version, event_dict, event, context ): key = ( @@ -2592,14 +2709,19 @@ class FederationHandler(BaseHandler): event.content["third_party_invite"]["signed"]["token"], ) original_invite = None - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids() original_invite_id = prev_state_ids.get(key) if original_invite_id: - original_invite = yield self.store.get_event( + original_invite = await self.store.get_event( original_invite_id, allow_none=True ) if original_invite: - display_name = original_invite.content["display_name"] + # If the m.room.third_party_invite event's content is empty, it means the + # invite has been revoked. In this case, we don't have to raise an error here + # because the auth check will fail on the invite (because it's not able to + # fetch public keys from the m.room.third_party_invite event's content, which + # is empty). + display_name = original_invite.content.get("display_name") event_dict["content"]["third_party_invite"]["display_name"] = display_name else: logger.info( @@ -2611,14 +2733,13 @@ class FederationHandler(BaseHandler): builder = self.event_builder_factory.new(room_version, event_dict) EventValidator().validate_builder(builder) - event, context = yield self.event_creation_handler.create_new_client_event( + event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) - EventValidator().validate_new(event) + EventValidator().validate_new(event, self.config) return (event, context) - @defer.inlineCallbacks - def _check_signature(self, event, context): + async def _check_signature(self, event, context): """ Checks that the signature in the event is consistent with its invite. @@ -2635,12 +2756,12 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids() invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token)) invite_event = None if invite_event_id: - invite_event = yield self.store.get_event(invite_event_id, allow_none=True) + invite_event = await self.store.get_event(invite_event_id, allow_none=True) if not invite_event: raise AuthError(403, "Could not find invite") @@ -2689,7 +2810,7 @@ class FederationHandler(BaseHandler): raise try: if "key_validity_url" in public_key_object: - yield self._check_key_revocation( + await self._check_key_revocation( public_key, public_key_object["key_validity_url"] ) except Exception: @@ -2703,8 +2824,7 @@ class FederationHandler(BaseHandler): last_exception = e raise last_exception - @defer.inlineCallbacks - def _check_key_revocation(self, public_key, url): + async def _check_key_revocation(self, public_key, url): """ Checks whether public_key has been revoked. @@ -2718,47 +2838,58 @@ class FederationHandler(BaseHandler): for revocation. 
""" try: - response = yield self.http_client.get_json(url, {"public_key": public_key}) + response = await self.http_client.get_json(url, {"public_key": public_key}) except Exception: raise SynapseError(502, "Third party certificate could not be checked") if "valid" not in response or not response["valid"]: raise AuthError(403, "Third party certificate was invalid") - @defer.inlineCallbacks - def persist_events_and_notify(self, event_and_contexts, backfilled=False): + async def persist_events_and_notify( + self, + event_and_contexts: Sequence[Tuple[EventBase, EventContext]], + backfilled: bool = False, + ) -> int: """Persists events and tells the notifier/pushers about them, if necessary. Args: - event_and_contexts(list[tuple[FrozenEvent, EventContext]]) - backfilled (bool): Whether these events are a result of + event_and_contexts: + backfilled: Whether these events are a result of backfilling or not - - Returns: - Deferred """ - if self.config.worker_app: - yield self._send_events_to_master( + if self.config.worker.writers.events != self._instance_name: + result = await self._send_events( + instance_name=self.config.worker.writers.events, store=self.store, event_and_contexts=event_and_contexts, backfilled=backfilled, ) + return result["max_stream_id"] else: - max_stream_id = yield self.store.persist_events( + max_stream_id = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) + if self._ephemeral_messages_enabled: + for (event, context) in event_and_contexts: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: - yield self._notify_persisted_event(event, max_stream_id) + await self._notify_persisted_event(event, max_stream_id) + + return max_stream_id - def _notify_persisted_event(self, event, max_stream_id): + async def _notify_persisted_event( + self, event: EventBase, max_stream_id: int + ) -> None: """Checks to see if notifier/pushers should be notified about the event or not. Args: - event (FrozenEvent) - max_stream_id (int): The max_stream_id returned by persist_events + event: + max_stream_id: The max_stream_id returned by persist_events """ extra_users = [] @@ -2782,32 +2913,31 @@ class FederationHandler(BaseHandler): event, event_stream_id, max_stream_id, extra_users=extra_users ) - return self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) + await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) - def _clean_room_for_join(self, room_id): + async def _clean_room_for_join(self, room_id: str) -> None: """Called to clean up any data in DB for a given room, ready for the server to join the room. 
Args: - room_id (str) + room_id """ if self.config.worker_app: - return self._clean_room_for_join_client(room_id) + await self._clean_room_for_join_client(room_id) else: - return self.store.clean_room_for_join(room_id) + await self.store.clean_room_for_join(room_id) - def user_joined_room(self, user, room_id): + async def user_joined_room(self, user: UserID, room_id: str) -> None: """Called when a new user has joined the room """ if self.config.worker_app: - return self._notify_user_membership_change( + await self._notify_user_membership_change( room_id=room_id, user_id=user.to_string(), change="joined" ) else: - return user_joined_room(self.distributor, user, room_id) + user_joined_room(self.distributor, user, room_id) - @defer.inlineCallbacks - def get_room_complexity(self, remote_room_hosts, room_id): + async def get_room_complexity(self, remote_room_hosts, room_id): """ Fetch the complexity of a remote room over federation. @@ -2821,12 +2951,12 @@ class FederationHandler(BaseHandler): """ for host in remote_room_hosts: - res = yield self.federation_client.get_room_complexity(host, room_id) + res = await self.federation_client.get_room_complexity(host, room_id) # We got a result, return it. if res: - defer.returnValue(res) + return res # We fell off the bottom, couldn't get the complexity from anyone. Oh # well. - defer.returnValue(None) + return None diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 46eb9ee88b..ebe8d25bd8 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -18,8 +18,6 @@ import logging from six import iteritems -from twisted.internet import defer - from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.types import get_domain_from_id @@ -63,7 +61,7 @@ def _create_rerouter(func_name): return f -class GroupsLocalHandler(object): +class GroupsLocalWorkerHandler(object): def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() @@ -81,55 +79,33 @@ class GroupsLocalHandler(object): self.profile_handler = hs.get_profile_handler() - # Ensure attestations get renewed - hs.get_groups_attestation_renewer() - # The following functions merely route the query to the local groups server # or federation depending on if the group is local or remote get_group_profile = _create_rerouter("get_group_profile") - update_group_profile = _create_rerouter("update_group_profile") get_rooms_in_group = _create_rerouter("get_rooms_in_group") - get_invited_users_in_group = _create_rerouter("get_invited_users_in_group") - - add_room_to_group = _create_rerouter("add_room_to_group") - update_room_in_group = _create_rerouter("update_room_in_group") - remove_room_from_group = _create_rerouter("remove_room_from_group") - - update_group_summary_room = _create_rerouter("update_group_summary_room") - delete_group_summary_room = _create_rerouter("delete_group_summary_room") - - update_group_category = _create_rerouter("update_group_category") - delete_group_category = _create_rerouter("delete_group_category") get_group_category = _create_rerouter("get_group_category") get_group_categories = _create_rerouter("get_group_categories") - - update_group_summary_user = _create_rerouter("update_group_summary_user") - delete_group_summary_user = _create_rerouter("delete_group_summary_user") - - update_group_role = _create_rerouter("update_group_role") - delete_group_role = _create_rerouter("delete_group_role") get_group_role = _create_rerouter("get_group_role") get_group_roles = 
_create_rerouter("get_group_roles") - set_group_join_policy = _create_rerouter("set_group_join_policy") - - @defer.inlineCallbacks - def get_group_summary(self, group_id, requester_user_id): + async def get_group_summary(self, group_id, requester_user_id): """Get the group summary for a group. If the group is remote we check that the users have valid attestations. """ if self.is_mine_id(group_id): - res = yield self.groups_server_handler.get_group_summary( + res = await self.groups_server_handler.get_group_summary( group_id, requester_user_id ) else: try: - res = yield self.transport_client.get_group_summary( + res = await self.transport_client.get_group_summary( get_domain_from_id(group_id), group_id, requester_user_id ) + except HttpResponseException as e: + raise e.to_synapse_error() except RequestSendFailed: raise SynapseError(502, "Failed to contact group server") @@ -143,7 +119,7 @@ class GroupsLocalHandler(object): attestation = entry.pop("attestation", {}) try: if get_domain_from_id(g_user_id) != group_server_name: - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( attestation, group_id=group_id, user_id=g_user_id, @@ -160,22 +136,155 @@ class GroupsLocalHandler(object): # Add `is_publicised` flag to indicate whether the user has publicised their # membership of the group on their profile - result = yield self.store.get_publicised_groups_for_user(requester_user_id) + result = await self.store.get_publicised_groups_for_user(requester_user_id) is_publicised = group_id in result res.setdefault("user", {})["is_publicised"] = is_publicised return res - @defer.inlineCallbacks - def create_group(self, group_id, user_id, content): + async def get_users_in_group(self, group_id, requester_user_id): + """Get users in a group + """ + if self.is_mine_id(group_id): + res = await self.groups_server_handler.get_users_in_group( + group_id, requester_user_id + ) + return res + + group_server_name = get_domain_from_id(group_id) + + try: + res = await self.transport_client.get_users_in_group( + get_domain_from_id(group_id), group_id, requester_user_id + ) + except HttpResponseException as e: + raise e.to_synapse_error() + except RequestSendFailed: + raise SynapseError(502, "Failed to contact group server") + + chunk = res["chunk"] + valid_entries = [] + for entry in chunk: + g_user_id = entry["user_id"] + attestation = entry.pop("attestation", {}) + try: + if get_domain_from_id(g_user_id) != group_server_name: + await self.attestations.verify_attestation( + attestation, + group_id=group_id, + user_id=g_user_id, + server_name=get_domain_from_id(g_user_id), + ) + valid_entries.append(entry) + except Exception as e: + logger.info("Failed to verify user is in group: %s", e) + + res["chunk"] = valid_entries + + return res + + async def get_joined_groups(self, user_id): + group_ids = await self.store.get_joined_groups(user_id) + return {"groups": group_ids} + + async def get_publicised_groups_for_user(self, user_id): + if self.hs.is_mine_id(user_id): + result = await self.store.get_publicised_groups_for_user(user_id) + + # Check AS associated groups for this user - this depends on the + # RegExps in the AS registration file (under `users`) + for app_service in self.store.get_app_services(): + result.extend(app_service.get_groups_for_user(user_id)) + + return {"groups": result} + else: + try: + bulk_result = await self.transport_client.bulk_get_publicised_groups( + get_domain_from_id(user_id), [user_id] + ) + except HttpResponseException as e: + raise 
e.to_synapse_error() + except RequestSendFailed: + raise SynapseError(502, "Failed to contact group server") + + result = bulk_result.get("users", {}).get(user_id) + # TODO: Verify attestations + return {"groups": result} + + async def bulk_get_publicised_groups(self, user_ids, proxy=True): + destinations = {} + local_users = set() + + for user_id in user_ids: + if self.hs.is_mine_id(user_id): + local_users.add(user_id) + else: + destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id) + + if not proxy and destinations: + raise SynapseError(400, "Some user_ids are not local") + + results = {} + failed_results = [] + for destination, dest_user_ids in iteritems(destinations): + try: + r = await self.transport_client.bulk_get_publicised_groups( + destination, list(dest_user_ids) + ) + results.update(r["users"]) + except Exception: + failed_results.extend(dest_user_ids) + + for uid in local_users: + results[uid] = await self.store.get_publicised_groups_for_user(uid) + + # Check AS associated groups for this user - this depends on the + # RegExps in the AS registration file (under `users`) + for app_service in self.store.get_app_services(): + results[uid].extend(app_service.get_groups_for_user(uid)) + + return {"users": results} + + +class GroupsLocalHandler(GroupsLocalWorkerHandler): + def __init__(self, hs): + super(GroupsLocalHandler, self).__init__(hs) + + # Ensure attestations get renewed + hs.get_groups_attestation_renewer() + + # The following functions merely route the query to the local groups server + # or federation depending on if the group is local or remote + + update_group_profile = _create_rerouter("update_group_profile") + + add_room_to_group = _create_rerouter("add_room_to_group") + update_room_in_group = _create_rerouter("update_room_in_group") + remove_room_from_group = _create_rerouter("remove_room_from_group") + + update_group_summary_room = _create_rerouter("update_group_summary_room") + delete_group_summary_room = _create_rerouter("delete_group_summary_room") + + update_group_category = _create_rerouter("update_group_category") + delete_group_category = _create_rerouter("delete_group_category") + + update_group_summary_user = _create_rerouter("update_group_summary_user") + delete_group_summary_user = _create_rerouter("delete_group_summary_user") + + update_group_role = _create_rerouter("update_group_role") + delete_group_role = _create_rerouter("delete_group_role") + + set_group_join_policy = _create_rerouter("set_group_join_policy") + + async def create_group(self, group_id, user_id, content): """Create a group """ logger.info("Asking to create group with ID: %r", group_id) if self.is_mine_id(group_id): - res = yield self.groups_server_handler.create_group( + res = await self.groups_server_handler.create_group( group_id, user_id, content ) local_attestation = None @@ -184,17 +293,19 @@ class GroupsLocalHandler(object): local_attestation = self.attestations.create_attestation(group_id, user_id) content["attestation"] = local_attestation - content["user_profile"] = yield self.profile_handler.get_profile(user_id) + content["user_profile"] = await self.profile_handler.get_profile(user_id) try: - res = yield self.transport_client.create_group( + res = await self.transport_client.create_group( get_domain_from_id(group_id), group_id, user_id, content ) + except HttpResponseException as e: + raise e.to_synapse_error() except RequestSendFailed: raise SynapseError(502, "Failed to contact group server") remote_attestation = res["attestation"] - yield 
self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, group_id=group_id, user_id=user_id, @@ -202,7 +313,7 @@ class GroupsLocalHandler(object): ) is_publicised = content.get("publicise", False) - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="join", @@ -215,52 +326,11 @@ class GroupsLocalHandler(object): return res - @defer.inlineCallbacks - def get_users_in_group(self, group_id, requester_user_id): - """Get users in a group - """ - if self.is_mine_id(group_id): - res = yield self.groups_server_handler.get_users_in_group( - group_id, requester_user_id - ) - return res - - group_server_name = get_domain_from_id(group_id) - - try: - res = yield self.transport_client.get_users_in_group( - get_domain_from_id(group_id), group_id, requester_user_id - ) - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - chunk = res["chunk"] - valid_entries = [] - for entry in chunk: - g_user_id = entry["user_id"] - attestation = entry.pop("attestation", {}) - try: - if get_domain_from_id(g_user_id) != group_server_name: - yield self.attestations.verify_attestation( - attestation, - group_id=group_id, - user_id=g_user_id, - server_name=get_domain_from_id(g_user_id), - ) - valid_entries.append(entry) - except Exception as e: - logger.info("Failed to verify user is in group: %s", e) - - res["chunk"] = valid_entries - - return res - - @defer.inlineCallbacks - def join_group(self, group_id, user_id, content): + async def join_group(self, group_id, user_id, content): """Request to join a group """ if self.is_mine_id(group_id): - yield self.groups_server_handler.join_group(group_id, user_id, content) + await self.groups_server_handler.join_group(group_id, user_id, content) local_attestation = None remote_attestation = None else: @@ -268,15 +338,17 @@ class GroupsLocalHandler(object): content["attestation"] = local_attestation try: - res = yield self.transport_client.join_group( + res = await self.transport_client.join_group( get_domain_from_id(group_id), group_id, user_id, content ) + except HttpResponseException as e: + raise e.to_synapse_error() except RequestSendFailed: raise SynapseError(502, "Failed to contact group server") remote_attestation = res["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, group_id=group_id, user_id=user_id, @@ -286,7 +358,7 @@ class GroupsLocalHandler(object): # TODO: Check that the group is public and we're being added publically is_publicised = content.get("publicise", False) - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="join", @@ -299,12 +371,11 @@ class GroupsLocalHandler(object): return {} - @defer.inlineCallbacks - def accept_invite(self, group_id, user_id, content): + async def accept_invite(self, group_id, user_id, content): """Accept an invite to a group """ if self.is_mine_id(group_id): - yield self.groups_server_handler.accept_invite(group_id, user_id, content) + await self.groups_server_handler.accept_invite(group_id, user_id, content) local_attestation = None remote_attestation = None else: @@ -312,15 +383,17 @@ class GroupsLocalHandler(object): content["attestation"] = local_attestation try: - res = yield self.transport_client.accept_group_invite( + res = await self.transport_client.accept_group_invite( 
get_domain_from_id(group_id), group_id, user_id, content ) + except HttpResponseException as e: + raise e.to_synapse_error() except RequestSendFailed: raise SynapseError(502, "Failed to contact group server") remote_attestation = res["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, group_id=group_id, user_id=user_id, @@ -330,7 +403,7 @@ class GroupsLocalHandler(object): # TODO: Check that the group is public and we're being added publically is_publicised = content.get("publicise", False) - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="join", @@ -343,31 +416,31 @@ class GroupsLocalHandler(object): return {} - @defer.inlineCallbacks - def invite(self, group_id, user_id, requester_user_id, config): + async def invite(self, group_id, user_id, requester_user_id, config): """Invite a user to a group """ content = {"requester_user_id": requester_user_id, "config": config} if self.is_mine_id(group_id): - res = yield self.groups_server_handler.invite_to_group( + res = await self.groups_server_handler.invite_to_group( group_id, user_id, requester_user_id, content ) else: try: - res = yield self.transport_client.invite_to_group( + res = await self.transport_client.invite_to_group( get_domain_from_id(group_id), group_id, user_id, requester_user_id, content, ) + except HttpResponseException as e: + raise e.to_synapse_error() except RequestSendFailed: raise SynapseError(502, "Failed to contact group server") return res - @defer.inlineCallbacks - def on_invite(self, group_id, user_id, content): + async def on_invite(self, group_id, user_id, content): """One of our users were invited to a group """ # TODO: Support auto join and rejection @@ -382,7 +455,7 @@ class GroupsLocalHandler(object): if "avatar_url" in content["profile"]: local_profile["avatar_url"] = content["profile"]["avatar_url"] - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="invite", @@ -390,19 +463,20 @@ class GroupsLocalHandler(object): ) self.notifier.on_new_event("groups_key", token, users=[user_id]) try: - user_profile = yield self.profile_handler.get_profile(user_id) + user_profile = await self.profile_handler.get_profile(user_id) except Exception as e: - logger.warn("No profile for user %s: %s", user_id, e) + logger.warning("No profile for user %s: %s", user_id, e) user_profile = {} return {"state": "invite", "user_profile": user_profile} - @defer.inlineCallbacks - def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + async def remove_user_from_group( + self, group_id, user_id, requester_user_id, content + ): """Remove a user from a group """ if user_id == requester_user_id: - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" ) self.notifier.on_new_event("groups_key", token, users=[user_id]) @@ -411,93 +485,31 @@ class GroupsLocalHandler(object): # retry if the group server is currently down. 
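# A minimal sketch of the routing that _create_rerouter and the handlers in this
# file perform: queries for a local group go straight to the local groups server
# handler, while queries for a remote group are proxied over federation to the
# group's domain, with remote errors translated for the client. The parameters
# here (handler, transport, is_mine_id) are illustrative stand-ins for the
# attributes the real class pulls off the HomeServer; this is not the actual
# _create_rerouter body.
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import get_domain_from_id

async def route_group_query(handler, transport, is_mine_id, func_name, group_id, *args):
    if is_mine_id(group_id):
        # Local group: call straight into the local groups server handler.
        return await getattr(handler, func_name)(group_id, *args)

    destination = get_domain_from_id(group_id)  # server part of "+group:server"
    try:
        # Remote group: proxy the call over federation.
        return await getattr(transport, func_name)(destination, group_id, *args)
    except HttpResponseException as e:
        # Surface the remote server's error to the client.
        raise e.to_synapse_error()
    except RequestSendFailed:
        raise SynapseError(502, "Failed to contact group server")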
if self.is_mine_id(group_id): - res = yield self.groups_server_handler.remove_user_from_group( + res = await self.groups_server_handler.remove_user_from_group( group_id, user_id, requester_user_id, content ) else: content["requester_user_id"] = requester_user_id try: - res = yield self.transport_client.remove_user_from_group( + res = await self.transport_client.remove_user_from_group( get_domain_from_id(group_id), group_id, requester_user_id, user_id, content, ) + except HttpResponseException as e: + raise e.to_synapse_error() except RequestSendFailed: raise SynapseError(502, "Failed to contact group server") return res - @defer.inlineCallbacks - def user_removed_from_group(self, group_id, user_id, content): + async def user_removed_from_group(self, group_id, user_id, content): """One of our users was removed/kicked from a group """ # TODO: Check if user in group - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" ) self.notifier.on_new_event("groups_key", token, users=[user_id]) - - @defer.inlineCallbacks - def get_joined_groups(self, user_id): - group_ids = yield self.store.get_joined_groups(user_id) - return {"groups": group_ids} - - @defer.inlineCallbacks - def get_publicised_groups_for_user(self, user_id): - if self.hs.is_mine_id(user_id): - result = yield self.store.get_publicised_groups_for_user(user_id) - - # Check AS associated groups for this user - this depends on the - # RegExps in the AS registration file (under `users`) - for app_service in self.store.get_app_services(): - result.extend(app_service.get_groups_for_user(user_id)) - - return {"groups": result} - else: - try: - bulk_result = yield self.transport_client.bulk_get_publicised_groups( - get_domain_from_id(user_id), [user_id] - ) - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - result = bulk_result.get("users", {}).get(user_id) - # TODO: Verify attestations - return {"groups": result} - - @defer.inlineCallbacks - def bulk_get_publicised_groups(self, user_ids, proxy=True): - destinations = {} - local_users = set() - - for user_id in user_ids: - if self.hs.is_mine_id(user_id): - local_users.add(user_id) - else: - destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id) - - if not proxy and destinations: - raise SynapseError(400, "Some user_ids are not local") - - results = {} - failed_results = [] - for destination, dest_user_ids in iteritems(destinations): - try: - r = yield self.transport_client.bulk_get_publicised_groups( - destination, list(dest_user_ids) - ) - results.update(r["users"]) - except Exception: - failed_results.extend(dest_user_ids) - - for uid in local_users: - results[uid] = yield self.store.get_publicised_groups_for_user(uid) - - # Check AS associated groups for this user - this depends on the - # RegExps in the AS registration file (under `users`) - for app_service in self.store.get_app_services(): - results[uid].extend(app_service.get_groups_for_user(uid)) - - return {"users": results} diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 512f38e5a6..4ba0042768 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -18,74 +18,57 @@ """Utilities for interacting with Identity Servers""" import logging +import urllib.parse from canonicaljson import json +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json +from unpaddedbase64 import decode_base64 
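# The signedjson / unpaddedbase64 imports added just above are what the new
# identity-server code further down (_lookup_3pid_v1 / _verify_any_signature)
# uses to check the signature on a v1 3PID lookup response. A self-contained
# sketch of that round trip, with a throwaway key standing in for the identity
# server's real ed25519 key; the server name and payload are made up here.
from signedjson.key import (
    decode_verify_key_bytes,
    encode_verify_key_base64,
    generate_signing_key,
    get_verify_key,
)
from signedjson.sign import sign_json, verify_signed_json
from unpaddedbase64 import decode_base64

# Identity-server side: sign the lookup result with its ed25519 key.
signing_key = generate_signing_key("akey")
signed = sign_json({"mxid": "@alice:example.com"}, "id.example.com", signing_key)

# Homeserver side: fetch the base64 public key (in reality via
# /_matrix/identity/api/v1/pubkey/ed25519:akey), decode it, and verify.
public_key_b64 = encode_verify_key_base64(get_verify_key(signing_key))
verify_key = decode_verify_key_bytes("ed25519:akey", decode_base64(public_key_b64))
verify_signed_json(signed, "id.example.com", verify_key)  # raises if the signature is bad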
-from twisted.internet import defer +from twisted.internet.error import TimeoutError from synapse.api.errors import ( + AuthError, CodeMessageException, Codes, HttpResponseException, SynapseError, ) -from synapse.util.stringutils import random_string +from synapse.config.emailconfig import ThreepidBehaviour +from synapse.http.client import SimpleHttpClient +from synapse.util.hash import sha256_and_url_safe_base64 +from synapse.util.stringutils import assert_valid_client_secret, random_string from ._base import BaseHandler logger = logging.getLogger(__name__) +id_server_scheme = "https://" + class IdentityHandler(BaseHandler): def __init__(self, hs): super(IdentityHandler, self).__init__(hs) - self.http_client = hs.get_simple_http_client() + self.http_client = SimpleHttpClient(hs) + # We create a blacklisting instance of SimpleHttpClient for contacting identity + # servers specified by clients + self.blacklisting_http_client = SimpleHttpClient( + hs, ip_blacklist=hs.config.federation_ip_range_blacklist + ) self.federation_http_client = hs.get_http_client() self.hs = hs - def _extract_items_from_creds_dict(self, creds): - """ - Retrieve entries from a "credentials" dictionary - - Args: - creds (dict[str, str]): Dictionary of credentials that contain the following keys: - * client_secret|clientSecret: A unique secret str provided by the client - * id_server|idServer: the domain of the identity server to query - * id_access_token: The access token to authenticate to the identity - server with. - - Returns: - tuple(str, str, str|None): A tuple containing the client_secret, the id_server, - and the id_access_token value if available. - """ - client_secret = creds.get("client_secret") or creds.get("clientSecret") - if not client_secret: - raise SynapseError( - 400, "No client_secret in creds", errcode=Codes.MISSING_PARAM - ) - - id_server = creds.get("id_server") or creds.get("idServer") - if not id_server: - raise SynapseError( - 400, "No id_server in creds", errcode=Codes.MISSING_PARAM - ) - - id_access_token = creds.get("id_access_token") - return client_secret, id_server, id_access_token - - @defer.inlineCallbacks - def threepid_from_creds(self, id_server, creds): + async def threepid_from_creds(self, id_server, creds): """ Retrieve and validate a threepid identifier from a "credentials" dictionary against a given identity server Args: - id_server (str|None): The identity server to validate 3PIDs against. If None, - we will attempt to extract id_server creds + id_server (str): The identity server to validate 3PIDs against. 
Must be a + complete URL including the protocol (http(s)://) creds (dict[str, str]): Dictionary containing the following keys: - * id_server|idServer: An optional domain name of an identity server * client_secret|clientSecret: A unique secret str provided by the client * sid: The ID of the validation session @@ -99,56 +82,65 @@ class IdentityHandler(BaseHandler): raise SynapseError( 400, "Missing param client_secret in creds", errcode=Codes.MISSING_PARAM ) + assert_valid_client_secret(client_secret) + session_id = creds.get("sid") if not session_id: raise SynapseError( 400, "Missing param session_id in creds", errcode=Codes.MISSING_PARAM ) - if not id_server: - # Attempt to get the id_server from the creds dict - id_server = creds.get("id_server") or creds.get("idServer") - if not id_server: - raise SynapseError( - 400, "Missing param id_server in creds", errcode=Codes.MISSING_PARAM - ) query_params = {"sid": session_id, "client_secret": client_secret} - url = "https://%s%s" % ( - id_server, - "/_matrix/identity/api/v1/3pid/getValidated3pid", - ) + url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid" - data = yield self.http_client.get_json(url, query_params) - return data if "medium" in data else None + try: + data = await self.http_client.get_json(url, query_params) + except TimeoutError: + raise SynapseError(500, "Timed out contacting identity server") + except HttpResponseException as e: + logger.info( + "%s returned %i for threepid validation for: %s", + id_server, + e.code, + creds, + ) + return None + + # Old versions of Sydent return a 200 http code even on a failed validation + # check. Thus, in addition to the HttpResponseException check above (which + # checks for non-200 errors), we need to make sure validation_session isn't + # actually an error, identified by the absence of a "medium" key + # See https://github.com/matrix-org/sydent/issues/215 for details + if "medium" in data: + return data + + logger.info("%s reported non-validated threepid: %s", id_server, creds) + return None - @defer.inlineCallbacks - def bind_threepid(self, creds, mxid, use_v2=True): + async def bind_threepid( + self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True + ): """Bind a 3PID to an identity server Args: - creds (dict[str, str]): Dictionary of credentials that contain the following keys: - * client_secret|clientSecret: A unique secret str provided by the client - * id_server|idServer: the domain of the identity server to query - * id_access_token: The access token to authenticate to the identity - server with. Required if use_v2 is true + client_secret (str): A unique secret provided by the client + + sid (str): The ID of the validation session + mxid (str): The MXID to bind the 3PID to - use_v2 (bool): Whether to use v2 Identity Service API endpoints + + id_server (str): The domain of the identity server to query + + id_access_token (str): The access token to authenticate to the identity + server with, if necessary. Required if use_v2 is true + + use_v2 (bool): Whether to use v2 Identity Service API endpoints. 
Defaults to True Returns: Deferred[dict]: The response from the identity server """ - logger.debug("binding threepid %r to %s", creds, mxid) - - client_secret, id_server, id_access_token = self._extract_items_from_creds_dict( - creds - ) - - sid = creds.get("sid") - if not sid: - raise SynapseError( - 400, "No sid in three_pid_creds", errcode=Codes.MISSING_PARAM - ) + logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server) # If an id_access_token is not supplied, force usage of v1 if id_access_token is None: @@ -164,13 +156,14 @@ class IdentityHandler(BaseHandler): bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,) try: - data = yield self.http_client.post_json_get_json( + # Use the blacklisting http client as this call is only to identity servers + # provided by a client + data = await self.blacklisting_http_client.post_json_get_json( bind_url, bind_data, headers=headers ) - logger.debug("bound threepid %r to %s", creds, mxid) # Remember where we bound the threepid - yield self.store.add_user_bound_threepid( + await self.store.add_user_bound_threepid( user_id=mxid, medium=data["medium"], address=data["address"], @@ -182,15 +175,19 @@ class IdentityHandler(BaseHandler): if e.code != 404 or not use_v2: logger.error("3PID bind failed with Matrix error: %r", e) raise e.to_synapse_error() + except TimeoutError: + raise SynapseError(500, "Timed out contacting identity server") except CodeMessageException as e: data = json.loads(e.msg) # XXX WAT? return data logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url) - return (yield self.bind_threepid(creds, mxid, use_v2=False)) + res = await self.bind_threepid( + client_secret, sid, mxid, id_server, id_access_token, use_v2=False + ) + return res - @defer.inlineCallbacks - def try_unbind_threepid(self, mxid, threepid): + async def try_unbind_threepid(self, mxid, threepid): """Attempt to remove a 3PID from an identity server, or if one is not provided, all identity servers we're aware the binding is present on @@ -210,7 +207,7 @@ class IdentityHandler(BaseHandler): if threepid.get("id_server"): id_servers = [threepid["id_server"]] else: - id_servers = yield self.store.get_id_servers_user_bound( + id_servers = await self.store.get_id_servers_user_bound( user_id=mxid, medium=threepid["medium"], address=threepid["address"] ) @@ -220,14 +217,13 @@ class IdentityHandler(BaseHandler): changed = True for id_server in id_servers: - changed &= yield self.try_unbind_threepid_with_id_server( + changed &= await self.try_unbind_threepid_with_id_server( mxid, threepid, id_server ) return changed - @defer.inlineCallbacks - def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server): + async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server): """Removes a binding from an identity server Args: @@ -263,18 +259,24 @@ class IdentityHandler(BaseHandler): headers = {b"Authorization": auth_headers} try: - yield self.http_client.post_json_get_json(url, content, headers) + # Use the blacklisting http client as this call is only to identity servers + # provided by a client + await self.blacklisting_http_client.post_json_get_json( + url, content, headers + ) changed = True except HttpResponseException as e: changed = False if e.code in (400, 404, 501): # The remote server probably doesn't support unbinding (yet) - logger.warn("Received %d response while unbinding threepid", e.code) + logger.warning("Received %d response while unbinding threepid", e.code) else: logger.error("Failed 
to unbind threepid on identity server: %s", e) - raise SynapseError(502, "Failed to contact identity server") + raise SynapseError(500, "Failed to contact identity server") + except TimeoutError: + raise SynapseError(500, "Timed out contacting identity server") - yield self.store.remove_user_bound_threepid( + await self.store.remove_user_bound_threepid( user_id=mxid, medium=threepid["medium"], address=threepid["address"], @@ -283,8 +285,7 @@ class IdentityHandler(BaseHandler): return changed - @defer.inlineCallbacks - def send_threepid_validation( + async def send_threepid_validation( self, email_address, client_secret, @@ -312,7 +313,7 @@ class IdentityHandler(BaseHandler): """ # Check that this email/client_secret/send_attempt combo is new or # greater than what we've seen previously - session = yield self.store.get_threepid_validation_session( + session = await self.store.get_threepid_validation_session( "email", client_secret, address=email_address, validated=False ) @@ -331,13 +332,22 @@ class IdentityHandler(BaseHandler): # Generate a session id session_id = random_string(16) + if next_link: + # Manipulate the next_link to add the sid, because the caller won't get + # it until we send a response, by which time we've sent the mail. + if "?" in next_link: + next_link += "&" + else: + next_link += "?" + next_link += "sid=" + urllib.parse.quote(session_id) + # Generate a new validation token token = random_string(32) # Send the mail with the link containing the token, client_secret # and session_id try: - yield send_email_func(email_address, token, client_secret, session_id) + await send_email_func(email_address, token, client_secret, session_id) except Exception: logger.exception( "Error sending threepid validation email to %s", email_address @@ -348,7 +358,7 @@ class IdentityHandler(BaseHandler): self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime ) - yield self.store.start_or_continue_validation_session( + await self.store.start_or_continue_validation_session( "email", email_address, session_id, @@ -361,8 +371,7 @@ class IdentityHandler(BaseHandler): return session_id - @defer.inlineCallbacks - def requestEmailToken( + async def requestEmailToken( self, id_server, email, client_secret, send_attempt, next_link=None ): """ @@ -389,7 +398,7 @@ class IdentityHandler(BaseHandler): if self.hs.config.using_identity_server_from_trusted_list: # Warn that a deprecated config option is in use - logger.warn( + logger.warning( 'The config option "trust_identity_server_for_password_resets" ' 'has been replaced by "account_threepid_delegate". 
' "Please consult the sample config at docs/sample_config.yaml for " @@ -397,7 +406,7 @@ class IdentityHandler(BaseHandler): ) try: - data = yield self.http_client.post_json_get_json( + data = await self.http_client.post_json_get_json( id_server + "/_matrix/identity/api/v1/validate/email/requestToken", params, ) @@ -405,9 +414,10 @@ class IdentityHandler(BaseHandler): except HttpResponseException as e: logger.info("Proxied requestToken failed: %r", e) raise e.to_synapse_error() + except TimeoutError: + raise SynapseError(500, "Timed out contacting identity server") - @defer.inlineCallbacks - def requestMsisdnToken( + async def requestMsisdnToken( self, id_server, country, @@ -441,7 +451,7 @@ class IdentityHandler(BaseHandler): if self.hs.config.using_identity_server_from_trusted_list: # Warn that a deprecated config option is in use - logger.warn( + logger.warning( 'The config option "trust_identity_server_for_password_resets" ' 'has been replaced by "account_threepid_delegate". ' "Please consult the sample config at docs/sample_config.yaml for " @@ -449,14 +459,441 @@ class IdentityHandler(BaseHandler): ) try: - data = yield self.http_client.post_json_get_json( + data = await self.http_client.post_json_get_json( id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken", params, ) - return data except HttpResponseException as e: logger.info("Proxied requestToken failed: %r", e) raise e.to_synapse_error() + except TimeoutError: + raise SynapseError(500, "Timed out contacting identity server") + + assert self.hs.config.public_baseurl + + # we need to tell the client to send the token back to us, since it doesn't + # otherwise know where to send it, so add submit_url response parameter + # (see also MSC2078) + data["submit_url"] = ( + self.hs.config.public_baseurl + + "_matrix/client/unstable/add_threepid/msisdn/submit_token" + ) + return data + + async def validate_threepid_session(self, client_secret, sid): + """Validates a threepid session with only the client secret and session ID + Tries validating against any configured account_threepid_delegates as well as locally. + + Args: + client_secret (str): A secret provided by the client + + sid (str): The ID of the session + + Returns: + Dict[str, str|int] if validation was successful, otherwise None + """ + # XXX: We shouldn't need to keep wrapping and unwrapping this value + threepid_creds = {"client_secret": client_secret, "sid": sid} + + # We don't actually know which medium this 3PID is. 
Thus we first assume it's email, + # and if validation fails we try msisdn + validation_session = None + + # Try to validate as email + if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + # Ask our delegated email identity server + validation_session = await self.threepid_from_creds( + self.hs.config.account_threepid_delegate_email, threepid_creds + ) + elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + # Get a validated session matching these details + validation_session = await self.store.get_threepid_validation_session( + "email", client_secret, sid=sid, validated=True + ) + + if validation_session: + return validation_session + + # Try to validate as msisdn + if self.hs.config.account_threepid_delegate_msisdn: + # Ask our delegated msisdn identity server + validation_session = await self.threepid_from_creds( + self.hs.config.account_threepid_delegate_msisdn, threepid_creds + ) + + return validation_session + + async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token): + """Proxy a POST submitToken request to an identity server for verification purposes + + Args: + id_server (str): The identity server URL to contact + + client_secret (str): Secret provided by the client + + sid (str): The ID of the session + + token (str): The verification token + + Raises: + SynapseError: If we failed to contact the identity server + + Returns: + Deferred[dict]: The response dict from the identity server + """ + body = {"client_secret": client_secret, "sid": sid, "token": token} + + try: + return await self.http_client.post_json_get_json( + id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken", + body, + ) + except TimeoutError: + raise SynapseError(500, "Timed out contacting identity server") + except HttpResponseException as e: + logger.warning("Error contacting msisdn account_threepid_delegate: %s", e) + raise SynapseError(400, "Error contacting the identity server") + + async def lookup_3pid(self, id_server, medium, address, id_access_token=None): + """Looks up a 3pid in the passed identity server. + + Args: + id_server (str): The server name (including port, if required) + of the identity server to use. + medium (str): The type of the third party identifier (e.g. "email"). + address (str): The third party identifier (e.g. "foo@example.com"). + id_access_token (str|None): The access token to authenticate to the identity + server with + + Returns: + str|None: the matrix ID of the 3pid, or None if it is not recognized. + """ + if id_access_token is not None: + try: + results = await self._lookup_3pid_v2( + id_server, id_access_token, medium, address + ) + return results + + except Exception as e: + # Catch HttpResponseExcept for a non-200 response code + # Check if this identity server does not know about v2 lookups + if isinstance(e, HttpResponseException) and e.code == 404: + # This is an old identity server that does not yet support v2 lookups + logger.warning( + "Attempted v2 lookup on v1 identity server %s. Falling " + "back to v1", + id_server, + ) + else: + logger.warning("Error when looking up hashing details: %s", e) + return None + + return await self._lookup_3pid_v1(id_server, medium, address) + + async def _lookup_3pid_v1(self, id_server, medium, address): + """Looks up a 3pid in the passed identity server using v1 lookup. + + Args: + id_server (str): The server name (including port, if required) + of the identity server to use. + medium (str): The type of the third party identifier (e.g. "email"). 
+ address (str): The third party identifier (e.g. "foo@example.com"). + + Returns: + str: the matrix ID of the 3pid, or None if it is not recognized. + """ + try: + data = await self.blacklisting_http_client.get_json( + "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server), + {"medium": medium, "address": address}, + ) + + if "mxid" in data: + if "signatures" not in data: + raise AuthError(401, "No signatures on 3pid binding") + await self._verify_any_signature(data, id_server) + return data["mxid"] + except TimeoutError: + raise SynapseError(500, "Timed out contacting identity server") + except IOError as e: + logger.warning("Error from v1 identity server lookup: %s" % (e,)) + + return None + + async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address): + """Looks up a 3pid in the passed identity server using v2 lookup. + + Args: + id_server (str): The server name (including port, if required) + of the identity server to use. + id_access_token (str): The access token to authenticate to the identity server with + medium (str): The type of the third party identifier (e.g. "email"). + address (str): The third party identifier (e.g. "foo@example.com"). + + Returns: + Deferred[str|None]: the matrix ID of the 3pid, or None if it is not recognised. + """ + # Check what hashing details are supported by this identity server + try: + hash_details = await self.blacklisting_http_client.get_json( + "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server), + {"access_token": id_access_token}, + ) + except TimeoutError: + raise SynapseError(500, "Timed out contacting identity server") + + if not isinstance(hash_details, dict): + logger.warning( + "Got non-dict object when checking hash details of %s%s: %s", + id_server_scheme, + id_server, + hash_details, + ) + raise SynapseError( + 400, + "Non-dict object from %s%s during v2 hash_details request: %s" + % (id_server_scheme, id_server, hash_details), + ) + + # Extract information from hash_details + supported_lookup_algorithms = hash_details.get("algorithms") + lookup_pepper = hash_details.get("lookup_pepper") + if ( + not supported_lookup_algorithms + or not isinstance(supported_lookup_algorithms, list) + or not lookup_pepper + or not isinstance(lookup_pepper, str) + ): + raise SynapseError( + 400, + "Invalid hash details received from identity server %s%s: %s" + % (id_server_scheme, id_server, hash_details), + ) + + # Check if any of the supported lookup algorithms are present + if LookupAlgorithm.SHA256 in supported_lookup_algorithms: + # Perform a hashed lookup + lookup_algorithm = LookupAlgorithm.SHA256 + + # Hash address, medium and the pepper with sha256 + to_hash = "%s %s %s" % (address, medium, lookup_pepper) + lookup_value = sha256_and_url_safe_base64(to_hash) + + elif LookupAlgorithm.NONE in supported_lookup_algorithms: + # Perform a non-hashed lookup + lookup_algorithm = LookupAlgorithm.NONE + + # Combine together plaintext address and medium + lookup_value = "%s %s" % (address, medium) + + else: + logger.warning( + "None of the provided lookup algorithms of %s are supported: %s", + id_server, + supported_lookup_algorithms, + ) + raise SynapseError( + 400, + "Provided identity server does not support any v2 lookup " + "algorithms that this homeserver supports.", + ) + + # Authenticate with identity server given the access token from the client + headers = {"Authorization": create_id_access_token_header(id_access_token)} + + try: + lookup_results = await 
self.blacklisting_http_client.post_json_get_json( + "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server), + { + "addresses": [lookup_value], + "algorithm": lookup_algorithm, + "pepper": lookup_pepper, + }, + headers=headers, + ) + except TimeoutError: + raise SynapseError(500, "Timed out contacting identity server") + except Exception as e: + logger.warning("Error when performing a v2 3pid lookup: %s", e) + raise SynapseError( + 500, "Unknown error occurred during identity server lookup" + ) + + # Check for a mapping from what we looked up to an MXID + if "mappings" not in lookup_results or not isinstance( + lookup_results["mappings"], dict + ): + logger.warning("No results from 3pid lookup") + return None + + # Return the MXID if it's available, or None otherwise + mxid = lookup_results["mappings"].get(lookup_value) + return mxid + + async def _verify_any_signature(self, data, server_hostname): + if server_hostname not in data["signatures"]: + raise AuthError(401, "No signature from server %s" % (server_hostname,)) + for key_name, signature in data["signatures"][server_hostname].items(): + try: + key_data = await self.blacklisting_http_client.get_json( + "%s%s/_matrix/identity/api/v1/pubkey/%s" + % (id_server_scheme, server_hostname, key_name) + ) + except TimeoutError: + raise SynapseError(500, "Timed out contacting identity server") + if "public_key" not in key_data: + raise AuthError( + 401, "No public key named %s from %s" % (key_name, server_hostname) + ) + verify_signed_json( + data, + server_hostname, + decode_verify_key_bytes( + key_name, decode_base64(key_data["public_key"]) + ), + ) + return + + async def ask_id_server_for_third_party_invite( + self, + requester, + id_server, + medium, + address, + room_id, + inviter_user_id, + room_alias, + room_avatar_url, + room_join_rules, + room_name, + inviter_display_name, + inviter_avatar_url, + id_access_token=None, + ): + """ + Asks an identity server for a third party invite. + + Args: + requester (Requester) + id_server (str): hostname + optional port for the identity server. + medium (str): The literal string "email". + address (str): The third party address being invited. + room_id (str): The ID of the room to which the user is invited. + inviter_user_id (str): The user ID of the inviter. + room_alias (str): An alias for the room, for cosmetic notifications. + room_avatar_url (str): The URL of the room's avatar, for cosmetic + notifications. + room_join_rules (str): The join rules of the email (e.g. "public"). + room_name (str): The m.room.name of the room. + inviter_display_name (str): The current display name of the + inviter. + inviter_avatar_url (str): The URL of the inviter's avatar. + id_access_token (str|None): The access token to authenticate to the identity + server with + + Returns: + A deferred tuple containing: + token (str): The token which must be signed to prove authenticity. + public_keys ([{"public_key": str, "key_validity_url": str}]): + public_key is a base64-encoded ed25519 public key. + fallback_public_key: One element from public_keys. + display_name (str): A user-friendly name to represent the invited + user. 
+ """ + invite_config = { + "medium": medium, + "address": address, + "room_id": room_id, + "room_alias": room_alias, + "room_avatar_url": room_avatar_url, + "room_join_rules": room_join_rules, + "room_name": room_name, + "sender": inviter_user_id, + "sender_display_name": inviter_display_name, + "sender_avatar_url": inviter_avatar_url, + } + + # Add the identity service access token to the JSON body and use the v2 + # Identity Service endpoints if id_access_token is present + data = None + base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server) + + if id_access_token: + key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % ( + id_server_scheme, + id_server, + ) + + # Attempt a v2 lookup + url = base_url + "/v2/store-invite" + try: + data = await self.blacklisting_http_client.post_json_get_json( + url, + invite_config, + {"Authorization": create_id_access_token_header(id_access_token)}, + ) + except TimeoutError: + raise SynapseError(500, "Timed out contacting identity server") + except HttpResponseException as e: + if e.code != 404: + logger.info("Failed to POST %s with JSON: %s", url, e) + raise e + + if data is None: + key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( + id_server_scheme, + id_server, + ) + url = base_url + "/api/v1/store-invite" + + try: + data = await self.blacklisting_http_client.post_json_get_json( + url, invite_config + ) + except TimeoutError: + raise SynapseError(500, "Timed out contacting identity server") + except HttpResponseException as e: + logger.warning( + "Error trying to call /store-invite on %s%s: %s", + id_server_scheme, + id_server, + e, + ) + + if data is None: + # Some identity servers may only support application/x-www-form-urlencoded + # types. This is especially true with old instances of Sydent, see + # https://github.com/matrix-org/sydent/pull/170 + try: + data = await self.blacklisting_http_client.post_urlencoded_get_json( + url, invite_config + ) + except HttpResponseException as e: + logger.warning( + "Error calling /store-invite on %s%s with fallback " + "encoding: %s", + id_server_scheme, + id_server, + e, + ) + raise e + + # TODO: Check for success + token = data["token"] + public_keys = data.get("public_keys", []) + if "public_key" in data: + fallback_public_key = { + "public_key": data["public_key"], + "key_validity_url": key_validity_url, + } + else: + fallback_public_key = public_keys[0] + + if not public_keys: + public_keys.append(fallback_public_key) + display_name = data["display_name"] + return token, public_keys, fallback_public_key, display_name def create_id_access_token_header(id_access_token): diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index f991efeee3..f88bad5f25 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -18,7 +18,7 @@ import logging from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import SynapseError from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig from synapse.types import StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async_helpers import concurrently_execute -from synapse.util.caches.snapshot_cache import SnapshotCache +from 
synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -41,8 +41,10 @@ class InitialSyncHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = SnapshotCache() + self.snapshot_cache = ResponseCache(hs, "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() + self.storage = hs.get_storage() + self.state_store = self.storage.state def snapshot_all_rooms( self, @@ -77,21 +79,17 @@ class InitialSyncHandler(BaseHandler): as_client_event, include_archived, ) - now_ms = self.clock.time_msec() - result = self.snapshot_cache.get(now_ms, key) - if result is not None: - return result - return self.snapshot_cache.set( - now_ms, + return self.snapshot_cache.wrap( key, - self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - ), + self._snapshot_all_rooms, + user_id, + pagin_config, + as_client_event, + include_archived, ) - @defer.inlineCallbacks - def _snapshot_all_rooms( + async def _snapshot_all_rooms( self, user_id=None, pagin_config=None, @@ -103,7 +101,7 @@ class InitialSyncHandler(BaseHandler): if include_archived: memberships.append(Membership.LEAVE) - room_list = yield self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, membership_list=memberships ) @@ -111,33 +109,32 @@ class InitialSyncHandler(BaseHandler): rooms_ret = [] - now_token = yield self.hs.get_event_sources().get_current_token() + now_token = await self.hs.get_event_sources().get_current_token() presence_stream = self.hs.get_event_sources().sources["presence"] pagination_config = PaginationConfig(from_token=now_token) - presence, _ = yield presence_stream.get_pagination_rows( + presence, _ = await presence_stream.get_pagination_rows( user, pagination_config.get_source_config("presence"), None ) receipt_stream = self.hs.get_event_sources().sources["receipt"] - receipt, _ = yield receipt_stream.get_pagination_rows( + receipt, _ = await receipt_stream.get_pagination_rows( user, pagination_config.get_source_config("receipt"), None ) - tags_by_room = yield self.store.get_tags_for_user(user_id) + tags_by_room = await self.store.get_tags_for_user(user_id) - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user(user_id) + account_data, account_data_by_room = await self.store.get_account_data_for_user( + user_id ) - public_room_ids = yield self.store.get_public_room_ids() + public_room_ids = await self.store.get_public_room_ids() limit = pagin_config.limit if limit is None: limit = 10 - @defer.inlineCallbacks - def handle_room(event): + async def handle_room(event): d = { "room_id": event.room_id, "membership": event.membership, @@ -150,8 +147,8 @@ class InitialSyncHandler(BaseHandler): time_now = self.clock.time_msec() d["inviter"] = event.sender - invite_event = yield self.store.get_event(event.event_id) - d["invite"] = yield self._event_serializer.serialize_event( + invite_event = await self.store.get_event(event.event_id) + d["invite"] = await self._event_serializer.serialize_event( invite_event, time_now, as_client_event ) @@ -169,13 +166,13 @@ class InitialSyncHandler(BaseHandler): elif event.membership == Membership.LEAVE: room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = run_in_background( - self.store.get_state_for_events, [event.event_id] + 
self.state_store.get_state_for_events, [event.event_id] ) deferred_room_state.addCallback( lambda states: states[event.event_id] ) - (messages, token), current_state = yield make_deferred_yieldable( + (messages, token), current_state = await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -189,7 +186,9 @@ class InitialSyncHandler(BaseHandler): ) ).addErrback(unwrapFirstError) - messages = yield filter_events_for_client(self.store, user_id, messages) + messages = await filter_events_for_client( + self.storage, user_id, messages + ) start_token = now_token.copy_and_replace("room_key", token) end_token = now_token.copy_and_replace("room_key", room_end_token) @@ -197,7 +196,7 @@ class InitialSyncHandler(BaseHandler): d["messages"] = { "chunk": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( messages, time_now=time_now, as_client_event=as_client_event ) ), @@ -205,7 +204,7 @@ class InitialSyncHandler(BaseHandler): "end": end_token.to_string(), } - d["state"] = yield self._event_serializer.serialize_events( + d["state"] = await self._event_serializer.serialize_events( current_state.values(), time_now=time_now, as_client_event=as_client_event, @@ -228,7 +227,7 @@ class InitialSyncHandler(BaseHandler): except Exception: logger.exception("Failed to get snapshot") - yield concurrently_execute(handle_room, room_list, 10) + await concurrently_execute(handle_room, room_list, 10) account_data_events = [] for account_data_type, content in account_data.items(): @@ -252,8 +251,7 @@ class InitialSyncHandler(BaseHandler): return ret - @defer.inlineCallbacks - def room_initial_sync(self, requester, room_id, pagin_config=None): + async def room_initial_sync(self, requester, room_id, pagin_config=None): """Capture the a snapshot of a room. If user is currently a member of the room this will be what is currently in the room. If the user left the room this will be what was in the room when they left. @@ -270,32 +268,35 @@ class InitialSyncHandler(BaseHandler): A JSON serialisable dict with the snapshot of the room. 
""" - blocked = yield self.store.is_room_blocked(room_id) + blocked = await self.store.is_room_blocked(room_id) if blocked: raise SynapseError(403, "This room has been blocked on this server") user_id = requester.user.to_string() - membership, member_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id + ( + membership, + member_event_id, + ) = await self.auth.check_user_in_room_or_world_readable( + room_id, user_id, allow_departed_users=True, ) is_peeking = member_event_id is None if membership == Membership.JOIN: - result = yield self._room_initial_sync_joined( + result = await self._room_initial_sync_joined( user_id, room_id, pagin_config, membership, is_peeking ) elif membership == Membership.LEAVE: - result = yield self._room_initial_sync_parted( + result = await self._room_initial_sync_parted( user_id, room_id, pagin_config, membership, member_event_id, is_peeking ) account_data_events = [] - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) if tags: account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) - account_data = yield self.store.get_account_data_for_room(user_id, room_id) + account_data = await self.store.get_account_data_for_room(user_id, room_id) for account_data_type, content in account_data.items(): account_data_events.append({"type": account_data_type, "content": content}) @@ -303,11 +304,10 @@ class InitialSyncHandler(BaseHandler): return result - @defer.inlineCallbacks - def _room_initial_sync_parted( + async def _room_initial_sync_parted( self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking ): - room_state = yield self.store.get_state_for_events([member_event_id]) + room_state = await self.state_store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -315,14 +315,14 @@ class InitialSyncHandler(BaseHandler): if limit is None: limit = 10 - stream_token = yield self.store.get_stream_token_for_event(member_event_id) + stream_token = await self.store.get_stream_token_for_event(member_event_id) - messages, token = yield self.store.get_recent_events_for_room( + messages, token = await self.store.get_recent_events_for_room( room_id, limit=limit, end_token=stream_token ) - messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking + messages = await filter_events_for_client( + self.storage, user_id, messages, is_peeking=is_peeking ) start_token = StreamToken.START.copy_and_replace("room_key", token) @@ -335,13 +335,13 @@ class InitialSyncHandler(BaseHandler): "room_id": room_id, "messages": { "chunk": ( - yield self._event_serializer.serialize_events(messages, time_now) + await self._event_serializer.serialize_events(messages, time_now) ), "start": start_token.to_string(), "end": end_token.to_string(), }, "state": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( room_state.values(), time_now ) ), @@ -349,19 +349,18 @@ class InitialSyncHandler(BaseHandler): "receipts": [], } - @defer.inlineCallbacks - def _room_initial_sync_joined( + async def _room_initial_sync_joined( self, user_id, room_id, pagin_config, membership, is_peeking ): - current_state = yield self.state.get_current_state(room_id=room_id) + current_state = await self.state.get_current_state(room_id=room_id) # TODO: These concurrently time_now = self.clock.time_msec() - state = yield self._event_serializer.serialize_events( + state = await 
self._event_serializer.serialize_events( current_state.values(), time_now ) - now_token = yield self.hs.get_event_sources().get_current_token() + now_token = await self.hs.get_event_sources().get_current_token() limit = pagin_config.limit if pagin_config else None if limit is None: @@ -376,28 +375,32 @@ class InitialSyncHandler(BaseHandler): presence_handler = self.hs.get_presence_handler() - @defer.inlineCallbacks - def get_presence(): + async def get_presence(): # If presence is disabled, return an empty list if not self.hs.config.use_presence: return [] - states = yield presence_handler.get_states( - [m.user_id for m in room_members], as_event=True + states = await presence_handler.get_states( + [m.user_id for m in room_members] ) - return states + return [ + { + "type": EventTypes.Presence, + "content": format_user_presence_state(s, time_now), + } + for s in states + ] - @defer.inlineCallbacks - def get_receipts(): - receipts = yield self.store.get_linearized_receipts_for_room( + async def get_receipts(): + receipts = await self.store.get_linearized_receipts_for_room( room_id, to_key=now_token.receipt_key ) if not receipts: receipts = [] return receipts - presence, receipts, (messages, token) = yield make_deferred_yieldable( + presence, receipts, (messages, token) = await make_deferred_yieldable( defer.gatherResults( [ run_in_background(get_presence), @@ -413,8 +416,8 @@ class InitialSyncHandler(BaseHandler): ).addErrback(unwrapFirstError) ) - messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking + messages = await filter_events_for_client( + self.storage, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace("room_key", token) @@ -426,7 +429,7 @@ class InitialSyncHandler(BaseHandler): "room_id": room_id, "messages": { "chunk": ( - yield self._event_serializer.serialize_events(messages, time_now) + await self._event_serializer.serialize_events(messages, time_now) ), "start": start_token.to_string(), "end": end_token.to_string(), @@ -439,26 +442,3 @@ class InitialSyncHandler(BaseHandler): ret["membership"] = membership return ret - - @defer.inlineCallbacks - def _check_in_room_or_world_readable(self, room_id, user_id): - try: - # check_user_was_in_room will return the most recent membership - # event for the user if: - # * The user is a non-guest user, and was ever in the room - # * The user is a guest user, and has joined the room - # else it will throw. - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) - return member_event.membership, member_event.event_id - except AuthError: - visibility = yield self.state_handler.get_current_state( - room_id, EventTypes.RoomHistoryVisibility, "" - ) - if ( - visibility - and visibility.content["history_visibility"] == "world_readable" - ): - return Membership.JOIN, None - raise AuthError( - 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN - ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 1f8272784e..649ca1f08a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
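# The initial_sync changes above replace SnapshotCache with a ResponseCache and
# route /initialSync requests through snapshot_cache.wrap(key, self._snapshot_all_rooms, ...).
# Below is a toy asyncio sketch of the idea behind that wrap() call, namely
# deduplicating concurrent identical requests. It is not Synapse's actual
# ResponseCache implementation; the key and callback used here are illustrative.
import asyncio

class SimpleResponseCache:
    """Toy deduplication: concurrent callers that share a key all await the
    task started by the first caller."""

    def __init__(self):
        self._pending = {}  # key -> asyncio.Task

    def wrap(self, key, callback, *args, **kwargs):
        task = self._pending.get(key)
        if task is None:
            task = asyncio.ensure_future(callback(*args, **kwargs))
            self._pending[key] = task
            # Unlike the real cache, drop the entry as soon as the call finishes.
            task.add_done_callback(lambda _: self._pending.pop(key, None))
        return task

async def _snapshot_all_rooms(user_id):
    await asyncio.sleep(0.1)  # stand-in for the expensive initial sync
    return {"user_id": user_id, "rooms": []}

async def main():
    cache = SimpleResponseCache()
    key = ("@alice:hs", None, True, False)
    t1 = cache.wrap(key, _snapshot_all_rooms, "@alice:hs")
    t2 = cache.wrap(key, _snapshot_all_rooms, "@alice:hs")
    assert t1 is t2  # the second caller reuses the in-flight task
    print(await t1)

asyncio.run(main())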
import logging +from typing import Optional, Tuple from six import iteritems, itervalues, string_types @@ -22,9 +23,16 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from twisted.internet.defer import succeed +from twisted.internet.interfaces import IDelayedCall from synapse import event_auth -from synapse.api.constants import EventTypes, Membership, RelationTypes, UserTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, + UserTypes, +) from synapse.api.errors import ( AuthError, Codes, @@ -32,14 +40,16 @@ from synapse.api.errors import ( NotFoundError, SynapseError, ) -from synapse.api.room_versions import RoomVersions +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder +from synapse.events import EventBase from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import RoomAlias, UserID, create_requester +from synapse.types import Collection, RoomAlias, UserID, create_requester from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.metrics import measure_func @@ -59,7 +69,19 @@ class MessageHandler(object): self.clock = hs.get_clock() self.state = hs.get_state_handler() self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + + # The scheduled call to self._expire_event. None if no call is currently + # scheduled. + self._scheduled_expiry = None # type: Optional[IDelayedCall] + + if not hs.config.worker_app: + run_as_background_process( + "_schedule_next_expiry", self._schedule_next_expiry + ) @defer.inlineCallbacks def get_room_data( @@ -74,15 +96,18 @@ class MessageHandler(object): Raises: SynapseError if something went wrong. 
""" - membership, membership_event_id = yield self.auth.check_in_room_or_world_readable( - room_id, user_id + ( + membership, + membership_event_id, + ) = yield self.auth.check_user_in_room_or_world_readable( + room_id, user_id, allow_departed_users=True ) if membership == Membership.JOIN: data = yield self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -135,12 +160,12 @@ class MessageHandler(object): raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.store, user_id, last_events + self.storage, user_id, last_events, filter_send_to_client=False ) event = last_events[0] if visible_events: - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] @@ -151,8 +176,11 @@ class MessageHandler(object): % (user_id, room_id, at_token), ) else: - membership, membership_event_id = ( - yield self.auth.check_in_room_or_world_readable(room_id, user_id) + ( + membership, + membership_event_id, + ) = yield self.auth.check_user_in_room_or_world_readable( + room_id, user_id, allow_departed_users=True ) if membership == Membership.JOIN: @@ -161,7 +189,7 @@ class MessageHandler(object): ) room_state = yield self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [membership_event_id], state_filter=state_filter ) room_state = room_state[membership_event_id] @@ -192,8 +220,8 @@ class MessageHandler(object): if not requester.app_service: # We check AS auth after fetching the room membership, as it # requires us to pull out all joined members anyway. - membership, _ = yield self.auth.check_in_room_or_world_readable( - room_id, user_id + membership, _ = yield self.auth.check_user_in_room_or_world_readable( + room_id, user_id, allow_departed_users=True ) if membership != Membership.JOIN: raise NotImplementedError( @@ -221,24 +249,129 @@ class MessageHandler(object): for user_id, profile in iteritems(users_with_profile) } + def maybe_schedule_expiry(self, event): + """Schedule the expiry of an event if there's not already one scheduled, + or if the one running is for an event that will expire after the provided + timestamp. + + This function needs to invalidate the event cache, which is only possible on + the master process, and therefore needs to be run on there. + + Args: + event (EventBase): The event to schedule the expiry of. + """ + + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if not isinstance(expiry_ts, int) or event.is_state(): + return + + # _schedule_expiry_for_event won't actually schedule anything if there's already + # a task scheduled for a timestamp that's sooner than the provided one. + self._schedule_expiry_for_event(event.event_id, expiry_ts) + + @defer.inlineCallbacks + def _schedule_next_expiry(self): + """Retrieve the ID and the expiry timestamp of the next event to be expired, + and schedule an expiry task for it. 
+ + If there's no event left to expire, set _expiry_scheduled to None so that a + future call to save_expiry_ts can schedule a new expiry task. + """ + # Try to get the expiry timestamp of the next event to expire. + res = yield self.store.get_next_event_to_expire() + if res: + event_id, expiry_ts = res + self._schedule_expiry_for_event(event_id, expiry_ts) + + def _schedule_expiry_for_event(self, event_id, expiry_ts): + """Schedule an expiry task for the provided event if there's not already one + scheduled at a timestamp that's sooner than the provided one. + + Args: + event_id (str): The ID of the event to expire. + expiry_ts (int): The timestamp at which to expire the event. + """ + if self._scheduled_expiry: + # If the provided timestamp refers to a time before the scheduled time of the + # next expiry task, cancel that task and reschedule it for this timestamp. + next_scheduled_expiry_ts = self._scheduled_expiry.getTime() * 1000 + if expiry_ts < next_scheduled_expiry_ts: + self._scheduled_expiry.cancel() + else: + return + + # Figure out how many seconds we need to wait before expiring the event. + now_ms = self.clock.time_msec() + delay = (expiry_ts - now_ms) / 1000 + + # callLater doesn't support negative delays, so trim the delay to 0 if we're + # in that case. + if delay < 0: + delay = 0 + + logger.info("Scheduling expiry for event %s in %.3fs", event_id, delay) + + self._scheduled_expiry = self.clock.call_later( + delay, + run_as_background_process, + "_expire_event", + self._expire_event, + event_id, + ) + + @defer.inlineCallbacks + def _expire_event(self, event_id): + """Retrieve and expire an event that needs to be expired from the database. + + If the event doesn't exist in the database, log it and delete the expiry date + from the database (so that we don't try to expire it again). + """ + assert self._ephemeral_events_enabled + + self._scheduled_expiry = None + + logger.info("Expiring event %s", event_id) + + try: + # Expire the event if we know about it. This function also deletes the expiry + # date from the database in the same database transaction. + yield self.store.expire_event(event_id) + except Exception as e: + logger.error("Could not expire event %s: %r", event_id, e) + + # Schedule the expiry of the next event to expire. 
+ yield self._schedule_next_expiry() + + +# The duration (in ms) after which rooms should be removed +# `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try +# to generate a dummy event for them once more) +# +_DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000 + class EventCreationHandler(object): def __init__(self, hs): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() self.profile_handler = hs.get_profile_handler() self.event_builder_factory = hs.get_event_builder_factory() self.server_name = hs.hostname - self.ratelimiter = hs.get_ratelimiter() self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases + self._is_event_writer = ( + self.config.worker.writers.events == hs.get_instance_name() + ) - self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs) + self.room_invite_state_types = self.hs.config.room_invite_state_types + + self.send_event = ReplicationSendEventRestServlet.make_client(hs) # This is only used to get at ratelimit function, and maybe_kick_guest_users self.base_handler = BaseHandler(hs) @@ -258,6 +391,13 @@ class EventCreationHandler(object): self.config.block_events_without_consent_error ) + # Rooms which should be excluded from dummy insertion. (For instance, + # those without local users who can send events into the room). + # + # map from room id to time-of-last-attempt. + # + self._rooms_to_exclude_from_dummy_event_insertion = {} # type: dict[str, int] + # we need to construct a ConsentURIBuilder here, as it checks that the necessary # config options, but *only* if we have a configuration for which we are # going to need it. @@ -276,6 +416,12 @@ class EventCreationHandler(object): 5 * 60 * 1000, ) + self._message_handler = hs.get_message_handler() + + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + + self._dummy_events_threshold = hs.config.dummy_events_threshold + @defer.inlineCallbacks def create_event( self, @@ -283,7 +429,7 @@ class EventCreationHandler(object): event_dict, token_id=None, txn_id=None, - prev_events_and_hashes=None, + prev_event_ids: Optional[Collection[str]] = None, require_consent=True, ): """ @@ -300,10 +446,9 @@ class EventCreationHandler(object): token_id (str) txn_id (str) - prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + prev_event_ids: the forward extremities to use as the prev_events for the - new event. For each event, a tuple of (event_id, hashes, depth) - where *hashes* is a map from algorithm to hash. + new event. If None, they will be requested from the database. 
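# A minimal illustrative sketch (not taken verbatim from the patch) of how a caller uses
# the new `prev_event_ids` argument: a plain list of forward-extremity event IDs replaces
# the old (event_id, hashes, depth) tuples. This mirrors the dummy-event sender later in
# this file; `room_id`, `requester` and `user_id` are assumed to already be in scope.
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
event, context = await self.create_event(
    requester,
    {
        "type": "org.matrix.dummy_event",
        "content": {},
        "room_id": room_id,
        "sender": user_id,
    },
    prev_event_ids=latest_event_ids,
)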
@@ -321,7 +466,9 @@ class EventCreationHandler(object): room_version = event_dict["content"]["room_version"] else: try: - room_version = yield self.store.get_room_version(event_dict["room_id"]) + room_version = yield self.store.get_room_version_id( + event_dict["room_id"] + ) except NotFoundError: raise AuthError(403, "Unknown room") @@ -340,9 +487,13 @@ class EventCreationHandler(object): try: if "displayname" not in content: - content["displayname"] = yield profile.get_displayname(target) + displayname = yield profile.get_displayname(target) + if displayname is not None: + content["displayname"] = displayname if "avatar_url" not in content: - content["avatar_url"] = yield profile.get_avatar_url(target) + avatar_url = yield profile.get_avatar_url(target) + if avatar_url is not None: + content["avatar_url"] = avatar_url except Exception as e: logger.info( "Failed to get profile information for %r: %s", target, e @@ -359,9 +510,7 @@ class EventCreationHandler(object): builder.internal_metadata.txn_id = txn_id event, context = yield self.create_new_client_event( - builder=builder, - requester=requester, - prev_events_and_hashes=prev_events_and_hashes, + builder=builder, requester=requester, prev_event_ids=prev_event_ids, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -376,7 +525,7 @@ class EventCreationHandler(object): # federation as well as those created locally. As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( yield self.store.get_event(prev_event_id, allow_none=True) @@ -398,7 +547,7 @@ class EventCreationHandler(object): 403, "You must be in the room to create an alias for it" ) - self.validator.validate_new(event) + self.validator.validate_new(event, self.config) return (event, context) @@ -484,8 +633,9 @@ class EventCreationHandler(object): msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) - @defer.inlineCallbacks - def send_nonmember_event(self, requester, event, context, ratelimit=True): + async def send_nonmember_event( + self, requester, event, context, ratelimit=True + ) -> int: """ Persists and notifies local clients and federation of an event. @@ -494,6 +644,9 @@ class EventCreationHandler(object): context (Context) the context of the event. ratelimit (bool): Whether to rate limit this send. is_guest (bool): Whether the sender is a guest. + + Return: + The stream_id of the persisted event. """ if event.type == EventTypes.Member: raise SynapseError( @@ -505,7 +658,7 @@ class EventCreationHandler(object): assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if event.is_state(): - prev_state = yield self.deduplicate_state_event(event, context) + prev_state = await self.deduplicate_state_event(event, context) if prev_state is not None: logger.info( "Not bothering to persist state event %s duplicated by %s", @@ -514,7 +667,7 @@ class EventCreationHandler(object): ) return prev_state - yield self.handle_new_client_event( + return await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit ) @@ -526,7 +679,7 @@ class EventCreationHandler(object): If so, returns the version of the event in context. 
Otherwise, returns None. """ - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: return @@ -541,10 +694,9 @@ class EventCreationHandler(object): return prev_event return - @defer.inlineCallbacks - def create_and_send_nonmember_event( + async def create_and_send_nonmember_event( self, requester, event_dict, ratelimit=True, txn_id=None - ): + ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. @@ -556,8 +708,8 @@ class EventCreationHandler(object): # a situation where event persistence can't keep up, causing # extremities to pile up, which in turn leads to state resolution # taking longer. - with (yield self.limiter.queue(event_dict["room_id"])): - event, context = yield self.create_event( + with (await self.limiter.queue(event_dict["room_id"])): + event, context = await self.create_event( requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id ) @@ -567,15 +719,15 @@ class EventCreationHandler(object): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) - yield self.send_nonmember_event( + stream_id = await self.send_nonmember_event( requester, event, context, ratelimit=ratelimit ) - return event + return event, stream_id @measure_func("create_new_client_event") @defer.inlineCallbacks def create_new_client_event( - self, builder, requester=None, prev_events_and_hashes=None + self, builder, requester=None, prev_event_ids: Optional[Collection[str]] = None ): """Create a new event for a local client @@ -584,10 +736,9 @@ class EventCreationHandler(object): requester (synapse.types.Requester|None): - prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + prev_event_ids: the forward extremities to use as the prev_events for the - new event. For each event, a tuple of (event_id, hashes, depth) - where *hashes* is a map from algorithm to hash. + new event. If None, they will be requested from the database. @@ -595,27 +746,20 @@ class EventCreationHandler(object): Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)] """ - if prev_events_and_hashes is not None: - assert len(prev_events_and_hashes) <= 10, ( + if prev_event_ids is not None: + assert len(prev_event_ids) <= 10, ( "Attempting to create an event with %i prev_events" - % (len(prev_events_and_hashes),) + % (len(prev_event_ids),) ) else: - prev_events_and_hashes = yield self.store.get_prev_events_for_room( - builder.room_id - ) - - prev_events = [ - (event_id, prev_hashes) - for event_id, prev_hashes, _ in prev_events_and_hashes - ] + prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id) - event = yield builder.build(prev_event_ids=[p for p, _ in prev_events]) + event = yield builder.build(prev_event_ids=prev_event_ids) context = yield self.state.compute_event_context(event) if requester: context.app_service = requester.app_service - self.validator.validate_new(event) + self.validator.validate_new(event, self.config) # If this event is an annotation then we check that that the sender # can't annotate the same way twice (e.g. 
stops users from liking an @@ -636,10 +780,9 @@ class EventCreationHandler(object): return (event, context) @measure_func("handle_new_client_event") - @defer.inlineCallbacks - def handle_new_client_event( + async def handle_new_client_event( self, requester, event, context, ratelimit=True, extra_users=[] - ): + ) -> int: """Processes a new event. This includes checking auth, persisting it, notifying users, sending to remote servers, etc. @@ -652,6 +795,9 @@ class EventCreationHandler(object): context (EventContext) ratelimit (bool) extra_users (list(UserID)): Any extra users to notify about event + + Return: + The stream_id of the persisted event. """ if event.is_state() and (event.type, event.state_key) == ( @@ -660,9 +806,9 @@ class EventCreationHandler(object): ): room_version = event.content.get("room_version", RoomVersions.V1.identifier) else: - room_version = yield self.store.get_room_version(event.room_id) + room_version = await self.store.get_room_version_id(event.room_id) - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -671,9 +817,9 @@ class EventCreationHandler(object): ) try: - yield self.auth.check_from_context(room_version, event, context) + await self.auth.check_from_context(room_version, event, context) except AuthError as err: - logger.warn("Denying new event %r because %s", event, err) + logger.warning("Denying new event %r because %s", event, err) raise err # Ensure that we can round trip before trying to persist in db @@ -684,15 +830,16 @@ class EventCreationHandler(object): logger.exception("Failed to encode content: %r", event.content) raise - yield self.action_generator.handle_push_actions_for_event(event, context) + await self.action_generator.handle_push_actions_for_event(event, context) # reraise does not allow inlineCallbacks to preserve the stacktrace, so we # hack around with a try/finally instead. success = False try: # If we're a worker we need to hit out to the master. - if self.config.worker_app: - yield self.send_event_to_master( + if not self._is_event_writer: + result = await self.send_event( + instance_name=self.config.worker.writers.events, event_id=event.event_id, store=self.store, requester=requester, @@ -701,14 +848,17 @@ class EventCreationHandler(object): ratelimit=ratelimit, extra_users=extra_users, ) + stream_id = result["stream_id"] + event.internal_metadata.stream_ordering = stream_id success = True - return + return stream_id - yield self.persist_and_notify_client_event( + stream_id = await self.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) success = True + return stream_id finally: if not success: # Ensure that we actually remove the entries in the push actions @@ -718,15 +868,46 @@ class EventCreationHandler(object): ) @defer.inlineCallbacks - def persist_and_notify_client_event( - self, requester, event, context, ratelimit=True, extra_users=[] + def _validate_canonical_alias( + self, directory_handler, room_alias_str, expected_room_id ): + """ + Ensure that the given room alias points to the expected room ID. + + Args: + directory_handler: The directory handler object. + room_alias_str: The room alias to check. + expected_room_id: The room ID that the alias should point to. 
+ """ + room_alias = RoomAlias.from_string(room_alias_str) + try: + mapping = yield directory_handler.get_association(room_alias) + except SynapseError as e: + # Turn M_NOT_FOUND errors into M_BAD_ALIAS errors. + if e.errcode == Codes.NOT_FOUND: + raise SynapseError( + 400, + "Room alias %s does not point to the room" % (room_alias_str,), + Codes.BAD_ALIAS, + ) + raise + + if mapping["room_id"] != expected_room_id: + raise SynapseError( + 400, + "Room alias %s does not point to the room" % (room_alias_str,), + Codes.BAD_ALIAS, + ) + + async def persist_and_notify_client_event( + self, requester, event, context, ratelimit=True, extra_users=[] + ) -> int: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. - This should only be run on master. + This should only be run on the instance in charge of persisting events. """ - assert not self.config.worker_app + assert self._is_event_writer if ratelimit: # We check if this is a room admin redacting an event so that we @@ -735,9 +916,9 @@ class EventCreationHandler(object): # user is actually admin or not). is_admin_redaction = False if event.type == EventTypes.Redaction: - original_event = yield self.store.get_event( + original_event = await self.store.get_event( event.redacts, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=False, allow_none=True, @@ -747,24 +928,52 @@ class EventCreationHandler(object): original_event and event.sender != original_event.sender ) - yield self.base_handler.ratelimit( + await self.base_handler.ratelimit( requester, is_admin_redaction=is_admin_redaction ) - yield self.base_handler.maybe_kick_guest_users(event, context) + await self.base_handler.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: - # Check the alias is acually valid (at this time at least) + # Validate a newly added alias or newly added alt_aliases. + + original_alias = None + original_alt_aliases = set() + + original_event_id = event.unsigned.get("replaces_state") + if original_event_id: + original_event = await self.store.get_event(original_event_id) + + if original_event: + original_alias = original_event.content.get("alias", None) + original_alt_aliases = original_event.content.get("alt_aliases", []) + + # Check the alias is currently valid (if it has changed). room_alias_str = event.content.get("alias", None) - if room_alias_str: - room_alias = RoomAlias.from_string(room_alias_str) - directory_handler = self.hs.get_handlers().directory_handler - mapping = yield directory_handler.get_association(room_alias) - - if mapping["room_id"] != event.room_id: - raise SynapseError( - 400, - "Room alias %s does not point to the room" % (room_alias_str,), + directory_handler = self.hs.get_handlers().directory_handler + if room_alias_str and room_alias_str != original_alias: + await self._validate_canonical_alias( + directory_handler, room_alias_str, event.room_id + ) + + # Check that alt_aliases is the proper form. + alt_aliases = event.content.get("alt_aliases", []) + if not isinstance(alt_aliases, (list, tuple)): + raise SynapseError( + 400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM + ) + + # If the old version of alt_aliases is of an unknown form, + # completely replace it. + if not isinstance(original_alt_aliases, (list, tuple)): + original_alt_aliases = [] + + # Check that each alias is currently valid. 
+ new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) + if new_alt_aliases: + for alias_str in new_alt_aliases: + await self._validate_canonical_alias( + directory_handler, alias_str, event.room_id ) federation_handler = self.hs.get_handlers().federation_handler @@ -775,16 +984,16 @@ class EventCreationHandler(object): def is_inviter_member_event(e): return e.type == EventTypes.Member and e.sender == event.sender - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = await context.get_current_state_ids() state_to_include_ids = [ e_id for k, e_id in iteritems(current_state_ids) - if k[0] in self.hs.config.room_invite_state_types + if k[0] in self.room_invite_state_types or k == (EventTypes.Member, event.sender) ] - state_to_include = yield self.store.get_events(state_to_include_ids) + state_to_include = await self.store.get_events(state_to_include_ids) event.unsigned["invite_room_state"] = [ { @@ -802,19 +1011,18 @@ class EventCreationHandler(object): # way? If we have been invited by a remote server, we need # to get them to sign the event. - returned_invite = yield federation_handler.send_invite( + returned_invite = await federation_handler.send_invite( invitee.domain, event ) - event.unsigned.pop("room_state", None) # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) if event.type == EventTypes.Redaction: - original_event = yield self.store.get_event( + original_event = await self.store.get_event( event.redacts, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=False, allow_none=True, @@ -828,15 +1036,19 @@ class EventCreationHandler(object): if original_event.room_id != event.room_id: raise SynapseError(400, "Cannot redact event from a different room") - prev_state_ids = yield context.get_prev_state_ids(self.store) - auth_events_ids = yield self.auth.compute_auth_events( + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = await self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} - room_version = yield self.store.get_room_version(event.room_id) - if event_auth.check_redaction(room_version, event, auth_events=auth_events): + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + if event_auth.check_redaction( + room_version_obj, event, auth_events=auth_events + ): # this user doesn't have 'redact' rights, so we need to do some more # checks on the original event. Let's start by checking the original # event exists. @@ -850,15 +1062,19 @@ class EventCreationHandler(object): event.internal_metadata.recheck_redaction = False if event.type == EventTypes.Create: - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids() if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - (event_stream_id, max_stream_id) = yield self.store.persist_event( + event_stream_id, max_stream_id = await self.storage.persistence.persist_event( event, context=context ) - yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) + if self._ephemeral_events_enabled: + # If there's an expiry timestamp on the event, schedule its expiry. 
+ self._message_handler.maybe_schedule_expiry(event) + + await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): try: @@ -875,61 +1091,90 @@ class EventCreationHandler(object): # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) - @defer.inlineCallbacks - def _bump_active_time(self, user): + return event_stream_id + + async def _bump_active_time(self, user): try: presence = self.hs.get_presence_handler() - yield presence.bump_presence_active_time(user) + await presence.bump_presence_active_time(user) except Exception: logger.exception("Error bumping presence active time") - @defer.inlineCallbacks - def _send_dummy_events_to_fill_extremities(self): + async def _send_dummy_events_to_fill_extremities(self): """Background task to send dummy events into rooms that have a large number of extremities """ - - room_ids = yield self.store.get_rooms_with_many_extremities( - min_count=10, limit=5 + self._expire_rooms_to_exclude_from_dummy_event_insertion() + room_ids = await self.store.get_rooms_with_many_extremities( + min_count=self._dummy_events_threshold, + limit=5, + room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(), ) for room_id in room_ids: # For each room we need to find a joined member we can use to send # the dummy event with. - prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) - - latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) + latest_event_ids = await self.store.get_prev_events_for_room(room_id) - members = yield self.state.get_current_users_in_room( + members = await self.state.get_current_users_in_room( room_id, latest_event_ids=latest_event_ids ) + dummy_event_sent = False + for user_id in members: + if not self.hs.is_mine_id(user_id): + continue + requester = create_requester(user_id) + try: + event, context = await self.create_event( + requester, + { + "type": "org.matrix.dummy_event", + "content": {}, + "room_id": room_id, + "sender": user_id, + }, + prev_event_ids=latest_event_ids, + ) + + event.internal_metadata.proactively_send = False - user_id = None - for member in members: - if self.hs.is_mine_id(member): - user_id = member + await self.send_nonmember_event( + requester, event, context, ratelimit=False + ) + dummy_event_sent = True break + except ConsentNotGivenError: + logger.info( + "Failed to send dummy event into room %s for user %s due to " + "lack of consent. Will try another user" % (room_id, user_id) + ) + except AuthError: + logger.info( + "Failed to send dummy event into room %s for user %s due to " + "lack of power. Will try another user" % (room_id, user_id) + ) - if not user_id: - # We don't have a joined user. - # TODO: We should do something here to stop the room from - # appearing next time. - continue - - requester = create_requester(user_id) - - event, context = yield self.create_event( - requester, - { - "type": "org.matrix.dummy_event", - "content": {}, - "room_id": room_id, - "sender": user_id, - }, - prev_events_and_hashes=prev_events_and_hashes, + if not dummy_event_sent: + # Did not find a valid user in the room, so remove from future attempts + # Exclusion is time limited, so the room will be rechecked in the future + # dependent on _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY + logger.info( + "Failed to send dummy event into room %s. 
Will exclude it from " + "future attempts until cache expires" % (room_id,) + ) + now = self.clock.time_msec() + self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now + + def _expire_rooms_to_exclude_from_dummy_event_insertion(self): + expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY + to_expire = set() + for room_id, time in self._rooms_to_exclude_from_dummy_event_insertion.items(): + if time < expire_before: + to_expire.add(room_id) + for room_id in to_expire: + logger.debug( + "Expiring room id %s from dummy event insertion exclusion cache", + room_id, ) - - event.internal_metadata.proactively_send = False - - yield self.send_nonmember_event(requester, event, context, ratelimit=False) + del self._rooms_to_exclude_from_dummy_event_insertion[room_id] diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py new file mode 100644 index 0000000000..9c08eb5399 --- /dev/null +++ b/synapse/handlers/oidc_handler.py @@ -0,0 +1,1049 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging +from typing import Dict, Generic, List, Optional, Tuple, TypeVar +from urllib.parse import urlencode + +import attr +import pymacaroons +from authlib.common.security import generate_token +from authlib.jose import JsonWebToken +from authlib.oauth2.auth import ClientAuth +from authlib.oauth2.rfc6749.parameters import prepare_grant_uri +from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo +from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url +from jinja2 import Environment, Template +from pymacaroons.exceptions import ( + MacaroonDeserializationException, + MacaroonInvalidSignatureException, +) +from typing_extensions import TypedDict + +from twisted.web.client import readBody + +from synapse.config import ConfigError +from synapse.http.server import finish_request +from synapse.http.site import SynapseRequest +from synapse.logging.context import make_deferred_yieldable +from synapse.push.mailer import load_jinja2_templates +from synapse.server import HomeServer +from synapse.types import UserID, map_username_to_mxid_localpart + +logger = logging.getLogger(__name__) + +SESSION_COOKIE_NAME = b"oidc_session" + +#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and +#: OpenID.Core sec 3.1.3.3. +Token = TypedDict( + "Token", + { + "access_token": str, + "token_type": str, + "id_token": Optional[str], + "refresh_token": Optional[str], + "expires_in": int, + "scope": Optional[str], + }, +) + +#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but +#: there is no real point of doing this in our case. +JWK = Dict[str, str] + +#: A JWK Set, as per RFC7517 sec 5. 
+JWKS = TypedDict("JWKS", {"keys": List[JWK]}) + + +class OidcError(Exception): + """Used to catch errors when calling the token_endpoint + """ + + def __init__(self, error, error_description=None): + self.error = error + self.error_description = error_description + + def __str__(self): + if self.error_description: + return "{}: {}".format(self.error, self.error_description) + return self.error + + +class MappingException(Exception): + """Used to catch errors when mapping the UserInfo object + """ + + +class OidcHandler: + """Handles requests related to the OpenID Connect login flow. + """ + + def __init__(self, hs: HomeServer): + self._callback_url = hs.config.oidc_callback_url # type: str + self._scopes = hs.config.oidc_scopes # type: List[str] + self._client_auth = ClientAuth( + hs.config.oidc_client_id, + hs.config.oidc_client_secret, + hs.config.oidc_client_auth_method, + ) # type: ClientAuth + self._client_auth_method = hs.config.oidc_client_auth_method # type: str + self._provider_metadata = OpenIDProviderMetadata( + issuer=hs.config.oidc_issuer, + authorization_endpoint=hs.config.oidc_authorization_endpoint, + token_endpoint=hs.config.oidc_token_endpoint, + userinfo_endpoint=hs.config.oidc_userinfo_endpoint, + jwks_uri=hs.config.oidc_jwks_uri, + ) # type: OpenIDProviderMetadata + self._provider_needs_discovery = hs.config.oidc_discover # type: bool + self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class( + hs.config.oidc_user_mapping_provider_config + ) # type: OidcMappingProvider + self._skip_verification = hs.config.oidc_skip_verification # type: bool + + self._http_client = hs.get_proxied_http_client() + self._auth_handler = hs.get_auth_handler() + self._registration_handler = hs.get_registration_handler() + self._datastore = hs.get_datastore() + self._clock = hs.get_clock() + self._hostname = hs.hostname # type: str + self._server_name = hs.config.server_name # type: str + self._macaroon_secret_key = hs.config.macaroon_secret_key + self._error_template = load_jinja2_templates( + hs.config.sso_template_dir, ["sso_error.html"] + )[0] + + # identifier for the external_ids table + self._auth_provider_id = "oidc" + + def _render_error( + self, request, error: str, error_description: Optional[str] = None + ) -> None: + """Renders the error template and respond with it. + + This is used to show errors to the user. The template of this page can + be found under ``synapse/res/templates/sso_error.html``. + + Args: + request: The incoming request from the browser. + We'll respond with an HTML page describing the error. + error: A technical identifier for this error. Those include + well-known OAuth2/OIDC error types like invalid_request or + access_denied. + error_description: A human-readable description of the error. + """ + html_bytes = self._error_template.render( + error=error, error_description=error_description + ).encode("utf-8") + + request.setResponseCode(400) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%i" % len(html_bytes)) + request.write(html_bytes) + finish_request(request) + + def _validate_metadata(self): + """Verifies the provider metadata. + + This checks the validity of the currently loaded provider. 
Not + everything is checked, only: + + - ``issuer`` + - ``authorization_endpoint`` + - ``token_endpoint`` + - ``response_types_supported`` (checks if "code" is in it) + - ``jwks_uri`` + + Raises: + ValueError: if something in the provider is not valid + """ + # Skip verification to allow non-compliant providers (e.g. issuers not running on a secure origin) + if self._skip_verification is True: + return + + m = self._provider_metadata + m.validate_issuer() + m.validate_authorization_endpoint() + m.validate_token_endpoint() + + if m.get("token_endpoint_auth_methods_supported") is not None: + m.validate_token_endpoint_auth_methods_supported() + if ( + self._client_auth_method + not in m["token_endpoint_auth_methods_supported"] + ): + raise ValueError( + '"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'.format( + auth_method=self._client_auth_method, + supported=m["token_endpoint_auth_methods_supported"], + ) + ) + + if m.get("response_types_supported") is not None: + m.validate_response_types_supported() + + if "code" not in m["response_types_supported"]: + raise ValueError( + '"code" not in "response_types_supported" (%r)' + % (m["response_types_supported"],) + ) + + # If the openid scope was not requested, we need a userinfo endpoint to fetch user info + if self._uses_userinfo: + if m.get("userinfo_endpoint") is None: + raise ValueError( + 'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested' + ) + else: + # If we're not using userinfo, we need a valid jwks to validate the ID token + if m.get("jwks") is None: + if m.get("jwks_uri") is not None: + m.validate_jwks_uri() + else: + raise ValueError('"jwks_uri" must be set') + + @property + def _uses_userinfo(self) -> bool: + """Returns True if the ``userinfo_endpoint`` should be used. + + This is based on the requested scopes: if the scopes include + ``openid``, the provider should give us an ID token containing the + user information. If not, we should fetch it using the + ``access_token`` with the ``userinfo_endpoint``. + """ + + # Maybe that should be user-configurable and not inferred? + return "openid" not in self._scopes + + async def load_metadata(self) -> OpenIDProviderMetadata: + """Load and validate the provider metadata. + + The metadata values are discovered if ``oidc_config.discovery`` is + ``True`` and then cached. + + Raises: + ValueError: if something in the provider is not valid + + Returns: + The provider's metadata. + """ + # If we are using the OpenID Discovery document, it needs to be loaded once + # FIXME: should there be a lock here? + if self._provider_needs_discovery: + url = get_well_known_url(self._provider_metadata["issuer"], external=True) + metadata_response = await self._http_client.get_json(url) + # TODO: maybe update the other way around to let user override some values? + self._provider_metadata.update(metadata_response) + self._provider_needs_discovery = False + + self._validate_metadata() + + return self._provider_metadata + + async def load_jwks(self, force: bool = False) -> JWKS: + """Load the JSON Web Key Set used to sign ID tokens. + + If we're not using the ``userinfo_endpoint``, user info is extracted + from the ID token, which is a JWT signed by keys given by the provider. + The keys are then cached. + + Args: + force: Force reloading the keys.
+ + Returns: + The key set + + Looks like this:: + + { + 'keys': [ + { + 'kid': 'abcdef', + 'kty': 'RSA', + 'alg': 'RS256', + 'use': 'sig', + 'e': 'XXXX', + 'n': 'XXXX', + } + ] + } + """ + if self._uses_userinfo: + # We're not using jwt signing, return an empty jwk set + return {"keys": []} + + # First check if the JWKS are loaded in the provider metadata. + # It can happen either if the provider gives its JWKS in the discovery + # document directly or if it was already loaded once. + metadata = await self.load_metadata() + jwk_set = metadata.get("jwks") + if jwk_set is not None and not force: + return jwk_set + + # Loading the JWKS using the `jwks_uri` metadata + uri = metadata.get("jwks_uri") + if not uri: + raise RuntimeError('Missing "jwks_uri" in metadata') + + jwk_set = await self._http_client.get_json(uri) + + # Caching the JWKS in the provider's metadata + self._provider_metadata["jwks"] = jwk_set + return jwk_set + + async def _exchange_code(self, code: str) -> Token: + """Exchange an authorization code for a token. + + This calls the ``token_endpoint`` with the authorization code we + received in the callback to exchange it for a token. The call uses the + ``ClientAuth`` to authenticate with the client with its ID and secret. + + See: + https://tools.ietf.org/html/rfc6749#section-3.2 + https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint + + Args: + code: The authorization code we got from the callback. + + Returns: + A dict containing various tokens. + + May look like this:: + + { + 'token_type': 'bearer', + 'access_token': 'abcdef', + 'expires_in': 3599, + 'id_token': 'ghijkl', + 'refresh_token': 'mnopqr', + } + + Raises: + OidcError: when the ``token_endpoint`` returned an error. + """ + metadata = await self.load_metadata() + token_endpoint = metadata.get("token_endpoint") + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": self._http_client.user_agent, + "Accept": "application/json", + } + + args = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self._callback_url, + } + body = urlencode(args, True) + + # Fill the body/headers with credentials + uri, headers, body = self._client_auth.prepare( + method="POST", uri=token_endpoint, headers=headers, body=body + ) + headers = {k: [v] for (k, v) in headers.items()} + + # Do the actual request + # We're not using the SimpleHttpClient util methods as we don't want to + # check the HTTP status code and we do the body encoding ourself. + response = await self._http_client.request( + method="POST", uri=uri, data=body.encode("utf-8"), headers=headers, + ) + + # This is used in multiple error messages below + status = "{code} {phrase}".format( + code=response.code, phrase=response.phrase.decode("utf-8") + ) + + resp_body = await make_deferred_yieldable(readBody(response)) + + if response.code >= 500: + # In case of a server error, we should first try to decode the body + # and check for an error field. If not, we respond with a generic + # error message. + try: + resp = json.loads(resp_body.decode("utf-8")) + error = resp["error"] + description = resp.get("error_description", error) + except (ValueError, KeyError): + # Catch ValueError for the JSON decoding and KeyError for the "error" field + error = "server_error" + description = ( + ( + 'Authorization server responded with a "{status}" error ' + "while exchanging the authorization code." 
+ ).format(status=status), + ) + + raise OidcError(error, description) + + # Since it is a not a 5xx code, body should be a valid JSON. It will + # raise if not. + resp = json.loads(resp_body.decode("utf-8")) + + if "error" in resp: + error = resp["error"] + # In case the authorization server responded with an error field, + # it should be a 4xx code. If not, warn about it but don't do + # anything special and report the original error message. + if response.code < 400: + logger.debug( + "Invalid response from the authorization server: " + 'responded with a "{status}" ' + "but body has an error field: {error!r}".format( + status=status, error=resp["error"] + ) + ) + + description = resp.get("error_description", error) + raise OidcError(error, description) + + # Now, this should not be an error. According to RFC6749 sec 5.1, it + # should be a 200 code. We're a bit more flexible than that, and will + # only throw on a 4xx code. + if response.code >= 400: + description = ( + 'Authorization server responded with a "{status}" error ' + 'but did not include an "error" field in its response.'.format( + status=status + ) + ) + logger.warning(description) + # Body was still valid JSON. Might be useful to log it for debugging. + logger.warning("Code exchange response: {resp!r}".format(resp=resp)) + raise OidcError("server_error", description) + + return resp + + async def _fetch_userinfo(self, token: Token) -> UserInfo: + """Fetch user informations from the ``userinfo_endpoint``. + + Args: + token: the token given by the ``token_endpoint``. + Must include an ``access_token`` field. + + Returns: + UserInfo: an object representing the user. + """ + metadata = await self.load_metadata() + + resp = await self._http_client.get_json( + metadata["userinfo_endpoint"], + headers={"Authorization": ["Bearer {}".format(token["access_token"])]}, + ) + + return UserInfo(resp) + + async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: + """Return an instance of UserInfo from token's ``id_token``. + + Args: + token: the token given by the ``token_endpoint``. + Must include an ``id_token`` field. + nonce: the nonce value originally sent in the initial authorization + request. This value should match the one inside the token. + + Returns: + An object representing the user. + """ + metadata = await self.load_metadata() + claims_params = { + "nonce": nonce, + "client_id": self._client_auth.client_id, + } + if "access_token" in token: + # If we got an `access_token`, there should be an `at_hash` claim + # in the `id_token` that we can check against. 
+ claims_params["access_token"] = token["access_token"] + claims_cls = CodeIDToken + else: + claims_cls = ImplicitIDToken + + alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) + + jwt = JsonWebToken(alg_values) + + claim_options = {"iss": {"values": [metadata["issuer"]]}} + + # Try to decode the keys in cache first, then retry by forcing the keys + # to be reloaded + jwk_set = await self.load_jwks() + try: + claims = jwt.decode( + token["id_token"], + key=jwk_set, + claims_cls=claims_cls, + claims_options=claim_options, + claims_params=claims_params, + ) + except ValueError: + logger.info("Reloading JWKS after decode error") + jwk_set = await self.load_jwks(force=True) # try reloading the jwks + claims = jwt.decode( + token["id_token"], + key=jwk_set, + claims_cls=claims_cls, + claims_options=claim_options, + claims_params=claims_params, + ) + + claims.validate(leeway=120) # allows 2 min of clock skew + return UserInfo(claims) + + async def handle_redirect_request( + self, + request: SynapseRequest, + client_redirect_url: bytes, + ui_auth_session_id: Optional[str] = None, + ) -> str: + """Handle an incoming request to /login/sso/redirect + + It returns a redirect to the authorization endpoint with a few + parameters: + + - ``client_id``: the client ID set in ``oidc_config.client_id`` + - ``response_type``: ``code`` + - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback`` + - ``scope``: the list of scopes set in ``oidc_config.scopes`` + - ``state``: a random string + - ``nonce``: a random string + + In addition generating a redirect URL, we are setting a cookie with + a signed macaroon token containing the state, the nonce and the + client_redirect_url params. Those are then checked when the client + comes back from the provider. + + Args: + request: the incoming request from the browser. + We'll respond to it with a redirect and a cookie. + client_redirect_url: the URL that we should redirect the client to + when everything is done + ui_auth_session_id: The session ID of the ongoing UI Auth (or + None if this is a login). + + Returns: + The redirect URL to the authorization endpoint. + + """ + + state = generate_token() + nonce = generate_token() + + cookie = self._generate_oidc_session_token( + state=state, + nonce=nonce, + client_redirect_url=client_redirect_url.decode(), + ui_auth_session_id=ui_auth_session_id, + ) + request.addCookie( + SESSION_COOKIE_NAME, + cookie, + path="/_synapse/oidc", + max_age="3600", + httpOnly=True, + sameSite="lax", + ) + + metadata = await self.load_metadata() + authorization_endpoint = metadata.get("authorization_endpoint") + return prepare_grant_uri( + authorization_endpoint, + client_id=self._client_auth.client_id, + response_type="code", + redirect_uri=self._callback_url, + scope=self._scopes, + state=state, + nonce=nonce, + ) + + async def handle_oidc_callback(self, request: SynapseRequest) -> None: + """Handle an incoming request to /_synapse/oidc/callback + + Since we might want to display OIDC-related errors in a user-friendly + way, we don't raise SynapseError from here. Instead, we call + ``self._render_error`` which displays an HTML page for the error. 
+ + Most of the OpenID Connect logic happens here: + + - first, we check if there was any error returned by the provider and + display it + - then we fetch the session cookie, decode and verify it + - the ``state`` query parameter should match the one stored in the + session cookie + - once we know this session is legit, exchange the code with the + provider using the ``token_endpoint`` (see ``_exchange_code``) + - once we have the token, use it to either extract the UserInfo from + the ``id_token`` (``_parse_id_token``), or use the ``access_token`` + to fetch UserInfo from the ``userinfo_endpoint`` + (``_fetch_userinfo``) + - map that UserInfo to a Matrix user (``_map_userinfo_to_user``) and + finish the login + + Args: + request: the incoming request from the browser. + """ + + # The provider might redirect with an error. + # In that case, just display it as-is. + if b"error" in request.args: + # error response from the auth server. see: + # https://tools.ietf.org/html/rfc6749#section-4.1.2.1 + # https://openid.net/specs/openid-connect-core-1_0.html#AuthError + error = request.args[b"error"][0].decode() + description = request.args.get(b"error_description", [b""])[0].decode() + + # Most of the errors returned by the provider could be due to + # either the provider misbehaving or Synapse being misconfigured. + # The only exception to that is "access_denied", where the user + # probably cancelled the login flow. In other cases, log those errors. + if error != "access_denied": + logger.error("Error from the OIDC provider: %s %s", error, description) + + self._render_error(request, error, description) + return + + # otherwise, it is presumably a successful response. see: + # https://tools.ietf.org/html/rfc6749#section-4.1.2 + + # Fetch the session cookie + session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] + if session is None: + logger.info("No session cookie found") + self._render_error(request, "missing_session", "No session cookie found") + return + + # Remove the cookie. There is a good chance that if the callback failed + # once, it will fail next time and the code will already be exchanged. + # Removing it early avoids spamming the provider with token requests. + request.addCookie( + SESSION_COOKIE_NAME, + b"", + path="/_synapse/oidc", + expires="Thu, Jan 01 1970 00:00:00 UTC", + httpOnly=True, + sameSite="lax", + ) + + # Check for the state query parameter + if b"state" not in request.args: + logger.info("State parameter is missing") + self._render_error(request, "invalid_request", "State parameter is missing") + return + + state = request.args[b"state"][0].decode() + + # Deserialize the session token and verify it.
+ try: + ( + nonce, + client_redirect_url, + ui_auth_session_id, + ) = self._verify_oidc_session_token(session, state) + except MacaroonDeserializationException as e: + logger.exception("Invalid session") + self._render_error(request, "invalid_session", str(e)) + return + except MacaroonInvalidSignatureException as e: + logger.exception("Could not verify session") + self._render_error(request, "mismatching_session", str(e)) + return + + # Exchange the code with the provider + if b"code" not in request.args: + logger.info("Code parameter is missing") + self._render_error(request, "invalid_request", "Code parameter is missing") + return + + logger.debug("Exchanging code") + code = request.args[b"code"][0].decode() + try: + token = await self._exchange_code(code) + except OidcError as e: + logger.exception("Could not exchange code") + self._render_error(request, e.error, e.error_description) + return + + logger.debug("Successfully obtained OAuth2 access token") + + # Now that we have a token, get the userinfo, either by decoding the + # `id_token` or by fetching the `userinfo_endpoint`. + if self._uses_userinfo: + logger.debug("Fetching userinfo") + try: + userinfo = await self._fetch_userinfo(token) + except Exception as e: + logger.exception("Could not fetch userinfo") + self._render_error(request, "fetch_error", str(e)) + return + else: + logger.debug("Extracting userinfo from id_token") + try: + userinfo = await self._parse_id_token(token, nonce=nonce) + except Exception as e: + logger.exception("Invalid id_token") + self._render_error(request, "invalid_token", str(e)) + return + + # Call the mapper to register/login the user + try: + user_id = await self._map_userinfo_to_user(userinfo, token) + except MappingException as e: + logger.exception("Could not map user") + self._render_error(request, "mapping_error", str(e)) + return + + # and finally complete the login + if ui_auth_session_id: + await self._auth_handler.complete_sso_ui_auth( + user_id, ui_auth_session_id, request + ) + else: + await self._auth_handler.complete_sso_login( + user_id, request, client_redirect_url + ) + + def _generate_oidc_session_token( + self, + state: str, + nonce: str, + client_redirect_url: str, + ui_auth_session_id: Optional[str], + duration_in_ms: int = (60 * 60 * 1000), + ) -> str: + """Generates a signed token storing data about an OIDC session. + + When Synapse initiates an authorization flow, it creates a random state + and a random nonce. Those parameters are given to the provider and + should be verified when the client comes back from the provider. + It is also used to store the client_redirect_url, which is used to + complete the SSO login flow. + + Args: + state: The ``state`` parameter passed to the OIDC provider. + nonce: The ``nonce`` parameter passed to the OIDC provider. + client_redirect_url: The URL the client gave when it initiated the + flow. + ui_auth_session_id: The session ID of the ongoing UI Auth (or + None if this is a login). + duration_in_ms: An optional duration for the token in milliseconds. + Defaults to an hour. + + Returns: + A signed macaroon token with the session informations. 
+ """ + macaroon = pymacaroons.Macaroon( + location=self._server_name, identifier="key", key=self._macaroon_secret_key, + ) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = session") + macaroon.add_first_party_caveat("state = %s" % (state,)) + macaroon.add_first_party_caveat("nonce = %s" % (nonce,)) + macaroon.add_first_party_caveat( + "client_redirect_url = %s" % (client_redirect_url,) + ) + if ui_auth_session_id: + macaroon.add_first_party_caveat( + "ui_auth_session_id = %s" % (ui_auth_session_id,) + ) + now = self._clock.time_msec() + expiry = now + duration_in_ms + macaroon.add_first_party_caveat("time < %d" % (expiry,)) + + return macaroon.serialize() + + def _verify_oidc_session_token( + self, session: bytes, state: str + ) -> Tuple[str, str, Optional[str]]: + """Verifies and extract an OIDC session token. + + This verifies that a given session token was issued by this homeserver + and extract the nonce and client_redirect_url caveats. + + Args: + session: The session token to verify + state: The state the OIDC provider gave back + + Returns: + The nonce, client_redirect_url, and ui_auth_session_id for this session + """ + macaroon = pymacaroons.Macaroon.deserialize(session) + + v = pymacaroons.Verifier() + v.satisfy_exact("gen = 1") + v.satisfy_exact("type = session") + v.satisfy_exact("state = %s" % (state,)) + v.satisfy_general(lambda c: c.startswith("nonce = ")) + v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) + # Sometimes there's a UI auth session ID, it seems to be OK to attempt + # to always satisfy this. + v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) + v.satisfy_general(self._verify_expiry) + + v.verify(macaroon, self._macaroon_secret_key) + + # Extract the `nonce`, `client_redirect_url`, and maybe the + # `ui_auth_session_id` from the token. + nonce = self._get_value_from_macaroon(macaroon, "nonce") + client_redirect_url = self._get_value_from_macaroon( + macaroon, "client_redirect_url" + ) + try: + ui_auth_session_id = self._get_value_from_macaroon( + macaroon, "ui_auth_session_id" + ) # type: Optional[str] + except ValueError: + ui_auth_session_id = None + + return nonce, client_redirect_url, ui_auth_session_id + + def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str: + """Extracts a caveat value from a macaroon token. + + Args: + macaroon: the token + key: the key of the caveat to extract + + Returns: + The extracted value + + Raises: + Exception: if the caveat was not in the macaroon + """ + prefix = key + " = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(prefix): + return caveat.caveat_id[len(prefix) :] + raise ValueError("No %s caveat in macaroon" % (key,)) + + def _verify_expiry(self, caveat: str) -> bool: + prefix = "time < " + if not caveat.startswith(prefix): + return False + expiry = int(caveat[len(prefix) :]) + now = self._clock.time_msec() + return now < expiry + + async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str: + """Maps a UserInfo object to a mxid. + + UserInfo should have a claim that uniquely identifies users. This claim + is usually `sub`, but can be configured with `oidc_config.subject_claim`. + It is then used as an `external_id`. + + If we don't find the user that way, we should register the user, + mapping the localpart and the display name from the UserInfo. + + If a user already exists with the mxid we've mapped, raise an exception. 
+ + Args: + userinfo: an object representing the user + token: a dict with the tokens obtained from the provider + + Raises: + MappingException: if there was an error while mapping some properties + + Returns: + The mxid of the user + """ + try: + remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo) + except Exception as e: + raise MappingException( + "Failed to extract subject from OIDC response: %s" % (e,) + ) + + logger.info( + "Looking for existing mapping for user %s:%s", + self._auth_provider_id, + remote_user_id, + ) + + registered_user_id = await self._datastore.get_user_by_external_id( + self._auth_provider_id, remote_user_id, + ) + + if registered_user_id is not None: + logger.info("Found existing mapping %s", registered_user_id) + return registered_user_id + + try: + attributes = await self._user_mapping_provider.map_user_attributes( + userinfo, token + ) + except Exception as e: + raise MappingException( + "Could not extract user attributes from OIDC response: " + str(e) + ) + + logger.debug( + "Retrieved user attributes from user mapping provider: %r", attributes + ) + + if not attributes["localpart"]: + raise MappingException("localpart is empty") + + localpart = map_username_to_mxid_localpart(attributes["localpart"]) + + user_id = UserID(localpart, self._hostname) + if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()): + # This mxid is taken + raise MappingException( + "mxid '{}' is already taken".format(user_id.to_string()) + ) + + # It's the first time this user is logging in and the mapped mxid was + # not taken, register the user + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, default_display_name=attributes["display_name"], + ) + + await self._datastore.record_user_external_id( + self._auth_provider_id, remote_user_id, registered_user_id, + ) + return registered_user_id + + +UserAttribute = TypedDict( + "UserAttribute", {"localpart": str, "display_name": Optional[str]} +) +C = TypeVar("C") + + +class OidcMappingProvider(Generic[C]): + """A mapping provider maps a UserInfo object to user attributes. + + It should provide the API described by this class. + """ + + def __init__(self, config: C): + """ + Args: + config: A custom config object from this module, parsed by ``parse_config()`` + """ + + @staticmethod + def parse_config(config: dict) -> C: + """Parse the dict provided by the homeserver's config + + Args: + config: A dictionary containing configuration options for this provider + + Returns: + A custom config object for this module + """ + raise NotImplementedError() + + def get_remote_user_id(self, userinfo: UserInfo) -> str: + """Get a unique user ID for this user. + + Usually, in an OIDC-compliant scenario, it should be the ``sub`` claim from the UserInfo object. + + Args: + userinfo: An object representing the user given by the OIDC provider + + Returns: + A unique user ID + """ + raise NotImplementedError() + + async def map_user_attributes( + self, userinfo: UserInfo, token: Token + ) -> UserAttribute: + """Map a ``UserInfo`` objects into user attributes. 
+ + Args: + userinfo: An object representing the user given by the OIDC provider + token: A dict with the tokens returned by the provider + + Returns: + A dict containing the ``localpart`` and (optionally) the ``display_name`` + """ + raise NotImplementedError() + + +# Used to clear out "None" values in templates +def jinja_finalize(thing): + return thing if thing is not None else "" + + +env = Environment(finalize=jinja_finalize) + + +@attr.s +class JinjaOidcMappingConfig: + subject_claim = attr.ib() # type: str + localpart_template = attr.ib() # type: Template + display_name_template = attr.ib() # type: Optional[Template] + + +class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): + """An implementation of a mapping provider based on Jinja templates. + + This is the default mapping provider. + """ + + def __init__(self, config: JinjaOidcMappingConfig): + self._config = config + + @staticmethod + def parse_config(config: dict) -> JinjaOidcMappingConfig: + subject_claim = config.get("subject_claim", "sub") + + if "localpart_template" not in config: + raise ConfigError( + "missing key: oidc_config.user_mapping_provider.config.localpart_template" + ) + + try: + localpart_template = env.from_string(config["localpart_template"]) + except Exception as e: + raise ConfigError( + "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r" + % (e,) + ) + + display_name_template = None # type: Optional[Template] + if "display_name_template" in config: + try: + display_name_template = env.from_string(config["display_name_template"]) + except Exception as e: + raise ConfigError( + "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r" + % (e,) + ) + + return JinjaOidcMappingConfig( + subject_claim=subject_claim, + localpart_template=localpart_template, + display_name_template=display_name_template, + ) + + def get_remote_user_id(self, userinfo: UserInfo) -> str: + return userinfo[self._config.subject_claim] + + async def map_user_attributes( + self, userinfo: UserInfo, token: Token + ) -> UserAttribute: + localpart = self._config.localpart_template.render(user=userinfo).strip() + + display_name = None # type: Optional[str] + if self._config.display_name_template is not None: + display_name = self._config.display_name_template.render( + user=userinfo + ).strip() + + if display_name == "": + display_name = None + + return UserAttribute(localpart=localpart, display_name=display_name) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5744f4579d..d7442c62a7 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -15,12 +15,15 @@ # limitations under the License. 
import logging
+from six import iteritems
+
from twisted.internet import defer
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.logging.context import run_in_background
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
@@ -69,6 +72,8 @@ class PaginationHandler(object):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
self.clock = hs.get_clock()
self._server_name = hs.hostname
@@ -78,6 +83,122 @@ class PaginationHandler(object):
self._purges_by_id = {}
self._event_serializer = hs.get_event_client_serializer()
+ self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
+
+ if hs.config.retention_enabled:
+ # Run the purge jobs described in the configuration file.
+ for job in hs.config.retention_purge_jobs:
+ logger.info("Setting up purge job with config: %s", job)
+
+ self.clock.looping_call(
+ run_as_background_process,
+ job["interval"],
+ "purge_history_for_rooms_in_range",
+ self.purge_history_for_rooms_in_range,
+ job["shortest_max_lifetime"],
+ job["longest_max_lifetime"],
+ )
+
+ @defer.inlineCallbacks
+ def purge_history_for_rooms_in_range(self, min_ms, max_ms):
+ """Purge outdated events from rooms within the given retention range.
+
+ If a default retention policy is defined in the server's configuration and its
+ 'max_lifetime' is within this range, also targets rooms which don't have a
+ retention policy.
+
+ Args:
+ min_ms (int|None): Duration in milliseconds that defines the lower limit of
+ the range to handle (exclusive). If None, it means that the range has no
+ lower limit.
+ max_ms (int|None): Duration in milliseconds that defines the upper limit of
+ the range to handle (inclusive). If None, it means that the range has no
+ upper limit.
+ """
+ # We want the storage layer to include rooms with no retention policy in its
+ # return value only if a default retention policy is defined in the server's
+ # configuration and that policy's 'max_lifetime' is higher than min_ms and lower
+ # than (or equal to) max_ms.
+ if self._retention_default_max_lifetime is not None:
+ include_null = True
+
+ if min_ms is not None and min_ms >= self._retention_default_max_lifetime:
+ # The default max_lifetime is lower than (or equal to) min_ms.
+ include_null = False
+
+ if max_ms is not None and max_ms < self._retention_default_max_lifetime:
+ # The default max_lifetime is higher than max_ms.
+ include_null = False + else: + include_null = False + + logger.info( + "[purge] Running purge job for %s < max_lifetime <= %s (include NULLs = %s)", + min_ms, + max_ms, + include_null, + ) + + rooms = yield self.store.get_rooms_for_retention_period_in_range( + min_ms, max_ms, include_null + ) + + logger.debug("[purge] Rooms to purge: %s", rooms) + + for room_id, retention_policy in iteritems(rooms): + logger.info("[purge] Attempting to purge messages in room %s", room_id) + + if room_id in self._purges_in_progress_by_room: + logger.warning( + "[purge] not purging room %s as there's an ongoing purge running" + " for this room", + room_id, + ) + continue + + max_lifetime = retention_policy["max_lifetime"] + + if max_lifetime is None: + # If max_lifetime is None, it means that include_null equals True, + # therefore we can safely assume that there is a default policy defined + # in the server's configuration. + max_lifetime = self._retention_default_max_lifetime + + # Figure out what token we should start purging at. + ts = self.clock.time_msec() - max_lifetime + + stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts) + + r = yield self.store.get_room_event_before_stream_ordering( + room_id, stream_ordering, + ) + if not r: + logger.warning( + "[purge] purging events not possible: No event found " + "(ts %i => stream_ordering %i)", + ts, + stream_ordering, + ) + continue + + (stream, topo, _event_id) = r + token = "t%d-%d" % (topo, stream) + + purge_id = random_string(16) + + self._purges_by_id[purge_id] = PurgeStatus() + + logger.info( + "Starting purging events in room %s (purge_id %s)" % (room_id, purge_id) + ) + + # We want to purge everything, including local events, and to run the purge in + # the background so that it's not blocking any other operation apart from + # other purges in the same room. + run_as_background_process( + "_purge_history", self._purge_history, purge_id, room_id, token, True, + ) + def start_purge_history(self, room_id, token, delete_local_events=False): """Start off a history purge on a room. 
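The retention hunk above boils down to a small range check: rooms without a retention policy are only swept when the server's default max_lifetime is strictly above min_ms and at or below max_ms. A minimal standalone sketch of that decision, for illustration only and not part of the diff (the function name is invented):

    from typing import Optional

    def include_rooms_without_policy(
        default_max_lifetime: Optional[int],
        min_ms: Optional[int],
        max_ms: Optional[int],
    ) -> bool:
        # No default policy configured: rooms without a policy are never purged.
        if default_max_lifetime is None:
            return False
        # The default max_lifetime sits at or below the range's exclusive lower bound.
        if min_ms is not None and min_ms >= default_max_lifetime:
            return False
        # The default max_lifetime sits above the range's inclusive upper bound.
        if max_ms is not None and max_ms < default_max_lifetime:
            return False
        return True

    # e.g. a one-week default policy and a job covering "up to one day":
    # include_rooms_without_policy(7 * 24 * 3600 * 1000, None, 24 * 3600 * 1000) -> False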
@@ -125,7 +246,9 @@ class PaginationHandler(object): self._purges_in_progress_by_room.add(room_id) try: with (yield self.pagination_lock.write(room_id)): - yield self.store.purge_history(room_id, token, delete_local_events) + yield self.storage.purge_events.purge_history( + room_id, token, delete_local_events + ) logger.info("[purge] complete") self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE except Exception: @@ -158,7 +281,7 @@ class PaginationHandler(object): """Purge the given room from the database""" with (await self.pagination_lock.write(room_id)): # check we know about the room - await self.store.get_room_version(room_id) + await self.store.get_room_version_id(room_id) # first check that we have no users in this room joined = await defer.maybeDeferred( @@ -168,10 +291,9 @@ class PaginationHandler(object): if joined: raise SynapseError(400, "Users are still joined to this room") - await self.store.purge_room(room_id) + await self.storage.purge_events.purge_room(room_id) - @defer.inlineCallbacks - def get_messages( + async def get_messages( self, requester, room_id=None, @@ -197,7 +319,7 @@ class PaginationHandler(object): room_token = pagin_config.from_token.room_key else: pagin_config.from_token = ( - yield self.hs.get_event_sources().get_current_token_for_pagination() + await self.hs.get_event_sources().get_current_token_for_pagination() ) room_token = pagin_config.from_token.room_key @@ -209,9 +331,12 @@ class PaginationHandler(object): source_config = pagin_config.get_source_config("room") - with (yield self.pagination_lock.read(room_id)): - membership, member_event_id = yield self.auth.check_in_room_or_world_readable( - room_id, user_id + with (await self.pagination_lock.read(room_id)): + ( + membership, + member_event_id, + ) = await self.auth.check_user_in_room_or_world_readable( + room_id, user_id, allow_departed_users=True ) if source_config.direction == "b": @@ -220,7 +345,7 @@ class PaginationHandler(object): if room_token.topological: max_topo = room_token.topological else: - max_topo = yield self.store.get_max_topological_token( + max_topo = await self.store.get_max_topological_token( room_id, room_token.stream ) @@ -228,18 +353,18 @@ class PaginationHandler(object): # If they have left the room then clamp the token to be before # they left the room, to save the effort of loading from the # database. 
- leave_token = yield self.store.get_topological_token_for_event( + leave_token = await self.store.get_topological_token_for_event( member_event_id ) leave_token = RoomStreamToken.parse(leave_token) if leave_token.topological < max_topo: source_config.from_key = str(leave_token) - yield self.hs.get_handlers().federation_handler.maybe_backfill( + await self.hs.get_handlers().federation_handler.maybe_backfill( room_id, max_topo ) - events, next_key = yield self.store.paginate_room_events( + events, next_key = await self.store.paginate_room_events( room_id=room_id, from_key=source_config.from_key, to_key=source_config.to_key, @@ -254,8 +379,8 @@ class PaginationHandler(object): if event_filter: events = event_filter.filter(events) - events = yield filter_events_for_client( - self.store, user_id, events, is_peeking=(member_event_id is None) + events = await filter_events_for_client( + self.storage, user_id, events, is_peeking=(member_event_id is None) ) if not events: @@ -274,19 +399,19 @@ class PaginationHandler(object): (EventTypes.Member, event.sender) for event in events ) - state_ids = yield self.store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) if state_ids: - state = yield self.store.get_events(list(state_ids.values())) + state = await self.store.get_events(list(state_ids.values())) state = state.values() time_now = self.clock.time_msec() chunk = { "chunk": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( events, time_now, as_client_event=as_client_event ) ), @@ -295,10 +420,8 @@ class PaginationHandler(object): } if state: - chunk["state"] = ( - yield self._event_serializer.serialize_events( - state, time_now, as_client_event=as_client_event - ) + chunk["state"] = await self._event_serializer.serialize_events( + state, time_now, as_client_event=as_client_event ) return chunk diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py new file mode 100644 index 0000000000..d06b110269 --- /dev/null +++ b/synapse/handlers/password_policy.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from synapse.api.errors import Codes, PasswordRefusedError + +logger = logging.getLogger(__name__) + + +class PasswordPolicyHandler(object): + def __init__(self, hs): + self.policy = hs.config.password_policy + self.enabled = hs.config.password_policy_enabled + + # Regexps for the spec'd policy parameters. + self.regexp_digit = re.compile("[0-9]") + self.regexp_symbol = re.compile("[^a-zA-Z0-9]") + self.regexp_uppercase = re.compile("[A-Z]") + self.regexp_lowercase = re.compile("[a-z]") + + def validate_password(self, password): + """Checks whether a given password complies with the server's policy. + + Args: + password (str): The password to check against the server's policy. 
+ + Raises: + PasswordRefusedError: The password doesn't comply with the server's policy. + """ + + if not self.enabled: + return + + minimum_accepted_length = self.policy.get("minimum_length", 0) + if len(password) < minimum_accepted_length: + raise PasswordRefusedError( + msg=( + "The password must be at least %d characters long" + % minimum_accepted_length + ), + errcode=Codes.PASSWORD_TOO_SHORT, + ) + + if ( + self.policy.get("require_digit", False) + and self.regexp_digit.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one digit", + errcode=Codes.PASSWORD_NO_DIGIT, + ) + + if ( + self.policy.get("require_symbol", False) + and self.regexp_symbol.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one symbol", + errcode=Codes.PASSWORD_NO_SYMBOL, + ) + + if ( + self.policy.get("require_uppercase", False) + and self.regexp_uppercase.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one uppercase letter", + errcode=Codes.PASSWORD_NO_UPPERCASE, + ) + + if ( + self.policy.get("require_lowercase", False) + and self.regexp_lowercase.search(password) is None + ): + raise PasswordRefusedError( + msg="The password must include at least one lowercase letter", + errcode=Codes.PASSWORD_NO_LOWERCASE, + ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 053cf66b28..3594f3b00f 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
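Stepping back to the password policy handler added above: `minimum_length` is a plain length check, and each `require_*` flag maps to a single regex test. A rough standalone sketch of how those checks combine, illustrative only (`password_satisfies_policy` is an invented helper, not part of the diff):

    import re

    # Mirrors the checks in PasswordPolicyHandler.validate_password: length first,
    # then one regex per enabled policy flag.
    def password_satisfies_policy(password: str, policy: dict) -> bool:
        if len(password) < policy.get("minimum_length", 0):
            return False
        checks = [
            ("require_digit", "[0-9]"),
            ("require_symbol", "[^a-zA-Z0-9]"),
            ("require_uppercase", "[A-Z]"),
            ("require_lowercase", "[a-z]"),
        ]
        for flag, pattern in checks:
            if policy.get(flag, False) and re.search(pattern, password) is None:
                return False
        return True

    # password_satisfies_policy("hunter2", {"minimum_length": 10}) -> False
    # password_satisfies_policy("Correct-Horse-1", {"require_digit": True, "require_symbol": True}) -> True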
@@ -21,13 +22,15 @@ The methods that define policy are: - PresenceHandler._handle_timeouts - should_notify """ - +import abc import logging from contextlib import contextmanager +from typing import Dict, Iterable, List, Set from six import iteritems, itervalues from prometheus_client import Counter +from typing_extensions import ContextManager from twisted.internet import defer @@ -39,12 +42,16 @@ from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState -from synapse.types import UserID, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer +MYPY = False +if MYPY: + import synapse.server + logger = logging.getLogger(__name__) @@ -93,35 +100,122 @@ EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000 assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER -class PresenceHandler(object): - def __init__(self, hs): - """ +class BasePresenceHandler(abc.ABC): + """Parts of the PresenceHandler that are shared between workers and master""" + + def __init__(self, hs: "synapse.server.HomeServer"): + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + active_presence = self.store.take_presence_startup_info() + self.user_to_current_state = {state.user_id: state for state in active_presence} + + @abc.abstractmethod + async def user_syncing( + self, user_id: str, affect_presence: bool + ) -> ContextManager[None]: + """Returns a context manager that should surround any stream requests + from the user. + + This allows us to keep track of who is currently streaming and who isn't + without having to have timers outside of this module to avoid flickering + when users disconnect/reconnect. Args: - hs (synapse.server.HomeServer): + user_id: the user that is starting a sync + affect_presence: If false this function will be a no-op. + Useful for streams that are not associated with an actual + client that is being used by a user. + """ + + @abc.abstractmethod + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + """Get an iterable of syncing users on this worker, to send to the presence handler + + This is called when a replication connection is established. It should return + a list of user ids, which are then sent as USER_SYNC commands to inform the + process handling presence about those users. + + Returns: + An iterable of user_id strings. """ + + async def get_state(self, target_user: UserID) -> UserPresenceState: + results = await self.get_states([target_user.to_string()]) + return results[0] + + async def get_states( + self, target_user_ids: Iterable[str] + ) -> List[UserPresenceState]: + """Get the presence state for users.""" + + updates_d = await self.current_state_for_users(target_user_ids) + updates = list(updates_d.values()) + + for user_id in set(target_user_ids) - {u.user_id for u in updates}: + updates.append(UserPresenceState.default(user_id)) + + return updates + + async def current_state_for_users( + self, user_ids: Iterable[str] + ) -> Dict[str, UserPresenceState]: + """Get the current presence state for multiple users. 
+ + Returns: + dict: `user_id` -> `UserPresenceState` + """ + states = { + user_id: self.user_to_current_state.get(user_id, None) + for user_id in user_ids + } + + missing = [user_id for user_id, state in iteritems(states) if not state] + if missing: + # There are things not in our in memory cache. Lets pull them out of + # the database. + res = await self.store.get_presence_for_users(missing) + states.update(res) + + missing = [user_id for user_id, state in iteritems(states) if not state] + if missing: + new = { + user_id: UserPresenceState.default(user_id) for user_id in missing + } + states.update(new) + self.user_to_current_state.update(new) + + return states + + @abc.abstractmethod + async def set_state( + self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False + ) -> None: + """Set the presence state of the user. """ + + @abc.abstractmethod + async def bump_presence_active_time(self, user: UserID): + """We've seen the user do something that indicates they're interacting + with the app. + """ + + +class PresenceHandler(BasePresenceHandler): + def __init__(self, hs: "synapse.server.HomeServer"): + super().__init__(hs) self.hs = hs - self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id self.server_name = hs.hostname - self.clock = hs.get_clock() - self.store = hs.get_datastore() self.wheel_timer = WheelTimer() self.notifier = hs.get_notifier() self.federation = hs.get_federation_sender() self.state = hs.get_state_handler() + self._presence_enabled = hs.config.use_presence federation_registry = hs.get_federation_registry() federation_registry.register_edu_handler("m.presence", self.incoming_presence) - active_presence = self.store.take_presence_startup_info() - - # A dictionary of the current state of users. This is prefilled with - # non-offline presence from the DB. We should fetch from the DB if - # we can't find a users presence in here. - self.user_to_current_state = {state.user_id: state for state in active_presence} - LaterGauge( "synapse_handlers_presence_user_to_current_state_size", "", @@ -130,7 +224,7 @@ class PresenceHandler(object): ) now = self.clock.time_msec() - for state in active_presence: + for state in self.user_to_current_state.values(): self.wheel_timer.insert( now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER ) @@ -154,7 +248,7 @@ class PresenceHandler(object): # Set of users who have presence in the `user_to_current_state` that # have not yet been persisted - self.unpersisted_users_changes = set() + self.unpersisted_users_changes = set() # type: Set[str] hs.get_reactor().addSystemEventTrigger( "before", @@ -164,12 +258,11 @@ class PresenceHandler(object): self._on_shutdown, ) - self.serial_to_user = {} self._next_serial = 1 # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self.user_to_num_current_syncs = {} + self.user_to_num_current_syncs = {} # type: Dict[str, int] # Keeps track of the number of *ongoing* syncs on other processes. # While any sync is ongoing on another process the user will never @@ -179,8 +272,9 @@ class PresenceHandler(object): # we assume that all the sync requests on that process have stopped. # Stored as a dict from process_id to set of user_id, and a dict of # process_id to millisecond timestamp last updated. 
- self.external_process_to_current_syncs = {} - self.external_process_last_updated_ms = {} + self.external_process_to_current_syncs = {} # type: Dict[int, Set[str]] + self.external_process_last_updated_ms = {} # type: Dict[int, int] + self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") # Start a LoopingCall in 30s that fires every 5s. @@ -216,8 +310,7 @@ class PresenceHandler(object): self._event_pos = self.store.get_current_events_token() self._event_processing = False - @defer.inlineCallbacks - def _on_shutdown(self): + async def _on_shutdown(self): """Gets called when shutting down. This lets us persist any updates that we haven't yet persisted, e.g. updates that only changes some internal timers. This allows changes to persist across startup without having to @@ -228,7 +321,7 @@ class PresenceHandler(object): is some spurious presence changes that will self-correct. """ # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.store.db.is_running(): return logger.info( @@ -238,7 +331,7 @@ class PresenceHandler(object): if self.unpersisted_users_changes: - yield self.store.update_presence( + await self.store.update_presence( [ self.user_to_current_state[user_id] for user_id in self.unpersisted_users_changes @@ -246,8 +339,7 @@ class PresenceHandler(object): ) logger.info("Finished _on_shutdown") - @defer.inlineCallbacks - def _persist_unpersisted_changes(self): + async def _persist_unpersisted_changes(self): """We periodically persist the unpersisted changes, as otherwise they may stack up and slow down shutdown times. """ @@ -256,12 +348,11 @@ class PresenceHandler(object): if unpersisted: logger.info("Persisting %d unpersisted presence updates", len(unpersisted)) - yield self.store.update_presence( + await self.store.update_presence( [self.user_to_current_state[user_id] for user_id in unpersisted] ) - @defer.inlineCallbacks - def _update_states(self, new_states): + async def _update_states(self, new_states): """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state should be sent to clients/servers. @@ -270,7 +361,7 @@ class PresenceHandler(object): with Measure(self.clock, "presence_update_states"): - # NOTE: We purposefully don't yield between now and when we've + # NOTE: We purposefully don't await between now and when we've # calculated what we want to do with the new states, to avoid races. to_notify = {} # Changes we want to notify everyone about @@ -314,9 +405,9 @@ class PresenceHandler(object): if to_notify: notified_presence_counter.inc(len(to_notify)) - yield self._persist_and_notify(list(to_notify.values())) + await self._persist_and_notify(list(to_notify.values())) - self.unpersisted_users_changes |= set(s.user_id for s in new_states) + self.unpersisted_users_changes |= {s.user_id for s in new_states} self.unpersisted_users_changes -= set(to_notify.keys()) to_federation_ping = { @@ -329,7 +420,7 @@ class PresenceHandler(object): self._push_to_remotes(to_federation_ping.values()) - def _handle_timeouts(self): + async def _handle_timeouts(self): """Checks the presence of users that have timed out and updates as appropriate. """ @@ -349,10 +440,13 @@ class PresenceHandler(object): if now - last_update > EXTERNAL_PROCESS_EXPIRY ] for process_id in expired_process_ids: + # For each expired process drop tracking info and check the users + # that were syncing on that process to see if they need to be timed + # out. 
users_to_check.update( - self.external_process_last_updated_ms.pop(process_id, ()) + self.external_process_to_current_syncs.pop(process_id, ()) ) - self.external_process_last_update.pop(process_id) + self.external_process_last_updated_ms.pop(process_id) states = [ self.user_to_current_state.get(user_id, UserPresenceState.default(user_id)) @@ -361,17 +455,24 @@ class PresenceHandler(object): timers_fired_counter.inc(len(states)) + syncing_user_ids = { + user_id + for user_id, count in self.user_to_num_current_syncs.items() + if count + } + for user_ids in self.external_process_to_current_syncs.values(): + syncing_user_ids.update(user_ids) + changes = handle_timeouts( states, is_mine_fn=self.is_mine_id, - syncing_user_ids=self.get_currently_syncing_users(), + syncing_user_ids=syncing_user_ids, now=now, ) - return self._update_states(changes) + return await self._update_states(changes) - @defer.inlineCallbacks - def bump_presence_active_time(self, user): + async def bump_presence_active_time(self, user): """We've seen the user do something that indicates they're interacting with the app. """ @@ -383,16 +484,17 @@ class PresenceHandler(object): bump_active_time_counter.inc() - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) new_fields = {"last_active_ts": self.clock.time_msec()} if prev_state.state == PresenceState.UNAVAILABLE: new_fields["state"] = PresenceState.ONLINE - yield self._update_states([prev_state.copy_and_replace(**new_fields)]) + await self._update_states([prev_state.copy_and_replace(**new_fields)]) - @defer.inlineCallbacks - def user_syncing(self, user_id, affect_presence=True): + async def user_syncing( + self, user_id: str, affect_presence: bool = True + ) -> ContextManager[None]: """Returns a context manager that should surround any stream requests from the user. @@ -415,11 +517,11 @@ class PresenceHandler(object): curr_sync = self.user_to_num_current_syncs.get(user_id, 0) self.user_to_num_current_syncs[user_id] = curr_sync + 1 - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) if prev_state.state == PresenceState.OFFLINE: # If they're currently offline then bring them online, otherwise # just update the last sync times. - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace( state=PresenceState.ONLINE, @@ -429,7 +531,7 @@ class PresenceHandler(object): ] ) else: - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec() @@ -437,13 +539,12 @@ class PresenceHandler(object): ] ) - @defer.inlineCallbacks - def _end(): + async def _end(): try: self.user_to_num_current_syncs[user_id] -= 1 - prev_state = yield self.current_state_for_user(user_id) - yield self._update_states( + prev_state = await self.current_state_for_user(user_id) + await self._update_states( [ prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec() @@ -463,25 +564,11 @@ class PresenceHandler(object): return _user_syncing() - def get_currently_syncing_users(self): - """Get the set of user ids that are currently syncing on this HS. - Returns: - set(str): A set of user_id strings. 
- """ - if self.hs.config.use_presence: - syncing_user_ids = { - user_id - for user_id, count in self.user_to_num_current_syncs.items() - if count - } - for user_ids in self.external_process_to_current_syncs.values(): - syncing_user_ids.update(user_ids) - return syncing_user_ids - else: - return set() + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + # since we are the process handling presence, there is nothing to do here. + return [] - @defer.inlineCallbacks - def update_external_syncs_row( + async def update_external_syncs_row( self, process_id, user_id, is_syncing, sync_time_msec ): """Update the syncing users for an external process as a delta. @@ -494,8 +581,8 @@ class PresenceHandler(object): is_syncing (bool): Whether or not the user is now syncing sync_time_msec(int): Time in ms when the user was last syncing """ - with (yield self.external_sync_linearizer.queue(process_id)): - prev_state = yield self.current_state_for_user(user_id) + with (await self.external_sync_linearizer.queue(process_id)): + prev_state = await self.current_state_for_user(user_id) process_presence = self.external_process_to_current_syncs.setdefault( process_id, set() @@ -525,25 +612,24 @@ class PresenceHandler(object): process_presence.discard(user_id) if updates: - yield self._update_states(updates) + await self._update_states(updates) self.external_process_last_updated_ms[process_id] = self.clock.time_msec() - @defer.inlineCallbacks - def update_external_syncs_clear(self, process_id): + async def update_external_syncs_clear(self, process_id): """Marks all users that had been marked as syncing by a given process as offline. Used when the process has stopped/disappeared. """ - with (yield self.external_sync_linearizer.queue(process_id)): + with (await self.external_sync_linearizer.queue(process_id)): process_presence = self.external_process_to_current_syncs.pop( process_id, set() ) - prev_states = yield self.current_state_for_users(process_presence) + prev_states = await self.current_state_for_users(process_presence) time_now_ms = self.clock.time_msec() - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) for prev_state in itervalues(prev_states) @@ -551,50 +637,19 @@ class PresenceHandler(object): ) self.external_process_last_updated_ms.pop(process_id, None) - @defer.inlineCallbacks - def current_state_for_user(self, user_id): + async def current_state_for_user(self, user_id): """Get the current presence state for a user. """ - res = yield self.current_state_for_users([user_id]) + res = await self.current_state_for_users([user_id]) return res[user_id] - @defer.inlineCallbacks - def current_state_for_users(self, user_ids): - """Get the current presence state for multiple users. - - Returns: - dict: `user_id` -> `UserPresenceState` - """ - states = { - user_id: self.user_to_current_state.get(user_id, None) - for user_id in user_ids - } - - missing = [user_id for user_id, state in iteritems(states) if not state] - if missing: - # There are things not in our in memory cache. Lets pull them out of - # the database. 
- res = yield self.store.get_presence_for_users(missing) - states.update(res) - - missing = [user_id for user_id, state in iteritems(states) if not state] - if missing: - new = { - user_id: UserPresenceState.default(user_id) for user_id in missing - } - states.update(new) - self.user_to_current_state.update(new) - - return states - - @defer.inlineCallbacks - def _persist_and_notify(self, states): + async def _persist_and_notify(self, states): """Persist states in the database, poke the notifier and send to interested remote servers """ - stream_id, max_token = yield self.store.update_presence(states) + stream_id, max_token = await self.store.update_presence(states) - parties = yield get_interested_parties(self.store, states) + parties = await get_interested_parties(self.store, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -606,9 +661,8 @@ class PresenceHandler(object): self._push_to_remotes(states) - @defer.inlineCallbacks - def notify_for_states(self, state, stream_id): - parties = yield get_interested_parties(self.store, [state]) + async def notify_for_states(self, state, stream_id): + parties = await get_interested_parties(self.store, [state]) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -626,17 +680,17 @@ class PresenceHandler(object): """ self.federation.send_presence(states) - @defer.inlineCallbacks - def incoming_presence(self, origin, content): + async def incoming_presence(self, origin, content): """Called when we receive a `m.presence` EDU from a remote server. """ + if not self._presence_enabled: + return + now = self.clock.time_msec() updates = [] for push in content.get("push", []): # A "push" contains a list of presence that we are probably interested # in. - # TODO: Actually check if we're interested, rather than blindly - # accepting presence updates. user_id = push.get("user_id", None) if not user_id: logger.info( @@ -670,51 +724,14 @@ class PresenceHandler(object): new_fields["status_msg"] = push.get("status_msg", None) new_fields["currently_active"] = push.get("currently_active", False) - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) updates.append(prev_state.copy_and_replace(**new_fields)) if updates: federation_presence_counter.inc(len(updates)) - yield self._update_states(updates) - - @defer.inlineCallbacks - def get_state(self, target_user, as_event=False): - results = yield self.get_states([target_user.to_string()], as_event=as_event) - - return results[0] - - @defer.inlineCallbacks - def get_states(self, target_user_ids, as_event=False): - """Get the presence state for users. + await self._update_states(updates) - Args: - target_user_ids (list) - as_event (bool): Whether to format it as a client event or not. - - Returns: - list - """ - - updates = yield self.current_state_for_users(target_user_ids) - updates = list(updates.values()) - - for user_id in set(target_user_ids) - set(u.user_id for u in updates): - updates.append(UserPresenceState.default(user_id)) - - now = self.clock.time_msec() - if as_event: - return [ - { - "type": "m.presence", - "content": format_user_presence_state(state, now), - } - for state in updates - ] - else: - return updates - - @defer.inlineCallbacks - def set_state(self, target_user, state, ignore_status_msg=False): + async def set_state(self, target_user, state, ignore_status_msg=False): """Set the presence state of the user. 
""" status_msg = state.get("status_msg", None) @@ -730,7 +747,7 @@ class PresenceHandler(object): user_id = target_user.to_string() - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) new_fields = {"state": presence} @@ -741,16 +758,15 @@ class PresenceHandler(object): if presence == PresenceState.ONLINE: new_fields["last_active_ts"] = self.clock.time_msec() - yield self._update_states([prev_state.copy_and_replace(**new_fields)]) + await self._update_states([prev_state.copy_and_replace(**new_fields)]) - @defer.inlineCallbacks - def is_visible(self, observed_user, observer_user): + async def is_visible(self, observed_user, observer_user): """Returns whether a user can see another user's presence. """ - observer_room_ids = yield self.store.get_rooms_for_user( + observer_room_ids = await self.store.get_rooms_for_user( observer_user.to_string() ) - observed_room_ids = yield self.store.get_rooms_for_user( + observed_room_ids = await self.store.get_rooms_for_user( observed_user.to_string() ) @@ -759,8 +775,7 @@ class PresenceHandler(object): return False - @defer.inlineCallbacks - def get_all_presence_updates(self, last_id, current_id): + async def get_all_presence_updates(self, last_id, current_id, limit): """ Gets a list of presence update rows from between the given stream ids. Each row has: @@ -775,7 +790,7 @@ class PresenceHandler(object): """ # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. - rows = yield self.store.get_all_presence_updates(last_id, current_id) + rows = await self.store.get_all_presence_updates(last_id, current_id, limit) return rows def notify_new_event(self): @@ -786,38 +801,43 @@ class PresenceHandler(object): if self._event_processing: return - @defer.inlineCallbacks - def _process_presence(): + async def _process_presence(): assert not self._event_processing self._event_processing = True try: - yield self._unsafe_process() + await self._unsafe_process() finally: self._event_processing = False run_as_background_process("presence.notify_new_event", _process_presence) - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "presence_delta"): - deltas = yield self.store.get_current_state_deltas(self._event_pos) - if not deltas: + room_max_stream_ordering = self.store.get_room_max_stream_ordering() + if self._event_pos == room_max_stream_ordering: return - yield self._handle_state_delta(deltas) + logger.debug( + "Processing presence stats %s->%s", + self._event_pos, + room_max_stream_ordering, + ) + max_pos, deltas = await self.store.get_current_state_deltas( + self._event_pos, room_max_stream_ordering + ) + await self._handle_state_delta(deltas) - self._event_pos = deltas[-1]["stream_id"] + self._event_pos = max_pos # Expose current event processing position to prometheus synapse.metrics.event_processing_positions.labels("presence").set( - self._event_pos + max_pos ) - @defer.inlineCallbacks - def _handle_state_delta(self, deltas): + async def _handle_state_delta(self, deltas): """Process current state deltas to find new joins that need to be handled. """ @@ -838,13 +858,13 @@ class PresenceHandler(object): # joins. 
continue - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not event or event.content.get("membership") != Membership.JOIN: # We only care about joins continue if prev_event_id: - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) if ( prev_event and prev_event.content.get("membership") == Membership.JOIN @@ -852,10 +872,9 @@ class PresenceHandler(object): # Ignore changes to join events. continue - yield self._on_user_joined_room(room_id, state_key) + await self._on_user_joined_room(room_id, state_key) - @defer.inlineCallbacks - def _on_user_joined_room(self, room_id, user_id): + async def _on_user_joined_room(self, room_id, user_id): """Called when we detect a user joining the room via the current state delta stream. @@ -874,11 +893,11 @@ class PresenceHandler(object): # TODO: We should be able to filter the hosts down to those that # haven't previously seen the user - state = yield self.current_state_for_user(user_id) - hosts = yield self.state.get_current_hosts_in_room(room_id) + state = await self.current_state_for_user(user_id) + hosts = await self.state.get_current_hosts_in_room(room_id) # Filter out ourselves. - hosts = set(host for host in hosts if host != self.server_name) + hosts = {host for host in hosts if host != self.server_name} self.federation.send_presence_to_destinations( states=[state], destinations=hosts @@ -895,10 +914,10 @@ class PresenceHandler(object): # TODO: Check that this is actually a new server joining the # room. - user_ids = yield self.state.get_current_users_in_room(room_id) + user_ids = await self.state.get_current_users_in_room(room_id) user_ids = list(filter(self.is_mine_id, user_ids)) - states = yield self.current_state_for_users(user_ids) + states_d = await self.current_state_for_users(user_ids) # Filter out old presence, i.e. offline presence states where # the user hasn't been active for a week. 
We can change this @@ -908,7 +927,7 @@ class PresenceHandler(object): now = self.clock.time_msec() states = [ state - for state in states.values() + for state in states_d.values() if state.state != PresenceState.OFFLINE or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000 or state.status_msg is not None @@ -988,9 +1007,8 @@ class PresenceEventSource(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() - @defer.inlineCallbacks @log_function - def get_new_events( + async def get_new_events( self, user, from_key, @@ -1037,7 +1055,7 @@ class PresenceEventSource(object): presence = self.get_presence_handler() stream_change_cache = self.store.presence_stream_cache - users_interested_in = yield self._get_interested_in(user, explicit_room_id) + users_interested_in = await self._get_interested_in(user, explicit_room_id) user_ids_changed = set() changed = None @@ -1063,7 +1081,7 @@ class PresenceEventSource(object): else: user_ids_changed = users_interested_in - updates = yield presence.current_state_for_users(user_ids_changed) + updates = await presence.current_state_for_users(user_ids_changed) if include_offline: return (list(updates.values()), max_token) @@ -1076,11 +1094,11 @@ class PresenceEventSource(object): def get_current_key(self): return self.store.get_current_presence_token() - def get_pagination_rows(self, user, pagination_config, key): - return self.get_new_events(user, from_key=None, include_offline=False) + async def get_pagination_rows(self, user, pagination_config, key): + return await self.get_new_events(user, from_key=None, include_offline=False) - @cachedInlineCallbacks(num_args=2, cache_context=True) - def _get_interested_in(self, user, explicit_room_id, cache_context): + @cached(num_args=2, cache_context=True) + async def _get_interested_in(self, user, explicit_room_id, cache_context): """Returns the set of users that the given user should see presence updates for """ @@ -1088,13 +1106,13 @@ class PresenceEventSource(object): users_interested_in = set() users_interested_in.add(user_id) # So that we receive our own presence - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(users_who_share_room) if explicit_room_id: - user_ids = yield self.store.get_users_in_room( + user_ids = await self.store.get_users_in_room( explicit_room_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(user_ids) @@ -1269,8 +1287,8 @@ def get_interested_parties(store, states): 2-tuple: `(room_ids_to_states, users_to_states)`, with each item being a dict of `entity_name` -> `[UserPresenceState]` """ - room_ids_to_states = {} - users_to_states = {} + room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]] + users_to_states = {} # type: Dict[str, List[UserPresenceState]] for state in states: room_ids = yield store.get_rooms_for_user(state.user_id) for room_id in room_ids: diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 8690f69d45..302efc1b9a 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -28,7 +28,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import UserID, get_domain_from_id +from synapse.types import UserID, create_requester, get_domain_from_id from ._base import BaseHandler @@ -141,8 +141,9 @@ class 
BaseProfileHandler(BaseHandler): return result["displayname"] - @defer.inlineCallbacks - def set_displayname(self, target_user, requester, new_displayname, by_admin=False): + async def set_displayname( + self, target_user, requester, new_displayname, by_admin=False + ): """Set the displayname of a user Args: @@ -152,11 +153,20 @@ class BaseProfileHandler(BaseHandler): by_admin (bool): Whether this change was made by an administrator. """ if not self.hs.is_mine(target_user): - raise SynapseError(400, "User is not hosted on this Home Server") + raise SynapseError(400, "User is not hosted on this homeserver") if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") + if not by_admin and not self.hs.config.enable_set_displayname: + profile = await self.store.get_profileinfo(target_user.localpart) + if profile.display_name: + raise SynapseError( + 400, + "Changing display name is disabled on this server", + Codes.FORBIDDEN, + ) + if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) @@ -165,15 +175,21 @@ class BaseProfileHandler(BaseHandler): if new_displayname == "": new_displayname = None - yield self.store.set_profile_displayname(target_user.localpart, new_displayname) + # If the admin changes the display name of a user, the requesting user cannot send + # the join event to update the displayname in the rooms. + # This must be done by the target user himself. + if by_admin: + requester = create_requester(target_user) + + await self.store.set_profile_displayname(target_user.localpart, new_displayname) if self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(target_user.localpart) - yield self.user_directory_handler.handle_local_profile_change( + profile = await self.store.get_profileinfo(target_user.localpart) + await self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile ) - yield self._update_join_states(requester, target_user) + await self._update_join_states(requester, target_user) @defer.inlineCallbacks def get_avatar_url(self, target_user): @@ -202,36 +218,48 @@ class BaseProfileHandler(BaseHandler): return result["avatar_url"] - @defer.inlineCallbacks - def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False): + async def set_avatar_url( + self, target_user, requester, new_avatar_url, by_admin=False + ): """target_user is the user whose avatar_url is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): - raise SynapseError(400, "User is not hosted on this Home Server") + raise SynapseError(400, "User is not hosted on this homeserver") if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") + if not by_admin and not self.hs.config.enable_set_avatar_url: + profile = await self.store.get_profileinfo(target_user.localpart) + if profile.avatar_url: + raise SynapseError( + 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN + ) + if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) ) - yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) + # Same like set_displayname + if by_admin: + requester = create_requester(target_user) + + await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) if 
self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(target_user.localpart) - yield self.user_directory_handler.handle_local_profile_change( + profile = await self.store.get_profileinfo(target_user.localpart) + await self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile ) - yield self._update_join_states(requester, target_user) + await self._update_join_states(requester, target_user) @defer.inlineCallbacks def on_profile_query(self, args): user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): - raise SynapseError(400, "User is not hosted on this Home Server") + raise SynapseError(400, "User is not hosted on this homeserver") just_field = args.get("field", None) @@ -253,21 +281,20 @@ class BaseProfileHandler(BaseHandler): return response - @defer.inlineCallbacks - def _update_join_states(self, requester, target_user): + async def _update_join_states(self, requester, target_user): if not self.hs.is_mine(target_user): return - yield self.ratelimit(requester) + await self.ratelimit(requester) - room_ids = yield self.store.get_rooms_for_user(target_user.to_string()) + room_ids = await self.store.get_rooms_for_user(target_user.to_string()) for room_id in room_ids: handler = self.hs.get_room_member_handler() try: # Assume the target_user isn't a guest, # because we don't let guests set profile or avatar data. - yield handler.update_membership( + await handler.update_membership( requester, target_user, room_id, @@ -275,7 +302,7 @@ class BaseProfileHandler(BaseHandler): ratelimit=False, # Try to hide that these events aren't atomic. ) except Exception as e: - logger.warn( + logger.warning( "Failed to update join event for room %s - %s", room_id, str(e) ) @@ -295,12 +322,16 @@ class BaseProfileHandler(BaseHandler): be found to be in any room the server is in, and therefore the query is denied. """ + # Implementation of MSC1301: don't allow looking up profiles if the # requester isn't in the same room as the target. We expect requester to # be None when this function is called outside of a profile query, e.g. # when building a membership event. In this case, we must allow the # lookup. - if not self.hs.config.require_auth_for_profile_requests or not requester: + if ( + not self.hs.config.limit_profile_requests_to_users_who_share_rooms + or not requester + ): return # Always allow the user to query their own profile. diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 3e4d8c93a4..e3b528d271 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.util.async_helpers import Linearizer from ._base import BaseHandler @@ -32,8 +30,7 @@ class ReadMarkerHandler(BaseHandler): self.read_marker_linearizer = Linearizer(name="read_marker") self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def received_client_read_marker(self, room_id, user_id, event_id): + async def received_client_read_marker(self, room_id, user_id, event_id): """Updates the read marker for a given user in a given room if the event ID given is ahead in the stream relative to the current read marker. @@ -41,8 +38,8 @@ class ReadMarkerHandler(BaseHandler): the read marker has changed. 
""" - with (yield self.read_marker_linearizer.queue((room_id, user_id))): - existing_read_marker = yield self.store.get_account_data_for_room_and_type( + with await self.read_marker_linearizer.queue((room_id, user_id)): + existing_read_marker = await self.store.get_account_data_for_room_and_type( user_id, room_id, "m.fully_read" ) @@ -50,13 +47,13 @@ class ReadMarkerHandler(BaseHandler): if existing_read_marker: # Only update if the new marker is ahead in the stream - should_update = yield self.store.is_event_after( + should_update = await self.store.is_event_after( event_id, existing_read_marker["event_id"] ) if should_update: content = {"event_id": event_id} - max_id = yield self.store.add_account_data_to_room( + max_id = await self.store.add_account_data_to_room( user_id, room_id, "m.fully_read", content ) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 6854c751a6..8bc100db42 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.handlers._base import BaseHandler from synapse.types import ReadReceipt, get_domain_from_id +from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -36,8 +37,7 @@ class ReceiptsHandler(BaseHandler): self.clock = self.hs.get_clock() self.state = hs.get_state_handler() - @defer.inlineCallbacks - def _received_remote_receipt(self, origin, content): + async def _received_remote_receipt(self, origin, content): """Called when we receive an EDU of type m.receipt from a remote HS. """ receipts = [] @@ -62,17 +62,16 @@ class ReceiptsHandler(BaseHandler): ) ) - yield self._handle_new_receipts(receipts) + await self._handle_new_receipts(receipts) - @defer.inlineCallbacks - def _handle_new_receipts(self, receipts): + async def _handle_new_receipts(self, receipts): """Takes a list of receipts, stores them and informs the notifier. """ min_batch_id = None max_batch_id = None for receipt in receipts: - res = yield self.store.insert_receipt( + res = await self.store.insert_receipt( receipt.room_id, receipt.receipt_type, receipt.user_id, @@ -95,18 +94,19 @@ class ReceiptsHandler(BaseHandler): # no new receipts return False - affected_room_ids = list(set([r.room_id for r in receipts])) + affected_room_ids = list({r.room_id for r in receipts}) self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. - yield self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids + await maybe_awaitable( + self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids + ) ) return True - @defer.inlineCallbacks - def received_client_receipt(self, room_id, receipt_type, user_id, event_id): + async def received_client_receipt(self, room_id, receipt_type, user_id, event_id): """Called when a client tells us a local user has read up to the given event_id in the room. """ @@ -118,24 +118,11 @@ class ReceiptsHandler(BaseHandler): data={"ts": int(self.clock.time_msec())}, ) - is_new = yield self._handle_new_receipts([receipt]) + is_new = await self._handle_new_receipts([receipt]) if not is_new: return - yield self.federation.send_read_receipt(receipt) - - @defer.inlineCallbacks - def get_receipts_for_room(self, room_id, to_key): - """Gets all receipts for a room, upto the given key. 
- """ - result = yield self.store.get_linearized_receipts_for_room( - room_id, to_key=to_key - ) - - if not result: - return [] - - return result + await self.federation.send_read_receipt(receipt) class ReceiptEventSource(object): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 06bd03b77c..51979ea43e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -16,18 +16,9 @@ """Contains functions for registering clients.""" import logging -from twisted.internet import defer - from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, LoginType -from synapse.api.errors import ( - AuthError, - Codes, - ConsentNotGivenError, - LimitExceededError, - RegistrationError, - SynapseError, -) +from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet @@ -82,8 +73,9 @@ class RegistrationHandler(BaseHandler): self.session_lifetime = hs.config.session_lifetime - @defer.inlineCallbacks - def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): + async def check_username( + self, localpart, guest_access_token=None, assigned_user_id=None + ): if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, @@ -120,13 +112,13 @@ class RegistrationHandler(BaseHandler): Codes.INVALID_USERNAME, ) - users = yield self.store.get_users_by_id_case_insensitive(user_id) + users = await self.store.get_users_by_id_case_insensitive(user_id) if users: if not guest_access_token: raise SynapseError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) - user_data = yield self.auth.get_user_by_access_token(guest_access_token) + user_data = await self.auth.get_user_by_access_token(guest_access_token) if not user_data["is_guest"] or user_data["user"].localpart != localpart: raise AuthError( 403, @@ -135,11 +127,19 @@ class RegistrationHandler(BaseHandler): errcode=Codes.FORBIDDEN, ) - @defer.inlineCallbacks - def register_user( + if guest_access_token is None: + try: + int(localpart) + raise SynapseError( + 400, "Numeric user IDs are reserved for guest users." + ) + except ValueError: + pass + + async def register_user( self, localpart=None, - password=None, + password_hash=None, guest_access_token=None, make_guest=False, admin=False, @@ -148,13 +148,14 @@ class RegistrationHandler(BaseHandler): default_display_name=None, address=None, bind_emails=[], + by_admin=False, ): """Registers a new client on the server. Args: - localpart : The local part of the user ID to register. If None, + localpart: The local part of the user ID to register. If None, one will be generated. - password (unicode) : The password to assign to this user so they can + password_hash (str|None): The hashed password to assign to this user so they can login again. This can be None which means they cannot login again via a password (e.g. the user is an application service user). user_type (str|None): type of user. One of the values from @@ -163,31 +164,24 @@ class RegistrationHandler(BaseHandler): will be set to this. Defaults to 'localpart'. address (str|None): the IP address used to perform the registration. bind_emails (List[str]): list of emails to bind to this account. + by_admin (bool): True if this registration is being made via the + admin api, otherwise False. 
Returns: - Deferred[str]: user_id + str: user_id Raises: - RegistrationError if there was a problem registering. + SynapseError if there was a problem registering. """ + self.check_registration_ratelimit(address) - yield self.auth.check_auth_blocking(threepid=threepid) - password_hash = None - if password: - password_hash = yield self._auth_handler.hash(password) + # do not check_auth_blocking if the call is coming through the Admin API + if not by_admin: + await self.auth.check_auth_blocking(threepid=threepid) - if localpart: - yield self.check_username(localpart, guest_access_token=guest_access_token) + if localpart is not None: + await self.check_username(localpart, guest_access_token=guest_access_token) was_guest = guest_access_token is not None - if not was_guest: - try: - int(localpart) - raise RegistrationError( - 400, "Numeric user IDs are reserved for guest users." - ) - except ValueError: - pass - user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -198,7 +192,7 @@ class RegistrationHandler(BaseHandler): elif default_display_name is None: default_display_name = localpart - yield self.register_with_store( + await self.register_with_store( user_id=user_id, password_hash=password_hash, was_guest=was_guest, @@ -210,38 +204,51 @@ class RegistrationHandler(BaseHandler): ) if self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(localpart) - yield self.user_directory_handler.handle_local_profile_change( + profile = await self.store.get_profileinfo(localpart) + await self.user_directory_handler.handle_local_profile_change( user_id, profile ) else: # autogen a sequential user ID - attempts = 0 + fail_count = 0 user = None while not user: - localpart = yield self._generate_user_id(attempts > 0) + # Fail after being unable to find a suitable ID a few times + if fail_count > 10: + raise SynapseError(500, "Unable to find a suitable guest user ID") + + localpart = await self._generate_user_id() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - yield self.check_user_id_not_appservice_exclusive(user_id) + self.check_user_id_not_appservice_exclusive(user_id) if default_display_name is None: default_display_name = localpart try: - yield self.register_with_store( + await self.register_with_store( user_id=user_id, password_hash=password_hash, make_guest=make_guest, create_profile_with_displayname=default_display_name, address=address, ) + + # Successfully registered + break except SynapseError: # if user id is taken, just generate another user = None user_id = None - attempts += 1 + fail_count += 1 if not self.hs.config.user_consent_at_registration: - yield self._auto_join_rooms(user_id) + if not self.hs.config.auto_join_rooms_for_guests and make_guest: + logger.info( + "Skipping auto-join for %s because auto-join for guests is disabled", + user_id, + ) + else: + await self._auto_join_rooms(user_id) else: logger.info( "Skipping auto-join for %s because consent is required at registration", @@ -259,12 +266,11 @@ class RegistrationHandler(BaseHandler): } # Bind email to new account - yield self._register_email_threepid(user_id, threepid_dict, None, False) + await self._register_email_threepid(user_id, threepid_dict, None) return user_id - @defer.inlineCallbacks - def _auto_join_rooms(self, user_id): + async def _auto_join_rooms(self, user_id): """Automatically joins users to auto join rooms - creating the room in the first place if the user is the first to be created. 
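The rewritten guest-registration branch above replaces the open-ended attempts counter with a fail_count that gives up after ten collisions instead of retrying forever. A self-contained sketch of that retry loop, with next_id_iter and is_taken as hypothetical stand-ins for _generate_user_id() and the store lookup:

import asyncio
import itertools


async def generate_guest_user_id(next_id_iter, is_taken, max_failures=10):
    # Keep trying the next sequential localpart until one is free, giving up
    # after a handful of collisions rather than looping indefinitely.
    fail_count = 0
    while True:
        if fail_count > max_failures:
            raise RuntimeError("Unable to find a suitable guest user ID")
        localpart = str(next(next_id_iter))
        if not await is_taken(localpart):
            return localpart
        fail_count += 1


async def main():
    taken = {"1", "2", "3"}

    async def is_taken(localpart):
        return localpart in taken

    print(await generate_guest_user_id(itertools.count(1), is_taken))  # 4


asyncio.run(main())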
@@ -278,9 +284,9 @@ class RegistrationHandler(BaseHandler): # that an auto-generated support or bot user is not a real user and will never be # the user to create the room should_auto_create_rooms = False - is_real_user = yield self.store.is_real_user(user_id) + is_real_user = await self.store.is_real_user(user_id) if self.hs.config.autocreate_auto_join_rooms and is_real_user: - count = yield self.store.count_real_users() + count = await self.store.count_real_users() should_auto_create_rooms = count == 1 for r in self.hs.config.auto_join_rooms: logger.info("Auto-joining %s to %s", user_id, r) @@ -299,7 +305,7 @@ class RegistrationHandler(BaseHandler): # getting the RoomCreationHandler during init gives a dependency # loop - yield self.hs.get_room_creation_handler().create_room( + await self.hs.get_room_creation_handler().create_room( fake_requester, config={ "preset": "public_chat", @@ -308,7 +314,7 @@ class RegistrationHandler(BaseHandler): ratelimit=False, ) else: - yield self._join_user_to_room(fake_requester, r) + await self._join_user_to_room(fake_requester, r) except ConsentNotGivenError as e: # Technically not necessary to pull out this error though # moving away from bare excepts is a good thing to do. @@ -316,18 +322,16 @@ class RegistrationHandler(BaseHandler): except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) - @defer.inlineCallbacks - def post_consent_actions(self, user_id): + async def post_consent_actions(self, user_id): """A series of registration actions that can only be carried out once consent has been granted Args: user_id (str): The user to join """ - yield self._auto_join_rooms(user_id) + await self._auto_join_rooms(user_id) - @defer.inlineCallbacks - def appservice_register(self, user_localpart, as_token): + async def appservice_register(self, user_localpart, as_token): user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() service = self.store.get_app_service_by_token(as_token) @@ -342,11 +346,9 @@ class RegistrationHandler(BaseHandler): service_id = service.id if service.is_exclusive_user(user_id) else None - yield self.check_user_id_not_appservice_exclusive( - user_id, allowed_appservice=service - ) + self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service) - yield self.register_with_store( + await self.register_with_store( user_id=user_id, password_hash="", appservice_id=service_id, @@ -378,28 +380,26 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - @defer.inlineCallbacks - def _generate_user_id(self, reseed=False): - if reseed or self._next_generated_user_id is None: - with (yield self._generate_user_id_linearizer.queue(())): - if reseed or self._next_generated_user_id is None: + async def _generate_user_id(self): + if self._next_generated_user_id is None: + with await self._generate_user_id_linearizer.queue(()): + if self._next_generated_user_id is None: self._next_generated_user_id = ( - yield self.store.find_next_generated_user_id_localpart() + await self.store.find_next_generated_user_id_localpart() ) id = self._next_generated_user_id self._next_generated_user_id += 1 return str(id) - @defer.inlineCallbacks - def _join_user_to_room(self, requester, room_identifier): + async def _join_user_to_room(self, requester, room_identifier): room_member_handler = self.hs.get_room_member_handler() if RoomID.is_valid(room_identifier): room_id = room_identifier elif RoomAlias.is_valid(room_identifier): room_alias = RoomAlias.from_string(room_identifier) - room_id, 
remote_room_hosts = ( - yield room_member_handler.lookup_room_alias(room_alias) + room_id, remote_room_hosts = await room_member_handler.lookup_room_alias( + room_alias ) room_id = room_id.to_string() else: @@ -407,7 +407,7 @@ class RegistrationHandler(BaseHandler): 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - yield room_member_handler.update_membership( + await room_member_handler.update_membership( requester=requester, target=requester.user, room_id=room_id, @@ -416,6 +416,22 @@ class RegistrationHandler(BaseHandler): ratelimit=False, ) + def check_registration_ratelimit(self, address): + """A simple helper method to check whether the registration rate limit has been hit + for a given IP address + + Args: + address (str|None): the IP address used to perform the registration. If this is + None, no ratelimiting will be performed. + + Raises: + LimitExceededError: If the rate limit has been exceeded. + """ + if not address: + return + + self.ratelimiter.ratelimit(address) + def register_with_store( self, user_id, @@ -448,22 +464,6 @@ class RegistrationHandler(BaseHandler): Returns: Deferred """ - # Don't rate limit for app services - if appservice_id is None and address is not None: - time_now = self.clock.time() - - allowed, time_allowed = self.ratelimiter.can_do_action( - address, - time_now_s=time_now, - rate_hz=self.hs.config.rc_registration.per_second, - burst_count=self.hs.config.rc_registration.burst_count, - ) - - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) - if self.hs.config.worker_app: return self._register_client( user_id=user_id, @@ -488,8 +488,9 @@ class RegistrationHandler(BaseHandler): user_type=user_type, ) - @defer.inlineCallbacks - def register_device(self, user_id, device_id, initial_display_name, is_guest=False): + async def register_device( + self, user_id, device_id, initial_display_name, is_guest=False + ): """Register a device for a user and generate an access token. The access token will be limited by the homeserver's session_lifetime config. @@ -503,11 +504,11 @@ class RegistrationHandler(BaseHandler): is_guest (bool): Whether this is a guest account Returns: - defer.Deferred[tuple[str, str]]: Tuple of device ID and access token + tuple[str, str]: Tuple of device ID and access token """ if self.hs.config.worker_app: - r = yield self._register_device_client( + r = await self._register_device_client( user_id=user_id, device_id=device_id, initial_display_name=initial_display_name, @@ -523,7 +524,7 @@ class RegistrationHandler(BaseHandler): ) valid_until_ms = self.clock.time_msec() + self.session_lifetime - device_id = yield self.device_handler.check_device_registered( + device_id = await self.device_handler.check_device_registered( user_id, device_id, initial_display_name ) if is_guest: @@ -532,14 +533,13 @@ class RegistrationHandler(BaseHandler): user_id, ["guest = true"] ) else: - access_token = yield self._auth_handler.get_access_token_for_user_id( + access_token = await self._auth_handler.get_access_token_for_user_id( user_id, device_id=device_id, valid_until_ms=valid_until_ms ) return (device_id, access_token) - @defer.inlineCallbacks - def post_registration_actions(self, user_id, auth_result, access_token): + async def post_registration_actions(self, user_id, auth_result, access_token): """A user has completed registration Args: @@ -550,7 +550,7 @@ class RegistrationHandler(BaseHandler): device, or None if `inhibit_login` enabled. 
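check_registration_ratelimit() above pulls the per-IP limiting out of register_with_store and skips it entirely when no address is supplied. A much-simplified, self-contained sketch of a token-bucket limiter with that behaviour (not Synapse's Ratelimiter, whose rate_hz and burst_count come from the rc_registration config):

import time


class LimitExceededError(Exception):
    def __init__(self, retry_after_ms):
        super().__init__("Too Many Requests")
        self.retry_after_ms = retry_after_ms


class SimpleRatelimiter:
    """Allow burst_count actions per key, refilling at rate_hz per second."""

    def __init__(self, rate_hz, burst_count):
        self.rate_hz = rate_hz
        self.burst_count = burst_count
        self._buckets = {}  # key -> (tokens, last_update)

    def ratelimit(self, key, _now=None):
        if key is None:
            # Mirrors check_registration_ratelimit(): no address, no limiting.
            return
        now = time.monotonic() if _now is None else _now
        tokens, last = self._buckets.get(key, (self.burst_count, now))
        tokens = min(self.burst_count, tokens + (now - last) * self.rate_hz)
        if tokens < 1:
            raise LimitExceededError(int(1000 * (1 - tokens) / self.rate_hz))
        self._buckets[key] = (tokens - 1, now)


limiter = SimpleRatelimiter(rate_hz=0.5, burst_count=3)
for _ in range(3):
    limiter.ratelimit("10.0.0.1", _now=0.0)
try:
    limiter.ratelimit("10.0.0.1", _now=0.0)
except LimitExceededError as e:
    print("rate limited, retry after", e.retry_after_ms, "ms")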
""" if self.hs.config.worker_app: - yield self._post_registration_client( + await self._post_registration_client( user_id=user_id, auth_result=auth_result, access_token=access_token ) return @@ -562,19 +562,18 @@ class RegistrationHandler(BaseHandler): if is_threepid_reserved( self.hs.config.mau_limits_reserved_threepids, threepid ): - yield self.store.upsert_monthly_active_user(user_id) + await self.store.upsert_monthly_active_user(user_id) - yield self._register_email_threepid(user_id, threepid, access_token) + await self._register_email_threepid(user_id, threepid, access_token) if auth_result and LoginType.MSISDN in auth_result: threepid = auth_result[LoginType.MSISDN] - yield self._register_msisdn_threepid(user_id, threepid) + await self._register_msisdn_threepid(user_id, threepid) if auth_result and LoginType.TERMS in auth_result: - yield self._on_user_consented(user_id, self.hs.config.user_consent_version) + await self._on_user_consented(user_id, self.hs.config.user_consent_version) - @defer.inlineCallbacks - def _on_user_consented(self, user_id, consent_version): + async def _on_user_consented(self, user_id, consent_version): """A user consented to the terms on registration Args: @@ -583,11 +582,10 @@ class RegistrationHandler(BaseHandler): consented to. """ logger.info("%s has consented to the privacy policy", user_id) - yield self.store.user_set_consent_version(user_id, consent_version) - yield self.post_consent_actions(user_id) + await self.store.user_set_consent_version(user_id, consent_version) + await self.post_consent_actions(user_id) - @defer.inlineCallbacks - def _register_email_threepid(self, user_id, threepid, token): + async def _register_email_threepid(self, user_id, threepid, token): """Add an email address as a 3pid identifier Also adds an email pusher for the email address, if configured in the @@ -600,8 +598,6 @@ class RegistrationHandler(BaseHandler): threepid (object): m.login.email.identity auth response token (str|None): access_token for the user, or None if not logged in. - Returns: - defer.Deferred: """ reqd = ("medium", "address", "validated_at") if any(x not in threepid for x in reqd): @@ -609,14 +605,14 @@ class RegistrationHandler(BaseHandler): logger.info("Can't add incomplete 3pid") return - yield self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + await self._auth_handler.add_threepid( + user_id, threepid["medium"], threepid["address"], threepid["validated_at"], ) # And we add an email pusher for them by default, but only # if email notifications are enabled (so people don't start # getting mail spam where they weren't before if email - # notifs are set up on a home server) + # notifs are set up on a homeserver) if ( self.hs.config.email_enable_notifs and self.hs.config.email_notif_for_new_users @@ -626,10 +622,10 @@ class RegistrationHandler(BaseHandler): # It would really make more sense for this to be passed # up when the access token is saved, but that's quite an # invasive change I'd rather do separately. 
- user_tuple = yield self.store.get_user_by_access_token(token) + user_tuple = await self.store.get_user_by_access_token(token) token_id = user_tuple["token_id"] - yield self.pusher_pool.add_pusher( + await self.pusher_pool.add_pusher( user_id=user_id, access_token=token_id, kind="email", @@ -641,8 +637,7 @@ class RegistrationHandler(BaseHandler): data={}, ) - @defer.inlineCallbacks - def _register_msisdn_threepid(self, user_id, threepid): + async def _register_msisdn_threepid(self, user_id, threepid): """Add a phone number as a 3pid identifier Must be called on master. @@ -650,8 +645,6 @@ class RegistrationHandler(BaseHandler): Args: user_id (str): id of user threepid (object): m.login.msisdn auth response - Returns: - defer.Deferred: """ try: assert_params_in_dict(threepid, ["medium", "address", "validated_at"]) @@ -662,6 +655,6 @@ class RegistrationHandler(BaseHandler): return None raise - yield self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + await self._auth_handler.add_threepid( + user_id, threepid["medium"], threepid["address"], threepid["validated_at"], ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 970be3c846..61db3ccc43 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,21 +16,31 @@ # limitations under the License. """Contains functions for performing events on rooms.""" + import itertools import logging import math import string from collections import OrderedDict +from typing import Tuple from six import iteritems, string_types -from twisted.internet import defer - from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.events.utils import copy_power_levels_contents +from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter -from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID +from synapse.types import ( + Requester, + RoomAlias, + RoomID, + RoomStreamToken, + StateMap, + StreamToken, + UserID, +) from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache @@ -52,18 +63,21 @@ class RoomCreationHandler(BaseHandler): "history_visibility": "shared", "original_invitees_have_ops": False, "guest_can_join": True, + "power_level_content_override": {"invite": 0}, }, RoomCreationPreset.TRUSTED_PRIVATE_CHAT: { "join_rules": JoinRules.INVITE, "history_visibility": "shared", "original_invitees_have_ops": True, "guest_can_join": True, + "power_level_content_override": {"invite": 0}, }, RoomCreationPreset.PUBLIC_CHAT: { "join_rules": JoinRules.PUBLIC, "history_visibility": "shared", "original_invitees_have_ops": False, "guest_can_join": False, + "power_level_content_override": {}, }, } @@ -75,6 +89,8 @@ class RoomCreationHandler(BaseHandler): self.room_member_handler = hs.get_room_member_handler() self.config = hs.config + 
self._replication = hs.get_replication_data_handler() + # linearizer to stop two upgrades happening at once self._upgrade_linearizer = Linearizer("room_upgrade_linearizer") @@ -88,19 +104,20 @@ class RoomCreationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() - @defer.inlineCallbacks - def upgrade_room(self, requester, old_room_id, new_version): + async def upgrade_room( + self, requester: Requester, old_room_id: str, new_version: RoomVersion + ): """Replace a room with a new room with a different version Args: - requester (synapse.types.Requester): the user requesting the upgrade - old_room_id (unicode): the id of the room to be replaced - new_version (unicode): the new room version to use + requester: the user requesting the upgrade + old_room_id: the id of the room to be replaced + new_version: the new room version to use Returns: Deferred[unicode]: the new room id """ - yield self.ratelimit(requester) + await self.ratelimit(requester) user_id = requester.user.to_string() @@ -121,53 +138,56 @@ class RoomCreationHandler(BaseHandler): # If this user has sent multiple upgrade requests for the same room # and one of them is not complete yet, cache the response and # return it to all subsequent requests - ret = yield self._upgrade_response_cache.wrap( + ret = await self._upgrade_response_cache.wrap( (old_room_id, user_id), self._upgrade_room, requester, old_room_id, new_version, # args for _upgrade_room ) + return ret - @defer.inlineCallbacks - def _upgrade_room(self, requester, old_room_id, new_version): + async def _upgrade_room( + self, requester: Requester, old_room_id: str, new_version: RoomVersion + ): user_id = requester.user.to_string() # start by allocating a new room id - r = yield self.store.get_room(old_room_id) + r = await self.store.get_room(old_room_id) if r is None: raise NotFoundError("Unknown room id %s" % (old_room_id,)) - new_room_id = yield self._generate_room_id( - creator_id=user_id, is_public=r["is_public"] + new_room_id = await self._generate_room_id( + creator_id=user_id, is_public=r["is_public"], room_version=new_version, ) logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) # we create and auth the tombstone event before properly creating the new # room, to check our user has perms in the old room. 
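upgrade_room() above deduplicates concurrent upgrade requests by wrapping _upgrade_room in a response cache keyed on (old_room_id, user_id). A simplified, asyncio-only stand-in for that wrap() pattern (Synapse's ResponseCache also keeps the finished result around for a while; this sketch only shares the in-flight call):

import asyncio


class InFlightCache:
    """Concurrent callers passing the same key share a single in-flight
    coroutine, and therefore a single room upgrade."""

    def __init__(self):
        self._pending = {}  # key -> asyncio.Task

    async def wrap(self, key, callback, *args):
        task = self._pending.get(key)
        if task is None:
            task = asyncio.ensure_future(callback(*args))
            task.add_done_callback(lambda _: self._pending.pop(key, None))
            self._pending[key] = task
        return await task


async def main():
    cache = InFlightCache()
    upgrades_run = 0

    async def upgrade_room(old_room_id):
        nonlocal upgrades_run
        upgrades_run += 1
        await asyncio.sleep(0.01)  # stand-in for the slow upgrade work
        return "!new-room:example.org"

    key = ("!old-room:example.org", "@alice:example.org")
    results = await asyncio.gather(
        cache.wrap(key, upgrade_room, "!old-room:example.org"),
        cache.wrap(key, upgrade_room, "!old-room:example.org"),
    )
    print(results, "- upgrade ran %d time(s)" % upgrades_run)


asyncio.run(main())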
- tombstone_event, tombstone_context = ( - yield self.event_creation_handler.create_event( - requester, - { - "type": EventTypes.Tombstone, - "state_key": "", - "room_id": old_room_id, - "sender": user_id, - "content": { - "body": "This room has been replaced", - "replacement_room": new_room_id, - }, + ( + tombstone_event, + tombstone_context, + ) = await self.event_creation_handler.create_event( + requester, + { + "type": EventTypes.Tombstone, + "state_key": "", + "room_id": old_room_id, + "sender": user_id, + "content": { + "body": "This room has been replaced", + "replacement_room": new_room_id, }, - token_id=requester.access_token_id, - ) + }, + token_id=requester.access_token_id, ) - old_room_version = yield self.store.get_room_version(old_room_id) - yield self.auth.check_from_context( + old_room_version = await self.store.get_room_version_id(old_room_id) + await self.auth.check_from_context( old_room_version, tombstone_event, tombstone_context ) - yield self.clone_existing_room( + await self.clone_existing_room( requester, old_room_id=old_room_id, new_room_id=new_room_id, @@ -176,36 +196,44 @@ class RoomCreationHandler(BaseHandler): ) # now send the tombstone - yield self.event_creation_handler.send_nonmember_event( + await self.event_creation_handler.send_nonmember_event( requester, tombstone_event, tombstone_context ) - old_room_state = yield tombstone_context.get_current_state_ids(self.store) + old_room_state = await tombstone_context.get_current_state_ids() # update any aliases - yield self._move_aliases_to_new_room( + await self._move_aliases_to_new_room( requester, old_room_id, new_room_id, old_room_state ) - # and finally, shut down the PLs in the old room, and update them in the new + # Copy over user push rules, tags and migrate room directory state + await self.room_member_handler.transfer_room_state_on_room_upgrade( + old_room_id, new_room_id + ) + + # finally, shut down the PLs in the old room, and update them in the new # room. - yield self._update_upgraded_room_pls( - requester, old_room_id, new_room_id, old_room_state + await self._update_upgraded_room_pls( + requester, old_room_id, new_room_id, old_room_state, ) return new_room_id - @defer.inlineCallbacks - def _update_upgraded_room_pls( - self, requester, old_room_id, new_room_id, old_room_state + async def _update_upgraded_room_pls( + self, + requester: Requester, + old_room_id: str, + new_room_id: str, + old_room_state: StateMap[str], ): """Send updated power levels in both rooms after an upgrade Args: - requester (synapse.types.Requester): the user requesting the upgrade - old_room_id (unicode): the id of the room to be replaced - new_room_id (unicode): the id of the replacement room - old_room_state (dict[tuple[str, str], str]): the state map for the old room + requester: the user requesting the upgrade + old_room_id: the id of the room to be replaced + new_room_id: the id of the replacement room + old_room_state: the state map for the old room Returns: Deferred @@ -219,7 +247,7 @@ class RoomCreationHandler(BaseHandler): ) return - old_room_pl_state = yield self.store.get_event(old_room_pl_event_id) + old_room_pl_state = await self.store.get_event(old_room_pl_event_id) # we try to stop regular users from speaking by setting the PL required # to send regular events and invites to 'Moderator' level. 
That's normally @@ -234,7 +262,7 @@ class RoomCreationHandler(BaseHandler): for v in ("invite", "events_default"): current = int(pl_content.get(v, 0)) if current < restricted_level: - logger.info( + logger.debug( "Setting level for %s in %s to %i (was %i)", v, old_room_id, @@ -244,11 +272,11 @@ class RoomCreationHandler(BaseHandler): pl_content[v] = restricted_level updated = True else: - logger.info("Not setting level for %s (already %i)", v, current) + logger.debug("Not setting level for %s (already %i)", v, current) if updated: try: - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.PowerLevels, @@ -262,8 +290,7 @@ class RoomCreationHandler(BaseHandler): except AuthError as e: logger.warning("Unable to update PLs in old room: %s", e) - logger.info("Setting correct PLs in new room") - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.PowerLevels, @@ -275,22 +302,25 @@ class RoomCreationHandler(BaseHandler): ratelimit=False, ) - @defer.inlineCallbacks - def clone_existing_room( - self, requester, old_room_id, new_room_id, new_room_version, tombstone_event_id + async def clone_existing_room( + self, + requester: Requester, + old_room_id: str, + new_room_id: str, + new_room_version: RoomVersion, + tombstone_event_id: str, ): """Populate a new room based on an old room Args: - requester (synapse.types.Requester): the user requesting the upgrade - old_room_id (unicode): the id of the room to be replaced - new_room_id (unicode): the id to give the new room (should already have been + requester: the user requesting the upgrade + old_room_id : the id of the room to be replaced + new_room_id: the id to give the new room (should already have been created with _gemerate_room_id()) - new_room_version (unicode): the new room version to use - tombstone_event_id (unicode|str): the ID of the tombstone event in the old - room. + new_room_version: the new room version to use + tombstone_event_id: the ID of the tombstone event in the old room. 
Returns: - Deferred[None] + Deferred """ user_id = requester.user.to_string() @@ -298,21 +328,21 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(403, "You are not permitted to create rooms") creation_content = { - "room_version": new_room_version, + "room_version": new_room_version.identifier, "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id}, } # Check if old room was non-federatable # Get old room's create event - old_room_create_event = yield self.store.get_create_event_for_room(old_room_id) + old_room_create_event = await self.store.get_create_event_for_room(old_room_id) # Check if the create event specified a non-federatable room if not old_room_create_event.content.get("m.federate", True): # If so, mark the new room as non-federatable as well creation_content["m.federate"] = False - initial_state = dict() + initial_state = {} # Replicate relevant room events types_to_copy = ( @@ -322,23 +352,52 @@ class RoomCreationHandler(BaseHandler): (EventTypes.RoomHistoryVisibility, ""), (EventTypes.GuestAccess, ""), (EventTypes.RoomAvatar, ""), - (EventTypes.Encryption, ""), + (EventTypes.RoomEncryption, ""), (EventTypes.ServerACL, ""), (EventTypes.RelatedGroups, ""), + (EventTypes.PowerLevels, ""), ) - old_room_state_ids = yield self.store.get_filtered_current_state_ids( + old_room_state_ids = await self.store.get_filtered_current_state_ids( old_room_id, StateFilter.from_types(types_to_copy) ) # map from event_id to BaseEvent - old_room_state_events = yield self.store.get_events(old_room_state_ids.values()) + old_room_state_events = await self.store.get_events(old_room_state_ids.values()) for k, old_event_id in iteritems(old_room_state_ids): old_event = old_room_state_events.get(old_event_id) if old_event: initial_state[k] = old_event.content - yield self._send_events_for_new_room( + # deep-copy the power-levels event before we start modifying it + # note that if frozen_dicts are enabled, `power_levels` will be a frozen + # dict so we can't just copy.deepcopy it. 
+ initial_state[ + (EventTypes.PowerLevels, "") + ] = power_levels = copy_power_levels_contents( + initial_state[(EventTypes.PowerLevels, "")] + ) + + # Resolve the minimum power level required to send any state event + # We will give the upgrading user this power level temporarily (if necessary) such that + # they are able to copy all of the state events over, then revert them back to their + # original power level afterwards in _update_upgraded_room_pls + + # Copy over user power levels now as this will not be possible with >100PL users once + # the room has been created + + # Calculate the minimum power level needed to clone the room + event_power_levels = power_levels.get("events", {}) + state_default = power_levels.get("state_default", 0) + ban = power_levels.get("ban") + needed_power_level = max(state_default, ban, max(event_power_levels.values())) + + # Raise the requester's power level in the new room if necessary + current_power_level = power_levels["users"][user_id] + if current_power_level < needed_power_level: + power_levels["users"][user_id] = needed_power_level + + await self._send_events_for_new_room( requester, new_room_id, # we expect to override all the presets with initial_state, so this is @@ -350,12 +409,12 @@ class RoomCreationHandler(BaseHandler): ) # Transfer membership events - old_room_member_state_ids = yield self.store.get_filtered_current_state_ids( + old_room_member_state_ids = await self.store.get_filtered_current_state_ids( old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) ) # map from event_id to BaseEvent - old_room_member_state_events = yield self.store.get_events( + old_room_member_state_events = await self.store.get_events( old_room_member_state_ids.values() ) for k, old_event in iteritems(old_room_member_state_events): @@ -364,7 +423,7 @@ class RoomCreationHandler(BaseHandler): "membership" in old_event.content and old_event.content["membership"] == "ban" ): - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester, UserID.from_string(old_event["state_key"]), new_room_id, @@ -376,102 +435,93 @@ class RoomCreationHandler(BaseHandler): # XXX invites/joins # XXX 3pid invites - @defer.inlineCallbacks - def _move_aliases_to_new_room( - self, requester, old_room_id, new_room_id, old_room_state + async def _move_aliases_to_new_room( + self, + requester: Requester, + old_room_id: str, + new_room_id: str, + old_room_state: StateMap[str], ): - directory_handler = self.hs.get_handlers().directory_handler - - aliases = yield self.store.get_aliases_for_room(old_room_id) - # check to see if we have a canonical alias. - canonical_alias = None + canonical_alias_event = None canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, "")) if canonical_alias_event_id: - canonical_alias_event = yield self.store.get_event(canonical_alias_event_id) - if canonical_alias_event: - canonical_alias = canonical_alias_event.content.get("alias", "") + canonical_alias_event = await self.store.get_event(canonical_alias_event_id) - # first we try to remove the aliases from the old room (we suppress sending - # the room_aliases event until the end). - # - # Note that we'll only be able to remove aliases that (a) aren't owned by an AS, - # and (b) unless the user is a server admin, which the user created. 
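The cloning code above deep-copies the old room's power levels and temporarily raises the upgrading user to the minimum level needed to re-send every copied state event. The arithmetic, as a standalone sketch (defaults for missing ban/events keys are added here for robustness and are not in the original):

def power_level_needed_to_clone(power_levels):
    # The upgrading user must be able to send every state event being copied
    # over, so take the maximum of state_default, ban and any per-event-type
    # override.
    event_levels = power_levels.get("events", {})
    state_default = power_levels.get("state_default", 0)
    ban = power_levels.get("ban", 50)
    return max(state_default, ban, max(event_levels.values(), default=0))


power_levels = {
    "users": {"@alice:example.org": 100, "@bob:example.org": 50},
    "events": {"m.room.power_levels": 100, "m.room.name": 50},
    "state_default": 50,
    "ban": 50,
}
needed = power_level_needed_to_clone(power_levels)
upgrader = "@bob:example.org"
if power_levels["users"][upgrader] < needed:
    # Temporarily promote the upgrader; _update_upgraded_room_pls later
    # restores the intended levels in the new room.
    power_levels["users"][upgrader] = needed
print(needed, power_levels["users"][upgrader])  # 100 100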
- # - # This is probably correct - given we don't allow such aliases to be deleted - # normally, it would be odd to allow it in the case of doing a room upgrade - - # but it makes the upgrade less effective, and you have to wonder why a room - # admin can't remove aliases that point to that room anyway. - # (cf https://github.com/matrix-org/synapse/issues/2360) - # - removed_aliases = [] - for alias_str in aliases: - alias = RoomAlias.from_string(alias_str) - try: - yield directory_handler.delete_association( - requester, alias, send_event=False - ) - removed_aliases.append(alias_str) - except SynapseError as e: - logger.warning("Unable to remove alias %s from old room: %s", alias, e) + await self.store.update_aliases_for_room(old_room_id, new_room_id) - # if we didn't find any aliases, or couldn't remove anyway, we can skip the rest - # of this. - if not removed_aliases: + if not canonical_alias_event: return + # If there is a canonical alias we need to update the one in the old + # room and set one in the new one. + old_canonical_alias_content = dict(canonical_alias_event.content) + new_canonical_alias_content = {} + + canonical = canonical_alias_event.content.get("alias") + if canonical and self.hs.is_mine_id(canonical): + new_canonical_alias_content["alias"] = canonical + old_canonical_alias_content.pop("alias", None) + + # We convert to a list as it will be a Tuple. + old_alt_aliases = list(old_canonical_alias_content.get("alt_aliases", [])) + if old_alt_aliases: + old_canonical_alias_content["alt_aliases"] = old_alt_aliases + new_alt_aliases = new_canonical_alias_content.setdefault("alt_aliases", []) + for alias in canonical_alias_event.content.get("alt_aliases", []): + try: + if self.hs.is_mine_id(alias): + new_alt_aliases.append(alias) + old_alt_aliases.remove(alias) + except Exception: + logger.info( + "Invalid alias %s in canonical alias event %s", + alias, + canonical_alias_event_id, + ) + + if not old_alt_aliases: + old_canonical_alias_content.pop("alt_aliases") + + # If a canonical alias event existed for the old room, fire a canonical + # alias event for the new room with a copy of the information. try: - # this can fail if, for some reason, our user doesn't have perms to send - # m.room.aliases events in the old room (note that we've already checked that - # they have perms to send a tombstone event, so that's not terribly likely). - # - # If that happens, it's regrettable, but we should carry on: it's the same - # as when you remove an alias from the directory normally - it just means that - # the aliases event gets out of sync with the directory - # (cf https://github.com/vector-im/riot-web/issues/2369) - yield directory_handler.send_room_alias_update_event(requester, old_room_id) - except AuthError as e: - logger.warning("Failed to send updated alias event on old room: %s", e) - - # we can now add any aliases we successfully removed to the new room. - for alias in removed_aliases: - try: - yield directory_handler.create_association( - requester, - RoomAlias.from_string(alias), - new_room_id, - servers=(self.hs.hostname,), - send_event=False, - check_membership=False, - ) - logger.info("Moved alias %s to new room", alias) - except SynapseError as e: - # I'm not really expecting this to happen, but it could if the spam - # checking module decides it shouldn't, or similar. 
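The new _move_aliases_to_new_room() above no longer deletes and recreates aliases one by one; it repoints the directory in one call and then splits the m.room.canonical_alias content so that only aliases served by this homeserver move to the new room. A self-contained sketch of that split, with is_mine standing in for hs.is_mine_id():

def split_canonical_alias_content(content, is_mine):
    # Aliases local to this server move into the new room's canonical-alias
    # event; everything else stays behind in a trimmed copy of the old one.
    old_content = dict(content)
    new_content = {}

    canonical = content.get("alias")
    if canonical and is_mine(canonical):
        new_content["alias"] = canonical
        old_content.pop("alias", None)

    old_alt = list(old_content.get("alt_aliases", []))
    new_alt = []
    for alias in content.get("alt_aliases", []):
        if is_mine(alias):
            new_alt.append(alias)
            old_alt.remove(alias)
    if new_alt:
        new_content["alt_aliases"] = new_alt
    if old_alt:
        old_content["alt_aliases"] = old_alt
    else:
        old_content.pop("alt_aliases", None)

    return old_content, new_content


def is_local(alias):
    return alias.endswith(":example.org")


old, new = split_canonical_alias_content(
    {"alias": "#room:example.org", "alt_aliases": ["#r:example.org", "#r:other.net"]},
    is_local,
)
print(old)  # {'alt_aliases': ['#r:other.net']}
print(new)  # {'alias': '#room:example.org', 'alt_aliases': ['#r:example.org']}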
- logger.error("Error adding alias %s to new room: %s", alias, e) + await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.CanonicalAlias, + "state_key": "", + "room_id": old_room_id, + "sender": requester.user.to_string(), + "content": old_canonical_alias_content, + }, + ratelimit=False, + ) + except SynapseError as e: + # again I'm not really expecting this to fail, but if it does, I'd rather + # we returned the new room to the client at this point. + logger.error("Unable to send updated alias events in old room: %s", e) try: - if canonical_alias and (canonical_alias in removed_aliases): - yield self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.CanonicalAlias, - "state_key": "", - "room_id": new_room_id, - "sender": requester.user.to_string(), - "content": {"alias": canonical_alias}, - }, - ratelimit=False, - ) - - yield directory_handler.send_room_alias_update_event(requester, new_room_id) + await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.CanonicalAlias, + "state_key": "", + "room_id": new_room_id, + "sender": requester.user.to_string(), + "content": new_canonical_alias_content, + }, + ratelimit=False, + ) except SynapseError as e: # again I'm not really expecting this to fail, but if it does, I'd rather # we returned the new room to the client at this point. logger.error("Unable to send updated alias events in new room: %s", e) - @defer.inlineCallbacks - def create_room(self, requester, config, ratelimit=True, creator_join_profile=None): + async def create_room( + self, requester, config, ratelimit=True, creator_join_profile=None + ) -> Tuple[dict, int]: """ Creates a new room. Args: @@ -488,9 +538,9 @@ class RoomCreationHandler(BaseHandler): `avatar_url` and/or `displayname`. Returns: - Deferred[dict]: - a dict containing the keys `room_id` and, if an alias was - requested, `room_alias`. + First, a dict containing the keys `room_id` and, if an alias + was, requested, `room_alias`. Secondly, the stream_id of the + last persisted event. Raises: SynapseError if the room ID couldn't be stored, or something went horribly wrong. @@ -499,7 +549,7 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) if ( self._server_notices_mxid is not None @@ -508,13 +558,17 @@ class RoomCreationHandler(BaseHandler): # allow the server notices mxid to create rooms is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) # Check whether the third party rules allows/changes the room create # request. 
- yield self.third_party_event_rules.on_create_room( + event_allowed = await self.third_party_event_rules.on_create_room( requester, config, is_requester_admin=is_requester_admin ) + if not event_allowed: + raise SynapseError( + 403, "You are not permitted to create rooms", Codes.FORBIDDEN + ) if not is_requester_admin and not self.spam_checker.user_may_create_room( user_id @@ -522,16 +576,17 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(403, "You are not permitted to create rooms") if ratelimit: - yield self.ratelimit(requester) + await self.ratelimit(requester) - room_version = config.get( + room_version_id = config.get( "room_version", self.config.default_room_version.identifier ) - if not isinstance(room_version, string_types): + if not isinstance(room_version_id, string_types): raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON) - if room_version not in KNOWN_ROOM_VERSIONS: + room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) + if room_version is None: raise SynapseError( 400, "Your homeserver does not support this room version", @@ -544,7 +599,7 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(400, "Invalid characters in room alias") room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname) - mapping = yield self.store.get_association_from_room_alias(room_alias) + mapping = await self.store.get_association_from_room_alias(room_alias) if mapping: raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) @@ -554,11 +609,12 @@ class RoomCreationHandler(BaseHandler): invite_list = config.get("invite", []) for i in invite_list: try: - UserID.from_string(i) + uid = UserID.from_string(i) + parse_and_validate_server_name(uid.domain) except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) - yield self.event_creation_handler.assert_accepted_privacy_policy(requester) + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") if ( @@ -577,19 +633,27 @@ class RoomCreationHandler(BaseHandler): visibility = config.get("visibility", None) is_public = visibility == "public" - room_id = yield self._generate_room_id(creator_id=user_id, is_public=is_public) + room_id = await self._generate_room_id( + creator_id=user_id, is_public=is_public, room_version=room_version, + ) directory_handler = self.hs.get_handlers().directory_handler if room_alias: - yield directory_handler.create_association( + await directory_handler.create_association( requester=requester, room_id=room_id, room_alias=room_alias, servers=[self.hs.hostname], - send_event=False, check_membership=False, ) + if is_public: + if not self.config.is_publishing_room_allowed(user_id, room_id, room_alias): + # Lets just return a generic message, as there may be all sorts of + # reasons why we said no. TODO: Allow configurable error messages + # per alias creation rule? 
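create_room() above now resolves the requested room version against KNOWN_ROOM_VERSIONS and validates every invited user ID, including its server name. A rough, self-contained sketch of those checks (KNOWN_ROOM_VERSIONS, DEFAULT_ROOM_VERSION and the user-ID parsing here are simplified stand-ins, not Synapse's real version table or parse_and_validate_server_name):

# Hypothetical stand-ins for values Synapse reads from constants/config.
KNOWN_ROOM_VERSIONS = {"4": "room v4", "5": "room v5", "6": "room v6"}
DEFAULT_ROOM_VERSION = "5"


class SynapseError(Exception):
    def __init__(self, code, msg):
        super().__init__(msg)
        self.code = code


def validate_create_room_config(config):
    # The requested room version must be a known string, and every invited
    # user ID must look like @localpart:domain with a non-empty domain.
    room_version_id = config.get("room_version", DEFAULT_ROOM_VERSION)
    if not isinstance(room_version_id, str):
        raise SynapseError(400, "room_version must be a string")
    if room_version_id not in KNOWN_ROOM_VERSIONS:
        raise SynapseError(400, "Your homeserver does not support this room version")

    for user_id in config.get("invite", []):
        localpart, sep, domain = user_id.partition(":")
        if not localpart.startswith("@") or not sep or not domain:
            raise SynapseError(400, "Invalid user_id: %s" % (user_id,))

    return KNOWN_ROOM_VERSIONS[room_version_id]


print(validate_create_room_config({"room_version": "5", "invite": ["@bob:example.org"]}))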
+ raise SynapseError(403, "Not allowed to publish room") + preset_config = config.get( "preset", RoomCreationPreset.PRIVATE_CHAT @@ -606,9 +670,9 @@ class RoomCreationHandler(BaseHandler): creation_content = config.get("creation_content", {}) # override any attempt to set room versions via the creation_content - creation_content["room_version"] = room_version + creation_content["room_version"] = room_version.identifier - yield self._send_events_for_new_room( + last_stream_id = await self._send_events_for_new_room( requester, room_id, preset_config=preset_config, @@ -622,7 +686,10 @@ class RoomCreationHandler(BaseHandler): if "name" in config: name = config["name"] - yield self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Name, @@ -636,7 +703,10 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - yield self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Topic, @@ -654,7 +724,7 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct - yield self.room_member_handler.update_membership( + _, last_stream_id = await self.room_member_handler.update_membership( requester, UserID.from_string(invitee), room_id, @@ -668,7 +738,7 @@ class RoomCreationHandler(BaseHandler): id_access_token = invite_3pid.get("id_access_token") # optional address = invite_3pid["address"] medium = invite_3pid["medium"] - yield self.hs.get_room_member_handler().do_3pid_invite( + last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -683,12 +753,15 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() - yield directory_handler.send_room_alias_update_event(requester, room_id) - return result + # Always wait for room creation to progate before returning + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", last_stream_id + ) + + return result, last_stream_id - @defer.inlineCallbacks - def _send_events_for_new_room( + async def _send_events_for_new_room( self, creator, # A Requester object. room_id, @@ -697,9 +770,15 @@ class RoomCreationHandler(BaseHandler): initial_state, creation_content, room_alias=None, - power_level_content_override=None, + power_level_content_override=None, # Doesn't apply when initial state has power level state event content creator_join_profile=None, - ): + ) -> int: + """Sends the initial events into a new room. + + Returns: + The stream_id of the last event persisted. 
+ """ + def create(etype, content, **kwargs): e = {"type": etype, "content": content} @@ -708,13 +787,16 @@ class RoomCreationHandler(BaseHandler): return e - @defer.inlineCallbacks - def send(etype, content, **kwargs): + async def send(etype, content, **kwargs) -> int: event = create(etype, content, **kwargs) - logger.info("Sending %s in new room", etype) - yield self.event_creation_handler.create_and_send_nonmember_event( + logger.debug("Sending %s in new room", etype) + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( creator, event, ratelimit=False ) + return last_stream_id config = RoomCreationHandler.PRESETS_DICT[preset_config] @@ -723,10 +805,10 @@ class RoomCreationHandler(BaseHandler): event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} creation_content.update({"creator": creator_id}) - yield send(etype=EventTypes.Create, content=creation_content) + await send(etype=EventTypes.Create, content=creation_content) - logger.info("Sending %s in new room", EventTypes.Member) - yield self.room_member_handler.update_membership( + logger.debug("Sending %s in new room", EventTypes.Member) + await self.room_member_handler.update_membership( creator, creator.user, room_id, @@ -739,7 +821,9 @@ class RoomCreationHandler(BaseHandler): # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - yield send(etype=EventTypes.PowerLevels, content=pl_content) + last_sent_stream_id = await send( + etype=EventTypes.PowerLevels, content=pl_content + ) else: power_level_content = { "users": {creator_id: 100}, @@ -750,52 +834,65 @@ class RoomCreationHandler(BaseHandler): EventTypes.RoomHistoryVisibility: 100, EventTypes.CanonicalAlias: 50, EventTypes.RoomAvatar: 50, + EventTypes.Tombstone: 100, + EventTypes.ServerACL: 100, + EventTypes.RoomEncryption: 100, }, "events_default": 0, "state_default": 50, "ban": 50, "kick": 50, "redact": 50, - "invite": 0, + "invite": 50, } if config["original_invitees_have_ops"]: for invitee in invite_list: power_level_content["users"][invitee] = 100 + # Power levels overrides are defined per chat preset + power_level_content.update(config["power_level_content_override"]) + if power_level_content_override: power_level_content.update(power_level_content_override) - yield send(etype=EventTypes.PowerLevels, content=power_level_content) + last_sent_stream_id = await send( + etype=EventTypes.PowerLevels, content=power_level_content + ) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - yield send( + last_sent_stream_id = await send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) if (EventTypes.JoinRules, "") not in initial_state: - yield send( + last_sent_stream_id = await send( etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} ) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - yield send( + last_sent_stream_id = await send( etype=EventTypes.RoomHistoryVisibility, content={"history_visibility": config["history_visibility"]}, ) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - yield send( + last_sent_stream_id = await send( etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} ) for (etype, state_key), content in initial_state.items(): - yield send(etype=etype, state_key=state_key, content=content) + last_sent_stream_id = await send( + etype=etype, state_key=state_key, content=content + ) - 
@defer.inlineCallbacks - def _generate_room_id(self, creator_id, is_public): + return last_sent_stream_id + + async def _generate_room_id( + self, creator_id: str, is_public: str, room_version: RoomVersion, + ): # autogen room IDs and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. attempts = 0 @@ -805,10 +902,11 @@ class RoomCreationHandler(BaseHandler): gen_room_id = RoomID(random_string, self.hs.hostname).to_string() if isinstance(gen_room_id, bytes): gen_room_id = gen_room_id.decode("utf-8") - yield self.store.store_room( + await self.store.store_room( room_id=gen_room_id, room_creator_user_id=creator_id, is_public=is_public, + room_version=room_version, ) return gen_room_id except StoreError: @@ -820,9 +918,10 @@ class RoomContextHandler(object): def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state - @defer.inlineCallbacks - def get_event_context(self, user, room_id, event_id, limit, event_filter): + async def get_event_context(self, user, room_id, event_id, limit, event_filter): """Retrieves events, pagination tokens and state around a given event in a room. @@ -841,31 +940,38 @@ class RoomContextHandler(object): before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit - users = yield self.store.get_users_in_room(room_id) + users = await self.store.get_users_in_room(room_id) is_peeking = user.to_string() not in users def filter_evts(events): return filter_events_for_client( - self.store, user.to_string(), events, is_peeking=is_peeking + self.storage, user.to_string(), events, is_peeking=is_peeking ) - event = yield self.store.get_event( + event = await self.store.get_event( event_id, get_prev_content=True, allow_none=True ) if not event: return None - filtered = yield (filter_evts([event])) + filtered = await filter_evts([event]) if not filtered: raise AuthError(403, "You don't have permission to access that event.") - results = yield self.store.get_events_around( + results = await self.store.get_events_around( room_id, event_id, before_limit, after_limit, event_filter ) - results["events_before"] = yield filter_evts(results["events_before"]) - results["events_after"] = yield filter_evts(results["events_after"]) - results["event"] = event + if event_filter: + results["events_before"] = event_filter.filter(results["events_before"]) + results["events_after"] = event_filter.filter(results["events_after"]) + + results["events_before"] = await filter_evts(results["events_before"]) + results["events_after"] = await filter_evts(results["events_after"]) + # filter_evts can return a pruned event in case the user is allowed to see that + # there's something there but not see the content, so use the event that's in + # `filtered` rather than the event we retrieved from the datastore. + results["event"] = filtered[0] if results["events_after"]: last_event_id = results["events_after"][-1].event_id @@ -888,10 +994,15 @@ class RoomContextHandler(object): # first? Shouldn't we be consistent with /sync? 
# https://github.com/matrix-org/matrix-doc/issues/687 - state = yield self.store.get_state_for_events( + state = await self.state_store.get_state_for_events( [last_event_id], state_filter=state_filter ) - results["state"] = list(state[last_event_id].values()) + + state_events = list(state[last_event_id].values()) + if event_filter: + state_events = event_filter.filter(state_events) + + results["state"] = await filter_evts(state_events) # We use a dummy token here as we only care about the room portion of # the token, which we replace. @@ -910,17 +1021,16 @@ class RoomEventSource(object): def __init__(self, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_new_events( + async def get_new_events( self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None ): # We just ignore the key for now. - to_key = yield self.get_current_key() + to_key = await self.get_current_key() from_token = RoomStreamToken.parse(from_key) if from_token.topological: - logger.warn("Stream has topological part!!!! %r", from_key) + logger.warning("Stream has topological part!!!! %r", from_key) from_key = "s%s" % (from_token.stream,) app_service = self.store.get_app_service_by_user_id(user.to_string()) @@ -929,11 +1039,11 @@ class RoomEventSource(object): # See https://github.com/matrix-org/matrix-doc/issues/1144 raise NotImplementedError() else: - room_events = yield self.store.get_membership_changes_for_user( + room_events = await self.store.get_membership_changes_for_user( user.to_string(), from_key, to_key ) - room_to_events = yield self.store.get_room_events_stream_for_rooms( + room_to_events = await self.store.get_room_events_stream_for_rooms( room_ids=room_ids, from_key=from_key, to_key=to_key, @@ -961,15 +1071,3 @@ class RoomEventSource(object): def get_current_key_for_room(self, room_id): return self.store.get_room_events_max_id(room_id) - - @defer.inlineCallbacks - def get_pagination_rows(self, user, config, key): - events, next_key = yield self.store.paginate_room_events( - room_id=key, - from_key=config.from_key, - to_key=config.to_key, - direction=config.direction, - limit=config.limit, - ) - - return (events, next_key) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index a7e55f00e5..4cbc02b0d0 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -15,9 +15,9 @@ import logging from collections import namedtuple +from typing import Any, Dict, Optional -from six import PY3, iteritems -from six.moves import range +from six import iteritems import msgpack from unpaddedbase64 import decode_base64, encode_base64 @@ -27,7 +27,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, HttpResponseException from synapse.types import ThirdPartyInstanceID -from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.response_cache import ResponseCache @@ -37,7 +36,6 @@ logger = logging.getLogger(__name__) REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 - # This is used to indicate we should only return rooms published to the main list. EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) @@ -72,6 +70,8 @@ class RoomListHandler(BaseHandler): This can be (None, None) to indicate the main list, or a particular appservice and network id to use an appservice specific one. Setting to None returns all public rooms across all lists. 
+ from_federation (bool): true iff the request comes from the federation + API """ if not self.enable_room_list_search: return defer.succeed({"chunk": [], "total_room_count_estimate": 0}) @@ -89,16 +89,12 @@ class RoomListHandler(BaseHandler): # appservice specific lists. logger.info("Bypassing cache as search request.") - # XXX: Quick hack to stop room directory queries taking too long. - # Timeout request after 60s. Probably want a more fundamental - # solution at some point - timeout = self.clock.time() + 60 return self._get_public_room_list( limit, since_token, search_filter, network_tuple=network_tuple, - timeout=timeout, + from_federation=from_federation, ) key = (limit, since_token, network_tuple) @@ -114,258 +110,124 @@ class RoomListHandler(BaseHandler): @defer.inlineCallbacks def _get_public_room_list( self, - limit=None, - since_token=None, - search_filter=None, - network_tuple=EMPTY_THIRD_PARTY_ID, - from_federation=False, - timeout=None, - ): + limit: Optional[int] = None, + since_token: Optional[str] = None, + search_filter: Optional[Dict] = None, + network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID, + from_federation: bool = False, + ) -> Dict[str, Any]: """Generate a public room list. Args: - limit (int|None): Maximum amount of rooms to return. - since_token (str|None) - search_filter (dict|None): Dictionary to filter rooms by. - network_tuple (ThirdPartyInstanceID): Which public list to use. + limit: Maximum amount of rooms to return. + since_token: + search_filter: Dictionary to filter rooms by. + network_tuple: Which public list to use. This can be (None, None) to indicate the main list, or a particular appservice and network id to use an appservice specific one. Setting to None returns all public rooms across all lists. - from_federation (bool): Whether this request originated from a + from_federation: Whether this request originated from a federating server or a client. Used for room filtering. - timeout (int|None): Amount of seconds to wait for a response before - timing out. """ - if since_token and since_token != "END": - since_token = RoomListNextBatch.from_token(since_token) - else: - since_token = None - rooms_to_order_value = {} - rooms_to_num_joined = {} + # Pagination tokens work by storing the room ID sent in the last batch, + # plus the direction (forwards or backwards). Next batch tokens always + # go forwards, prev batch tokens always go backwards. - newly_visible = [] - newly_unpublished = [] if since_token: - stream_token = since_token.stream_ordering - current_public_id = yield self.store.get_current_public_room_stream_id() - public_room_stream_id = since_token.public_room_stream_id - newly_visible, newly_unpublished = yield self.store.get_public_room_changes( - public_room_stream_id, current_public_id, network_tuple=network_tuple - ) - else: - stream_token = yield self.store.get_room_max_stream_ordering() - public_room_stream_id = yield self.store.get_current_public_room_stream_id() - - room_ids = yield self.store.get_public_room_ids_at_stream_id( - public_room_stream_id, network_tuple=network_tuple - ) - - # We want to return rooms in a particular order: the number of joined - # users. We then arbitrarily use the room_id as a tie breaker. - - @defer.inlineCallbacks - def get_order_for_room(room_id): - # Most of the rooms won't have changed between the since token and - # now (especially if the since token is "now"). 
So, we can ask what - # the current users are in a room (that will hit a cache) and then - # check if the room has changed since the since token. (We have to - # do it in that order to avoid races). - # If things have changed then fall back to getting the current state - # at the since token. - joined_users = yield self.store.get_users_in_room(room_id) - if self.store.has_room_changed_since(room_id, stream_token): - latest_event_ids = yield self.store.get_forward_extremeties_for_room( - room_id, stream_token - ) - - if not latest_event_ids: - return - - joined_users = yield self.state_handler.get_current_users_in_room( - room_id, latest_event_ids - ) + batch_token = RoomListNextBatch.from_token(since_token) - num_joined_users = len(joined_users) - rooms_to_num_joined[room_id] = num_joined_users + bounds = (batch_token.last_joined_members, batch_token.last_room_id) + forwards = batch_token.direction_is_forward + else: + batch_token = None + bounds = None - if num_joined_users == 0: - return + forwards = True - # We want larger rooms to be first, hence negating num_joined_users - rooms_to_order_value[room_id] = (-num_joined_users, room_id) + # we request one more than wanted to see if there are more pages to come + probing_limit = limit + 1 if limit is not None else None - logger.info( - "Getting ordering for %i rooms since %s", len(room_ids), stream_token + results = yield self.store.get_largest_public_rooms( + network_tuple, + search_filter, + probing_limit, + bounds=bounds, + forwards=forwards, + ignore_non_federatable=from_federation, ) - yield concurrently_execute(get_order_for_room, room_ids, 10) - sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1]) - sorted_rooms = [room_id for room_id, _ in sorted_entries] + def build_room_entry(room): + entry = { + "room_id": room["room_id"], + "name": room["name"], + "topic": room["topic"], + "canonical_alias": room["canonical_alias"], + "num_joined_members": room["joined_members"], + "avatar_url": room["avatar"], + "world_readable": room["history_visibility"] == "world_readable", + "guest_can_join": room["guest_access"] == "can_join", + } - # `sorted_rooms` should now be a list of all public room ids that is - # stable across pagination. Therefore, we can use indices into this - # list as our pagination tokens. + # Filter out Nones – rather omit the field altogether + return {k: v for k, v in entry.items() if v is not None} - # Filter out rooms that we don't want to return - rooms_to_scan = [ - r - for r in sorted_rooms - if r not in newly_unpublished and rooms_to_num_joined[r] > 0 - ] + results = [build_room_entry(r) for r in results] - total_room_count = len(rooms_to_scan) + response = {} + num_results = len(results) + if limit is not None: + more_to_come = num_results == probing_limit - if since_token: - # Filter out rooms we've already returned previously - # `since_token.current_limit` is the index of the last room we - # sent down, so we exclude it and everything before/after it. - if since_token.direction_is_forward: - rooms_to_scan = rooms_to_scan[since_token.current_limit + 1 :] + # Depending on direction we trim either the front or back. 
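The rewritten room-list pagination above asks the store for limit + 1 rows; the presence of that extra probe row is what tells it whether another page exists, and the surplus is trimmed from whichever end the client is paginating towards. A standalone sketch of the probing-limit trick (rows stands in for the get_largest_public_rooms results):

def paginate(rows, limit, forwards):
    # Fetch one more row than requested; if we got it, there is another page.
    probing_limit = limit + 1 if limit is not None else None
    fetched = rows[:probing_limit] if probing_limit is not None else list(rows)

    more_to_come = limit is not None and len(fetched) == probing_limit
    if limit is not None:
        # Trim the extra probe row from the end we are paginating towards.
        fetched = fetched[:limit] if forwards else fetched[-limit:]
    return fetched, more_to_come


rooms = ["!a", "!b", "!c", "!d", "!e"]
print(paginate(rooms, limit=2, forwards=True))      # (['!a', '!b'], True)
print(paginate(rooms[4:], limit=2, forwards=True))  # (['!e'], False)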
+ if forwards: + results = results[:limit] else: - rooms_to_scan = rooms_to_scan[: since_token.current_limit] - rooms_to_scan.reverse() - - logger.info("After sorting and filtering, %i rooms remain", len(rooms_to_scan)) - - # _append_room_entry_to_chunk will append to chunk but will stop if - # len(chunk) > limit - # - # Normally we will generate enough results on the first iteration here, - # but if there is a search filter, _append_room_entry_to_chunk may - # filter some results out, in which case we loop again. - # - # We don't want to scan over the entire range either as that - # would potentially waste a lot of work. - # - # XXX if there is no limit, we may end up DoSing the server with - # calls to get_current_state_ids for every single room on the - # server. Surely we should cap this somehow? - # - if limit: - step = limit + 1 + results = results[-limit:] else: - # step cannot be zero - step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1 - - chunk = [] - for i in range(0, len(rooms_to_scan), step): - if timeout and self.clock.time() > timeout: - raise Exception("Timed out searching room directory") - - batch = rooms_to_scan[i : i + step] - logger.info("Processing %i rooms for result", len(batch)) - yield concurrently_execute( - lambda r: self._append_room_entry_to_chunk( - r, - rooms_to_num_joined[r], - chunk, - limit, - search_filter, - from_federation=from_federation, - ), - batch, - 5, - ) - logger.info("Now %i rooms in result", len(chunk)) - if len(chunk) >= limit + 1: - break - - chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"])) - - # Work out the new limit of the batch for pagination, or None if we - # know there are no more results that would be returned. - # i.e., [since_token.current_limit..new_limit] is the batch of rooms - # we've returned (or the reverse if we paginated backwards) - # We tried to pull out limit + 1 rooms above, so if we have <= limit - # then we know there are no more results to return - new_limit = None - if chunk and (not limit or len(chunk) > limit): - - if not since_token or since_token.direction_is_forward: - if limit: - chunk = chunk[:limit] - last_room_id = chunk[-1]["room_id"] + more_to_come = False + + if num_results > 0: + final_entry = results[-1] + initial_entry = results[0] + + if forwards: + if batch_token: + # If there was a token given then we assume that there + # must be previous results. 
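# An illustrative summary of how the pagination tokens pair up: when walking
# forwards, the previous-page token anchors on the first entry returned and points
# backwards, while the next-page token anchors on the last entry and is only emitted
# when the probe said more rows exist; the roles swap when paginating backwards.
# make_token and batch_tokens are hypothetical helpers, not the RoomListNextBatch API.

def make_token(entry, direction_is_forward):
    return {
        "last_joined_members": entry["num_joined_members"],
        "last_room_id": entry["room_id"],
        "direction_is_forward": direction_is_forward,
    }

def batch_tokens(results, forwards, more_to_come, had_since_token):
    tokens = {}
    if not results:
        return tokens
    first, last = results[0], results[-1]
    if forwards:
        if had_since_token:
            tokens["prev_batch"] = make_token(first, False)
        if more_to_come:
            tokens["next_batch"] = make_token(last, True)
    else:
        if had_since_token:
            tokens["next_batch"] = make_token(last, True)
        if more_to_come:
            tokens["prev_batch"] = make_token(first, False)
    return tokens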
+ response["prev_batch"] = RoomListNextBatch( + last_joined_members=initial_entry["num_joined_members"], + last_room_id=initial_entry["room_id"], + direction_is_forward=False, + ).to_token() + + if more_to_come: + response["next_batch"] = RoomListNextBatch( + last_joined_members=final_entry["num_joined_members"], + last_room_id=final_entry["room_id"], + direction_is_forward=True, + ).to_token() else: - if limit: - chunk = chunk[-limit:] - last_room_id = chunk[0]["room_id"] - - new_limit = sorted_rooms.index(last_room_id) - - results = {"chunk": chunk, "total_room_count_estimate": total_room_count} - - if since_token: - results["new_rooms"] = bool(newly_visible) - - if not since_token or since_token.direction_is_forward: - if new_limit is not None: - results["next_batch"] = RoomListNextBatch( - stream_ordering=stream_token, - public_room_stream_id=public_room_stream_id, - current_limit=new_limit, - direction_is_forward=True, - ).to_token() - - if since_token: - results["prev_batch"] = since_token.copy_and_replace( - direction_is_forward=False, - current_limit=since_token.current_limit + 1, - ).to_token() - else: - if new_limit is not None: - results["prev_batch"] = RoomListNextBatch( - stream_ordering=stream_token, - public_room_stream_id=public_room_stream_id, - current_limit=new_limit, - direction_is_forward=False, - ).to_token() - - if since_token: - results["next_batch"] = since_token.copy_and_replace( - direction_is_forward=True, - current_limit=since_token.current_limit - 1, - ).to_token() - - return results - - @defer.inlineCallbacks - def _append_room_entry_to_chunk( - self, - room_id, - num_joined_users, - chunk, - limit, - search_filter, - from_federation=False, - ): - """Generate the entry for a room in the public room list and append it - to the `chunk` if it matches the search filter - - Args: - room_id (str): The ID of the room. - num_joined_users (int): The number of joined users in the room. - chunk (list) - limit (int|None): Maximum amount of rooms to display. Function will - return if length of chunk is greater than limit + 1. - search_filter (dict|None) - from_federation (bool): Whether this request originated from a - federating server or a client. Used for room filtering. - """ - if limit and len(chunk) > limit + 1: - # We've already got enough, so lets just drop it. - return - - result = yield self.generate_room_entry(room_id, num_joined_users) - if not result: - return - - if from_federation and not result.get("m.federate", True): - # This is a room that other servers cannot join. Do not show them - # this room. 
- return + if batch_token: + response["next_batch"] = RoomListNextBatch( + last_joined_members=final_entry["num_joined_members"], + last_room_id=final_entry["room_id"], + direction_is_forward=True, + ).to_token() + + if more_to_come: + response["prev_batch"] = RoomListNextBatch( + last_joined_members=initial_entry["num_joined_members"], + last_room_id=initial_entry["room_id"], + direction_is_forward=False, + ).to_token() + + response["chunk"] = results + + response["total_room_count_estimate"] = yield self.store.count_public_rooms( + network_tuple, ignore_non_federatable=from_federation + ) - if _matches_room_entry(result, search_filter): - chunk.append(result) + return response @cachedInlineCallbacks(num_args=1, cache_context=True) def generate_room_entry( @@ -391,10 +253,21 @@ class RoomListHandler(BaseHandler): """ result = {"room_id": room_id, "num_joined_members": num_joined_users} + if with_alias: + aliases = yield self.store.get_aliases_for_room( + room_id, on_invalidate=cache_context.invalidate + ) + if aliases: + result["aliases"] = aliases + current_state_ids = yield self.store.get_current_state_ids( room_id, on_invalidate=cache_context.invalidate ) + if not current_state_ids: + # We're not in the room, so may as well bail out here. + return result + event_map = yield self.store.get_events( [ event_id @@ -427,14 +300,7 @@ class RoomListHandler(BaseHandler): create_event = current_state.get((EventTypes.Create, "")) result["m.federate"] = create_event.content.get("m.federate", True) - if with_alias: - aliases = yield self.store.get_aliases_for_room( - room_id, on_invalidate=cache_context.invalidate - ) - if aliases: - result["aliases"] = aliases - - name_event = yield current_state.get((EventTypes.Name, "")) + name_event = current_state.get((EventTypes.Name, "")) if name_event: name = name_event.content.get("name", None) if name: @@ -580,18 +446,15 @@ class RoomListNextBatch( namedtuple( "RoomListNextBatch", ( - "stream_ordering", # stream_ordering of the first public room list - "public_room_stream_id", # public room stream id for first public room list - "current_limit", # The number of previous rooms returned + "last_joined_members", # The count to get rooms after/before + "last_room_id", # The room_id to get rooms after/before "direction_is_forward", # Bool if this is a next_batch, false if prev_batch ), ) ): - KEY_DICT = { - "stream_ordering": "s", - "public_room_stream_id": "p", - "current_limit": "n", + "last_joined_members": "m", + "last_room_id": "r", "direction_is_forward": "d", } @@ -599,13 +462,7 @@ class RoomListNextBatch( @classmethod def from_token(cls, token): - if PY3: - # The argument raw=False is only available on new versions of - # msgpack, and only really needed on Python 3. Gate it behind - # a PY3 check to avoid causing issues on Debian-packaged versions. 
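# A rough sketch of the batch-token round trip, assuming to_token() mirrors the
# from_token() shown here: field names are shortened via KEY_DICT ("m", "r", "d"),
# the dict is serialised with msgpack and wrapped in unpadded base64 so it can travel
# as an opaque string in the since_token query parameter.

import msgpack
from unpaddedbase64 import decode_base64, encode_base64

KEY_DICT = {
    "last_joined_members": "m",
    "last_room_id": "r",
    "direction_is_forward": "d",
}
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}

def to_token(batch):
    shortened = {KEY_DICT[key]: val for key, val in batch.items()}
    return encode_base64(msgpack.dumps(shortened))

def from_token(token):
    decoded = msgpack.loads(decode_base64(token), raw=False)
    return {REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}

token = to_token(
    {"last_joined_members": 42, "last_room_id": "!abc:example.org", "direction_is_forward": True}
)
assert from_token(token)["last_room_id"] == "!abc:example.org"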
- decoded = msgpack.loads(decode_base64(token), raw=False) - else: - decoded = msgpack.loads(decode_base64(token)) + decoded = msgpack.loads(decode_base64(token), raw=False) return RoomListNextBatch( **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()} ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 35450feb6f..0f7af982f0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,30 +17,26 @@ import abc import logging +from typing import Dict, Iterable, List, Optional, Tuple from six.moves import http_client -from signedjson.key import decode_verify_key_bytes -from signedjson.sign import verify_signed_json -from unpaddedbase64 import decode_base64 - -from twisted.internet import defer - from synapse import types from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, HttpResponseException, SynapseError -from synapse.handlers.identity import LookupAlgorithm, create_id_access_token_header -from synapse.types import RoomID, UserID +from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.replication.http.membership import ( + ReplicationLocallyRejectInviteRestServlet, +) +from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room -from synapse.util.hash import sha256_and_url_safe_base64 from ._base import BaseHandler logger = logging.getLogger(__name__) -id_server_scheme = "https://" - class RoomMemberHandler(object): # TODO(paul): This handler currently contains a messy conflation of @@ -51,20 +47,15 @@ class RoomMemberHandler(object): __metaclass__ = abc.ABCMeta def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.config = hs.config - self.simple_http_client = hs.get_simple_http_client() self.federation_handler = hs.get_handlers().federation_handler self.directory_handler = hs.get_handlers().directory_handler + self.identity_handler = hs.get_handlers().identity_handler self.registration_handler = hs.get_registration_handler() self.profile_handler = hs.get_profile_handler() self.event_creation_handler = hs.get_event_creation_handler() @@ -78,88 +69,117 @@ class RoomMemberHandler(object): self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles + self._event_stream_writer_instance = hs.config.worker.writers.events + self._is_on_event_persistence_instance = ( + self._event_stream_writer_instance == hs.get_instance_name() + ) + if self._is_on_event_persistence_instance: + self.persist_event_storage = hs.get_storage().persistence + else: + self._locally_reject_client = ReplicationLocallyRejectInviteRestServlet.make_client( + hs + ) + # This is only used to get at ratelimit function, and # maybe_kick_guest_users. It's fine there are multiple of these as # it doesn't store state. 
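# A simplified sketch of the dispatch pattern set up above: work that must happen on
# the event-persister worker is done directly when this process *is* that worker, and
# proxied over a replication HTTP client otherwise. The class and attribute names
# below are illustrative stand-ins, not the exact Synapse interfaces.

class InviteRejecter:
    def __init__(self, is_event_writer, persistence, replication_client, writer_instance):
        self._is_event_writer = is_event_writer
        self._persistence = persistence       # storage persistence layer (writer only)
        self._client = replication_client     # HTTP client pointed at the event writer
        self._writer_instance = writer_instance

    async def locally_reject_invite(self, user_id, room_id):
        if self._is_event_writer:
            # We can touch the events stream directly on this instance.
            return await self._persistence.locally_reject_invite(user_id, room_id)
        # Otherwise ask the event-persister worker to do it and relay its stream id.
        result = await self._client(
            instance_name=self._writer_instance, user_id=user_id, room_id=room_id
        )
        return result["stream_id"]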
self.base_handler = BaseHandler(hs) @abc.abstractmethod - def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> Tuple[str, int]: """Try and join a room that this server is not in Args: - requester (Requester) - remote_room_hosts (list[str]): List of servers that can be used - to join via. - room_id (str): Room that we are trying to join - user (UserID): User who is trying to join - content (dict): A dict that should be used as the content of the - join event. - - Returns: - Deferred + requester + remote_room_hosts: List of servers that can be used to join via. + room_id: Room that we are trying to join + user: User who is trying to join + content: A dict that should be used as the content of the join event. """ raise NotImplementedError() @abc.abstractmethod - def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + async def _remote_reject_invite( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + target: UserID, + content: dict, + ) -> Tuple[Optional[str], int]: """Attempt to reject an invite for a room this server is not in. If we fail to do so we locally mark the invite as rejected. Args: - requester (Requester) - remote_room_hosts (list[str]): List of servers to use to try and - reject invite - room_id (str) - target (UserID): The user rejecting the invite + requester + remote_room_hosts: List of servers to use to try and reject invite + room_id + target: The user rejecting the invite + content: The content for the rejection event Returns: - Deferred[dict]: A dictionary to be returned to the client, may + A dictionary to be returned to the client, may include event_id etc, or nothing if we locally rejected """ raise NotImplementedError() + async def locally_reject_invite(self, user_id: str, room_id: str) -> int: + """Mark the invite has having been rejected even though we failed to + create a leave event for it. + """ + if self._is_on_event_persistence_instance: + return await self.persist_event_storage.locally_reject_invite( + user_id, room_id + ) + else: + result = await self._locally_reject_client( + instance_name=self._event_stream_writer_instance, + user_id=user_id, + room_id=room_id, + ) + return result["stream_id"] + @abc.abstractmethod - def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has joined the room. Args: - target (UserID) - room_id (str) - - Returns: - Deferred|None + target + room_id """ raise NotImplementedError() @abc.abstractmethod - def _user_left_room(self, target, room_id): + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has left the room. 
Args: - target (UserID) - room_id (str) - - Returns: - Deferred|None + target + room_id """ raise NotImplementedError() - @defer.inlineCallbacks - def _local_membership_update( + async def _local_membership_update( self, - requester, - target, - room_id, - membership, - prev_events_and_hashes, - txn_id=None, - ratelimit=True, - content=None, - require_consent=True, - ): + requester: Requester, + target: UserID, + room_id: str, + membership: str, + prev_event_ids: Collection[str], + txn_id: Optional[str] = None, + ratelimit: bool = True, + content: Optional[dict] = None, + require_consent: bool = True, + ) -> Tuple[str, int]: user_id = target.to_string() if content is None: @@ -169,7 +189,7 @@ class RoomMemberHandler(object): if requester.is_guest: content["kind"] = "guest" - event, context = yield self.event_creation_handler.create_event( + event, context = await self.event_creation_handler.create_event( requester, { "type": EventTypes.Member, @@ -182,23 +202,24 @@ class RoomMemberHandler(object): }, token_id=requester.access_token_id, txn_id=txn_id, - prev_events_and_hashes=prev_events_and_hashes, + prev_event_ids=prev_event_ids, require_consent=require_consent, ) # Check if this event matches the previous membership event for the user. - duplicate = yield self.event_creation_handler.deduplicate_state_event( + duplicate = await self.event_creation_handler.deduplicate_state_event( event, context ) if duplicate is not None: # Discard the new event since this membership change is a no-op. - return duplicate + _, stream_id = await self.store.get_event_ordering(duplicate.event_id) + return duplicate.event_id, stream_id - yield self.event_creation_handler.handle_new_client_event( + stream_id = await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target], ratelimit=ratelimit ) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids() prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) @@ -208,48 +229,30 @@ class RoomMemberHandler(object): # info. newly_joined = True if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield self._user_joined_room(target, room_id) - - # Copy over direct message status and room tags if this is a join - # on an upgraded room - - # Check if this is an upgraded room - predecessor = yield self.store.get_room_predecessor(room_id) - - if predecessor: - # It is an upgraded room. 
Copy over old tags - self.copy_room_tags_and_direct_to_room( - predecessor["room_id"], room_id, user_id - ) - # Move over old push rules - self.store.move_push_rules_from_room_to_room_for_user( - predecessor["room_id"], room_id, user_id - ) + await self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - yield self._user_left_room(target, room_id) + await self._user_left_room(target, room_id) - return event + return event.event_id, stream_id - @defer.inlineCallbacks - def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id): + async def copy_room_tags_and_direct_to_room( + self, old_room_id, new_room_id, user_id + ) -> None: """Copies the tags and direct room state from one room to another. Args: - old_room_id (str) - new_room_id (str) - user_id (str) - - Returns: - Deferred[None] + old_room_id: The room ID of the old room. + new_room_id: The room ID of the new room. + user_id: The user's ID. """ # Retrieve user account data for predecessor room - user_account_data, _ = yield self.store.get_account_data_for_user(user_id) + user_account_data, _ = await self.store.get_account_data_for_user(user_id) # Copy direct message state if applicable direct_rooms = user_account_data.get("m.direct", {}) @@ -262,36 +265,35 @@ class RoomMemberHandler(object): direct_rooms[key].append(new_room_id) # Save back to user's m.direct account data - yield self.store.add_account_data_for_user( + await self.store.add_account_data_for_user( user_id, "m.direct", direct_rooms ) break # Copy room tags if applicable - room_tags = yield self.store.get_tags_for_room(user_id, old_room_id) + room_tags = await self.store.get_tags_for_room(user_id, old_room_id) # Copy each room tag to the new room for tag, tag_content in room_tags.items(): - yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) + await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) - @defer.inlineCallbacks - def update_membership( + async def update_membership( self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - require_consent=True, - ): + requester: Requester, + target: UserID, + room_id: str, + action: str, + txn_id: Optional[str] = None, + remote_room_hosts: Optional[List[str]] = None, + third_party_signed: Optional[dict] = None, + ratelimit: bool = True, + content: Optional[dict] = None, + require_consent: bool = True, + ) -> Tuple[Optional[str], int]: key = (room_id,) - with (yield self.member_linearizer.queue(key)): - result = yield self._update_membership( + with (await self.member_linearizer.queue(key)): + result = await self._update_membership( requester, target, room_id, @@ -306,20 +308,19 @@ class RoomMemberHandler(object): return result - @defer.inlineCallbacks - def _update_membership( + async def _update_membership( self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - require_consent=True, - ): + requester: Requester, + target: UserID, + room_id: str, + action: str, + txn_id: Optional[str] = None, + remote_room_hosts: Optional[List[str]] = None, + third_party_signed: Optional[dict] = None, + ratelimit: bool = True, + content: Optional[dict] = 
None, + require_consent: bool = True, + ) -> Tuple[Optional[str], int]: content_specified = bool(content) if content is None: content = {} @@ -342,7 +343,7 @@ class RoomMemberHandler(object): # if this is a join with a 3pid signature, we may need to turn a 3pid # invite into a normal invite before we can handle the join. if third_party_signed is not None: - yield self.federation_handler.exchange_third_party_invite( + await self.federation_handler.exchange_third_party_invite( third_party_signed["sender"], target.to_string(), room_id, @@ -353,7 +354,7 @@ class RoomMemberHandler(object): remote_room_hosts = [] if effective_membership_state not in ("leave", "ban"): - is_blocked = yield self.store.is_room_blocked(room_id) + is_blocked = await self.store.is_room_blocked(room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") @@ -372,7 +373,7 @@ class RoomMemberHandler(object): is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: if self.config.block_non_admin_invites: @@ -391,10 +392,9 @@ class RoomMemberHandler(object): if block_invite: raise SynapseError(403, "Invites have been disabled on this server") - prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) - latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) + latest_event_ids = await self.store.get_prev_events_for_room(room_id) - current_state_ids = yield self.state_handler.get_current_state_ids( + current_state_ids = await self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids ) @@ -403,7 +403,7 @@ class RoomMemberHandler(object): # transitions and generic otherwise old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) if old_state_id: - old_state = yield self.store.get_event(old_state_id, allow_none=True) + old_state = await self.store.get_event(old_state_id, allow_none=True) old_membership = old_state.content.get("membership") if old_state else None if action == "unban" and old_membership != "ban": raise SynapseError( @@ -424,7 +424,13 @@ class RoomMemberHandler(object): same_membership = old_membership == effective_membership_state same_sender = requester.user.to_string() == old_state.sender if same_sender and same_membership and same_content: - return old_state + _, stream_id = await self.store.get_event_ordering( + old_state.event_id + ) + return ( + old_state.event_id, + stream_id, + ) if old_membership in ["ban", "leave"] and action == "kick": raise AuthError(403, "The target user is not in the room") @@ -435,7 +441,7 @@ class RoomMemberHandler(object): old_membership == Membership.INVITE and effective_membership_state == Membership.LEAVE ): - is_blocked = yield self._is_server_notice_room(room_id) + is_blocked = await self._is_server_notice_room(room_id) if is_blocked: raise SynapseError( http_client.FORBIDDEN, @@ -446,18 +452,18 @@ class RoomMemberHandler(object): if action == "kick": raise AuthError(403, "The target user is not in the room") - is_host_in_room = yield self._is_host_in_room(current_state_ids) + is_host_in_room = await self._is_host_in_room(current_state_ids) if effective_membership_state == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(current_state_ids) + guest_can_join = await self._can_guest_join(current_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local 
concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") if not is_host_in_room: - inviter = yield self._get_inviter(target.to_string(), room_id) + inviter = await self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) @@ -465,21 +471,22 @@ class RoomMemberHandler(object): profile = self.profile_handler if not content_specified: - content["displayname"] = yield profile.get_displayname(target) - content["avatar_url"] = yield profile.get_avatar_url(target) + content["displayname"] = await profile.get_displayname(target) + content["avatar_url"] = await profile.get_avatar_url(target) if requester.is_guest: content["kind"] = "guest" - ret = yield self._remote_join( + remote_join_response = await self._remote_join( requester, remote_room_hosts, room_id, target, content ) - return ret + + return remote_join_response elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: # perhaps we've been invited - inviter = yield self._get_inviter(target.to_string(), room_id) + inviter = await self._get_inviter(target.to_string(), room_id) if not inviter: raise SynapseError(404, "Not a known room") @@ -493,36 +500,116 @@ class RoomMemberHandler(object): else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - res = yield self._remote_reject_invite( - requester, remote_room_hosts, room_id, target + return await self._remote_reject_invite( + requester, remote_room_hosts, room_id, target, content, ) - return res - res = yield self._local_membership_update( + return await self._local_membership_update( requester=requester, target=target, room_id=room_id, membership=effective_membership_state, txn_id=txn_id, ratelimit=ratelimit, - prev_events_and_hashes=prev_events_and_hashes, + prev_event_ids=latest_event_ids, content=content, require_consent=require_consent, ) - return res - @defer.inlineCallbacks - def send_membership_event(self, requester, event, context, ratelimit=True): + async def transfer_room_state_on_room_upgrade( + self, old_room_id: str, room_id: str + ) -> None: + """Upon our server becoming aware of an upgraded room, either by upgrading a room + ourselves or joining one, we can transfer over information from the previous room. + + Copies user state (tags/push rules) for every local user that was in the old room, as + well as migrating the room directory state. 
+ + Args: + old_room_id: The ID of the old room + room_id: The ID of the new room + """ + logger.info("Transferring room state from %s to %s", old_room_id, room_id) + + # Find all local users that were in the old room and copy over each user's state + users = await self.store.get_users_in_room(old_room_id) + await self.copy_user_state_on_room_upgrade(old_room_id, room_id, users) + + # Add new room to the room directory if the old room was there + # Remove old room from the room directory + old_room = await self.store.get_room(old_room_id) + if old_room and old_room["is_public"]: + await self.store.set_room_is_public(old_room_id, False) + await self.store.set_room_is_public(room_id, True) + + # Transfer alias mappings in the room directory + await self.store.update_aliases_for_room(old_room_id, room_id) + + # Check if any groups we own contain the predecessor room + local_group_ids = await self.store.get_local_groups_for_room(old_room_id) + for group_id in local_group_ids: + # Add new the new room to those groups + await self.store.add_room_to_group(group_id, room_id, old_room["is_public"]) + + # Remove the old room from those groups + await self.store.remove_room_from_group(group_id, old_room_id) + + async def copy_user_state_on_room_upgrade( + self, old_room_id: str, new_room_id: str, user_ids: Iterable[str] + ) -> None: + """Copy user-specific information when they join a new room when that new room is the + result of a room upgrade + + Args: + old_room_id: The ID of upgraded room + new_room_id: The ID of the new room + user_ids: User IDs to copy state for + """ + + logger.debug( + "Copying over room tags and push rules from %s to %s for users %s", + old_room_id, + new_room_id, + user_ids, + ) + + for user_id in user_ids: + try: + # It is an upgraded room. Copy over old tags + await self.copy_room_tags_and_direct_to_room( + old_room_id, new_room_id, user_id + ) + # Copy over push rules + await self.store.copy_push_rules_from_room_to_room_for_user( + old_room_id, new_room_id, user_id + ) + except Exception: + logger.exception( + "Error copying tags and/or push rules from rooms %s to %s for user %s. " + "Skipping...", + old_room_id, + new_room_id, + user_id, + ) + continue + + async def send_membership_event( + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, + ): """ Change the membership status of a user in a room. Args: - requester (Requester): The local user who requested the membership + requester: The local user who requested the membership event. If None, certain checks, like whether this homeserver can act as the sender, will be skipped. - event (SynapseEvent): The membership event. + event: The membership event. context: The context of the event. - ratelimit (bool): Whether to rate limit this request. + ratelimit: Whether to rate limit this request. Raises: SynapseError if there was a problem changing the membership. 
""" @@ -538,27 +625,27 @@ class RoomMemberHandler(object): else: requester = types.create_requester(target_user) - prev_event = yield self.event_creation_handler.deduplicate_state_event( + prev_event = await self.event_creation_handler.deduplicate_state_event( event, context ) if prev_event is not None: return - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids() if event.membership == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(prev_state_ids) + guest_can_join = await self._can_guest_join(prev_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") if event.membership not in (Membership.LEAVE, Membership.BAN): - is_blocked = yield self.store.is_room_blocked(room_id) + is_blocked = await self.store.is_room_blocked(room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") - yield self.event_creation_handler.handle_new_client_event( + await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target_user], ratelimit=ratelimit ) @@ -572,18 +659,19 @@ class RoomMemberHandler(object): # info. newly_joined = True if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield self._user_joined_room(target_user, room_id) + await self._user_joined_room(target_user, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - yield self._user_left_room(target_user, room_id) + await self._user_left_room(target_user, room_id) - @defer.inlineCallbacks - def _can_guest_join(self, current_state_ids): + async def _can_guest_join( + self, current_state_ids: Dict[Tuple[str, str], str] + ) -> bool: """ Returns whether a guest can join a room based on its current state. """ @@ -591,7 +679,7 @@ class RoomMemberHandler(object): if not guest_access_id: return False - guest_access = yield self.store.get_event(guest_access_id) + guest_access = await self.store.get_event(guest_access_id) return ( guest_access @@ -600,13 +688,14 @@ class RoomMemberHandler(object): and guest_access.content["guest_access"] == "can_join" ) - @defer.inlineCallbacks - def lookup_room_alias(self, room_alias): + async def lookup_room_alias( + self, room_alias: RoomAlias + ) -> Tuple[RoomID, List[str]]: """ Get the room ID associated with a room alias. Args: - room_alias (RoomAlias): The alias to look up. + room_alias: The alias to look up. Returns: A tuple of: The room ID as a RoomID object. @@ -615,7 +704,7 @@ class RoomMemberHandler(object): SynapseError if room alias could not be found. 
""" directory_handler = self.directory_handler - mapping = yield directory_handler.get_association(room_alias) + mapping = await directory_handler.get_association(room_alias) if not mapping: raise SynapseError(404, "No such room alias") @@ -630,28 +719,27 @@ class RoomMemberHandler(object): return RoomID.from_string(room_id), servers - @defer.inlineCallbacks - def _get_inviter(self, user_id, room_id): - invite = yield self.store.get_invite_for_user_in_room( + async def _get_inviter(self, user_id: str, room_id: str) -> Optional[UserID]: + invite = await self.store.get_invite_for_local_user_in_room( user_id=user_id, room_id=room_id ) if invite: return UserID.from_string(invite.sender) + return None - @defer.inlineCallbacks - def do_3pid_invite( + async def do_3pid_invite( self, - room_id, - inviter, - medium, - address, - id_server, - requester, - txn_id, - id_access_token=None, - ): + room_id: str, + inviter: UserID, + medium: str, + address: str, + id_server: str, + requester: Requester, + txn_id: Optional[str], + id_access_token: Optional[str] = None, + ) -> int: if self.config.block_non_admin_invites: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: raise SynapseError( 403, "Invites have been disabled on this server", Codes.FORBIDDEN @@ -659,9 +747,9 @@ class RoomMemberHandler(object): # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. - yield self.base_handler.ratelimit(requester) + await self.base_handler.ratelimit(requester) - can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited( + can_invite = await self.third_party_event_rules.check_threepid_can_be_invited( medium, address, room_id ) if not can_invite: @@ -676,14 +764,16 @@ class RoomMemberHandler(object): 403, "Looking up third-party identifiers is denied from this server" ) - invitee = yield self._lookup_3pid(id_server, medium, address, id_access_token) + invitee = await self.identity_handler.lookup_3pid( + id_server, medium, address, id_access_token + ) if invitee: - yield self.update_membership( + _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: - yield self._make_and_store_3pid_invite( + stream_id = await self._make_and_store_3pid_invite( requester, id_server, medium, @@ -694,215 +784,20 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - @defer.inlineCallbacks - def _lookup_3pid(self, id_server, medium, address, id_access_token=None): - """Looks up a 3pid in the passed identity server. - - Args: - id_server (str): The server name (including port, if required) - of the identity server to use. - medium (str): The type of the third party identifier (e.g. "email"). - address (str): The third party identifier (e.g. "foo@example.com"). - id_access_token (str|None): The access token to authenticate to the identity - server with - - Returns: - str|None: the matrix ID of the 3pid, or None if it is not recognized. 
- """ - if id_access_token is not None: - try: - results = yield self._lookup_3pid_v2( - id_server, id_access_token, medium, address - ) - return results - - except Exception as e: - # Catch HttpResponseExcept for a non-200 response code - # Check if this identity server does not know about v2 lookups - if isinstance(e, HttpResponseException) and e.code == 404: - # This is an old identity server that does not yet support v2 lookups - logger.warning( - "Attempted v2 lookup on v1 identity server %s. Falling " - "back to v1", - id_server, - ) - else: - logger.warning("Error when looking up hashing details: %s", e) - return None - - return (yield self._lookup_3pid_v1(id_server, medium, address)) - - @defer.inlineCallbacks - def _lookup_3pid_v1(self, id_server, medium, address): - """Looks up a 3pid in the passed identity server using v1 lookup. - - Args: - id_server (str): The server name (including port, if required) - of the identity server to use. - medium (str): The type of the third party identifier (e.g. "email"). - address (str): The third party identifier (e.g. "foo@example.com"). - - Returns: - str: the matrix ID of the 3pid, or None if it is not recognized. - """ - try: - data = yield self.simple_http_client.get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server), - {"medium": medium, "address": address}, - ) - - if "mxid" in data: - if "signatures" not in data: - raise AuthError(401, "No signatures on 3pid binding") - yield self._verify_any_signature(data, id_server) - return data["mxid"] - - except IOError as e: - logger.warning("Error from v1 identity server lookup: %s" % (e,)) - - return None - - @defer.inlineCallbacks - def _lookup_3pid_v2(self, id_server, id_access_token, medium, address): - """Looks up a 3pid in the passed identity server using v2 lookup. - - Args: - id_server (str): The server name (including port, if required) - of the identity server to use. - id_access_token (str): The access token to authenticate to the identity server with - medium (str): The type of the third party identifier (e.g. "email"). - address (str): The third party identifier (e.g. "foo@example.com"). - - Returns: - Deferred[str|None]: the matrix ID of the 3pid, or None if it is not recognised. 
- """ - # Check what hashing details are supported by this identity server - hash_details = yield self.simple_http_client.get_json( - "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server), - {"access_token": id_access_token}, - ) - - if not isinstance(hash_details, dict): - logger.warning( - "Got non-dict object when checking hash details of %s%s: %s", - id_server_scheme, - id_server, - hash_details, - ) - raise SynapseError( - 400, - "Non-dict object from %s%s during v2 hash_details request: %s" - % (id_server_scheme, id_server, hash_details), - ) - - # Extract information from hash_details - supported_lookup_algorithms = hash_details.get("algorithms") - lookup_pepper = hash_details.get("lookup_pepper") - if ( - not supported_lookup_algorithms - or not isinstance(supported_lookup_algorithms, list) - or not lookup_pepper - or not isinstance(lookup_pepper, str) - ): - raise SynapseError( - 400, - "Invalid hash details received from identity server %s%s: %s" - % (id_server_scheme, id_server, hash_details), - ) - - # Check if any of the supported lookup algorithms are present - if LookupAlgorithm.SHA256 in supported_lookup_algorithms: - # Perform a hashed lookup - lookup_algorithm = LookupAlgorithm.SHA256 - - # Hash address, medium and the pepper with sha256 - to_hash = "%s %s %s" % (address, medium, lookup_pepper) - lookup_value = sha256_and_url_safe_base64(to_hash) - - elif LookupAlgorithm.NONE in supported_lookup_algorithms: - # Perform a non-hashed lookup - lookup_algorithm = LookupAlgorithm.NONE - - # Combine together plaintext address and medium - lookup_value = "%s %s" % (address, medium) - - else: - logger.warning( - "None of the provided lookup algorithms of %s are supported: %s", - id_server, - supported_lookup_algorithms, - ) - raise SynapseError( - 400, - "Provided identity server does not support any v2 lookup " - "algorithms that this homeserver supports.", - ) - - # Authenticate with identity server given the access token from the client - headers = {"Authorization": create_id_access_token_header(id_access_token)} - - try: - lookup_results = yield self.simple_http_client.post_json_get_json( - "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server), - { - "addresses": [lookup_value], - "algorithm": lookup_algorithm, - "pepper": lookup_pepper, - }, - headers=headers, - ) - except Exception as e: - logger.warning("Error when performing a v2 3pid lookup: %s", e) - raise SynapseError( - 500, "Unknown error occurred during identity server lookup" - ) + return stream_id - # Check for a mapping from what we looked up to an MXID - if "mappings" not in lookup_results or not isinstance( - lookup_results["mappings"], dict - ): - logger.warning("No results from 3pid lookup") - return None - - # Return the MXID if it's available, or None otherwise - mxid = lookup_results["mappings"].get(lookup_value) - return mxid - - @defer.inlineCallbacks - def _verify_any_signature(self, data, server_hostname): - if server_hostname not in data["signatures"]: - raise AuthError(401, "No signature from server %s" % (server_hostname,)) - for key_name, signature in data["signatures"][server_hostname].items(): - key_data = yield self.simple_http_client.get_json( - "%s%s/_matrix/identity/api/v1/pubkey/%s" - % (id_server_scheme, server_hostname, key_name) - ) - if "public_key" not in key_data: - raise AuthError( - 401, "No public key named %s from %s" % (key_name, server_hostname) - ) - verify_signed_json( - data, - server_hostname, - decode_verify_key_bytes( - key_name, 
decode_base64(key_data["public_key"]) - ), - ) - return - - @defer.inlineCallbacks - def _make_and_store_3pid_invite( + async def _make_and_store_3pid_invite( self, - requester, - id_server, - medium, - address, - room_id, - user, - txn_id, - id_access_token=None, - ): - room_state = yield self.state_handler.get_current_state(room_id) + requester: Requester, + id_server: str, + medium: str, + address: str, + room_id: str, + user: UserID, + txn_id: Optional[str], + id_access_token: Optional[str] = None, + ) -> int: + room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" inviter_avatar_url = "" @@ -935,25 +830,31 @@ class RoomMemberHandler(object): if room_avatar_event: room_avatar_url = room_avatar_event.content.get("url", "") - token, public_keys, fallback_public_key, display_name = ( - yield self._ask_id_server_for_third_party_invite( - requester=requester, - id_server=id_server, - medium=medium, - address=address, - room_id=room_id, - inviter_user_id=user.to_string(), - room_alias=canonical_room_alias, - room_avatar_url=room_avatar_url, - room_join_rules=room_join_rules, - room_name=room_name, - inviter_display_name=inviter_display_name, - inviter_avatar_url=inviter_avatar_url, - id_access_token=id_access_token, - ) + ( + token, + public_keys, + fallback_public_key, + display_name, + ) = await self.identity_handler.ask_id_server_for_third_party_invite( + requester=requester, + id_server=id_server, + medium=medium, + address=address, + room_id=room_id, + inviter_user_id=user.to_string(), + room_alias=canonical_room_alias, + room_avatar_url=room_avatar_url, + room_join_rules=room_join_rules, + room_name=room_name, + inviter_display_name=inviter_display_name, + inviter_avatar_url=inviter_avatar_url, + id_access_token=id_access_token, ) - yield self.event_creation_handler.create_and_send_nonmember_event( + ( + event, + stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.ThirdPartyInvite, @@ -971,146 +872,11 @@ class RoomMemberHandler(object): ratelimit=False, txn_id=txn_id, ) + return stream_id - @defer.inlineCallbacks - def _ask_id_server_for_third_party_invite( - self, - requester, - id_server, - medium, - address, - room_id, - inviter_user_id, - room_alias, - room_avatar_url, - room_join_rules, - room_name, - inviter_display_name, - inviter_avatar_url, - id_access_token=None, - ): - """ - Asks an identity server for a third party invite. - - Args: - requester (Requester) - id_server (str): hostname + optional port for the identity server. - medium (str): The literal string "email". - address (str): The third party address being invited. - room_id (str): The ID of the room to which the user is invited. - inviter_user_id (str): The user ID of the inviter. - room_alias (str): An alias for the room, for cosmetic notifications. - room_avatar_url (str): The URL of the room's avatar, for cosmetic - notifications. - room_join_rules (str): The join rules of the email (e.g. "public"). - room_name (str): The m.room.name of the room. - inviter_display_name (str): The current display name of the - inviter. - inviter_avatar_url (str): The URL of the inviter's avatar. - id_access_token (str|None): The access token to authenticate to the identity - server with - - Returns: - A deferred tuple containing: - token (str): The token which must be signed to prove authenticity. - public_keys ([{"public_key": str, "key_validity_url": str}]): - public_key is a base64-encoded ed25519 public key. 
- fallback_public_key: One element from public_keys. - display_name (str): A user-friendly name to represent the invited - user. - """ - invite_config = { - "medium": medium, - "address": address, - "room_id": room_id, - "room_alias": room_alias, - "room_avatar_url": room_avatar_url, - "room_join_rules": room_join_rules, - "room_name": room_name, - "sender": inviter_user_id, - "sender_display_name": inviter_display_name, - "sender_avatar_url": inviter_avatar_url, - } - - # Add the identity service access token to the JSON body and use the v2 - # Identity Service endpoints if id_access_token is present - data = None - base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server) - - if id_access_token: - key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % ( - id_server_scheme, - id_server, - ) - - # Attempt a v2 lookup - url = base_url + "/v2/store-invite" - try: - data = yield self.simple_http_client.post_json_get_json( - url, - invite_config, - {"Authorization": create_id_access_token_header(id_access_token)}, - ) - except HttpResponseException as e: - if e.code != 404: - logger.info("Failed to POST %s with JSON: %s", url, e) - raise e - - if data is None: - key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, - id_server, - ) - url = base_url + "/api/v1/store-invite" - - try: - data = yield self.simple_http_client.post_json_get_json( - url, invite_config - ) - except HttpResponseException as e: - logger.warning( - "Error trying to call /store-invite on %s%s: %s", - id_server_scheme, - id_server, - e, - ) - - if data is None: - # Some identity servers may only support application/x-www-form-urlencoded - # types. This is especially true with old instances of Sydent, see - # https://github.com/matrix-org/sydent/pull/170 - try: - data = yield self.simple_http_client.post_urlencoded_get_json( - url, invite_config - ) - except HttpResponseException as e: - logger.warning( - "Error calling /store-invite on %s%s with fallback " - "encoding: %s", - id_server_scheme, - id_server, - e, - ) - raise e - - # TODO: Check for success - token = data["token"] - public_keys = data.get("public_keys", []) - if "public_key" in data: - fallback_public_key = { - "public_key": data["public_key"], - "key_validity_url": key_validity_url, - } - else: - fallback_public_key = public_keys[0] - - if not public_keys: - public_keys.append(fallback_public_key) - display_name = data["display_name"] - return token, public_keys, fallback_public_key, display_name - - @defer.inlineCallbacks - def _is_host_in_room(self, current_state_ids): + async def _is_host_in_room( + self, current_state_ids: Dict[Tuple[str, str], str] + ) -> bool: # Have we just created the room, and is this about to be the very # first member event? 
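# A condensed, in-memory illustration of the check that follows: given the room's
# current state keyed by (event_type, state_key), this server counts as "in the room"
# if any local user's m.room.member state is a join. The real handler looks each
# member event up in the store; memberships are inlined here for brevity.

from synapse.api.constants import EventTypes, Membership

def is_host_in_room(current_state, is_mine_id):
    for (etype, state_key), membership in current_state.items():
        if etype != EventTypes.Member or not is_mine_id(state_key):
            continue
        if membership == Membership.JOIN:
            return True
    return False

state = {(EventTypes.Member, "@alice:example.org"): Membership.JOIN}
assert is_host_in_room(state, lambda user_id: user_id.endswith(":example.org"))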
create_event_id = current_state_ids.get(("m.room.create", "")) @@ -1123,7 +889,7 @@ class RoomMemberHandler(object): continue event_id = current_state_ids[(etype, state_key)] - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not event: continue @@ -1132,11 +898,10 @@ class RoomMemberHandler(object): return False - @defer.inlineCallbacks - def _is_server_notice_room(self, room_id): + async def _is_server_notice_room(self, room_id: str) -> bool: if self._server_notices_mxid is None: return False - user_ids = yield self.store.get_users_in_room(room_id) + user_ids = await self.store.get_users_in_room(room_id) return self._server_notices_mxid in user_ids @@ -1148,20 +913,21 @@ class RoomMemberMasterHandler(RoomMemberHandler): self.distributor.declare("user_joined_room") self.distributor.declare("user_left_room") - @defer.inlineCallbacks - def _is_remote_room_too_complex(self, room_id, remote_room_hosts): + async def _is_remote_room_too_complex( + self, room_id: str, remote_room_hosts: List[str] + ) -> Optional[bool]: """ Check if complexity of a remote room is too great. Args: - room_id (str) - remote_room_hosts (list[str]) + room_id + remote_room_hosts Returns: bool of whether the complexity is too great, or None if unable to be fetched """ max_complexity = self.hs.config.limit_remote_rooms.complexity - complexity = yield self.federation_handler.get_room_complexity( + complexity = await self.federation_handler.get_room_complexity( remote_room_hosts, room_id ) @@ -1169,23 +935,26 @@ class RoomMemberMasterHandler(RoomMemberHandler): return complexity["v1"] > max_complexity return None - @defer.inlineCallbacks - def _is_local_room_too_complex(self, room_id): + async def _is_local_room_too_complex(self, room_id: str) -> bool: """ Check if the complexity of a local room is too great. Args: - room_id (str) - - Returns: bool + room_id: The room ID to check for complexity. """ max_complexity = self.hs.config.limit_remote_rooms.complexity - complexity = yield self.store.get_room_complexity(room_id) + complexity = await self.store.get_room_complexity(room_id) return complexity["v1"] > max_complexity - @defer.inlineCallbacks - def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join """ # filter ourselves out of remote_room_hosts: do_invite_join ignores it @@ -1200,7 +969,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): if self.hs.config.limit_remote_rooms.enabled: # Fetch the room complexity - too_complex = yield self._is_remote_room_too_complex( + too_complex = await self._is_remote_room_too_complex( room_id, remote_room_hosts ) if too_complex is True: @@ -1214,28 +983,28 @@ class RoomMemberMasterHandler(RoomMemberHandler): # join dance for now, since we're kinda implicitly checking # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. - yield self.federation_handler.do_invite_join( + event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user.to_string(), content ) - yield self._user_joined_room(user, room_id) + await self._user_joined_room(user, room_id) # Check the room we just joined wasn't too large, if we didn't fetch the # complexity of it before. 
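# A small sketch of the complexity gate applied around remote joins: the room's "v1"
# complexity score is compared against the configured limit_remote_rooms.complexity
# threshold, both before joining via a remote server and again after the join once
# local state is available. The helper name is illustrative, not a Synapse method.

def is_too_complex(complexity, max_complexity):
    # complexity is the dict returned by the room complexity API, e.g. {"v1": 57.3};
    # None means the remote servers could not tell us.
    if complexity is None:
        return None
    return complexity["v1"] > max_complexity

assert is_too_complex({"v1": 500.0}, max_complexity=100.0) is True
assert is_too_complex({"v1": 12.5}, max_complexity=100.0) is False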
if self.hs.config.limit_remote_rooms.enabled: if too_complex is False: # We checked, and we're under the limit. - return + return event_id, stream_id # Check again, but with the local state events - too_complex = yield self._is_local_room_too_complex(room_id) + too_complex = await self._is_local_room_too_complex(room_id) if too_complex is False: # We're under the limit. - return + return event_id, stream_id # The room is too large. Leave. requester = types.create_requester(user, None, False, None) - yield self.update_membership( + await self.update_membership( requester=requester, target=user, room_id=room_id, action="leave" ) raise SynapseError( @@ -1244,16 +1013,24 @@ class RoomMemberMasterHandler(RoomMemberHandler): errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) - @defer.inlineCallbacks - def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + return event_id, stream_id + + async def _remote_reject_invite( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + target: UserID, + content: dict, + ) -> Tuple[Optional[str], int]: """Implements RoomMemberHandler._remote_reject_invite """ fed_handler = self.federation_handler try: - ret = yield fed_handler.do_remotely_reject_invite( - remote_room_hosts, room_id, target.to_string() + event, stream_id = await fed_handler.do_remotely_reject_invite( + remote_room_hosts, room_id, target.to_string(), content=content, ) - return ret + return event.event_id, stream_id except Exception as e: # if we were unable to reject the exception, just mark # it as rejected on our end and plough ahead. @@ -1263,24 +1040,23 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warning("Failed to reject invite: %s", e) - yield self.store.locally_reject_invite(target.to_string(), room_id) - return {} + stream_id = await self.locally_reject_invite(target.to_string(), room_id) + return None, stream_id - def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room """ - return user_joined_room(self.distributor, target, room_id) + user_joined_room(self.distributor, target, room_id) - def _user_left_room(self, target, room_id): + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room """ - return user_left_room(self.distributor, target, room_id) + user_left_room(self.distributor, target, room_id) - @defer.inlineCallbacks - def forget(self, user, room_id): + async def forget(self, user: UserID, room_id: str) -> None: user_id = user.to_string() - member = yield self.state_handler.get_current_state( + member = await self.state_handler.get_current_state( room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) membership = member.membership if member else None @@ -1292,4 +1068,4 @@ class RoomMemberMasterHandler(RoomMemberHandler): raise SynapseError(400, "User %s in room %s" % (user_id, room_id)) if membership: - yield self.store.forget(user_id, room_id) + await self.store.forget(user_id, room_id) diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 75e96ae1a2..02e0c4103d 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -14,8 +14,7 @@ # limitations under the License. 
import logging - -from twisted.internet import defer +from typing import List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler @@ -24,6 +23,7 @@ from synapse.replication.http.membership import ( ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite, ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft, ) +from synapse.types import Requester, UserID logger = logging.getLogger(__name__) @@ -36,14 +36,20 @@ class RoomMemberWorkerHandler(RoomMemberHandler): self._remote_reject_client = ReplRejectInvite.make_client(hs) self._notify_change_client = ReplJoinedLeft.make_client(hs) - @defer.inlineCallbacks - def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join """ if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") - ret = yield self._remote_join_client( + ret = await self._remote_join_client( requester=requester, remote_room_hosts=remote_room_hosts, room_id=room_id, @@ -51,30 +57,39 @@ class RoomMemberWorkerHandler(RoomMemberHandler): content=content, ) - yield self._user_joined_room(user, room_id) + await self._user_joined_room(user, room_id) - return ret + return ret["event_id"], ret["stream_id"] - def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + async def _remote_reject_invite( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + target: UserID, + content: dict, + ) -> Tuple[Optional[str], int]: """Implements RoomMemberHandler._remote_reject_invite """ - return self._remote_reject_client( + ret = await self._remote_reject_client( requester=requester, remote_room_hosts=remote_room_hosts, room_id=room_id, user_id=target.to_string(), + content=content, ) + return ret["event_id"], ret["stream_id"] - def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room """ - return self._notify_change_client( + await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="joined" ) - def _user_left_room(self, target, room_id): + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room """ - return self._notify_change_client( + await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="left" ) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index a1ce6929cf..abecaa8313 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -13,45 +13,93 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging +import re +from typing import Callable, Dict, Optional, Set, Tuple import attr import saml2 +import saml2.response from saml2.client import Saml2Client from synapse.api.errors import SynapseError +from synapse.config import ConfigError from synapse.http.servlet import parse_string -from synapse.rest.client.v1.login import SSOAuthHandler +from synapse.http.site import SynapseRequest +from synapse.module_api import ModuleApi +from synapse.types import ( + UserID, + map_username_to_mxid_localpart, + mxid_localpart_allowed_characters, +) +from synapse.util.async_helpers import Linearizer +from synapse.util.iterutils import chunk_seq logger = logging.getLogger(__name__) +@attr.s +class Saml2SessionData: + """Data we track about SAML2 sessions""" + + # time the session was created, in milliseconds + creation_time = attr.ib() + # The user interactive authentication session ID associated with this SAML + # session (or None if this SAML session is for an initial login). + ui_auth_session_id = attr.ib(type=Optional[str], default=None) + + class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) - self._sso_auth_handler = SSOAuthHandler(hs) + self._auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() + self._registration_handler = hs.get_registration_handler() + + self._clock = hs.get_clock() + self._datastore = hs.get_datastore() + self._hostname = hs.hostname + self._saml2_session_lifetime = hs.config.saml2_session_lifetime + self._grandfathered_mxid_source_attribute = ( + hs.config.saml2_grandfathered_mxid_source_attribute + ) + + # plugin to do custom mapping from saml response to mxid + self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( + hs.config.saml2_user_mapping_provider_config, + ModuleApi(hs, hs.get_auth_handler()), + ) + + # identifier for the external_ids table + self._auth_provider_id = "saml" # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} - self._clock = hs.get_clock() - self._saml2_session_lifetime = hs.config.saml2_session_lifetime + # a lock on the mappings + self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) - def handle_redirect_request(self, client_redirect_url): + def handle_redirect_request( + self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None + ) -> bytes: """Handle an incoming request to /login/sso/redirect Args: - client_redirect_url (bytes): the URL that we should redirect the + client_redirect_url: the URL that we should redirect the client to when everything is done + ui_auth_session_id: The session ID of the ongoing UI Auth (or + None if this is a login). Returns: - bytes: URL to redirect to + URL to redirect to """ reqid, info = self._saml_client.prepare_for_authenticate( relay_state=client_redirect_url ) now = self._clock.time_msec() - self._outstanding_requests_dict[reqid] = Saml2SessionData(creation_time=now) + self._outstanding_requests_dict[reqid] = Saml2SessionData( + creation_time=now, ui_auth_session_id=ui_auth_session_id, + ) for key, value in info["headers"]: if key == "Location": @@ -60,15 +108,15 @@ class SamlHandler: # this shouldn't happen! raise Exception("prepare_for_authenticate didn't return a Location header") - def handle_saml_response(self, request): + async def handle_saml_response(self, request: SynapseRequest) -> None: """Handle an incoming request to /_matrix/saml2/authn_response Args: - request (SynapseRequest): the incoming request from the browser. 
We'll + request: the incoming request from the browser. We'll respond to it with a redirect. Returns: - Deferred[none]: Completes once we have handled the request. + Completes once we have handled the request. """ resp_bytes = parse_string(request, "SAMLResponse", required=True) relay_state = parse_string(request, "RelayState", required=True) @@ -77,6 +125,37 @@ class SamlHandler: # the dict. self.expire_sessions() + user_id, current_session = await self._map_saml_response_to_user( + resp_bytes, relay_state + ) + + # Complete the interactive auth session or the login. + if current_session and current_session.ui_auth_session_id: + await self._auth_handler.complete_sso_ui_auth( + user_id, current_session.ui_auth_session_id, request + ) + + else: + await self._auth_handler.complete_sso_login(user_id, request, relay_state) + + async def _map_saml_response_to_user( + self, resp_bytes: str, client_redirect_url: str + ) -> Tuple[str, Optional[Saml2SessionData]]: + """ + Given a sample response, retrieve the cached session and user for it. + + Args: + resp_bytes: The SAML response. + client_redirect_url: The redirect URL passed in by the client. + + Returns: + Tuple of the user ID and SAML session associated with this response. + + Raises: + SynapseError if there was a problem with the response. + RedirectException: some mapping providers may raise this if they need + to redirect to an interstitial page. + """ try: saml2_auth = self._saml_client.parse_authn_request_response( resp_bytes, @@ -84,26 +163,123 @@ class SamlHandler: outstanding=self._outstanding_requests_dict, ) except Exception as e: - logger.warning("Exception parsing SAML2 response: %s", e) raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,)) if saml2_auth.not_signed: - logger.warning("SAML2 response was not signed") raise SynapseError(400, "SAML2 response was not signed") - if "uid" not in saml2_auth.ava: - logger.warning("SAML2 response lacks a 'uid' attestation") - raise SynapseError(400, "uid not in SAML2 response") + logger.debug("SAML2 response: %s", saml2_auth.origxml) + for assertion in saml2_auth.assertions: + # kibana limits the length of a log field, whereas this is all rather + # useful, so split it up. + count = 0 + for part in chunk_seq(str(assertion), 10000): + logger.info( + "SAML2 assertion: %s%s", "(%i)..." 
% (count,) if count else "", part + ) + count += 1 - self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) + logger.info("SAML2 mapped attributes: %s", saml2_auth.ava) - username = saml2_auth.ava["uid"][0] - displayName = saml2_auth.ava.get("displayName", [None])[0] + current_session = self._outstanding_requests_dict.pop( + saml2_auth.in_response_to, None + ) - return self._sso_auth_handler.on_successful_auth( - username, request, relay_state, user_display_name=displayName + remote_user_id = self._user_mapping_provider.get_remote_user_id( + saml2_auth, client_redirect_url ) + if not remote_user_id: + raise Exception("Failed to extract remote user id from SAML response") + + with (await self._mapping_lock.queue(self._auth_provider_id)): + # first of all, check if we already have a mapping for this user + logger.info( + "Looking for existing mapping for user %s:%s", + self._auth_provider_id, + remote_user_id, + ) + registered_user_id = await self._datastore.get_user_by_external_id( + self._auth_provider_id, remote_user_id + ) + if registered_user_id is not None: + logger.info("Found existing mapping %s", registered_user_id) + return registered_user_id, current_session + + # backwards-compatibility hack: see if there is an existing user with a + # suitable mapping from the uid + if ( + self._grandfathered_mxid_source_attribute + and self._grandfathered_mxid_source_attribute in saml2_auth.ava + ): + attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0] + user_id = UserID( + map_username_to_mxid_localpart(attrval), self._hostname + ).to_string() + logger.info( + "Looking for existing account based on mapped %s %s", + self._grandfathered_mxid_source_attribute, + user_id, + ) + + users = await self._datastore.get_users_by_id_case_insensitive(user_id) + if users: + registered_user_id = list(users.keys())[0] + logger.info("Grandfathering mapping to %s", registered_user_id) + await self._datastore.record_user_external_id( + self._auth_provider_id, remote_user_id, registered_user_id + ) + return registered_user_id, current_session + + # Map saml response to user attributes using the configured mapping provider + for i in range(1000): + attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes( + saml2_auth, i, client_redirect_url=client_redirect_url, + ) + + logger.debug( + "Retrieved SAML attributes from user mapping provider: %s " + "(attempt %d)", + attribute_dict, + i, + ) + + localpart = attribute_dict.get("mxid_localpart") + if not localpart: + raise Exception( + "Error parsing SAML2 response: SAML mapping provider plugin " + "did not return a mxid_localpart value" + ) + + displayname = attribute_dict.get("displayname") + emails = attribute_dict.get("emails", []) + + # Check if this mxid already exists + if not await self._datastore.get_users_by_id_case_insensitive( + UserID(localpart, self._hostname).to_string() + ): + # This mxid is free + break + else: + # Unable to generate a username in 1000 iterations + # Break and return error to the user + raise SynapseError( + 500, "Unable to generate a Matrix ID from the SAML response" + ) + + logger.info("Mapped SAML user to local part %s", localpart) + + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, + default_display_name=displayname, + bind_emails=emails, + ) + + await self._datastore.record_user_external_id( + self._auth_provider_id, remote_user_id, registered_user_id + ) + return registered_user_id, current_session + def expire_sessions(self): 
expire_before = self._clock.time_msec() - self._saml2_session_lifetime to_expire = set() @@ -115,9 +291,146 @@ class SamlHandler: del self._outstanding_requests_dict[reqid] +DOT_REPLACE_PATTERN = re.compile( + ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) +) + + +def dot_replace_for_mxid(username: str) -> str: + """Replace any characters which are not allowed in Matrix IDs with a dot.""" + username = username.lower() + username = DOT_REPLACE_PATTERN.sub(".", username) + + # regular mxids aren't allowed to start with an underscore either + username = re.sub("^_", "", username) + return username + + +MXID_MAPPER_MAP = { + "hexencode": map_username_to_mxid_localpart, + "dotreplace": dot_replace_for_mxid, +} # type: Dict[str, Callable[[str], str]] + + @attr.s -class Saml2SessionData: - """Data we track about SAML2 sessions""" +class SamlConfig(object): + mxid_source_attribute = attr.ib() + mxid_mapper = attr.ib() - # time the session was created, in milliseconds - creation_time = attr.ib() + +class DefaultSamlMappingProvider(object): + __version__ = "0.0.1" + + def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi): + """The default SAML user mapping provider + + Args: + parsed_config: Module configuration + module_api: module api proxy + """ + self._mxid_source_attribute = parsed_config.mxid_source_attribute + self._mxid_mapper = parsed_config.mxid_mapper + + self._grandfathered_mxid_source_attribute = ( + module_api._hs.config.saml2_grandfathered_mxid_source_attribute + ) + + def get_remote_user_id( + self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str + ) -> str: + """Extracts the remote user id from the SAML response""" + try: + return saml_response.ava["uid"][0] + except KeyError: + logger.warning("SAML2 response lacks a 'uid' attestation") + raise SynapseError(400, "'uid' not in SAML2 response") + + def saml_response_to_user_attributes( + self, + saml_response: saml2.response.AuthnResponse, + failures: int, + client_redirect_url: str, + ) -> dict: + """Maps some text from a SAML response to attributes of a new user + + Args: + saml_response: A SAML auth response object + + failures: How many times a call to this function with this + saml_response has resulted in a failure + + client_redirect_url: where the client wants to redirect to + + Returns: + dict: A dict containing new user attributes. Possible keys: + * mxid_localpart (str): Required. 
The localpart of the user's mxid + * displayname (str): The displayname of the user + * emails (list[str]): Any emails for the user + """ + try: + mxid_source = saml_response.ava[self._mxid_source_attribute][0] + except KeyError: + logger.warning( + "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, + ) + raise SynapseError( + 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) + ) + + # Use the configured mapper for this mxid_source + base_mxid_localpart = self._mxid_mapper(mxid_source) + + # Append suffix integer if last call to this function failed to produce + # a usable mxid + localpart = base_mxid_localpart + (str(failures) if failures else "") + + # Retrieve the display name from the saml response + # If displayname is None, the mxid_localpart will be used instead + displayname = saml_response.ava.get("displayName", [None])[0] + + # Retrieve any emails present in the saml response + emails = saml_response.ava.get("email", []) + + return { + "mxid_localpart": localpart, + "displayname": displayname, + "emails": emails, + } + + @staticmethod + def parse_config(config: dict) -> SamlConfig: + """Parse the dict provided by the homeserver's config + Args: + config: A dictionary containing configuration options for this provider + Returns: + SamlConfig: A custom config object for this module + """ + # Parse config options and use defaults where necessary + mxid_source_attribute = config.get("mxid_source_attribute", "uid") + mapping_type = config.get("mxid_mapping", "hexencode") + + # Retrieve the associating mapping function + try: + mxid_mapper = MXID_MAPPER_MAP[mapping_type] + except KeyError: + raise ConfigError( + "saml2_config.user_mapping_provider.config: '%s' is not a valid " + "mxid_mapping value" % (mapping_type,) + ) + + return SamlConfig(mxid_source_attribute, mxid_mapper) + + @staticmethod + def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]: + """Returns the required attributes of a SAML + + Args: + config: A SamlConfig object containing configuration params for this provider + + Returns: + The first set equates to the saml auth response + attributes that are required for the module to function, whereas the + second set consists of those attributes which can be used if + available, but are not necessary + """ + return {"uid", config.mxid_source_attribute}, {"displayName", "email"} diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index cd5e90bacb..4d40d3ac9c 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -18,10 +18,8 @@ import logging from unpaddedbase64 import decode_base64, encode_base64 -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import SynapseError +from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.storage.state import StateFilter from synapse.visibility import filter_events_for_client @@ -35,9 +33,11 @@ class SearchHandler(BaseHandler): def __init__(self, hs): super(SearchHandler, self).__init__(hs) self._event_serializer = hs.get_event_client_serializer() + self.storage = hs.get_storage() + self.state_store = self.storage.state + self.auth = hs.get_auth() - @defer.inlineCallbacks - def get_old_rooms_from_upgraded_room(self, room_id): + async def get_old_rooms_from_upgraded_room(self, room_id): """Retrieves room IDs of old rooms in the history of an upgraded room. 
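[Editor's sketch] Relating to the saml_handler.py hunk above: the "dotreplace" mapper lowercases the source attribute, replaces disallowed characters with a dot, strips a leading underscore, and the handler retries with a numeric suffix when the resulting localpart is taken. A simplified standalone version; the allowed-character set is hard-coded here as an approximation (the real code takes it from synapse.types.mxid_localpart_allowed_characters).

import re

ALLOWED = "abcdefghijklmnopqrstuvwxyz0123456789_-./="  # assumption: approximate set
DOT_REPLACE_PATTERN = re.compile("[^%s]" % (re.escape(ALLOWED),))


def dot_replace_for_mxid(username: str) -> str:
    username = username.lower()
    username = DOT_REPLACE_PATTERN.sub(".", username)
    # regular mxids aren't allowed to start with an underscore
    return re.sub("^_", "", username)


def localpart_for_attempt(source: str, failures: int) -> str:
    # mirror the handler's retry loop: append the failure count as a suffix
    base = dot_replace_for_mxid(source)
    return base + (str(failures) if failures else "")


print(localpart_for_attempt("John Smith", 0))  # -> john.smith
print(localpart_for_attempt("John Smith", 2))  # -> john.smith2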
We do so by checking the m.room.create event of the room for a @@ -51,28 +51,42 @@ class SearchHandler(BaseHandler): room_id (str): id of the room to search through. Returns: - Deferred[iterable[unicode]]: predecessor room ids + Deferred[iterable[str]]: predecessor room ids """ historical_room_ids = [] - while True: - predecessor = yield self.store.get_room_predecessor(room_id) + # The initial room must have been known for us to get this far + predecessor = await self.store.get_room_predecessor(room_id) - # If no predecessor, assume we've hit a dead end + while True: if not predecessor: + # We have reached the end of the chain of predecessors + break + + if not isinstance(predecessor.get("room_id"), str): + # This predecessor object is malformed. Exit here + break + + predecessor_room_id = predecessor["room_id"] + + # Don't add it to the list until we have checked that we are in the room + try: + next_predecessor_room = await self.store.get_room_predecessor( + predecessor_room_id + ) + except NotFoundError: + # The predecessor is not a known room, so we are done here break - # Add predecessor's room ID - historical_room_ids.append(predecessor["room_id"]) + historical_room_ids.append(predecessor_room_id) - # Scan through the old room for further predecessors - room_id = predecessor["room_id"] + # And repeat + predecessor = next_predecessor_room return historical_room_ids - @defer.inlineCallbacks - def search(self, user, content, batch=None): + async def search(self, user, content, batch=None): """Performs a full text search for a user. Args: @@ -161,12 +175,12 @@ class SearchHandler(BaseHandler): search_filter = Filter(filter_dict) # TODO: Search through left rooms too - rooms = yield self.store.get_rooms_for_user_where_membership_is( + rooms = await self.store.get_rooms_for_local_user_where_membership_is( user.to_string(), membership_list=[Membership.JOIN], # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], ) - room_ids = set(r.room_id for r in rooms) + room_ids = {r.room_id for r in rooms} # If doing a subset of all rooms seearch, check if any of the rooms # are from an upgraded room, and search their contents as well @@ -174,7 +188,7 @@ class SearchHandler(BaseHandler): historical_room_ids = [] for room_id in search_filter.rooms: # Add any previous rooms to the search if they exist - ids = yield self.get_old_rooms_from_upgraded_room(room_id) + ids = await self.get_old_rooms_from_upgraded_room(room_id) historical_room_ids += ids # Prevent any historical events from being filtered @@ -205,7 +219,7 @@ class SearchHandler(BaseHandler): count = None if order_by == "rank": - search_result = yield self.store.search_msgs(room_ids, search_term, keys) + search_result = await self.store.search_msgs(room_ids, search_term, keys) count = search_result["count"] @@ -220,8 +234,8 @@ class SearchHandler(BaseHandler): filtered_events = search_filter.filter([r["event"] for r in results]) - events = yield filter_events_for_client( - self.store, user.to_string(), filtered_events + events = await filter_events_for_client( + self.storage, user.to_string(), filtered_events ) events.sort(key=lambda e: -rank_map[e.event_id]) @@ -249,7 +263,7 @@ class SearchHandler(BaseHandler): # But only go around 5 times since otherwise synapse will be sad. 
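[Editor's sketch] The rewritten get_old_rooms_from_upgraded_room above walks the chain of m.room.create "predecessor" pointers, only adding a room to the history once it has confirmed the predecessor is known (by catching NotFoundError from the store). The same walk reduced to plain data structures, with a dict and KeyError standing in for the datastore:

from typing import Dict, List, Optional

# room_id -> predecessor content (or None), standing in for the datastore
PREDECESSORS: Dict[str, Optional[dict]] = {
    "!v3:example.org": {"room_id": "!v2:example.org"},
    "!v2:example.org": {"room_id": "!v1:example.org"},
    "!v1:example.org": None,
}


def get_old_rooms_from_upgraded_room(room_id: str) -> List[str]:
    historical_room_ids: List[str] = []
    predecessor = PREDECESSORS[room_id]  # the initial room must be known
    while True:
        if not predecessor or not isinstance(predecessor.get("room_id"), str):
            break
        predecessor_room_id = predecessor["room_id"]
        try:
            # only include rooms we actually know about
            next_predecessor = PREDECESSORS[predecessor_room_id]
        except KeyError:
            break
        historical_room_ids.append(predecessor_room_id)
        predecessor = next_predecessor
    return historical_room_ids


print(get_old_rooms_from_upgraded_room("!v3:example.org"))
# -> ['!v2:example.org', '!v1:example.org']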
while len(room_events) < search_filter.limit() and i < 5: i += 1 - search_result = yield self.store.search_rooms( + search_result = await self.store.search_rooms( room_ids, search_term, keys, @@ -270,8 +284,8 @@ class SearchHandler(BaseHandler): filtered_events = search_filter.filter([r["event"] for r in results]) - events = yield filter_events_for_client( - self.store, user.to_string(), filtered_events + events = await filter_events_for_client( + self.storage, user.to_string(), filtered_events ) room_events.extend(events) @@ -325,11 +339,11 @@ class SearchHandler(BaseHandler): # If client has asked for "context" for each event (i.e. some surrounding # events and state), fetch that if event_context is not None: - now_token = yield self.hs.get_event_sources().get_current_token() + now_token = await self.hs.get_event_sources().get_current_token() contexts = {} for event in allowed_events: - res = yield self.store.get_events_around( + res = await self.store.get_events_around( event.room_id, event.event_id, before_limit, after_limit ) @@ -339,12 +353,12 @@ class SearchHandler(BaseHandler): len(res["events_after"]), ) - res["events_before"] = yield filter_events_for_client( - self.store, user.to_string(), res["events_before"] + res["events_before"] = await filter_events_for_client( + self.storage, user.to_string(), res["events_before"] ) - res["events_after"] = yield filter_events_for_client( - self.store, user.to_string(), res["events_after"] + res["events_after"] = await filter_events_for_client( + self.storage, user.to_string(), res["events_after"] ) res["start"] = now_token.copy_and_replace( @@ -356,12 +370,12 @@ class SearchHandler(BaseHandler): ).to_string() if include_profile: - senders = set( + senders = { ev.sender for ev in itertools.chain( res["events_before"], [event], res["events_after"] ) - ) + } if res["events_after"]: last_event_id = res["events_after"][-1].event_id @@ -372,7 +386,7 @@ class SearchHandler(BaseHandler): [(EventTypes.Member, sender) for sender in senders] ) - state = yield self.store.get_state_for_event( + state = await self.state_store.get_state_for_event( last_event_id, state_filter ) @@ -394,22 +408,18 @@ class SearchHandler(BaseHandler): time_now = self.clock.time_msec() for context in contexts.values(): - context["events_before"] = ( - yield self._event_serializer.serialize_events( - context["events_before"], time_now - ) + context["events_before"] = await self._event_serializer.serialize_events( + context["events_before"], time_now ) - context["events_after"] = ( - yield self._event_serializer.serialize_events( - context["events_after"], time_now - ) + context["events_after"] = await self._event_serializer.serialize_events( + context["events_after"], time_now ) state_results = {} if include_state: - rooms = set(e.room_id for e in allowed_events) + rooms = {e.room_id for e in allowed_events} for room_id in rooms: - state = yield self.state_handler.get_current_state(room_id) + state = await self.state_handler.get_current_state(room_id) state_results[room_id] = list(state.values()) state_results.values() @@ -423,7 +433,7 @@ class SearchHandler(BaseHandler): { "rank": rank_map[e.event_id], "result": ( - yield self._event_serializer.serialize_event(e, time_now) + await self._event_serializer.serialize_event(e, time_now) ), "context": contexts.get(e.event_id, {}), } @@ -438,7 +448,7 @@ class SearchHandler(BaseHandler): if state_results: s = {} for room_id, state in state_results.items(): - s[room_id] = yield self._event_serializer.serialize_events( + s[room_id] = 
await self._event_serializer.serialize_events( state, time_now ) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index d90c9e0108..4d245b618b 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging - -from twisted.internet import defer +from typing import Optional from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.types import Requester from ._base import BaseHandler @@ -30,32 +30,37 @@ class SetPasswordHandler(BaseHandler): super(SetPasswordHandler, self).__init__(hs) self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() + self._password_policy_handler = hs.get_password_policy_handler() - @defer.inlineCallbacks - def set_password(self, user_id, newpassword, requester=None): + async def set_password( + self, + user_id: str, + password_hash: str, + logout_devices: bool, + requester: Optional[Requester] = None, + ): if not self.hs.config.password_localdb_enabled: raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) - password_hash = yield self._auth_handler.hash(newpassword) - - except_device_id = requester.device_id if requester else None - except_access_token_id = requester.access_token_id if requester else None - try: - yield self.store.user_set_password_hash(user_id, password_hash) + await self.store.user_set_password_hash(user_id, password_hash) except StoreError as e: if e.code == 404: raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise e - # we want to log out all of the user's other sessions. First delete - # all his other devices. - yield self._device_handler.delete_all_devices_for_user( - user_id, except_device_id=except_device_id - ) - - # and now delete any access tokens which weren't associated with - # devices (or were associated with this device). - yield self._auth_handler.delete_access_tokens_for_user( - user_id, except_token_id=except_access_token_id - ) + # Optionally, log out all of the user's other sessions. + if logout_devices: + except_device_id = requester.device_id if requester else None + except_access_token_id = requester.access_token_id if requester else None + + # First delete all of their other devices. + await self._device_handler.delete_all_devices_for_user( + user_id, except_device_id=except_device_id + ) + + # and now delete any access tokens which weren't associated with + # devices (or were associated with this device). + await self._auth_handler.delete_access_tokens_for_user( + user_id, except_token_id=except_access_token_id + ) diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index f065970c40..8590c1eff4 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - logger = logging.getLogger(__name__) @@ -24,8 +22,7 @@ class StateDeltasHandler(object): def __init__(self, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks - def _get_key_change(self, prev_event_id, event_id, key_name, public_value): + async def _get_key_change(self, prev_event_id, event_id, key_name, public_value): """Given two events check if the `key_name` field in content changed from not matching `public_value` to doing so. 
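[Editor's sketch] The set_password.py hunk above changes the handler to accept a pre-computed password hash and an explicit logout_devices flag, and only deletes other devices and access tokens when that flag is set, exempting the requester's own session. A minimal stand-in with invented stubs (not the real handlers) showing that flow:

import asyncio
from typing import Optional


class ToySetPasswordHandler:
    def __init__(self) -> None:
        self.hashes = {}  # user_id -> password hash
        self.devices = {"@u:hs": {"DEV1", "DEV2"}}
        self.tokens = {"@u:hs": {"tok1", "tok2"}}

    async def set_password(
        self,
        user_id: str,
        password_hash: str,
        logout_devices: bool,
        requester_device_id: Optional[str] = None,
        requester_token: Optional[str] = None,
    ) -> None:
        self.hashes[user_id] = password_hash

        if logout_devices:
            # drop every other device, then every token not tied to the
            # requester's current session
            self.devices[user_id] = {
                d for d in self.devices[user_id] if d == requester_device_id
            }
            self.tokens[user_id] = {
                t for t in self.tokens[user_id] if t == requester_token
            }


h = ToySetPasswordHandler()
asyncio.run(h.set_password("@u:hs", "$2b$...", True, "DEV1", "tok1"))
print(h.devices["@u:hs"], h.tokens["@u:hs"])  # {'DEV1'} {'tok1'}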
@@ -41,10 +38,10 @@ class StateDeltasHandler(object): prev_event = None event = None if prev_event_id: - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) if event_id: - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not event and not prev_event: logger.debug("Neither event exists: %r %r", prev_event_id, event_id) diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index cbac7c347a..149f861239 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -16,17 +16,14 @@ import logging from collections import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership -from synapse.handlers.state_deltas import StateDeltasHandler from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process logger = logging.getLogger(__name__) -class StatsHandler(StateDeltasHandler): +class StatsHandler: """Handles keeping the *_stats tables updated with a simple time-series of information about the users, rooms and media on the server, such that admins have some idea of who is consuming their resources. @@ -35,7 +32,6 @@ class StatsHandler(StateDeltasHandler): """ def __init__(self, hs): - super(StatsHandler, self).__init__(hs) self.hs = hs self.store = hs.get_datastore() self.state = hs.get_state_handler() @@ -45,6 +41,8 @@ class StatsHandler(StateDeltasHandler): self.is_mine_id = hs.is_mine_id self.stats_bucket_size = hs.config.stats_bucket_size + self.stats_enabled = hs.config.stats_enabled + # The current position in the current_state_delta stream self.pos = None @@ -61,25 +59,23 @@ class StatsHandler(StateDeltasHandler): def notify_new_event(self): """Called when there may be more deltas to process """ - if not self.hs.config.stats_enabled or self._is_processing: + if not self.stats_enabled or self._is_processing: return self._is_processing = True - @defer.inlineCallbacks - def process(): + async def process(): try: - yield self._unsafe_process() + await self._unsafe_process() finally: self._is_processing = False run_as_background_process("stats.notify_new_event", process) - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): # If self.pos is None then means we haven't fetched it from DB if self.pos is None: - self.pos = yield self.store.get_stats_positions() + self.pos = await self.store.get_stats_positions() # Loop round handling deltas until we're up to date @@ -87,24 +83,29 @@ class StatsHandler(StateDeltasHandler): # Be sure to read the max stream_ordering *before* checking if there are any outstanding # deltas, since there is otherwise a chance that we could miss updates which arrive # after we check the deltas. 
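[Editor's sketch] The stats handler's catch-up loop below reads the maximum stream position *before* asking for deltas, so anything that arrives while a batch is being processed is picked up on the next pass rather than silently skipped. A toy version using plain lists instead of database streams (ToyStore and its methods are invented names):

import asyncio
from typing import List, Tuple


class ToyStore:
    def __init__(self) -> None:
        self.deltas = [(1, "a"), (2, "b"), (3, "c")]  # (stream_id, payload)

    def get_max_stream_ordering(self) -> int:
        return self.deltas[-1][0] if self.deltas else 0

    async def get_deltas_between(self, lo: int, hi: int) -> Tuple[int, List[str]]:
        batch = [p for (sid, p) in self.deltas if lo < sid <= hi]
        return hi, batch


async def catch_up(store: ToyStore, pos: int) -> int:
    while True:
        # read the max position first, then fetch everything up to it
        max_pos = store.get_max_stream_ordering()
        if pos == max_pos:
            break
        max_pos, batch = await store.get_deltas_between(pos, max_pos)
        # ... handle the deltas here ...
        print("handled", batch)
        pos = max_pos
    return pos


print(asyncio.run(catch_up(ToyStore(), 0)))  # handles ['a', 'b', 'c'], returns 3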
- room_max_stream_ordering = yield self.store.get_room_max_stream_ordering() + room_max_stream_ordering = self.store.get_room_max_stream_ordering() if self.pos == room_max_stream_ordering: break - deltas = yield self.store.get_current_state_deltas(self.pos) + logger.debug( + "Processing room stats %s->%s", self.pos, room_max_stream_ordering + ) + max_pos, deltas = await self.store.get_current_state_deltas( + self.pos, room_max_stream_ordering + ) if deltas: logger.debug("Handling %d state deltas", len(deltas)) - room_deltas, user_deltas = yield self._handle_deltas(deltas) - - max_pos = deltas[-1]["stream_id"] + room_deltas, user_deltas = await self._handle_deltas(deltas) else: room_deltas = {} user_deltas = {} - max_pos = room_max_stream_ordering # Then count deltas for total_events and total_event_bytes. - room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes( + ( + room_count, + user_count, + ) = await self.store.get_changes_room_total_events_and_bytes( self.pos, max_pos ) @@ -118,7 +119,7 @@ class StatsHandler(StateDeltasHandler): logger.debug("user_deltas: %s", user_deltas) # Always call this so that we update the stats position. - yield self.store.bulk_update_stats_delta( + await self.store.bulk_update_stats_delta( self.clock.time_msec(), updates={"room": room_deltas, "user": user_deltas}, stream_id=max_pos, @@ -130,13 +131,12 @@ class StatsHandler(StateDeltasHandler): self.pos = max_pos - @defer.inlineCallbacks - def _handle_deltas(self, deltas): + async def _handle_deltas(self, deltas): """Called with the state deltas to process Returns: - Deferred[tuple[dict[str, Counter], dict[str, counter]]] - Resovles to two dicts, the room deltas and the user deltas, + tuple[dict[str, Counter], dict[str, counter]] + Two dicts: the room deltas and the user deltas, mapping from room/user ID to changes in the various fields. """ @@ -155,7 +155,7 @@ class StatsHandler(StateDeltasHandler): logger.debug("Handling: %r, %r %r, %s", room_id, typ, state_key, event_id) - token = yield self.store.get_earliest_token_for_stats("room", room_id) + token = await self.store.get_earliest_token_for_stats("room", room_id) # If the earliest token to begin from is larger than our current # stream ID, skip processing this delta. @@ -177,7 +177,7 @@ class StatsHandler(StateDeltasHandler): sender = None if event_id is not None: - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if event: event_content = event.content or {} sender = event.sender @@ -193,16 +193,16 @@ class StatsHandler(StateDeltasHandler): room_stats_delta["current_state_events"] += 1 if typ == EventTypes.Member: - # we could use _get_key_change here but it's a bit inefficient - # given we're not testing for a specific result; might as well - # just grab the prev_membership and membership strings and - # compare them. + # we could use StateDeltasHandler._get_key_change here but it's + # a bit inefficient given we're not testing for a specific + # result; might as well just grab the prev_membership and + # membership strings and compare them. # We take None rather than leave as a previous membership # in the absence of a previous event because we do not want to # reduce the leave count when a new-to-the-room user joins. 
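[Editor's sketch] The membership handling below compares the previous and new membership strings directly and folds the result into per-room Counter deltas; treating "no previous event" as None is what avoids decrementing, say, the leave count when a brand-new user joins. A rough pure-function version (field names are illustrative):

from collections import Counter
from typing import Optional

MEMBERSHIP_TO_FIELD = {
    "join": "joined_members",
    "invite": "invited_members",
    "leave": "left_members",
    "ban": "banned_members",
}


def membership_delta(prev: Optional[str], new: Optional[str]) -> Counter:
    # decrement the old bucket (if any), increment the new one
    delta: Counter = Counter()
    if prev in MEMBERSHIP_TO_FIELD:
        delta[MEMBERSHIP_TO_FIELD[prev]] -= 1
    if new in MEMBERSHIP_TO_FIELD:
        delta[MEMBERSHIP_TO_FIELD[new]] += 1
    return delta


room_stats: Counter = Counter()
room_stats.update(membership_delta(None, "join"))     # new user joins
room_stats.update(membership_delta("join", "leave"))  # then leaves
print(room_stats)  # Counter({'left_members': 1, 'joined_members': 0})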
prev_membership = None if prev_event_id is not None: - prev_event = yield self.store.get_event( + prev_event = await self.store.get_event( prev_event_id, allow_none=True ) if prev_event: @@ -279,7 +279,7 @@ class StatsHandler(StateDeltasHandler): room_state["history_visibility"] = event_content.get( "history_visibility" ) - elif typ == EventTypes.Encryption: + elif typ == EventTypes.RoomEncryption: room_state["encryption"] = event_content.get("algorithm") elif typ == EventTypes.Name: room_state["name"] = event_content.get("name") @@ -293,6 +293,7 @@ class StatsHandler(StateDeltasHandler): room_state["guest_access"] = event_content.get("guest_access") for room_id, state in room_to_state_updates.items(): - yield self.store.update_room_state(room_id, state) + logger.debug("Updating room_stats_state for %s: %s", room_id, state) + await self.store.update_room_state(room_id, state) return room_to_stats_deltas, user_to_stats_deltas diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 19bca6717f..6bdb24baff 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018, 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,22 +14,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import itertools import logging +from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple from six import iteritems, itervalues +import attr from prometheus_client import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership -from synapse.logging.context import LoggingContext +from synapse.api.filtering import FilterCollection +from synapse.events import EventBase +from synapse.logging.context import current_context from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter -from synapse.types import RoomStreamToken +from synapse.types import ( + Collection, + JsonDict, + RoomStreamToken, + StateMap, + StreamToken, + UserID, +) from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.lrucache import LruCache @@ -64,17 +72,22 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000 LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100 -SyncConfig = collections.namedtuple( - "SyncConfig", ["user", "filter_collection", "is_guest", "request_key", "device_id"] -) +@attr.s(slots=True, frozen=True) +class SyncConfig: + user = attr.ib(type=UserID) + filter_collection = attr.ib(type=FilterCollection) + is_guest = attr.ib(type=bool) + request_key = attr.ib(type=Tuple[Any, ...]) + device_id = attr.ib(type=str) -class TimelineBatch( - collections.namedtuple("TimelineBatch", ["prev_batch", "events", "limited"]) -): - __slots__ = [] +@attr.s(slots=True, frozen=True) +class TimelineBatch: + prev_batch = attr.ib(type=StreamToken) + events = attr.ib(type=List[EventBase]) + limited = attr.ib(bool) - def __nonzero__(self): + def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. 
""" @@ -83,23 +96,17 @@ class TimelineBatch( __bool__ = __nonzero__ # python3 -class JoinedSyncResult( - collections.namedtuple( - "JoinedSyncResult", - [ - "room_id", # str - "timeline", # TimelineBatch - "state", # dict[(str, str), FrozenEvent] - "ephemeral", - "account_data", - "unread_notifications", - "summary", - ], - ) -): - __slots__ = [] - - def __nonzero__(self): +@attr.s(slots=True, frozen=True) +class JoinedSyncResult: + room_id = attr.ib(type=str) + timeline = attr.ib(type=TimelineBatch) + state = attr.ib(type=StateMap[EventBase]) + ephemeral = attr.ib(type=List[JsonDict]) + account_data = attr.ib(type=List[JsonDict]) + unread_notifications = attr.ib(type=JsonDict) + summary = attr.ib(type=Optional[JsonDict]) + + def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ @@ -115,20 +122,14 @@ class JoinedSyncResult( __bool__ = __nonzero__ # python3 -class ArchivedSyncResult( - collections.namedtuple( - "ArchivedSyncResult", - [ - "room_id", # str - "timeline", # TimelineBatch - "state", # dict[(str, str), FrozenEvent] - "account_data", - ], - ) -): - __slots__ = [] - - def __nonzero__(self): +@attr.s(slots=True, frozen=True) +class ArchivedSyncResult: + room_id = attr.ib(type=str) + timeline = attr.ib(type=TimelineBatch) + state = attr.ib(type=StateMap[EventBase]) + account_data = attr.ib(type=List[JsonDict]) + + def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ @@ -137,70 +138,88 @@ class ArchivedSyncResult( __bool__ = __nonzero__ # python3 -class InvitedSyncResult( - collections.namedtuple( - "InvitedSyncResult", - ["room_id", "invite"], # str # FrozenEvent: the invite event - ) -): - __slots__ = [] +@attr.s(slots=True, frozen=True) +class InvitedSyncResult: + room_id = attr.ib(type=str) + invite = attr.ib(type=EventBase) - def __nonzero__(self): + def __nonzero__(self) -> bool: """Invited rooms should always be reported to the client""" return True __bool__ = __nonzero__ # python3 -class GroupsSyncResult( - collections.namedtuple("GroupsSyncResult", ["join", "invite", "leave"]) -): - __slots__ = [] +@attr.s(slots=True, frozen=True) +class GroupsSyncResult: + join = attr.ib(type=JsonDict) + invite = attr.ib(type=JsonDict) + leave = attr.ib(type=JsonDict) - def __nonzero__(self): + def __nonzero__(self) -> bool: return bool(self.join or self.invite or self.leave) __bool__ = __nonzero__ # python3 -class DeviceLists( - collections.namedtuple( - "DeviceLists", - [ - "changed", # list of user_ids whose devices may have changed - "left", # list of user_ids whose devices we no longer track - ], - ) -): - __slots__ = [] +@attr.s(slots=True, frozen=True) +class DeviceLists: + """ + Attributes: + changed: List of user_ids whose devices may have changed + left: List of user_ids whose devices we no longer track + """ + + changed = attr.ib(type=Collection[str]) + left = attr.ib(type=Collection[str]) - def __nonzero__(self): + def __nonzero__(self) -> bool: return bool(self.changed or self.left) __bool__ = __nonzero__ # python3 -class SyncResult( - collections.namedtuple( - "SyncResult", - [ - "next_batch", # Token for the next sync - "presence", # List of presence events for the user. - "account_data", # List of account_data events for the user. - "joined", # JoinedSyncResult for each joined room. - "invited", # InvitedSyncResult for each invited room. 
- "archived", # ArchivedSyncResult for each archived room. - "to_device", # List of direct messages for the device. - "device_lists", # List of user_ids whose devices have changed - "device_one_time_keys_count", # Dict of algorithm to count for one time keys - # for this device - "groups", - ], - ) -): - __slots__ = [] - - def __nonzero__(self): +@attr.s +class _RoomChanges: + """The set of room entries to include in the sync, plus the set of joined + and left room IDs since last sync. + """ + + room_entries = attr.ib(type=List["RoomSyncResultBuilder"]) + invited = attr.ib(type=List[InvitedSyncResult]) + newly_joined_rooms = attr.ib(type=List[str]) + newly_left_rooms = attr.ib(type=List[str]) + + +@attr.s(slots=True, frozen=True) +class SyncResult: + """ + Attributes: + next_batch: Token for the next sync + presence: List of presence events for the user. + account_data: List of account_data events for the user. + joined: JoinedSyncResult for each joined room. + invited: InvitedSyncResult for each invited room. + archived: ArchivedSyncResult for each archived room. + to_device: List of direct messages for the device. + device_lists: List of user_ids whose devices have changed + device_one_time_keys_count: Dict of algorithm to count for one time keys + for this device + groups: Group updates, if any + """ + + next_batch = attr.ib(type=StreamToken) + presence = attr.ib(type=List[JsonDict]) + account_data = attr.ib(type=List[JsonDict]) + joined = attr.ib(type=List[JoinedSyncResult]) + invited = attr.ib(type=List[InvitedSyncResult]) + archived = attr.ib(type=List[ArchivedSyncResult]) + to_device = attr.ib(type=List[JsonDict]) + device_lists = attr.ib(type=DeviceLists) + device_one_time_keys_count = attr.ib(type=JsonDict) + groups = attr.ib(type=Optional[GroupsSyncResult]) + + def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used to tell if the notifier needs to wait for more events when polling for events. @@ -230,6 +249,8 @@ class SyncHandler(object): self.response_cache = ResponseCache(hs, "sync") self.state = hs.get_state_handler() self.auth = hs.get_auth() + self.storage = hs.get_storage() + self.state_store = self.storage.state # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) self.lazy_loaded_members_cache = ExpiringCache( @@ -239,23 +260,24 @@ class SyncHandler(object): expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) - @defer.inlineCallbacks - def wait_for_sync_for_user( - self, sync_config, since_token=None, timeout=0, full_state=False - ): + async def wait_for_sync_for_user( + self, + sync_config: SyncConfig, + since_token: Optional[StreamToken] = None, + timeout: int = 0, + full_state: bool = False, + ) -> SyncResult: """Get the sync for a client if we have new data for it now. Otherwise wait for new data to arrive on the server. If the timeout expires, then return an empty sync result. 
- Returns: - Deferred[SyncResult] """ # If the user is not part of the mau group, then check that limits have # not been exceeded (if not part of the group by this point, almost certain # auth_blocking will occur) user_id = sync_config.user.to_string() - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) - res = yield self.response_cache.wrap( + res = await self.response_cache.wrap( sync_config.request_key, self._wait_for_sync_for_user, sync_config, @@ -265,8 +287,13 @@ class SyncHandler(object): ) return res - @defer.inlineCallbacks - def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state): + async def _wait_for_sync_for_user( + self, + sync_config: SyncConfig, + since_token: Optional[StreamToken] = None, + timeout: int = 0, + full_state: bool = False, + ) -> SyncResult: if since_token is None: sync_type = "initial_sync" elif full_state: @@ -274,14 +301,14 @@ class SyncHandler(object): else: sync_type = "incremental_sync" - context = LoggingContext.current_context() + context = current_context() if context: context.tag = sync_type if timeout == 0 or since_token is None or full_state: # we are going to return immediately, so don't bother calling # notifier.wait_for_events. - result = yield self.current_sync_for_user( + result = await self.current_sync_for_user( sync_config, since_token, full_state=full_state ) else: @@ -289,7 +316,7 @@ class SyncHandler(object): def current_sync_callback(before_token, after_token): return self.current_sync_for_user(sync_config, since_token) - result = yield self.notifier.wait_for_events( + result = await self.notifier.wait_for_events( sync_config.user.to_string(), timeout, current_sync_callback, @@ -305,27 +332,33 @@ class SyncHandler(object): return result - def current_sync_for_user(self, sync_config, since_token=None, full_state=False): + async def current_sync_for_user( + self, + sync_config: SyncConfig, + since_token: Optional[StreamToken] = None, + full_state: bool = False, + ) -> SyncResult: """Get the sync for client needed to match what the server has now. - Returns: - A Deferred SyncResult. """ - return self.generate_sync_result(sync_config, since_token, full_state) + return await self.generate_sync_result(sync_config, since_token, full_state) - @defer.inlineCallbacks - def push_rules_for_user(self, user): + async def push_rules_for_user(self, user: UserID) -> JsonDict: user_id = user.to_string() - rules = yield self.store.get_push_rules_for_user(user_id) + rules = await self.store.get_push_rules_for_user(user_id) rules = format_push_rules_for_user(user, rules) return rules - @defer.inlineCallbacks - def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): + async def ephemeral_by_room( + self, + sync_result_builder: "SyncResultBuilder", + now_token: StreamToken, + since_token: Optional[StreamToken] = None, + ) -> Tuple[StreamToken, Dict[str, List[JsonDict]]]: """Get the ephemeral events for each room the user is in Args: - sync_result_builder(SyncResultBuilder) - now_token (StreamToken): Where the server is currently up to. - since_token (StreamToken): Where the server was when the client + sync_result_builder + now_token: Where the server is currently up to. + since_token: Where the server was when the client last synced. 
Returns: A tuple of the now StreamToken, updated to reflect the which typing @@ -341,7 +374,7 @@ class SyncHandler(object): room_ids = sync_result_builder.joined_room_ids typing_source = self.event_sources.sources["typing"] - typing, typing_key = yield typing_source.get_new_events( + typing, typing_key = await typing_source.get_new_events( user=sync_config.user, from_key=typing_key, limit=sync_config.filter_collection.ephemeral_limit(), @@ -350,7 +383,7 @@ class SyncHandler(object): ) now_token = now_token.copy_and_replace("typing_key", typing_key) - ephemeral_by_room = {} + ephemeral_by_room = {} # type: JsonDict for event in typing: # we want to exclude the room_id from the event, but modifying the @@ -363,7 +396,7 @@ class SyncHandler(object): receipt_key = since_token.receipt_key if since_token else "0" receipt_source = self.event_sources.sources["receipt"] - receipts, receipt_key = yield receipt_source.get_new_events( + receipts, receipt_key = await receipt_source.get_new_events( user=sync_config.user, from_key=receipt_key, limit=sync_config.filter_collection.ephemeral_limit(), @@ -380,16 +413,15 @@ class SyncHandler(object): return now_token, ephemeral_by_room - @defer.inlineCallbacks - def _load_filtered_recents( + async def _load_filtered_recents( self, - room_id, - sync_config, - now_token, - since_token=None, - recents=None, - newly_joined_room=False, - ): + room_id: str, + sync_config: SyncConfig, + now_token: StreamToken, + since_token: Optional[StreamToken] = None, + potential_recents: Optional[List[EventBase]] = None, + newly_joined_room: bool = False, + ) -> TimelineBatch: """ Returns: a Deferred TimelineBatch @@ -400,24 +432,32 @@ class SyncHandler(object): sync_config.filter_collection.blocks_all_room_timeline() ) - if recents is None or newly_joined_room or timeline_limit < len(recents): + if ( + potential_recents is None + or newly_joined_room + or timeline_limit < len(potential_recents) + ): limited = True else: limited = False - if recents: - recents = sync_config.filter_collection.filter_room_timeline(recents) + if potential_recents: + recents = sync_config.filter_collection.filter_room_timeline( + potential_recents + ) # We check if there are any state events, if there are then we pass # all current state events to the filter_events function. This is to # ensure that we always include current state in the timeline - current_state_ids = frozenset() + current_state_ids = frozenset() # type: FrozenSet[str] if any(e.is_state() for e in recents): - current_state_ids = yield self.state.get_current_state_ids(room_id) - current_state_ids = frozenset(itervalues(current_state_ids)) + current_state_ids_map = await self.state.get_current_state_ids( + room_id + ) + current_state_ids = frozenset(itervalues(current_state_ids_map)) - recents = yield filter_events_for_client( - self.store, + recents = await filter_events_for_client( + self.storage, sync_config.user.to_string(), recents, always_include_ids=current_state_ids, @@ -447,14 +487,14 @@ class SyncHandler(object): # Otherwise, we want to return the last N events in the room # in toplogical ordering. 
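[Editor's sketch] The timeline loading just below fetches load_limit + 1 events; the usual shape of that trick is that getting the extra event back tells you the batch is "limited" (more history exists than is being returned), after which you trim to the requested size. A minimal sketch of the idea, not the exact Synapse logic:

from typing import List, Tuple


def load_recent(events: List[str], load_limit: int) -> Tuple[List[str], bool]:
    fetched = events[-(load_limit + 1):]   # newest load_limit + 1 events
    limited = len(fetched) > load_limit    # the extra one means there is more history
    return fetched[-load_limit:], limited


print(load_recent(["$1", "$2", "$3", "$4"], 2))  # (['$3', '$4'], True)
print(load_recent(["$1", "$2"], 5))              # (['$1', '$2'], False)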
if since_key: - events, end_key = yield self.store.get_room_events_stream_for_room( + events, end_key = await self.store.get_room_events_stream_for_room( room_id, limit=load_limit + 1, from_key=since_key, to_key=end_key, ) else: - events, end_key = yield self.store.get_recent_events_for_room( + events, end_key = await self.store.get_recent_events_for_room( room_id, limit=load_limit + 1, end_token=end_key ) loaded_recents = sync_config.filter_collection.filter_room_timeline( @@ -466,11 +506,13 @@ class SyncHandler(object): # ensure that we always include current state in the timeline current_state_ids = frozenset() if any(e.is_state() for e in loaded_recents): - current_state_ids = yield self.state.get_current_state_ids(room_id) - current_state_ids = frozenset(itervalues(current_state_ids)) + current_state_ids_map = await self.state.get_current_state_ids( + room_id + ) + current_state_ids = frozenset(itervalues(current_state_ids_map)) - loaded_recents = yield filter_events_for_client( - self.store, + loaded_recents = await filter_events_for_client( + self.storage, sync_config.user.to_string(), loaded_recents, always_include_ids=current_state_ids, @@ -496,20 +538,17 @@ class SyncHandler(object): limited=limited or newly_joined_room, ) - @defer.inlineCallbacks - def get_state_after_event(self, event, state_filter=StateFilter.all()): + async def get_state_after_event( + self, event: EventBase, state_filter: StateFilter = StateFilter.all() + ) -> StateMap[str]: """ Get the room state after the given event Args: - event(synapse.events.EventBase): event of interest - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A Deferred map from ((type, state_key)->Event) + event: event of interest + state_filter: The state filter used to fetch state from the database. """ - state_ids = yield self.store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( event.event_id, state_filter=state_filter ) if event.is_state(): @@ -517,30 +556,30 @@ class SyncHandler(object): state_ids[(event.type, event.state_key)] = event.event_id return state_ids - @defer.inlineCallbacks - def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()): + async def get_state_at( + self, + room_id: str, + stream_position: StreamToken, + state_filter: StateFilter = StateFilter.all(), + ) -> StateMap[str]: """ Get the room state at a particular stream position Args: - room_id(str): room for which to get state - stream_position(StreamToken): point at which to get state - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A Deferred map from ((type, state_key)->Event) + room_id: room for which to get state + stream_position: point at which to get state + state_filter: The state filter used to fetch state from the database. """ # FIXME this claims to get the state at a stream position, but # get_recent_events_for_room operates by topo ordering. This therefore # does not reliably give you the state at the given stream position. 
# (https://github.com/matrix-org/synapse/issues/3305) - last_events, _ = yield self.store.get_recent_events_for_room( + last_events, _ = await self.store.get_recent_events_for_room( room_id, end_token=stream_position.room_key, limit=1 ) if last_events: last_event = last_events[-1] - state = yield self.get_state_after_event( + state = await self.get_state_after_event( last_event, state_filter=state_filter ) @@ -549,30 +588,31 @@ class SyncHandler(object): state = {} return state - @defer.inlineCallbacks - def compute_summary(self, room_id, sync_config, batch, state, now_token): + async def compute_summary( + self, + room_id: str, + sync_config: SyncConfig, + batch: TimelineBatch, + state: StateMap[EventBase], + now_token: StreamToken, + ) -> Optional[JsonDict]: """ Works out a room summary block for this room, summarising the number of joined members in the room, and providing the 'hero' members if the room has no name so clients can consistently name rooms. Also adds state events to 'state' if needed to describe the heroes. - Args: - room_id(str): - sync_config(synapse.handlers.sync.SyncConfig): - batch(synapse.handlers.sync.TimelineBatch): The timeline batch for - the room that will be sent to the user. - state(dict): dict of (type, state_key) -> Event as returned by - compute_state_delta - now_token(str): Token of the end of the current batch. - - Returns: - A deferred dict describing the room summary + Args + room_id + sync_config + batch: The timeline batch for the room that will be sent to the user. + state: State as returned by compute_state_delta + now_token: Token of the end of the current batch. """ # FIXME: we could/should get this from room_stats when matthew/stats lands # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 - last_events, _ = yield self.store.get_recent_event_ids_for_room( + last_events, _ = await self.store.get_recent_event_ids_for_room( room_id, end_token=now_token.room_key, limit=1 ) @@ -580,7 +620,7 @@ class SyncHandler(object): return None last_event = last_events[-1] - state_ids = yield self.store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -588,7 +628,7 @@ class SyncHandler(object): ) # this is heavily cached, thus: fast. - details = yield self.store.get_room_summary(room_id) + details = await self.store.get_room_summary(room_id) name_id = state_ids.get((EventTypes.Name, "")) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) @@ -606,12 +646,12 @@ class SyncHandler(object): # calculating heroes. Empty strings are falsey, so we check # for the "name" value and default to an empty string. 
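[Editor's sketch] The m.heroes computation in the compute_summary hunk above prefers joined and invited members, falls back to users who have left, and caps the sorted list at five. The same selection reduced to a pure function:

from typing import List


def pick_heroes(joined: List[str], invited: List[str], gone: List[str]) -> List[str]:
    if joined or invited:
        return sorted(joined + invited)[0:5]
    return sorted(gone)[0:5]


print(pick_heroes(["@c:hs", "@a:hs"], ["@b:hs"], []))  # ['@a:hs', '@b:hs', '@c:hs']
print(pick_heroes([], [], ["@z:hs"]))                  # ['@z:hs']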
if name_id: - name = yield self.store.get_event(name_id, allow_none=True) + name = await self.store.get_event(name_id, allow_none=True) if name and name.content.get("name"): return summary if canonical_alias_id: - canonical_alias = yield self.store.get_event( + canonical_alias = await self.store.get_event( canonical_alias_id, allow_none=True ) if canonical_alias and canonical_alias.content.get("alias"): @@ -642,11 +682,9 @@ class SyncHandler(object): # FIXME: order by stream ordering rather than as returned by SQL if joined_user_ids or invited_user_ids: - summary["m.heroes"] = sorted( - [user_id for user_id in (joined_user_ids + invited_user_ids)] - )[0:5] + summary["m.heroes"] = sorted(joined_user_ids + invited_user_ids)[0:5] else: - summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5] + summary["m.heroes"] = sorted(gone_user_ids)[0:5] if not sync_config.filter_collection.lazy_load_members(): return summary @@ -657,9 +695,9 @@ class SyncHandler(object): # track which members the client should already know about via LL: # Ones which are already in state... - existing_members = set( + existing_members = { user_id for (typ, user_id) in state.keys() if typ == EventTypes.Member - ) + } # ...or ones which are in the timeline... for ev in batch.events: @@ -676,7 +714,7 @@ class SyncHandler(object): ) ] - missing_hero_state = yield self.store.get_events(missing_hero_event_ids) + missing_hero_state = await self.store.get_events(missing_hero_event_ids) missing_hero_state = missing_hero_state.values() for s in missing_hero_state: @@ -685,7 +723,7 @@ class SyncHandler(object): return summary - def get_lazy_loaded_members_cache(self, cache_key): + def get_lazy_loaded_members_cache(self, cache_key: Tuple[str, str]) -> LruCache: cache = self.lazy_loaded_members_cache.get(cache_key) if cache is None: logger.debug("creating LruCache for %r", cache_key) @@ -695,25 +733,25 @@ class SyncHandler(object): logger.debug("found LruCache for %r", cache_key) return cache - @defer.inlineCallbacks - def compute_state_delta( - self, room_id, batch, sync_config, since_token, now_token, full_state - ): + async def compute_state_delta( + self, + room_id: str, + batch: TimelineBatch, + sync_config: SyncConfig, + since_token: Optional[StreamToken], + now_token: StreamToken, + full_state: bool, + ) -> StateMap[EventBase]: """ Works out the difference in state between the start of the timeline and the previous sync. Args: - room_id(str): - batch(synapse.handlers.sync.TimelineBatch): The timeline batch for - the room that will be sent to the user. - sync_config(synapse.handlers.sync.SyncConfig): - since_token(str|None): Token of the end of the previous batch. May - be None. - now_token(str): Token of the end of the current batch. - full_state(bool): Whether to force returning the full state. - - Returns: - A deferred dict of (type, state_key) -> Event + room_id: + batch: The timeline batch for the room that will be sent to the user. + sync_config: + since_token: Token of the end of the previous batch. May be None. + now_token: Token of the end of the current batch. + full_state: Whether to force returning the full state. """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state @@ -733,10 +771,10 @@ class SyncHandler(object): # We only request state for the members needed to display the # timeline: - members_to_fetch = set( + members_to_fetch = { event.sender # FIXME: we also care about invite targets etc. 
for event in batch.events - ) + } if full_state: # always make sure we LL ourselves so we know we're in the room @@ -757,16 +795,16 @@ class SyncHandler(object): if full_state: if batch: - current_state_ids = yield self.store.get_state_ids_for_event( + current_state_ids = await self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) - state_ids = yield self.store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: - current_state_ids = yield self.get_state_at( + current_state_ids = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -781,13 +819,13 @@ class SyncHandler(object): ) elif batch.limited: if batch: - state_at_timeline_start = yield self.store.get_state_ids_for_event( + state_at_timeline_start = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: # We can get here if the user has ignored the senders of all # the recent events. - state_at_timeline_start = yield self.get_state_at( + state_at_timeline_start = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -805,19 +843,23 @@ class SyncHandler(object): # about them). state_filter = StateFilter.all() - state_at_previous_sync = yield self.get_state_at( + # If this is an initial sync then full_state should be set, and + # that case is handled above. We assert here to ensure that this + # is indeed the case. + assert since_token is not None + state_at_previous_sync = await self.get_state_at( room_id, stream_position=since_token, state_filter=state_filter ) if batch: - current_state_ids = yield self.store.get_state_ids_for_event( + current_state_ids = await self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) else: # Its not clear how we get here, but empirically we do # (#5407). Logging has been added elsewhere to try and # figure out where this state comes from. - current_state_ids = yield self.get_state_at( + current_state_ids = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -841,7 +883,7 @@ class SyncHandler(object): # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = yield self.store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, # we only want members! 
state_filter=StateFilter.from_types( @@ -879,29 +921,30 @@ class SyncHandler(object): if t[0] == EventTypes.Member: cache.set(t[1], event_id) - state = {} + state = {} # type: Dict[str, EventBase] if state_ids: - state = yield self.store.get_events(list(state_ids.values())) + state = await self.store.get_events(list(state_ids.values())) return { (e.type, e.state_key): e for e in sync_config.filter_collection.filter_room_state( list(state.values()) ) + if e.type != EventTypes.Aliases # until MSC2261 or alternative solution } - @defer.inlineCallbacks - def unread_notifs_for_room_id(self, room_id, sync_config): + async def unread_notifs_for_room_id( + self, room_id: str, sync_config: SyncConfig + ) -> Optional[Dict[str, str]]: with Measure(self.clock, "unread_notifs_for_room_id"): - last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user( + last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, receipt_type="m.read", ) - notifs = [] if last_unread_event_id: - notifs = yield self.store.get_unread_event_push_actions_by_room_for_user( + notifs = await self.store.get_unread_event_push_actions_by_room_for_user( room_id, sync_config.user.to_string(), last_unread_event_id ) return notifs @@ -910,25 +953,21 @@ class SyncHandler(object): # count is whatever it was last time. return None - @defer.inlineCallbacks - def generate_sync_result(self, sync_config, since_token=None, full_state=False): + async def generate_sync_result( + self, + sync_config: SyncConfig, + since_token: Optional[StreamToken] = None, + full_state: bool = False, + ) -> SyncResult: """Generates a sync result. - - Args: - sync_config (SyncConfig) - since_token (StreamToken) - full_state (bool) - - Returns: - Deferred(SyncResult) """ # NB: The now_token gets changed by some of the generate_sync_* methods, # this is due to some of the underlying streams not supporting the ability # to query up to a given point. 
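Several hunks above annotate containers that start out empty with `# type:` comments (e.g. `state = {}  # type: Dict[str, EventBase]`). A small illustrative sketch of the two equivalent spellings, assuming the comment form is used for compatibility with older Python versions still supported at this point:

from typing import Dict, List


def with_type_comment() -> List[int]:
    # A type comment keeps the annotation out of the syntax tree, so it also
    # works on interpreters that predate variable annotations.
    xs = []  # type: List[int]
    xs.append(1)
    return xs


def with_variable_annotation() -> Dict[str, int]:
    # The PEP 526 form expresses the same thing on Python 3.6+.
    ys: Dict[str, int] = {}
    ys["count"] = 1
    return ys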
# Always use the `now_token` in `SyncResultBuilder` - now_token = yield self.event_sources.get_current_token() + now_token = await self.event_sources.get_current_token() - logger.info( + logger.debug( "Calculating sync response for %r between %s and %s", sync_config.user, since_token, @@ -942,10 +981,9 @@ class SyncHandler(object): # See https://github.com/matrix-org/matrix-doc/issues/1144 raise NotImplementedError() else: - joined_room_ids = yield self.get_rooms_for_user_at( + joined_room_ids = await self.get_rooms_for_user_at( user_id, now_token.room_stream_id ) - sync_result_builder = SyncResultBuilder( sync_config, full_state, @@ -954,11 +992,11 @@ class SyncHandler(object): joined_room_ids=joined_room_ids, ) - account_data_by_room = yield self._generate_sync_entry_for_account_data( + account_data_by_room = await self._generate_sync_entry_for_account_data( sync_result_builder ) - res = yield self._generate_sync_entry_for_rooms( + res = await self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) newly_joined_rooms, newly_joined_or_invited_users, _, _ = res @@ -968,13 +1006,13 @@ class SyncHandler(object): since_token is None and sync_config.filter_collection.blocks_all_presence() ) if self.hs_config.use_presence and not block_all_presence_data: - yield self._generate_sync_entry_for_presence( + await self._generate_sync_entry_for_presence( sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users ) - yield self._generate_sync_entry_for_to_device(sync_result_builder) + await self._generate_sync_entry_for_to_device(sync_result_builder) - device_lists = yield self._generate_sync_entry_for_device_list( + device_lists = await self._generate_sync_entry_for_device_list( sync_result_builder, newly_joined_rooms=newly_joined_rooms, newly_joined_or_invited_users=newly_joined_or_invited_users, @@ -983,13 +1021,13 @@ class SyncHandler(object): ) device_id = sync_config.device_id - one_time_key_counts = {} + one_time_key_counts = {} # type: JsonDict if device_id: - one_time_key_counts = yield self.store.count_e2e_one_time_keys( + one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) - yield self._generate_sync_entry_for_groups(sync_result_builder) + await self._generate_sync_entry_for_groups(sync_result_builder) # debug for https://github.com/matrix-org/synapse/issues/4422 for joined_room in sync_result_builder.joined: @@ -1013,18 +1051,19 @@ class SyncHandler(object): ) @measure_func("_generate_sync_entry_for_groups") - @defer.inlineCallbacks - def _generate_sync_entry_for_groups(self, sync_result_builder): + async def _generate_sync_entry_for_groups( + self, sync_result_builder: "SyncResultBuilder" + ) -> None: user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token if since_token and since_token.groups_key: - results = yield self.store.get_groups_changes_for_user( + results = await self.store.get_groups_changes_for_user( user_id, since_token.groups_key, now_token.groups_key ) else: - results = yield self.store.get_all_groups_for_user( + results = await self.store.get_all_groups_for_user( user_id, now_token.groups_key ) @@ -1057,30 +1096,24 @@ class SyncHandler(object): ) @measure_func("_generate_sync_entry_for_device_list") - @defer.inlineCallbacks - def _generate_sync_entry_for_device_list( + async def _generate_sync_entry_for_device_list( self, - sync_result_builder, - newly_joined_rooms, - newly_joined_or_invited_users, - 
newly_left_rooms, - newly_left_users, - ): + sync_result_builder: "SyncResultBuilder", + newly_joined_rooms: Set[str], + newly_joined_or_invited_users: Set[str], + newly_left_rooms: Set[str], + newly_left_users: Set[str], + ) -> DeviceLists: """Generate the DeviceLists section of sync Args: - sync_result_builder (SyncResultBuilder) - newly_joined_rooms (set[str]): Set of rooms user has joined since - previous sync - newly_joined_or_invited_users (set[str]): Set of users that have - joined or been invited to a room since previous sync. - newly_left_rooms (set[str]): Set of rooms user has left since + sync_result_builder + newly_joined_rooms: Set of rooms user has joined since previous sync + newly_joined_or_invited_users: Set of users that have joined or + been invited to a room since previous sync. + newly_left_rooms: Set of rooms user has left since previous sync + newly_left_users: Set of users that have left a room we're in since previous sync - newly_left_users (set[str]): Set of users that have left a room - we're in since previous sync - - Returns: - Deferred[DeviceLists] """ user_id = sync_result_builder.sync_config.user.to_string() @@ -1106,27 +1139,41 @@ class SyncHandler(object): # room with by looking at all users that have left a room plus users # that were in a room we've left. - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) + # Always tell the user about their own devices. We check as the user + # ID is almost certainly already included (unless they're not in any + # rooms) and taking a copy of the set is relatively expensive. + if user_id not in users_who_share_room: + users_who_share_room = set(users_who_share_room) + users_who_share_room.add(user_id) + + tracked_users = users_who_share_room + # Step 1a, check for changes in devices of users we share a room with - users_that_have_changed = yield self.store.get_users_whose_devices_changed( - since_token.device_list_key, users_who_share_room + users_that_have_changed = await self.store.get_users_whose_devices_changed( + since_token.device_list_key, tracked_users ) # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: - joined_users = yield self.state.get_current_users_in_room(room_id) + joined_users = await self.state.get_current_users_in_room(room_id) newly_joined_or_invited_users.update(joined_users) # TODO: Check that these users are actually new, i.e. either they # weren't in the previous sync *or* they left and rejoined. users_that_have_changed.update(newly_joined_or_invited_users) + user_signatures_changed = await self.store.get_users_whose_signatures_changed( + user_id, since_token.device_list_key + ) + users_that_have_changed.update(user_signatures_changed) + # Now find users that we no longer track for room_id in newly_left_rooms: - left_users = yield self.state.get_current_users_in_room(room_id) + left_users = await self.state.get_current_users_in_room(room_id) newly_left_users.update(left_users) # Remove any users that we still share a room with. @@ -1136,16 +1183,11 @@ class SyncHandler(object): else: return DeviceLists(changed=[], left=[]) - @defer.inlineCallbacks - def _generate_sync_entry_for_to_device(self, sync_result_builder): + async def _generate_sync_entry_for_to_device( + self, sync_result_builder: "SyncResultBuilder" + ) -> None: """Generates the portion of the sync response. Populates `sync_result_builder` with the result. 
- - Args: - sync_result_builder(SyncResultBuilder) - - Returns: - Deferred(dict): A dictionary containing the per room account data. """ user_id = sync_result_builder.sync_config.user.to_string() device_id = sync_result_builder.sync_config.device_id @@ -1158,14 +1200,14 @@ class SyncHandler(object): # We only delete messages when a new message comes in, but that's # fine so long as we delete them at some point. - deleted = yield self.store.delete_messages_for_device( + deleted = await self.store.delete_messages_for_device( user_id, device_id, since_stream_id ) logger.debug( "Deleted %d to-device messages up to %d", deleted, since_stream_id ) - messages, stream_id = yield self.store.get_new_messages_for_device( + messages, stream_id = await self.store.get_new_messages_for_device( user_id, device_id, since_stream_id, now_token.to_device_key ) @@ -1183,42 +1225,45 @@ class SyncHandler(object): else: sync_result_builder.to_device = [] - @defer.inlineCallbacks - def _generate_sync_entry_for_account_data(self, sync_result_builder): + async def _generate_sync_entry_for_account_data( + self, sync_result_builder: "SyncResultBuilder" + ) -> Dict[str, Dict[str, JsonDict]]: """Generates the account data portion of the sync response. Populates `sync_result_builder` with the result. Args: - sync_result_builder(SyncResultBuilder) + sync_result_builder Returns: - Deferred(dict): A dictionary containing the per room account data. + A dictionary containing the per room account data. """ sync_config = sync_result_builder.sync_config user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token if since_token and not sync_result_builder.full_state: - account_data, account_data_by_room = ( - yield self.store.get_updated_account_data_for_user( - user_id, since_token.account_data_key - ) + ( + account_data, + account_data_by_room, + ) = await self.store.get_updated_account_data_for_user( + user_id, since_token.account_data_key ) - push_rules_changed = yield self.store.have_push_rules_changed_for_user( + push_rules_changed = await self.store.have_push_rules_changed_for_user( user_id, int(since_token.push_rules_key) ) if push_rules_changed: - account_data["m.push_rules"] = yield self.push_rules_for_user( + account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) else: - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user(sync_config.user.to_string()) - ) + ( + account_data, + account_data_by_room, + ) = await self.store.get_account_data_for_user(sync_config.user.to_string()) - account_data["m.push_rules"] = yield self.push_rules_for_user( + account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) @@ -1233,20 +1278,22 @@ class SyncHandler(object): return account_data_by_room - @defer.inlineCallbacks - def _generate_sync_entry_for_presence( - self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users - ): + async def _generate_sync_entry_for_presence( + self, + sync_result_builder: "SyncResultBuilder", + newly_joined_rooms: Set[str], + newly_joined_or_invited_users: Set[str], + ) -> None: """Generates the presence portion of the sync response. Populates the `sync_result_builder` with the result. 
Args: - sync_result_builder(SyncResultBuilder) - newly_joined_rooms(list): List of rooms that the user has joined - since the last sync (or empty if an initial sync) - newly_joined_or_invited_users(list): List of users that have joined - or been invited to rooms since the last sync (or empty if an initial - sync) + sync_result_builder + newly_joined_rooms: Set of rooms that the user has joined since + the last sync (or empty if an initial sync) + newly_joined_or_invited_users: Set of users that have joined or + been invited to rooms since the last sync (or empty if an + initial sync) """ now_token = sync_result_builder.now_token sync_config = sync_result_builder.sync_config @@ -1262,7 +1309,7 @@ class SyncHandler(object): presence_key = None include_offline = False - presence, presence_key = yield presence_source.get_new_events( + presence, presence_key = await presence_source.get_new_events( user=user, from_key=presence_key, is_guest=sync_config.is_guest, @@ -1274,12 +1321,12 @@ class SyncHandler(object): extra_users_ids = set(newly_joined_or_invited_users) for room_id in newly_joined_rooms: - users = yield self.state.get_current_users_in_room(room_id) + users = await self.state.get_current_users_in_room(room_id) extra_users_ids.update(users) extra_users_ids.discard(user.to_string()) if extra_users_ids: - states = yield self.presence_handler.get_states(extra_users_ids) + states = await self.presence_handler.get_states(extra_users_ids) presence.extend(states) # Deduplicate the presence entries so that there's at most one per user @@ -1289,17 +1336,20 @@ class SyncHandler(object): sync_result_builder.presence = presence - @defer.inlineCallbacks - def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room): + async def _generate_sync_entry_for_rooms( + self, + sync_result_builder: "SyncResultBuilder", + account_data_by_room: Dict[str, Dict[str, JsonDict]], + ) -> Tuple[Set[str], Set[str], Set[str], Set[str]]: """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. Args: - sync_result_builder(SyncResultBuilder) - account_data_by_room(dict): Dictionary of per room account data + sync_result_builder + account_data_by_room: Dictionary of per room account data Returns: - Deferred(tuple): Returns a 4-tuple of + Returns a 4-tuple of `(newly_joined_rooms, newly_joined_or_invited_users, newly_left_rooms, newly_left_users)` """ @@ -1310,9 +1360,9 @@ class SyncHandler(object): ) if block_all_room_ephemeral: - ephemeral_by_room = {} + ephemeral_by_room = {} # type: Dict[str, List[JsonDict]] else: - now_token, ephemeral_by_room = yield self.ephemeral_by_room( + now_token, ephemeral_by_room = await self.ephemeral_by_room( sync_result_builder, now_token=sync_result_builder.now_token, since_token=sync_result_builder.since_token, @@ -1320,20 +1370,20 @@ class SyncHandler(object): sync_result_builder.now_token = now_token # We check up front if anything has changed, if it hasn't then there is - # no point in going futher. + # no point in going further. 
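The presence hunk above collects presence updates and then deduplicates them so that at most one entry per user is sent down sync. A toy illustration of that idea (hypothetical `PresenceUpdate` type, not the Synapse implementation):

from typing import Dict, List, NamedTuple


class PresenceUpdate(NamedTuple):
    user_id: str
    state: str


def keep_latest_per_user(updates: List[PresenceUpdate]) -> List[PresenceUpdate]:
    # Later entries overwrite earlier ones, so only the most recent update
    # for each user survives.
    latest = {}  # type: Dict[str, PresenceUpdate]
    for update in updates:
        latest[update.user_id] = update
    return list(latest.values())


updates = [
    PresenceUpdate("@alice:example.org", "online"),
    PresenceUpdate("@alice:example.org", "unavailable"),
    PresenceUpdate("@bob:example.org", "online"),
]
assert set(keep_latest_per_user(updates)) == {
    PresenceUpdate("@alice:example.org", "unavailable"),
    PresenceUpdate("@bob:example.org", "online"),
}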
since_token = sync_result_builder.since_token if not sync_result_builder.full_state: if since_token and not ephemeral_by_room and not account_data_by_room: - have_changed = yield self._have_rooms_changed(sync_result_builder) + have_changed = await self._have_rooms_changed(sync_result_builder) if not have_changed: - tags_by_room = yield self.store.get_updated_tags( + tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) if not tags_by_room: logger.debug("no-oping sync") - return [], [], [], [] + return set(), set(), set(), set() - ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( + ignored_account_data = await self.store.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id=user_id ) @@ -1343,18 +1393,21 @@ class SyncHandler(object): ignored_users = frozenset() if since_token: - res = yield self._get_rooms_changed(sync_result_builder, ignored_users) - room_entries, invited, newly_joined_rooms, newly_left_rooms = res - - tags_by_room = yield self.store.get_updated_tags( + room_changes = await self._get_rooms_changed( + sync_result_builder, ignored_users + ) + tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) else: - res = yield self._get_all_rooms(sync_result_builder, ignored_users) - room_entries, invited, newly_joined_rooms = res - newly_left_rooms = [] + room_changes = await self._get_all_rooms(sync_result_builder, ignored_users) - tags_by_room = yield self.store.get_tags_for_user(user_id) + tags_by_room = await self.store.get_tags_for_user(user_id) + + room_entries = room_changes.room_entries + invited = room_changes.invited + newly_joined_rooms = room_changes.newly_joined_rooms + newly_left_rooms = room_changes.newly_left_rooms def handle_room_entries(room_entry): return self._generate_room_entry( @@ -1367,7 +1420,7 @@ class SyncHandler(object): always_include=sync_result_builder.full_state, ) - yield concurrently_execute(handle_room_entries, room_entries, 10) + await concurrently_execute(handle_room_entries, room_entries, 10) sync_result_builder.invited.extend(invited) @@ -1395,14 +1448,15 @@ class SyncHandler(object): newly_left_users -= newly_joined_or_invited_users return ( - newly_joined_rooms, + set(newly_joined_rooms), newly_joined_or_invited_users, - newly_left_rooms, + set(newly_left_rooms), newly_left_users, ) - @defer.inlineCallbacks - def _have_rooms_changed(self, sync_result_builder): + async def _have_rooms_changed( + self, sync_result_builder: "SyncResultBuilder" + ) -> bool: """Returns whether there may be any new events that should be sent down the sync. Returns True if there are. """ @@ -1413,7 +1467,7 @@ class SyncHandler(object): assert since_token # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_membership_changes_for_user( + rooms_changed = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) @@ -1426,23 +1480,10 @@ class SyncHandler(object): return True return False - @defer.inlineCallbacks - def _get_rooms_changed(self, sync_result_builder, ignored_users): + async def _get_rooms_changed( + self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str] + ) -> _RoomChanges: """Gets the the changes that have happened since the last sync. - - Args: - sync_result_builder(SyncResultBuilder) - ignored_users(set(str)): Set of users ignored by user. 
- - Returns: - Deferred(tuple): Returns a tuple of the form: - `(room_entries, invited_rooms, newly_joined_rooms, newly_left_rooms)` - - where: - room_entries is a list [RoomSyncResultBuilder] - invited_rooms is a list [InvitedSyncResult] - newly_joined_rooms is a list[str] of room ids - newly_left_rooms is a list[str] of room ids """ user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token @@ -1452,11 +1493,11 @@ class SyncHandler(object): assert since_token # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_membership_changes_for_user( + rooms_changed = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) - mem_change_events_by_room_id = {} + mem_change_events_by_room_id = {} # type: Dict[str, List[EventBase]] for event in rooms_changed: mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) @@ -1465,7 +1506,7 @@ class SyncHandler(object): room_entries = [] invited = [] for room_id, events in iteritems(mem_change_events_by_room_id): - logger.info( + logger.debug( "Membership changes in %s: [%s]", room_id, ", ".join(("%s (%s)" % (e.event_id, e.membership) for e in events)), @@ -1490,11 +1531,11 @@ class SyncHandler(object): continue if room_id in sync_result_builder.joined_room_ids or has_join: - old_state_ids = yield self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev = None if old_mem_ev_id: - old_mem_ev = yield self.store.get_event( + old_mem_ev = await self.store.get_event( old_mem_ev_id, allow_none=True ) @@ -1527,13 +1568,13 @@ class SyncHandler(object): newly_left_rooms.append(room_id) else: if not old_state_ids: - old_state_ids = yield self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get( (EventTypes.Member, user_id), None ) old_mem_ev = None if old_mem_ev_id: - old_mem_ev = yield self.store.get_event( + old_mem_ev = await self.store.get_event( old_mem_ev_id, allow_none=True ) if old_mem_ev and old_mem_ev.membership == Membership.JOIN: @@ -1557,7 +1598,7 @@ class SyncHandler(object): if leave_events: leave_event = leave_events[-1] - leave_stream_token = yield self.store.get_stream_token_for_event( + leave_stream_token = await self.store.get_stream_token_for_event( leave_event.event_id ) leave_token = since_token.copy_and_replace( @@ -1575,7 +1616,7 @@ class SyncHandler(object): # This is all screaming out for a refactor, as the logic here is # subtle and the moving parts numerous. if leave_event.internal_metadata.is_out_of_band_membership(): - batch_events = [leave_event] + batch_events = [leave_event] # type: Optional[List[EventBase]] else: batch_events = None @@ -1594,7 +1635,7 @@ class SyncHandler(object): timeline_limit = sync_config.filter_collection.timeline_limit() # Get all events for rooms we're currently joined to. - room_to_events = yield self.store.get_room_events_stream_for_rooms( + room_to_events = await self.store.get_room_events_stream_for_rooms( room_ids=sync_result_builder.joined_room_ids, from_key=since_token.room_key, to_key=now_token.room_key, @@ -1602,7 +1643,7 @@ class SyncHandler(object): ) # We loop through all room ids, even if there are no new events, in case - # there are non room events taht we need to notify about. + # there are non room events that we need to notify about. 
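The `_get_rooms_changed`/`_get_all_rooms` hunks above replace an anonymous 4-tuple return value with a small `_RoomChanges` container. A hedged sketch of that refactor using the same attrs idiom (field names mirror the diff; the defaults here are illustrative, not necessarily those of the real class):

from typing import Any, List

import attr


@attr.s
class RoomChanges:
    # Named fields instead of positional tuple unpacking at every call site.
    room_entries = attr.ib(type=List[Any], default=attr.Factory(list))
    invited = attr.ib(type=List[Any], default=attr.Factory(list))
    newly_joined_rooms = attr.ib(type=List[str], default=attr.Factory(list))
    newly_left_rooms = attr.ib(type=List[str], default=attr.Factory(list))


changes = RoomChanges(newly_joined_rooms=["!abc:example.org"])
assert changes.newly_joined_rooms == ["!abc:example.org"]
assert changes.invited == []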
for room_id in sync_result_builder.joined_room_ids: room_entry = room_to_events.get(room_id, None) @@ -1641,19 +1682,17 @@ class SyncHandler(object): ) room_entries.append(entry) - return room_entries, invited, newly_joined_rooms, newly_left_rooms + return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms) - @defer.inlineCallbacks - def _get_all_rooms(self, sync_result_builder, ignored_users): + async def _get_all_rooms( + self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str] + ) -> _RoomChanges: """Returns entries for all rooms for the user. Args: - sync_result_builder(SyncResultBuilder) - ignored_users(set(str)): Set of users ignored by user. + sync_result_builder + ignored_users: Set of users ignored by user. - Returns: - Deferred(tuple): Returns a tuple of the form: - `([RoomSyncResultBuilder], [InvitedSyncResult], [])` """ user_id = sync_result_builder.sync_config.user.to_string() @@ -1668,7 +1707,7 @@ class SyncHandler(object): Membership.BAN, ) - room_list = yield self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, membership_list=membership_list ) @@ -1691,7 +1730,7 @@ class SyncHandler(object): elif event.membership == Membership.INVITE: if event.sender in ignored_users: continue - invite = yield self.store.get_event(event.event_id) + invite = await self.store.get_event(event.event_id) invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) elif event.membership in (Membership.LEAVE, Membership.BAN): # Always send down rooms we were banned or kicked from. @@ -1715,31 +1754,30 @@ class SyncHandler(object): ) ) - return room_entries, invited, [] + return _RoomChanges(room_entries, invited, [], []) - @defer.inlineCallbacks - def _generate_room_entry( + async def _generate_room_entry( self, - sync_result_builder, - ignored_users, - room_builder, - ephemeral, - tags, - account_data, - always_include=False, + sync_result_builder: "SyncResultBuilder", + ignored_users: Set[str], + room_builder: "RoomSyncResultBuilder", + ephemeral: List[JsonDict], + tags: Optional[List[JsonDict]], + account_data: Dict[str, JsonDict], + always_include: bool = False, ): """Populates the `joined` and `archived` section of `sync_result_builder` based on the `room_builder`. Args: - sync_result_builder(SyncResultBuilder) - ignored_users(set(str)): Set of users ignored by user. - room_builder(RoomSyncResultBuilder) - ephemeral(list): List of new ephemeral events for room - tags(list): List of *all* tags for room, or None if there has been + sync_result_builder + ignored_users: Set of users ignored by user. + room_builder + ephemeral: List of new ephemeral events for room + tags: List of *all* tags for room, or None if there has been no change. - account_data(list): List of new account data for room - always_include(bool): Always include this room in the sync response, + account_data: List of new account data for room + always_include: Always include this room in the sync response, even if empty. 
""" newly_joined = room_builder.newly_joined @@ -1760,12 +1798,12 @@ class SyncHandler(object): since_token = room_builder.since_token upto_token = room_builder.upto_token - batch = yield self._load_filtered_recents( + batch = await self._load_filtered_recents( room_id, sync_config, now_token=upto_token, since_token=since_token, - recents=events, + potential_recents=events, newly_joined_room=newly_joined, ) @@ -1787,7 +1825,7 @@ class SyncHandler(object): # tag was added by synapse e.g. for server notice rooms. if full_state: user_id = sync_result_builder.sync_config.user.to_string() - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) # If there aren't any tags, don't send the empty tags list down # sync @@ -1812,11 +1850,11 @@ class SyncHandler(object): ): return - state = yield self.compute_state_delta( + state = await self.compute_state_delta( room_id, batch, sync_config, since_token, now_token, full_state=full_state ) - summary = {} + summary = {} # type: Optional[JsonDict] # we include a summary in room responses when we're lazy loading # members (as the client otherwise doesn't have enough info to form @@ -1835,12 +1873,12 @@ class SyncHandler(object): ) or since_token is None ): - summary = yield self.compute_summary( + summary = await self.compute_summary( room_id, sync_config, batch, state, now_token ) if room_builder.rtype == "joined": - unread_notifications = {} + unread_notifications = {} # type: Dict[str, str] room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1852,7 +1890,7 @@ class SyncHandler(object): ) if room_sync or always_include: - notifs = yield self.unread_notifs_for_room_id(room_id, sync_config) + notifs = await self.unread_notifs_for_room_id(room_id, sync_config) if notifs is not None: unread_notifications["notification_count"] = notifs["notify_count"] @@ -1862,24 +1900,25 @@ class SyncHandler(object): if batch.limited and since_token: user_id = sync_result_builder.sync_config.user.to_string() - logger.info( + logger.debug( "Incremental gappy sync of %s for user %s with %d state events" % (room_id, user_id, len(state)) ) elif room_builder.rtype == "archived": - room_sync = ArchivedSyncResult( + archived_room_sync = ArchivedSyncResult( room_id=room_id, timeline=batch, state=state, account_data=account_data_events, ) - if room_sync or always_include: - sync_result_builder.archived.append(room_sync) + if archived_room_sync or always_include: + sync_result_builder.archived.append(archived_room_sync) else: raise Exception("Unrecognized rtype: %r", room_builder.rtype) - @defer.inlineCallbacks - def get_rooms_for_user_at(self, user_id, stream_ordering): + async def get_rooms_for_user_at( + self, user_id: str, stream_ordering: int + ) -> FrozenSet[str]: """Get set of joined rooms for a user at the given stream ordering. The stream ordering *must* be recent, otherwise this may throw an @@ -1887,14 +1926,13 @@ class SyncHandler(object): current token, which should be perfectly fine). Args: - user_id (str) - stream_ordering (int) + user_id + stream_ordering ReturnValue: - Deferred[frozenset[str]]: Set of room_ids the user is in at given - stream_ordering. + Set of room_ids the user is in at given stream_ordering. 
""" - joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id) + joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id) joined_room_ids = set() @@ -1912,18 +1950,17 @@ class SyncHandler(object): logger.info("User joined room after current token: %s", room_id) - extrems = yield self.store.get_forward_extremeties_for_room( + extrems = await self.store.get_forward_extremeties_for_room( room_id, stream_ordering ) - users_in_room = yield self.state.get_current_users_in_room(room_id, extrems) + users_in_room = await self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: joined_room_ids.add(room_id) - joined_room_ids = frozenset(joined_room_ids) - return joined_room_ids + return frozenset(joined_room_ids) -def _action_has_highlight(actions): +def _action_has_highlight(actions: List[JsonDict]) -> bool: for action in actions: try: if action.get("set_tweak", None) == "highlight": @@ -1935,22 +1972,23 @@ def _action_has_highlight(actions): def _calculate_state( - timeline_contains, timeline_start, previous, current, lazy_load_members -): + timeline_contains: StateMap[str], + timeline_start: StateMap[str], + previous: StateMap[str], + current: StateMap[str], + lazy_load_members: bool, +) -> StateMap[str]: """Works out what state to include in a sync response. Args: - timeline_contains (dict): state in the timeline - timeline_start (dict): state at the start of the timeline - previous (dict): state at the end of the previous sync (or empty dict + timeline_contains: state in the timeline + timeline_start: state at the start of the timeline + previous: state at the end of the previous sync (or empty dict if this is an initial sync) - current (dict): state at the end of the timeline - lazy_load_members (bool): whether to return members from timeline_start + current: state at the end of the timeline + lazy_load_members: whether to return members from timeline_start or not. assumes that timeline_start has already been filtered to include only the members the client needs to know about. - - Returns: - dict """ event_id_to_key = { e: key @@ -1962,10 +2000,10 @@ def _calculate_state( ) } - c_ids = set(e for e in itervalues(current)) - ts_ids = set(e for e in itervalues(timeline_start)) - p_ids = set(e for e in itervalues(previous)) - tc_ids = set(e for e in itervalues(timeline_contains)) + c_ids = set(itervalues(current)) + ts_ids = set(itervalues(timeline_start)) + p_ids = set(itervalues(previous)) + tc_ids = set(itervalues(timeline_contains)) # If we are lazyloading room members, we explicitly add the membership events # for the senders in the timeline into the state block returned by /sync, @@ -1987,15 +2025,16 @@ def _calculate_state( return {event_id_to_key[e]: e for e in state_ids} -class SyncResultBuilder(object): +@attr.s +class SyncResultBuilder: """Used to help build up a new SyncResult for a user Attributes: - sync_config (SyncConfig) - full_state (bool) - since_token (StreamToken) - now_token (StreamToken) - joined_room_ids (list[str]) + sync_config + full_state: The full_state flag as specified by user + since_token: The token supplied by user, or None. + now_token: The token to sync up to. 
+ joined_room_ids: List of rooms the user is joined to # The following mirror the fields in a sync response presence (list) @@ -2003,61 +2042,45 @@ class SyncResultBuilder(object): joined (list[JoinedSyncResult]) invited (list[InvitedSyncResult]) archived (list[ArchivedSyncResult]) - device (list) groups (GroupsSyncResult|None) to_device (list) """ - def __init__( - self, sync_config, full_state, since_token, now_token, joined_room_ids - ): - """ - Args: - sync_config (SyncConfig) - full_state (bool): The full_state flag as specified by user - since_token (StreamToken): The token supplied by user, or None. - now_token (StreamToken): The token to sync up to. - joined_room_ids (list[str]): List of rooms the user is joined to - """ - self.sync_config = sync_config - self.full_state = full_state - self.since_token = since_token - self.now_token = now_token - self.joined_room_ids = joined_room_ids - - self.presence = [] - self.account_data = [] - self.joined = [] - self.invited = [] - self.archived = [] - self.device = [] - self.groups = None - self.to_device = [] + sync_config = attr.ib(type=SyncConfig) + full_state = attr.ib(type=bool) + since_token = attr.ib(type=Optional[StreamToken]) + now_token = attr.ib(type=StreamToken) + joined_room_ids = attr.ib(type=FrozenSet[str]) + + presence = attr.ib(type=List[JsonDict], default=attr.Factory(list)) + account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list)) + joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list)) + invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list)) + archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list)) + groups = attr.ib(type=Optional[GroupsSyncResult], default=None) + to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list)) +@attr.s class RoomSyncResultBuilder(object): """Stores information needed to create either a `JoinedSyncResult` or `ArchivedSyncResult`. + + Attributes: + room_id + rtype: One of `"joined"` or `"archived"` + events: List of events to include in the room (more events may be added + when generating result). + newly_joined: If the user has newly joined the room + full_state: Whether the full state should be sent in result + since_token: Earliest point to return events from, or None + upto_token: Latest point to return events from. """ - def __init__( - self, room_id, rtype, events, newly_joined, full_state, since_token, upto_token - ): - """ - Args: - room_id(str) - rtype(str): One of `"joined"` or `"archived"` - events(list[FrozenEvent]): List of events to include in the room - (more events may be added when generating result). - newly_joined(bool): If the user has newly joined the room - full_state(bool): Whether the full state should be sent in result - since_token(StreamToken): Earliest point to return events from, or None - upto_token(StreamToken): Latest point to return events from. 
- """ - self.room_id = room_id - self.rtype = rtype - self.events = events - self.newly_joined = newly_joined - self.full_state = full_state - self.since_token = since_token - self.upto_token = upto_token + room_id = attr.ib(type=str) + rtype = attr.ib(type=str) + events = attr.ib(type=Optional[List[EventBase]]) + newly_joined = attr.ib(type=bool) + full_state = attr.ib(type=bool) + since_token = attr.ib(type=Optional[StreamToken]) + upto_token = attr.ib(type=StreamToken) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index ca8ae9fb5b..c7bc14c623 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import List from twisted.internet import defer @@ -120,12 +121,12 @@ class TypingHandler(object): auth_user_id = auth_user.to_string() if not self.is_mine_id(target_user_id): - raise SynapseError(400, "User is not hosted on this Home Server") + raise SynapseError(400, "User is not hosted on this homeserver") if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_joined_room(room_id, target_user_id) + yield self.auth.check_user_in_room(room_id, target_user_id) logger.debug("%s has started typing in %s", target_user_id, room_id) @@ -150,12 +151,12 @@ class TypingHandler(object): auth_user_id = auth_user.to_string() if not self.is_mine_id(target_user_id): - raise SynapseError(400, "User is not hosted on this Home Server") + raise SynapseError(400, "User is not hosted on this homeserver") if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_joined_room(room_id, target_user_id) + yield self.auth.check_user_in_room(room_id, target_user_id) logger.debug("%s has stopped typing in %s", target_user_id, room_id) @@ -198,7 +199,7 @@ class TypingHandler(object): now=now, obj=member, then=now + FEDERATION_PING_INTERVAL ) - for domain in set(get_domain_from_id(u) for u in users): + for domain in {get_domain_from_id(u) for u in users}: if domain != self.server_name: logger.debug("sending typing update to %s", domain) self.federation.build_and_send_edu( @@ -231,7 +232,7 @@ class TypingHandler(object): return users = yield self.state.get_current_users_in_room(room_id) - domains = set(get_domain_from_id(u) for u in users) + domains = {get_domain_from_id(u) for u in users} if self.server_name in domains: logger.info("Got typing update from %s: %r", user_id, content) @@ -257,7 +258,13 @@ class TypingHandler(object): "typing_key", self._latest_room_serial, rooms=[member.room_id] ) - def get_all_typing_updates(self, last_id, current_id): + async def get_all_typing_updates( + self, last_id: int, current_id: int, limit: int + ) -> List[dict]: + """Get up to `limit` typing updates between the given tokens, earliest + updates first. 
+ """ + if last_id == current_id: return [] @@ -275,7 +282,7 @@ class TypingHandler(object): typing = self._room_typing[room_id] rows.append((serial, room_id, list(typing))) rows.sort() - return rows + return rows[:limit] def get_current_token(self): return self._latest_room_serial @@ -313,10 +320,7 @@ class TypingNotificationEventSource(object): events.append(self._make_event_for(room_id)) - return events, handler._latest_room_serial + return defer.succeed((events, handler._latest_room_serial)) def get_current_key(self): return self.get_typing_handler()._latest_room_serial - - def get_pagination_rows(self, user, pagination_config, key): - return [], pagination_config.from_key diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py new file mode 100644 index 0000000000..824f37f8f8 --- /dev/null +++ b/synapse/handlers/ui_auth/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module implements user-interactive auth verification. + +TODO: move more stuff out of AuthHandler in here. + +""" + +from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401 diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py new file mode 100644 index 0000000000..8b24a73319 --- /dev/null +++ b/synapse/handlers/ui_auth/checkers.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from canonicaljson import json + +from twisted.internet import defer +from twisted.web.client import PartialDownloadError + +from synapse.api.constants import LoginType +from synapse.api.errors import Codes, LoginError, SynapseError +from synapse.config.emailconfig import ThreepidBehaviour + +logger = logging.getLogger(__name__) + + +class UserInteractiveAuthChecker: + """Abstract base class for an interactive auth checker""" + + def __init__(self, hs): + pass + + def is_enabled(self): + """Check if the configuration of the homeserver allows this checker to work + + Returns: + bool: True if this login type is enabled. + """ + + def check_auth(self, authdict, clientip): + """Given the authentication dict from the client, attempt to check this step + + Args: + authdict (dict): authentication dictionary from the client + clientip (str): The IP address of the client. 
+ + Raises: + SynapseError if authentication failed + + Returns: + Deferred: the result of authentication (to pass back to the client?) + """ + raise NotImplementedError() + + +class DummyAuthChecker(UserInteractiveAuthChecker): + AUTH_TYPE = LoginType.DUMMY + + def is_enabled(self): + return True + + def check_auth(self, authdict, clientip): + return defer.succeed(True) + + +class TermsAuthChecker(UserInteractiveAuthChecker): + AUTH_TYPE = LoginType.TERMS + + def is_enabled(self): + return True + + def check_auth(self, authdict, clientip): + return defer.succeed(True) + + +class RecaptchaAuthChecker(UserInteractiveAuthChecker): + AUTH_TYPE = LoginType.RECAPTCHA + + def __init__(self, hs): + super().__init__(hs) + self._enabled = bool(hs.config.recaptcha_private_key) + self._http_client = hs.get_proxied_http_client() + self._url = hs.config.recaptcha_siteverify_api + self._secret = hs.config.recaptcha_private_key + + def is_enabled(self): + return self._enabled + + @defer.inlineCallbacks + def check_auth(self, authdict, clientip): + try: + user_response = authdict["response"] + except KeyError: + # Client tried to provide captcha but didn't give the parameter: + # bad request. + raise LoginError( + 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED + ) + + logger.info( + "Submitting recaptcha response %s with remoteip %s", user_response, clientip + ) + + # TODO: get this from the homeserver rather than creating a new one for + # each request + try: + resp_body = yield self._http_client.post_urlencoded_get_json( + self._url, + args={ + "secret": self._secret, + "response": user_response, + "remoteip": clientip, + }, + ) + except PartialDownloadError as pde: + # Twisted is silly + data = pde.response + resp_body = json.loads(data) + + if "success" in resp_body: + # Note that we do NOT check the hostname here: we explicitly + # intend the CAPTCHA to be presented by whatever client the + # user is using, we just care that they have completed a CAPTCHA. + logger.info( + "%s reCAPTCHA from hostname %s", + "Successful" if resp_body["success"] else "Failed", + resp_body.get("hostname"), + ) + if resp_body["success"]: + return True + raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + + +class _BaseThreepidAuthChecker: + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + + async def _check_threepid(self, medium, authdict): + if "threepid_creds" not in authdict: + raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) + + threepid_creds = authdict["threepid_creds"] + + identity_handler = self.hs.get_handlers().identity_handler + + logger.info("Getting validated threepid. 
threepidcreds: %r", (threepid_creds,)) + + # msisdns are currently always ThreepidBehaviour.REMOTE + if medium == "msisdn": + if not self.hs.config.account_threepid_delegate_msisdn: + raise SynapseError( + 400, "Phone number verification is not enabled on this homeserver" + ) + threepid = await identity_handler.threepid_from_creds( + self.hs.config.account_threepid_delegate_msisdn, threepid_creds + ) + elif medium == "email": + if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + assert self.hs.config.account_threepid_delegate_email + threepid = await identity_handler.threepid_from_creds( + self.hs.config.account_threepid_delegate_email, threepid_creds + ) + elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + threepid = None + row = await self.store.get_threepid_validation_session( + medium, + threepid_creds["client_secret"], + sid=threepid_creds["sid"], + validated=True, + ) + + if row: + threepid = { + "medium": row["medium"], + "address": row["address"], + "validated_at": row["validated_at"], + } + + # Valid threepid returned, delete from the db + await self.store.delete_threepid_session(threepid_creds["sid"]) + else: + raise SynapseError( + 400, "Email address verification is not enabled on this homeserver" + ) + else: + # this can't happen! + raise AssertionError("Unrecognized threepid medium: %s" % (medium,)) + + if not threepid: + raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + + if threepid["medium"] != medium: + raise LoginError( + 401, + "Expecting threepid of type '%s', got '%s'" + % (medium, threepid["medium"]), + errcode=Codes.UNAUTHORIZED, + ) + + threepid["threepid_creds"] = authdict["threepid_creds"] + + return threepid + + +class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): + AUTH_TYPE = LoginType.EMAIL_IDENTITY + + def __init__(self, hs): + UserInteractiveAuthChecker.__init__(self, hs) + _BaseThreepidAuthChecker.__init__(self, hs) + + def is_enabled(self): + return self.hs.config.threepid_behaviour_email in ( + ThreepidBehaviour.REMOTE, + ThreepidBehaviour.LOCAL, + ) + + def check_auth(self, authdict, clientip): + return defer.ensureDeferred(self._check_threepid("email", authdict)) + + +class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): + AUTH_TYPE = LoginType.MSISDN + + def __init__(self, hs): + UserInteractiveAuthChecker.__init__(self, hs) + _BaseThreepidAuthChecker.__init__(self, hs) + + def is_enabled(self): + return bool(self.hs.config.account_threepid_delegate_msisdn) + + def check_auth(self, authdict, clientip): + return defer.ensureDeferred(self._check_threepid("msisdn", authdict)) + + +INTERACTIVE_AUTH_CHECKERS = [ + DummyAuthChecker, + TermsAuthChecker, + RecaptchaAuthChecker, + EmailIdentityAuthChecker, + MsisdnAuthChecker, +] +"""A list of UserInteractiveAuthChecker classes""" diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index e53669e40d..12423b909a 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -17,14 +17,11 @@ import logging from six import iteritems, iterkeys -from twisted.internet import defer - import synapse.metrics from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.handlers.state_deltas import StateDeltasHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.roommember import ProfileInfo -from synapse.types import get_localpart_from_id from synapse.util.metrics import 
Measure logger = logging.getLogger(__name__) @@ -52,6 +49,7 @@ class UserDirectoryHandler(StateDeltasHandler): self.is_mine_id = hs.is_mine_id self.update_user_directory = hs.config.update_user_directory self.search_all_users = hs.config.user_directory_search_all_users + self.spam_checker = hs.get_spam_checker() # The current position in the current_state_delta stream self.pos = None @@ -65,7 +63,7 @@ class UserDirectoryHandler(StateDeltasHandler): # we start populating the user directory self.clock.call_later(0, self.notify_new_event) - def search_users(self, user_id, search_term, limit): + async def search_users(self, user_id, search_term, limit): """Searches for users in directory Returns: @@ -82,7 +80,16 @@ class UserDirectoryHandler(StateDeltasHandler): ] } """ - return self.store.search_user_dir(user_id, search_term, limit) + results = await self.store.search_user_dir(user_id, search_term, limit) + + # Remove any spammy users from the results. + results["results"] = [ + user + for user in results["results"] + if not self.spam_checker.check_username_for_spam(user) + ] + + return results def notify_new_event(self): """Called when there may be more deltas to process @@ -93,43 +100,39 @@ class UserDirectoryHandler(StateDeltasHandler): if self._is_processing: return - @defer.inlineCallbacks - def process(): + async def process(): try: - yield self._unsafe_process() + await self._unsafe_process() finally: self._is_processing = False self._is_processing = True run_as_background_process("user_directory.notify_new_event", process) - @defer.inlineCallbacks - def handle_local_profile_change(self, user_id, profile): + async def handle_local_profile_change(self, user_id, profile): """Called to update index of our local user profiles when they change irrespective of any rooms the user may be in. """ # FIXME(#3714): We should probably do this in the same worker as all # the other changes. - is_support = yield self.store.is_support_user(user_id) + is_support = await self.store.is_support_user(user_id) # Support users are for diagnostics and should not appear in the user directory. if not is_support: - yield self.store.update_profile_in_user_dir( + await self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) - @defer.inlineCallbacks - def handle_user_deactivated(self, user_id): + async def handle_user_deactivated(self, user_id): """Called when a user ID is deactivated """ # FIXME(#3714): We should probably do this in the same worker as all # the other changes. 
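The `search_users` hunk above now filters directory results through the spam checker before returning them. A self-contained sketch of that filtering step (the result shape and the `looks_spammy` callback are stand-ins, not the real spam-checker API):

def drop_spammy_results(results: dict, looks_spammy) -> dict:
    # Anything the checker flags is removed before the response is returned.
    results["results"] = [
        user for user in results["results"] if not looks_spammy(user)
    ]
    return results


search_results = {
    "limited": False,
    "results": [
        {"user_id": "@alice:example.org"},
        {"user_id": "@spambot:example.org"},
    ],
}
filtered = drop_spammy_results(
    search_results, lambda user: "spambot" in user["user_id"]
)
assert [u["user_id"] for u in filtered["results"]] == ["@alice:example.org"]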
- yield self.store.remove_from_user_dir(user_id) + await self.store.remove_from_user_dir(user_id) - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): # If self.pos is None then means we haven't fetched it from DB if self.pos is None: - self.pos = yield self.store.get_user_directory_stream_pos() + self.pos = await self.store.get_user_directory_stream_pos() # If still None then the initial background update hasn't happened yet if self.pos is None: @@ -138,24 +141,30 @@ class UserDirectoryHandler(StateDeltasHandler): # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "user_dir_delta"): - deltas = yield self.store.get_current_state_deltas(self.pos) - if not deltas: + room_max_stream_ordering = self.store.get_room_max_stream_ordering() + if self.pos == room_max_stream_ordering: return - logger.info("Handling %d state deltas", len(deltas)) - yield self._handle_deltas(deltas) + logger.debug( + "Processing user stats %s->%s", self.pos, room_max_stream_ordering + ) + max_pos, deltas = await self.store.get_current_state_deltas( + self.pos, room_max_stream_ordering + ) + + logger.debug("Handling %d state deltas", len(deltas)) + await self._handle_deltas(deltas) - self.pos = deltas[-1]["stream_id"] + self.pos = max_pos # Expose current event processing position to prometheus synapse.metrics.event_processing_positions.labels("user_dir").set( - self.pos + max_pos ) - yield self.store.update_user_directory_stream_pos(self.pos) + await self.store.update_user_directory_stream_pos(max_pos) - @defer.inlineCallbacks - def _handle_deltas(self, deltas): + async def _handle_deltas(self, deltas): """Called with the state deltas to process """ for delta in deltas: @@ -170,11 +179,11 @@ class UserDirectoryHandler(StateDeltasHandler): # For join rule and visibility changes we need to check if the room # may have become public or not and add/remove the users in said room if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules): - yield self._handle_room_publicity_change( + await self._handle_room_publicity_change( room_id, prev_event_id, event_id, typ ) elif typ == EventTypes.Member: - change = yield self._get_key_change( + change = await self._get_key_change( prev_event_id, event_id, key_name="membership", @@ -184,48 +193,49 @@ class UserDirectoryHandler(StateDeltasHandler): if change is False: # Need to check if the server left the room entirely, if so # we might need to remove all the users in that room - is_in_room = yield self.store.is_host_joined( + is_in_room = await self.store.is_host_joined( room_id, self.server_name ) if not is_in_room: - logger.info("Server left room: %r", room_id) + logger.debug("Server left room: %r", room_id) # Fetch all the users that we marked as being in user # directory due to being in the room and then check if # need to remove those users or not - user_ids = yield self.store.get_users_in_dir_due_to_room( + user_ids = await self.store.get_users_in_dir_due_to_room( room_id ) for user_id in user_ids: - yield self._handle_remove_user(room_id, user_id) + await self._handle_remove_user(room_id, user_id) return else: logger.debug("Server is still in room: %r", room_id) - is_support = yield self.store.is_support_user(state_key) + is_support = await self.store.is_support_user(state_key) if not is_support: if change is None: # Handle any profile changes - yield self._handle_profile_change( + await self._handle_profile_change( state_key, room_id, prev_event_id, event_id ) continue if change: # The user 
joined - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) profile = ProfileInfo( avatar_url=event.content.get("avatar_url"), display_name=event.content.get("displayname"), ) - yield self._handle_new_user(room_id, state_key, profile) + await self._handle_new_user(room_id, state_key, profile) else: # The user left - yield self._handle_remove_user(room_id, state_key) + await self._handle_remove_user(room_id, state_key) else: logger.debug("Ignoring irrelevant type: %r", typ) - @defer.inlineCallbacks - def _handle_room_publicity_change(self, room_id, prev_event_id, event_id, typ): + async def _handle_room_publicity_change( + self, room_id, prev_event_id, event_id, typ + ): """Handle a room having potentially changed from/to world_readable/publically joinable. @@ -238,14 +248,14 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug("Handling change for %s: %s", typ, room_id) if typ == EventTypes.RoomHistoryVisibility: - change = yield self._get_key_change( + change = await self._get_key_change( prev_event_id, event_id, key_name="history_visibility", public_value="world_readable", ) elif typ == EventTypes.JoinRules: - change = yield self._get_key_change( + change = await self._get_key_change( prev_event_id, event_id, key_name="join_rule", @@ -261,7 +271,7 @@ class UserDirectoryHandler(StateDeltasHandler): # There's been a change to or from being world readable. - is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + is_public = await self.store.is_room_world_readable_or_publicly_joinable( room_id ) @@ -276,11 +286,11 @@ class UserDirectoryHandler(StateDeltasHandler): # ignore the change return - users_with_profile = yield self.state.get_current_users_in_room(room_id) + users_with_profile = await self.state.get_current_users_in_room(room_id) # Remove every user from the sharing tables for that room. for user_id in iterkeys(users_with_profile): - yield self.store.remove_user_who_share_room(user_id, room_id) + await self.store.remove_user_who_share_room(user_id, room_id) # Then, re-add them to the tables. # NOTE: this is not the most efficient method, as handle_new_user sets @@ -289,26 +299,9 @@ class UserDirectoryHandler(StateDeltasHandler): # being added multiple times. The batching upserts shouldn't make this # too bad, though. for user_id, profile in iteritems(users_with_profile): - yield self._handle_new_user(room_id, user_id, profile) - - @defer.inlineCallbacks - def _handle_local_user(self, user_id): - """Adds a new local roomless user into the user_directory_search table. - Used to populate up the user index when we have an - user_directory_search_all_users specified. 
- """ - logger.debug("Adding new local user to dir, %r", user_id) - - profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id)) - - row = yield self.store.get_user_in_directory(user_id) - if not row: - yield self.store.update_profile_in_user_dir( - user_id, profile.display_name, profile.avatar_url - ) + await self._handle_new_user(room_id, user_id, profile) - @defer.inlineCallbacks - def _handle_new_user(self, room_id, user_id, profile): + async def _handle_new_user(self, room_id, user_id, profile): """Called when we might need to add user to directory Args: @@ -317,18 +310,18 @@ class UserDirectoryHandler(StateDeltasHandler): """ logger.debug("Adding new user to dir, %r", user_id) - yield self.store.update_profile_in_user_dir( + await self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) - is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + is_public = await self.store.is_room_world_readable_or_publicly_joinable( room_id ) # Now we update users who share rooms with users. - users_with_profile = yield self.state.get_current_users_in_room(room_id) + users_with_profile = await self.state.get_current_users_in_room(room_id) if is_public: - yield self.store.add_users_in_public_rooms(room_id, (user_id,)) + await self.store.add_users_in_public_rooms(room_id, (user_id,)) else: to_insert = set() @@ -359,10 +352,9 @@ class UserDirectoryHandler(StateDeltasHandler): to_insert.add((other_user_id, user_id)) if to_insert: - yield self.store.add_users_who_share_private_room(room_id, to_insert) + await self.store.add_users_who_share_private_room(room_id, to_insert) - @defer.inlineCallbacks - def _handle_remove_user(self, room_id, user_id): + async def _handle_remove_user(self, room_id, user_id): """Called when we might need to remove user from directory Args: @@ -372,24 +364,23 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug("Removing user %r", user_id) # Remove user from sharing tables - yield self.store.remove_user_who_share_room(user_id, room_id) + await self.store.remove_user_who_share_room(user_id, room_id) # Are they still in any rooms? If not, remove them entirely. - rooms_user_is_in = yield self.store.get_user_dir_rooms_user_is_in(user_id) + rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id) if len(rooms_user_is_in) == 0: - yield self.store.remove_from_user_dir(user_id) + await self.store.remove_from_user_dir(user_id) - @defer.inlineCallbacks - def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id): + async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id): """Check member event changes for any profile changes and update the database if there are. 
""" if not prev_event_id or not event_id: return - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) - event = yield self.store.get_event(event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not prev_event or not event: return @@ -404,4 +395,4 @@ class UserDirectoryHandler(StateDeltasHandler): new_avatar = event.content.get("avatar_url") if prev_name != new_name or prev_avatar != new_avatar: - yield self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) + await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 3acf772cd1..3880ce0d94 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -42,11 +42,13 @@ def cancelled_to_request_timed_out_error(value, timeout): ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$") +CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$") def redact_uri(uri): - """Strips access tokens from the uri replaces with <redacted>""" - return ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri) + """Strips sensitive information from the uri replaces with <redacted>""" + uri = ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri) + return CLIENT_SECRET_RE.sub(r"\1<redacted>\3", uri) class QuieterFileBodyProducer(FileBodyProducer): diff --git a/synapse/http/client.py b/synapse/http/client.py index 51765ae3c0..3cef747a4d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -45,10 +45,10 @@ from synapse.http import ( cancelled_to_request_timed_out_error, redact_uri, ) +from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.util.async_helpers import timeout_deferred -from synapse.util.caches import CACHE_SIZE_FACTOR logger = logging.getLogger(__name__) @@ -183,7 +183,15 @@ class SimpleHttpClient(object): using HTTP in Matrix """ - def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None): + def __init__( + self, + hs, + treq_args={}, + ip_whitelist=None, + ip_blacklist=None, + http_proxy=None, + https_proxy=None, + ): """ Args: hs (synapse.server.HomeServer) @@ -192,6 +200,8 @@ class SimpleHttpClient(object): we may not request. ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can request if it were otherwise caught in a blacklist. + http_proxy (bytes): proxy server to use for http connections. host[:port] + https_proxy (bytes): proxy server to use for https connections. host[:port] """ self.hs = hs @@ -230,17 +240,19 @@ class SimpleHttpClient(object): # tends to do so in batches, so we need to allow the pool to keep # lots of idle connections around. pool = HTTPConnectionPool(self.reactor) - pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) + # XXX: The justification for using the cache factor here is that larger instances + # will need both more cache and more connections. 
+ # Still, this should probably be a separate dial + pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5)) pool.cachedConnectionTimeout = 2 * 60 - # The default context factory in Twisted 14.0.0 (which we require) is - # BrowserLikePolicyForHTTPS which will do regular cert validation - # 'like a browser' - self.agent = Agent( + self.agent = ProxyAgent( self.reactor, connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, + http_proxy=http_proxy, + https_proxy=https_proxy, ) if self._ip_blacklist: @@ -327,7 +339,7 @@ class SimpleHttpClient(object): Args: uri (str): args (dict[str, str|List[str]]): query params - headers (dict[str, List[str]]|None): If not None, a map from + headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from header name to a list of values for that header Returns: @@ -349,6 +361,7 @@ class SimpleHttpClient(object): actual_headers = { b"Content-Type": [b"application/x-www-form-urlencoded"], b"User-Agent": [self.user_agent], + b"Accept": [b"application/json"], } if headers: actual_headers.update(headers) @@ -371,7 +384,7 @@ class SimpleHttpClient(object): Args: uri (str): post_json (object): - headers (dict[str, List[str]]|None): If not None, a map from + headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from header name to a list of values for that header Returns: @@ -389,6 +402,7 @@ class SimpleHttpClient(object): actual_headers = { b"Content-Type": [b"application/json"], b"User-Agent": [self.user_agent], + b"Accept": [b"application/json"], } if headers: actual_headers.update(headers) @@ -414,7 +428,7 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. - headers (dict[str, List[str]]|None): If not None, a map from + headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the @@ -424,6 +438,10 @@ class SimpleHttpClient(object): ValueError: if the response was not JSON """ + actual_headers = {b"Accept": [b"application/json"]} + if headers: + actual_headers.update(headers) + body = yield self.get_raw(uri, args, headers=headers) return json.loads(body) @@ -438,7 +456,7 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. - headers (dict[str, List[str]]|None): If not None, a map from + headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the @@ -457,6 +475,7 @@ class SimpleHttpClient(object): actual_headers = { b"Content-Type": [b"application/json"], b"User-Agent": [self.user_agent], + b"Accept": [b"application/json"], } if headers: actual_headers.update(headers) @@ -482,7 +501,7 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. - headers (dict[str, List[str]]|None): If not None, a map from + headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from header name to a list of values for that header Returns: Deferred: Succeeds when we get *any* 2xx HTTP response, with the @@ -516,7 +535,7 @@ class SimpleHttpClient(object): Args: url (str): The URL to GET output_stream (file): File to write the response body to. 
- headers (dict[str, List[str]]|None): If not None, a map from + headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from header name to a list of values for that header Returns: A (int,dict,string,int) tuple of the file length, dict of the response @@ -535,7 +554,7 @@ class SimpleHttpClient(object): b"Content-Length" in resp_headers and int(resp_headers[b"Content-Length"][0]) > max_size ): - logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) + logger.warning("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, "Requested file is too large > %r bytes" % (self.max_size,), @@ -543,7 +562,7 @@ class SimpleHttpClient(object): ) if response.code > 299: - logger.warn("Got %d when downloading %s" % (response.code, url)) + logger.warning("Got %d when downloading %s" % (response.code, url)) raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) # TODO: if our Content-Type is HTML or something, just read the first diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py new file mode 100644 index 0000000000..be7b2ceb8e --- /dev/null +++ b/synapse/http/connectproxyclient.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from zope.interface import implementer + +from twisted.internet import defer, protocol +from twisted.internet.error import ConnectError +from twisted.internet.interfaces import IStreamClientEndpoint +from twisted.internet.protocol import connectionDone +from twisted.web import http + +logger = logging.getLogger(__name__) + + +class ProxyConnectError(ConnectError): + pass + + +@implementer(IStreamClientEndpoint) +class HTTPConnectProxyEndpoint(object): + """An Endpoint implementation which will send a CONNECT request to an http proxy + + Wraps an existing HostnameEndpoint for the proxy. + + When we get the connect() request from the connection pool (via the TLS wrapper), + we'll first connect to the proxy endpoint with a ProtocolFactory which will make the + CONNECT request. Once that completes, we invoke the protocolFactory which was passed + in. + + Args: + reactor: the Twisted reactor to use for the connection + proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the + proxy + host (bytes): hostname that we want to CONNECT to + port (int): port that we want to connect to + """ + + def __init__(self, reactor, proxy_endpoint, host, port): + self._reactor = reactor + self._proxy_endpoint = proxy_endpoint + self._host = host + self._port = port + + def __repr__(self): + return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,) + + def connect(self, protocolFactory): + f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory) + d = self._proxy_endpoint.connect(f) + # once the tcp socket connects successfully, we need to wait for the + # CONNECT to complete. 
+ d.addCallback(lambda conn: f.on_connection) + return d + + +class HTTPProxiedClientFactory(protocol.ClientFactory): + """ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect. + + Once the CONNECT completes, invokes the original ClientFactory to build the + HTTP Protocol object and run the rest of the connection. + + Args: + dst_host (bytes): hostname that we want to CONNECT to + dst_port (int): port that we want to connect to + wrapped_factory (protocol.ClientFactory): The original Factory + """ + + def __init__(self, dst_host, dst_port, wrapped_factory): + self.dst_host = dst_host + self.dst_port = dst_port + self.wrapped_factory = wrapped_factory + self.on_connection = defer.Deferred() + + def startedConnecting(self, connector): + return self.wrapped_factory.startedConnecting(connector) + + def buildProtocol(self, addr): + wrapped_protocol = self.wrapped_factory.buildProtocol(addr) + + return HTTPConnectProtocol( + self.dst_host, self.dst_port, wrapped_protocol, self.on_connection + ) + + def clientConnectionFailed(self, connector, reason): + logger.debug("Connection to proxy failed: %s", reason) + if not self.on_connection.called: + self.on_connection.errback(reason) + return self.wrapped_factory.clientConnectionFailed(connector, reason) + + def clientConnectionLost(self, connector, reason): + logger.debug("Connection to proxy lost: %s", reason) + if not self.on_connection.called: + self.on_connection.errback(reason) + return self.wrapped_factory.clientConnectionLost(connector, reason) + + +class HTTPConnectProtocol(protocol.Protocol): + """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect + + Args: + host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal + to put in the CONNECT request + + port (int): The original HTTP(s) port to put in the CONNECT request + + wrapped_protocol (interfaces.IProtocol): the original protocol (probably + HTTPChannel or TLSMemoryBIOProtocol, but could be anything really) + + connected_deferred (Deferred): a Deferred which will be callbacked with + wrapped_protocol when the CONNECT completes + """ + + def __init__(self, host, port, wrapped_protocol, connected_deferred): + self.host = host + self.port = port + self.wrapped_protocol = wrapped_protocol + self.connected_deferred = connected_deferred + self.http_setup_client = HTTPConnectSetupClient(self.host, self.port) + self.http_setup_client.on_connected.addCallback(self.proxyConnected) + + def connectionMade(self): + self.http_setup_client.makeConnection(self.transport) + + def connectionLost(self, reason=connectionDone): + if self.wrapped_protocol.connected: + self.wrapped_protocol.connectionLost(reason) + + self.http_setup_client.connectionLost(reason) + + if not self.connected_deferred.called: + self.connected_deferred.errback(reason) + + def proxyConnected(self, _): + self.wrapped_protocol.makeConnection(self.transport) + + self.connected_deferred.callback(self.wrapped_protocol) + + # Get any pending data from the http buf and forward it to the original protocol + buf = self.http_setup_client.clearLineBuffer() + if buf: + self.wrapped_protocol.dataReceived(buf) + + def dataReceived(self, data): + # if we've set up the HTTP protocol, we can send the data there + if self.wrapped_protocol.connected: + return self.wrapped_protocol.dataReceived(data) + + # otherwise, we must still be setting up the connection: send the data to the + # setup client + return self.http_setup_client.dataReceived(data) + + +class 
HTTPConnectSetupClient(http.HTTPClient): + """HTTPClient protocol to send a CONNECT message for proxies and read the response. + + Args: + host (bytes): The hostname to send in the CONNECT message + port (int): The port to send in the CONNECT message + """ + + def __init__(self, host, port): + self.host = host + self.port = port + self.on_connected = defer.Deferred() + + def connectionMade(self): + logger.debug("Connected to proxy, sending CONNECT") + self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) + self.endHeaders() + + def handleStatus(self, version, status, message): + logger.debug("Got Status: %s %s %s", status, message, version) + if status != b"200": + raise ProxyConnectError("Unexpected status on CONNECT: %s" % status) + + def handleEndHeaders(self): + logger.debug("End Headers") + self.on_connected.callback(None) + + def handleResponse(self, body): + pass diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 647d26dc56..f5f917f5ae 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -45,7 +45,7 @@ class MatrixFederationAgent(object): Args: reactor (IReactor): twisted reactor to use for underlying requests - tls_client_options_factory (ClientTLSOptionsFactory|None): + tls_client_options_factory (FederationPolicyForHTTPS|None): factory to use for fetching client tls options, or none to disable TLS. _srv_resolver (SrvResolver|None): diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 3fe4ffb9e5..021b233a7d 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -148,7 +148,7 @@ class SrvResolver(object): # Try something in the cache, else rereaise cache_entry = self._cache.get(service_name, None) if cache_entry: - logger.warn( + logger.warning( "Failed to resolve %r, falling back to cache. 
%r", service_name, e ) return list(cache_entry) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 3f7c93ffcb..2d47b9ea00 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -19,7 +19,7 @@ import random import sys from io import BytesIO -from six import PY3, raise_from, string_types +from six import raise_from, string_types from six.moves import urllib import attr @@ -70,11 +70,7 @@ incoming_responses_counter = Counter( MAX_LONG_RETRIES = 10 MAX_SHORT_RETRIES = 3 - -if PY3: - MAXINT = sys.maxsize -else: - MAXINT = sys.maxint +MAXINT = sys.maxsize _next_id = 1 @@ -148,8 +144,13 @@ def _handle_json_response(reactor, timeout_sec, request, response): d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) body = yield make_deferred_yieldable(d) + except TimeoutError as e: + logger.warning( + "{%s} [%s] Timed out reading response", request.txn_id, request.destination, + ) + raise RequestSendFailed(e, can_retry=True) from e except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Error reading response: %s", request.txn_id, request.destination, @@ -408,6 +409,8 @@ class MatrixFederationHttpClient(object): _sec_timeout, ) + outgoing_requests_counter.labels(request.method).inc() + try: with Measure(self.clock, "outbound_request"): # we don't want all the fancy cookie and redirect handling @@ -426,25 +429,37 @@ class MatrixFederationHttpClient(object): ) response = yield request_deferred + except TimeoutError as e: + raise RequestSendFailed(e, can_retry=True) from e except DNSLookupError as e: raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e) except Exception as e: logger.info("Failed to send request: %s", e) raise_from(RequestSendFailed(e, can_retry=True), e) - logger.info( - "{%s} [%s] Got response headers: %d %s", - request.txn_id, - request.destination, - response.code, - response.phrase.decode("ascii", errors="replace"), - ) + incoming_responses_counter.labels( + request.method, response.code + ).inc() set_tag(tags.HTTP_STATUS_CODE, response.code) if 200 <= response.code < 300: + logger.debug( + "{%s} [%s] Got response headers: %d %s", + request.txn_id, + request.destination, + response.code, + response.phrase.decode("ascii", errors="replace"), + ) pass else: + logger.info( + "{%s} [%s] Got response headers: %d %s", + request.txn_id, + request.destination, + response.code, + response.phrase.decode("ascii", errors="replace"), + ) # :'( # Update transactions table? d = treq.content(response) @@ -457,7 +472,7 @@ class MatrixFederationHttpClient(object): except Exception as e: # Eh, we're already going to raise an exception so lets # ignore if this fails. - logger.warn( + logger.warning( "{%s} [%s] Failed to get error response: %s %s: %s", request.txn_id, request.destination, @@ -478,7 +493,7 @@ class MatrixFederationHttpClient(object): break except RequestSendFailed as e: - logger.warn( + logger.warning( "{%s} [%s] Request failed: %s %s: %s", request.txn_id, request.destination, @@ -513,7 +528,7 @@ class MatrixFederationHttpClient(object): raise except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Request failed: %s %s: %s", request.txn_id, request.destination, @@ -530,7 +545,7 @@ class MatrixFederationHttpClient(object): """ Builds the Authorization headers for a federation request Args: - destination (bytes|None): The desination home server of the request. + destination (bytes|None): The desination homeserver of the request. 
May be None if the destination is an identity server, in which case destination_is must be non-None. method (bytes): The HTTP method of the request @@ -889,7 +904,7 @@ class MatrixFederationHttpClient(object): d.addTimeout(self.default_timeout, self.reactor) length = yield make_deferred_yieldable(d) except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Error reading response: %s", request.txn_id, request.destination, diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py new file mode 100644 index 0000000000..332da02a8d --- /dev/null +++ b/synapse/http/proxyagent.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import re + +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS +from twisted.python.failure import Failure +from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase +from twisted.web.error import SchemeNotSupported +from twisted.web.iweb import IAgent + +from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint + +logger = logging.getLogger(__name__) + +_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z") + + +@implementer(IAgent) +class ProxyAgent(_AgentBase): + """An Agent implementation which will use an HTTP proxy if one was requested + + Args: + reactor: twisted reactor to place outgoing + connections. + + contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the + verification parameters of OpenSSL. The default is to use a + `BrowserLikePolicyForHTTPS`, so unless you have special + requirements you can leave this as-is. + + connectTimeout (float): The amount of time that this Agent will wait + for the peer to accept a connection. + + bindAddress (bytes): The local address for client sockets to bind to. + + pool (HTTPConnectionPool|None): connection pool to be used. If None, a + non-persistent pool instance will be created. + """ + + def __init__( + self, + reactor, + contextFactory=BrowserLikePolicyForHTTPS(), + connectTimeout=None, + bindAddress=None, + pool=None, + http_proxy=None, + https_proxy=None, + ): + _AgentBase.__init__(self, reactor, pool) + + self._endpoint_kwargs = {} + if connectTimeout is not None: + self._endpoint_kwargs["timeout"] = connectTimeout + if bindAddress is not None: + self._endpoint_kwargs["bindAddress"] = bindAddress + + self.http_proxy_endpoint = _http_proxy_endpoint( + http_proxy, reactor, **self._endpoint_kwargs + ) + + self.https_proxy_endpoint = _http_proxy_endpoint( + https_proxy, reactor, **self._endpoint_kwargs + ) + + self._policy_for_https = contextFactory + self._reactor = reactor + + def request(self, method, uri, headers=None, bodyProducer=None): + """ + Issue a request to the server indicated by the given uri. + + Supports `http` and `https` schemes. + + An existing connection from the connection pool may be used or a new one may be + created. 
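A rough usage sketch for this agent (the proxy address and URL are made up; proxies are passed as bytes in host[:port] form and default to port 1080 if no port is given):

    from twisted.internet import reactor
    from synapse.http.proxyagent import ProxyAgent

    agent = ProxyAgent(reactor, https_proxy=b"proxy.example.com:8888")
    d = agent.request(b"GET", b"https://matrix.org/_matrix/client/versions")
    d.addCallback(lambda response: print("got HTTP status", response.code))

For an https URI like this one the request is tunnelled through the proxy via the CONNECT endpoint from connectproxyclient.py, as the body of request() below shows; a plain http URI is instead sent straight to the proxy with the absolute URI as the request path.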
+ + See also: twisted.web.iweb.IAgent.request + + Args: + method (bytes): The request method to use, such as `GET`, `POST`, etc + + uri (bytes): The location of the resource to request. + + headers (Headers|None): Extra headers to send with the request + + bodyProducer (IBodyProducer|None): An object which can generate bytes to + make up the body of this request (for example, the properly encoded + contents of a file for a file upload). Or, None if the request is to + have no body. + + Returns: + Deferred[IResponse]: completes when the header of the response has + been received (regardless of the response status code). + """ + uri = uri.strip() + if not _VALID_URI.match(uri): + raise ValueError("Invalid URI {!r}".format(uri)) + + parsed_uri = URI.fromBytes(uri) + pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) + request_path = parsed_uri.originForm + + if parsed_uri.scheme == b"http" and self.http_proxy_endpoint: + # Cache *all* connections under the same key, since we are only + # connecting to a single destination, the proxy: + pool_key = ("http-proxy", self.http_proxy_endpoint) + endpoint = self.http_proxy_endpoint + request_path = uri + elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint: + endpoint = HTTPConnectProxyEndpoint( + self._reactor, + self.https_proxy_endpoint, + parsed_uri.host, + parsed_uri.port, + ) + else: + # not using a proxy + endpoint = HostnameEndpoint( + self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs + ) + + logger.debug("Requesting %s via %s", uri, endpoint) + + if parsed_uri.scheme == b"https": + tls_connection_creator = self._policy_for_https.creatorForNetloc( + parsed_uri.host, parsed_uri.port + ) + endpoint = wrapClientTLS(tls_connection_creator, endpoint) + elif parsed_uri.scheme == b"http": + pass + else: + return defer.fail( + Failure( + SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,)) + ) + ) + + return self._requestWithEndpoint( + pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path + ) + + +def _http_proxy_endpoint(proxy, reactor, **kwargs): + """Parses an http proxy setting and returns an endpoint for the proxy + + Args: + proxy (bytes|None): the proxy setting + reactor: reactor to be used to connect to the proxy + kwargs: other args to be passed to HostnameEndpoint + + Returns: + interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy, + or None + """ + if proxy is None: + return None + + # currently we only support hostname:port. Some apps also support + # protocol://<host>[:port], which allows a way of requiring a TLS connection to the + # proxy. + + host, port = parse_host_port(proxy, default_port=1080) + return HostnameEndpoint(reactor, host, port, **kwargs) + + +def parse_host_port(hostport, default_port=None): + # could have sworn we had one of these somewhere else... + if b":" in hostport: + host, port = hostport.rsplit(b":", 1) + try: + port = int(port) + return host, port + except ValueError: + # the thing after the : wasn't a valid port; presumably this is an + # IPv6 address. 
+ pass + + return hostport, default_port diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 46af27c8f6..b58ae3d9db 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -19,7 +19,7 @@ import threading from prometheus_client.core import Counter, Histogram -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context from synapse.metrics import LaterGauge logger = logging.getLogger(__name__) @@ -148,7 +148,7 @@ LaterGauge( class RequestMetrics(object): def start(self, time_sec, name, method): self.start = time_sec - self.start_context = LoggingContext.current_context() + self.start_context = current_context() self.name = name self.method = method @@ -163,14 +163,14 @@ class RequestMetrics(object): with _in_flight_requests_lock: _in_flight_requests.discard(self) - context = LoggingContext.current_context() + context = current_context() tag = "" if context: tag = context.tag if context != self.start_context: - logger.warn( + logger.warning( "Context have unexpectedly changed %r, %r", context, self.start_context, diff --git a/synapse/http/server.py b/synapse/http/server.py index cb9158fe1b..2487a72171 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -14,20 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cgi import collections +import html import http.client import logging import types import urllib from io import BytesIO +from typing import Awaitable, Callable, TypeVar, Union +import jinja2 from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json from twisted.internet import defer from twisted.python import failure from twisted.web import resource -from twisted.web.server import NOT_DONE_YET +from twisted.web.server import NOT_DONE_YET, Request from twisted.web.static import NoRangeStaticProducer from twisted.web.util import redirectTo @@ -36,9 +38,11 @@ import synapse.metrics from synapse.api.errors import ( CodeMessageException, Codes, + RedirectException, SynapseError, UnrecognizedRequestError, ) +from synapse.http.site import SynapseRequest from synapse.logging.context import preserve_fn from synapse.logging.opentracing import trace_servlet from synapse.util.caches import intern_dict @@ -129,7 +133,12 @@ def wrap_json_request_handler(h): return wrap_async_request_handler(wrapped_request_handler) -def wrap_html_request_handler(h): +TV = TypeVar("TV") + + +def wrap_html_request_handler( + h: Callable[[TV, SynapseRequest], Awaitable] +) -> Callable[[TV, SynapseRequest], Awaitable[None]]: """Wraps a request handler method with exception handling. Also does the wrapping with request.processing as per wrap_async_request_handler. @@ -140,27 +149,37 @@ def wrap_html_request_handler(h): async def wrapped_request_handler(self, request): try: - return await h(self, request) + await h(self, request) except Exception: f = failure.Failure() - return _return_html_error(f, request) + return_html_error(f, request, HTML_ERROR_TEMPLATE) return wrap_async_request_handler(wrapped_request_handler) -def _return_html_error(f, request): - """Sends an HTML error page corresponding to the given failure +def return_html_error( + f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template], +) -> None: + """Sends an HTML error page corresponding to the given failure. 
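A small sketch of how a handler might use this helper (the template string here is hypothetical; return_html_error fills in {code} and {msg}):

    from twisted.python import failure

    ERROR_PAGE = "<html><body><h1>Error {code}</h1><p>{msg}</p></body></html>"

    async def run_and_report(handler, request):
        try:
            await handler(request)
        except Exception:
            # capture the current exception and render it as an HTML error page
            return_html_error(failure.Failure(), request, ERROR_PAGE)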
+ + Handles RedirectException and other CodeMessageExceptions (such as SynapseError) Args: - f (twisted.python.failure.Failure): - request (twisted.web.iweb.IRequest): + f: the error to report + request: the failing request + error_template: the HTML template. Can be either a string (with `{code}`, + `{msg}` placeholders), or a jinja2 template """ if f.check(CodeMessageException): cme = f.value code = cme.code msg = cme.msg - if isinstance(cme, SynapseError): + if isinstance(cme, RedirectException): + logger.info("%s redirect to %s", request, cme.location) + request.setHeader(b"location", cme.location) + request.cookies.extend(cme.cookies) + elif isinstance(cme, SynapseError): logger.info("%s SynapseError: %s - %s", request, code, msg) else: logger.error( @@ -169,7 +188,7 @@ def _return_html_error(f, request): exc_info=(f.type, f.value, f.getTracebackObject()), ) else: - code = http.client.INTERNAL_SERVER_ERROR + code = http.HTTPStatus.INTERNAL_SERVER_ERROR msg = "Internal server error" logger.error( @@ -178,11 +197,16 @@ def _return_html_error(f, request): exc_info=(f.type, f.value, f.getTracebackObject()), ) - body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8") + if isinstance(error_template, str): + body = error_template.format(code=code, msg=html.escape(msg)) + else: + body = error_template.render(code=code, msg=msg) + + body_bytes = body.encode("utf-8") request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%i" % (len(body),)) - request.write(body) + request.setHeader(b"Content-Length", b"%i" % (len(body_bytes),)) + request.write(body_bytes) finish_request(request) @@ -345,13 +369,12 @@ class JsonResource(HttpServer, resource.Resource): register_paths, so will return (possibly via Deferred) either None, or a tuple of (http code, response body). """ - if request.method == b"OPTIONS": - return _options_handler, "options_request_handler", {} + request_path = request.path.decode("ascii") # Loop through all the registered callbacks to check if the method # and path regex match for path_entry in self.path_regexs.get(request.method, []): - m = path_entry.pattern.match(request.path.decode("ascii")) + m = path_entry.pattern.match(request_path) if m: # We found a match! 
return path_entry.callback, path_entry.servlet_classname, m.groupdict() @@ -388,7 +411,7 @@ class DirectServeResource(resource.Resource): if not callback: return super().render(request) - resp = callback(request) + resp = trace_servlet(self.__class__.__name__)(callback)(request) # If it's a coroutine, turn it into a Deferred if isinstance(resp, types.CoroutineType): @@ -441,6 +464,26 @@ class RootRedirect(resource.Resource): return resource.Resource.getChild(self, name, request) +class OptionsResource(resource.Resource): + """Responds to OPTION requests for itself and all children.""" + + def render_OPTIONS(self, request): + code, response_json_object = _options_handler(request) + + return respond_with_json( + request, code, response_json_object, send_cors=True, canonical_json=False, + ) + + def getChildWithDefault(self, path, request): + if request.method == b"OPTIONS": + return self # select ourselves as the child to render + return resource.Resource.getChildWithDefault(self, path, request) + + +class RootOptionsRedirectResource(OptionsResource, RootRedirect): + pass + + def respond_with_json( request, code, @@ -454,7 +497,7 @@ def respond_with_json( # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. if request._disconnected: - logger.warn( + logger.warning( "Not sending response to request %s, already disconnected.", request ) return diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 274c1a6a87..13fcb408a6 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -96,7 +96,7 @@ def parse_boolean_from_args(args, name, default=None, required=False): return {b"true": True, b"false": False}[args[name][0]] except Exception: message = ( - "Boolean query parameter %r must be one of" " ['true', 'false']" + "Boolean query parameter %r must be one of ['true', 'false']" ) % (name,) raise SynapseError(400, message) else: @@ -219,13 +219,13 @@ def parse_json_value_from_request(request, allow_empty_body=False): try: content_unicode = content_bytes.decode("utf8") except UnicodeDecodeError: - logger.warn("Unable to decode UTF-8") + logger.warning("Unable to decode UTF-8") raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) try: content = json.loads(content_unicode) except Exception as e: - logger.warn("Unable to parse JSON: %s", e) + logger.warning("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) return content diff --git a/synapse/http/site.py b/synapse/http/site.py index df5274c177..167293c46d 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -14,7 +14,9 @@ import contextlib import logging import time +from typing import Optional +from twisted.python.failure import Failure from twisted.web.server import Request, Site from synapse.http import redact_uri @@ -44,18 +46,18 @@ class SynapseRequest(Request): request even after the client has disconnected. Attributes: - logcontext(LoggingContext) : the log context for this request + logcontext: the log context for this request """ - def __init__(self, site, channel, *args, **kw): + def __init__(self, channel, *args, **kw): Request.__init__(self, channel, *args, **kw) - self.site = site + self.site = channel.site self._channel = channel # this is used by the tests self.authenticated_entity = None - self.start_time = 0 + self.start_time = 0.0 # we can't yet create the logcontext, as we don't know the method. 
- self.logcontext = None + self.logcontext = None # type: Optional[LoggingContext] global _next_request_seq self.request_seq = _next_request_seq @@ -88,7 +90,7 @@ class SynapseRequest(Request): def get_redacted_uri(self): uri = self.uri if isinstance(uri, bytes): - uri = self.uri.decode("ascii") + uri = self.uri.decode("ascii", errors="replace") return redact_uri(uri) def get_method(self): @@ -181,6 +183,7 @@ class SynapseRequest(Request): self.finish_time = time.time() Request.finish(self) if not self._is_processing: + assert self.logcontext is not None with PreserveLoggingContext(self.logcontext): self._finished_processing() @@ -190,16 +193,28 @@ class SynapseRequest(Request): Overrides twisted.web.server.Request.connectionLost to record the finish time and do logging. """ + # There is a bug in Twisted where reason is not wrapped in a Failure object + # Detect this and wrap it manually as a workaround + # More information: https://github.com/matrix-org/synapse/issues/7441 + if not isinstance(reason, Failure): + reason = Failure(reason) + self.finish_time = time.time() Request.connectionLost(self, reason) + if self.logcontext is None: + logger.info( + "Connection from %s lost before request headers were read", self.client + ) + return + # we only get here if the connection to the client drops before we send # the response. # # It's useful to log it here so that we can get an idea of when # the client disconnects. with PreserveLoggingContext(self.logcontext): - logger.warn( + logger.warning( "Error processing request %r: %s %s", self, reason.type, reason.value ) @@ -225,7 +240,7 @@ class SynapseRequest(Request): self.start_time, name=servlet_name, method=self.get_method() ) - self.site.access_logger.info( + self.site.access_logger.debug( "%s - %s - Received request: %s %s", self.getClientIP(), self.site.site_tag, @@ -236,13 +251,7 @@ class SynapseRequest(Request): def _finished_processing(self): """Log the completion of this request and update the metrics """ - - if self.logcontext is None: - # this can happen if the connection closed before we read the - # headers (so render was never called). In that case we'll already - # have logged a warning, so just bail out. 
- return - + assert self.logcontext is not None usage = self.logcontext.get_resource_usage() if self._processing_finished_time is None: @@ -305,7 +314,7 @@ class SynapseRequest(Request): try: self.request_metrics.stop(self.finish_time, self.code, self.sentLength) except Exception as e: - logger.warn("Failed to stop metrics: %r", e) + logger.warning("Failed to stop metrics: %r", e) class XForwardedForRequest(SynapseRequest): @@ -331,18 +340,6 @@ class XForwardedForRequest(SynapseRequest): ) -class SynapseRequestFactory(object): - def __init__(self, site, x_forwarded_for): - self.site = site - self.x_forwarded_for = x_forwarded_for - - def __call__(self, *args, **kwargs): - if self.x_forwarded_for: - return XForwardedForRequest(self.site, *args, **kwargs) - else: - return SynapseRequest(self.site, *args, **kwargs) - - class SynapseSite(Site): """ Subclass of a twisted http Site that does access logging with python's @@ -364,7 +361,7 @@ class SynapseSite(Site): self.site_tag = site_tag proxied = config.get("x_forwarded", False) - self.requestFactory = SynapseRequestFactory(self, proxied) + self.requestFactory = XForwardedForRequest if proxied else SynapseRequest self.access_logger = logging.getLogger(logger_name) self.server_version_string = server_version_string.encode("ascii") diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 3220e985a9..7372450b45 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -42,7 +42,7 @@ from synapse.logging._terse_json import ( TerseJSONToConsoleLogObserver, TerseJSONToTCPLogObserver, ) -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context def stdlib_log_level_to_twisted(level: str) -> LogLevel: @@ -86,7 +86,7 @@ class LogContextObserver(object): ].startswith("Timing out client"): return - context = LoggingContext.current_context() + context = current_context() # Copy the context information to the log event. if context is not None: @@ -185,7 +185,7 @@ DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}} def parse_drain_configs( - drains: dict + drains: dict, ) -> typing.Generator[DrainConfiguration, None, None]: """ Parse the drain configurations. @@ -261,6 +261,18 @@ def parse_drain_configs( ) +class StoppableLogPublisher(LogPublisher): + """ + A log publisher that can tell its observers to shut down any external + communications. + """ + + def stop(self): + for obs in self._observers: + if hasattr(obs, "stop"): + obs.stop() + + def setup_structured_logging( hs, config, @@ -336,7 +348,7 @@ def setup_structured_logging( # We should never get here, but, just in case, throw an error. raise ConfigError("%s drain type cannot be configured" % (observer.type,)) - publisher = LogPublisher(*observers) + publisher = StoppableLogPublisher(*observers) log_filter = LogLevelFilterPredicate() for namespace, namespace_config in log_config.get( diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 0ebbde06f2..c0b9384189 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -17,25 +17,29 @@ Log formatters that output terse JSON. 
""" +import json import sys +import traceback from collections import deque from ipaddress import IPv4Address, IPv6Address, ip_address from math import floor -from typing import IO +from typing import IO, Optional import attr -from simplejson import dumps from zope.interface import implementer from twisted.application.internet import ClientService +from twisted.internet.defer import Deferred from twisted.internet.endpoints import ( HostnameEndpoint, TCP4ClientEndpoint, TCP6ClientEndpoint, ) +from twisted.internet.interfaces import IPushProducer, ITransport from twisted.internet.protocol import Factory, Protocol from twisted.logger import FileLogObserver, ILogObserver, Logger -from twisted.python.failure import Failure + +_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":")) def flatten_event(event: dict, metadata: dict, include_time: bool = False): @@ -141,19 +145,57 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb def formatEvent(_event: dict) -> str: flattened = flatten_event(_event, metadata) - return dumps(flattened, ensure_ascii=False, separators=(",", ":")) + "\n" + return _encoder.encode(flattened) + "\n" return FileLogObserver(outFile, formatEvent) @attr.s +@implementer(IPushProducer) +class LogProducer(object): + """ + An IPushProducer that writes logs from its buffer to its transport when it + is resumed. + + Args: + buffer: Log buffer to read logs from. + transport: Transport to write to. + """ + + transport = attr.ib(type=ITransport) + _buffer = attr.ib(type=deque) + _paused = attr.ib(default=False, type=bool, init=False) + + def pauseProducing(self): + self._paused = True + + def stopProducing(self): + self._paused = True + self._buffer = deque() + + def resumeProducing(self): + self._paused = False + + while self._paused is False and (self._buffer and self.transport.connected): + try: + event = self._buffer.popleft() + self.transport.write(_encoder.encode(event).encode("utf8")) + self.transport.write(b"\n") + except Exception: + # Something has gone wrong writing to the transport -- log it + # and break out of the while. + traceback.print_exc(file=sys.__stderr__) + break + + +@attr.s @implementer(ILogObserver) class TerseJSONToTCPLogObserver(object): """ An IObserver that writes JSON logs to a TCP target. Args: - hs (HomeServer): The Homeserver that is being logged for. + hs (HomeServer): The homeserver that is being logged for. host: The host of the logging target. port: The logging target's port. metadata: Metadata to be added to each log entry. @@ -165,8 +207,9 @@ class TerseJSONToTCPLogObserver(object): metadata = attr.ib(type=dict) maximum_buffer = attr.ib(type=int) _buffer = attr.ib(default=attr.Factory(deque), type=deque) - _writer = attr.ib(default=None) + _connection_waiter = attr.ib(default=None, type=Optional[Deferred]) _logger = attr.ib(default=attr.Factory(Logger)) + _producer = attr.ib(default=None, type=Optional[LogProducer]) def start(self) -> None: @@ -187,38 +230,44 @@ class TerseJSONToTCPLogObserver(object): factory = Factory.forProtocol(Protocol) self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor()) self._service.startService() + self._connect() - def _write_loop(self) -> None: + def stop(self): + self._service.stopService() + + def _connect(self) -> None: """ - Implement the write loop. + Triggers an attempt to connect then write to the remote if not already writing. 
""" - if self._writer: + if self._connection_waiter: return - self._writer = self._service.whenConnected() + self._connection_waiter = self._service.whenConnected(failAfterFailures=1) + + @self._connection_waiter.addErrback + def fail(r): + r.printTraceback(file=sys.__stderr__) + self._connection_waiter = None + self._connect() - @self._writer.addBoth + @self._connection_waiter.addCallback def writer(r): - if isinstance(r, Failure): - r.printTraceback(file=sys.__stderr__) - self._writer = None - self.hs.get_reactor().callLater(1, self._write_loop) + # We have a connection. If we already have a producer, and its + # transport is the same, just trigger a resumeProducing. + if self._producer and r.transport is self._producer.transport: + self._producer.resumeProducing() + self._connection_waiter = None return - try: - for event in self._buffer: - r.transport.write( - dumps(event, ensure_ascii=False, separators=(",", ":")).encode( - "utf8" - ) - ) - r.transport.write(b"\n") - self._buffer.clear() - except Exception as e: - sys.__stderr__.write("Failed writing out logs with %s\n" % (str(e),)) - - self._writer = False - self.hs.get_reactor().callLater(1, self._write_loop) + # If the producer is still producing, stop it. + if self._producer: + self._producer.stopProducing() + + # Make a new producer and start it. + self._producer = LogProducer(buffer=self._buffer, transport=r.transport) + r.transport.registerProducer(self._producer, True) + self._producer.resumeProducing() + self._connection_waiter = None def _handle_pressure(self) -> None: """ @@ -277,4 +326,4 @@ class TerseJSONToTCPLogObserver(object): self._logger.failure("Failed clearing backpressure") # Try and write immediately. - self._write_loop() + self._connect() diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 370000e377..8b9c4e38bd 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -23,13 +23,20 @@ them. See doc/log_contexts.rst for details on how this works. """ +import inspect import logging import threading import types -from typing import Any, List +import warnings +from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union + +from typing_extensions import Literal from twisted.internet import defer, threads +if TYPE_CHECKING: + from synapse.logging.scopecontextmanager import _LogContextScope + logger = logging.getLogger(__name__) try: @@ -45,7 +52,7 @@ try: is_thread_resource_usage_supported = True - def get_thread_resource_usage(): + def get_thread_resource_usage() -> "Optional[resource._RUsage]": return resource.getrusage(RUSAGE_THREAD) @@ -54,7 +61,7 @@ except Exception: # won't track resource usage. 
is_thread_resource_usage_supported = False - def get_thread_resource_usage(): + def get_thread_resource_usage() -> "Optional[resource._RUsage]": return None @@ -90,7 +97,7 @@ class ContextResourceUsage(object): "evt_db_fetch_count", ] - def __init__(self, copy_from=None): + def __init__(self, copy_from: "Optional[ContextResourceUsage]" = None) -> None: """Create a new ContextResourceUsage Args: @@ -100,27 +107,28 @@ class ContextResourceUsage(object): if copy_from is None: self.reset() else: - self.ru_utime = copy_from.ru_utime - self.ru_stime = copy_from.ru_stime - self.db_txn_count = copy_from.db_txn_count + # FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now + self.ru_utime = copy_from.ru_utime # type: float + self.ru_stime = copy_from.ru_stime # type: float + self.db_txn_count = copy_from.db_txn_count # type: int - self.db_txn_duration_sec = copy_from.db_txn_duration_sec - self.db_sched_duration_sec = copy_from.db_sched_duration_sec - self.evt_db_fetch_count = copy_from.evt_db_fetch_count + self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float + self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float + self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int - def copy(self): + def copy(self) -> "ContextResourceUsage": return ContextResourceUsage(copy_from=self) - def reset(self): + def reset(self) -> None: self.ru_stime = 0.0 self.ru_utime = 0.0 self.db_txn_count = 0 - self.db_txn_duration_sec = 0 - self.db_sched_duration_sec = 0 + self.db_txn_duration_sec = 0.0 + self.db_sched_duration_sec = 0.0 self.evt_db_fetch_count = 0 - def __repr__(self): + def __repr__(self) -> str: return ( "<ContextResourceUsage ru_stime='%r', ru_utime='%r', " "db_txn_count='%r', db_txn_duration_sec='%r', " @@ -134,7 +142,7 @@ class ContextResourceUsage(object): self.evt_db_fetch_count, ) - def __iadd__(self, other): + def __iadd__(self, other: "ContextResourceUsage") -> "ContextResourceUsage": """Add another ContextResourceUsage's stats to this one's. 
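A tiny illustration of these operators (the values are arbitrary):

    a = ContextResourceUsage()
    a.ru_utime = 1.0
    b = ContextResourceUsage(copy_from=a)   # copies all counters from a

    total = a + b       # __add__ copies a, then __iadd__ folds b in
    delta = total - a   # __sub__/__isub__ go the other way
    print(total.ru_utime, delta.ru_utime)   # 2.0 1.0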
Args: @@ -148,7 +156,7 @@ class ContextResourceUsage(object): self.evt_db_fetch_count += other.evt_db_fetch_count return self - def __isub__(self, other): + def __isub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage": self.ru_utime -= other.ru_utime self.ru_stime -= other.ru_stime self.db_txn_count -= other.db_txn_count @@ -157,17 +165,67 @@ class ContextResourceUsage(object): self.evt_db_fetch_count -= other.evt_db_fetch_count return self - def __add__(self, other): + def __add__(self, other: "ContextResourceUsage") -> "ContextResourceUsage": res = ContextResourceUsage(copy_from=self) res += other return res - def __sub__(self, other): + def __sub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage": res = ContextResourceUsage(copy_from=self) res -= other return res +LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"] + + +class _Sentinel(object): + """Sentinel to represent the root context""" + + __slots__ = ["previous_context", "finished", "request", "scope", "tag"] + + def __init__(self) -> None: + # Minimal set for compatibility with LoggingContext + self.previous_context = None + self.finished = False + self.request = None + self.scope = None + self.tag = None + + def __str__(self): + return "sentinel" + + def copy_to(self, record): + pass + + def copy_to_twisted_log_entry(self, record): + record["request"] = None + record["scope"] = None + + def start(self, rusage: "Optional[resource._RUsage]"): + pass + + def stop(self, rusage: "Optional[resource._RUsage]"): + pass + + def add_database_transaction(self, duration_sec): + pass + + def add_database_scheduled(self, sched_sec): + pass + + def record_event_fetch(self, event_count): + pass + + def __nonzero__(self): + return False + + __bool__ = __nonzero__ # python3 + + +SENTINEL_CONTEXT = _Sentinel() + + class LoggingContext(object): """Additional context for log formatting. Contexts are scoped within a "with" block. @@ -189,67 +247,32 @@ class LoggingContext(object): "_resource_usage", "usage_start", "main_thread", - "alive", + "finished", "request", "tag", "scope", ] - thread_local = threading.local() - - class Sentinel(object): - """Sentinel to represent the root context""" - - __slots__ = [] # type: List[Any] - - def __str__(self): - return "sentinel" - - def copy_to(self, record): - pass - - def copy_to_twisted_log_entry(self, record): - record["request"] = None - record["scope"] = None - - def start(self): - pass - - def stop(self): - pass - - def add_database_transaction(self, duration_sec): - pass - - def add_database_scheduled(self, sched_sec): - pass - - def record_event_fetch(self, event_count): - pass - - def __nonzero__(self): - return False - - __bool__ = __nonzero__ # python3 - - sentinel = Sentinel() - - def __init__(self, name=None, parent_context=None, request=None): - self.previous_context = LoggingContext.current_context() + def __init__(self, name=None, parent_context=None, request=None) -> None: + self.previous_context = current_context() self.name = name # track the resources used by this context so far self._resource_usage = ContextResourceUsage() - # If alive has the thread resource usage when the logcontext last - # became active. - self.usage_start = None + # The thread resource usage when the logcontext became active. None + # if the context is not currently active. 
+ self.usage_start = None # type: Optional[resource._RUsage] self.main_thread = get_thread_id() self.request = None self.tag = "" - self.alive = True - self.scope = None + self.scope = None # type: Optional[_LogContextScope] + + # keep track of whether we have hit the __exit__ block for this context + # (suggesting that the the thing that created the context thinks it should + # be finished, and that re-activating it would suggest an error). + self.finished = False self.parent_context = parent_context @@ -260,76 +283,83 @@ class LoggingContext(object): # the request param overrides the request from the parent context self.request = request - def __str__(self): + def __str__(self) -> str: if self.request: return str(self.request) return "%s@%x" % (self.name, id(self)) @classmethod - def current_context(cls): + def current_context(cls) -> LoggingContextOrSentinel: """Get the current logging context from thread local storage + This exists for backwards compatibility. ``current_context()`` should be + called directly. + Returns: LoggingContext: the current logging context """ - return getattr(cls.thread_local, "current_context", cls.sentinel) + warnings.warn( + "synapse.logging.context.LoggingContext.current_context() is deprecated " + "in favor of synapse.logging.context.current_context().", + DeprecationWarning, + stacklevel=2, + ) + return current_context() @classmethod - def set_current_context(cls, context): + def set_current_context( + cls, context: LoggingContextOrSentinel + ) -> LoggingContextOrSentinel: """Set the current logging context in thread local storage + + This exists for backwards compatibility. ``set_current_context()`` should be + called directly. + Args: context(LoggingContext): The context to activate. Returns: The context that was previously active """ - current = cls.current_context() - - if current is not context: - current.stop() - cls.thread_local.current_context = context - context.start() - return current + warnings.warn( + "synapse.logging.context.LoggingContext.set_current_context() is deprecated " + "in favor of synapse.logging.context.set_current_context().", + DeprecationWarning, + stacklevel=2, + ) + return set_current_context(context) - def __enter__(self): + def __enter__(self) -> "LoggingContext": """Enters this logging context into thread local storage""" - old_context = self.set_current_context(self) + old_context = set_current_context(self) if self.previous_context != old_context: - logger.warn( + logger.warning( "Expected previous context %r, found %r", self.previous_context, old_context, ) - self.alive = True - return self - def __exit__(self, type, value, traceback): + def __exit__(self, type, value, traceback) -> None: """Restore the logging context in thread local storage to the state it was before this context was entered. Returns: None to avoid suppressing any exceptions that were thrown. 
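Taken together, the module-level helpers replace the old classmethods, and a context is used like this (a minimal sketch; the name is arbitrary):

    from synapse.logging.context import LoggingContext, current_context

    with LoggingContext(name="my_request") as ctx:
        assert current_context() is ctx   # work here is accounted to ctx
    # on exit the previously active context (or the sentinel) is restored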
""" - current = self.set_current_context(self.previous_context) + current = set_current_context(self.previous_context) if current is not self: - if current is self.sentinel: + if current is SENTINEL_CONTEXT: logger.warning("Expected logging context %s was lost", self) else: logger.warning( "Expected logging context %s but found %s", self, current ) - self.previous_context = None - self.alive = False - - # if we have a parent, pass our CPU usage stats on - if self.parent_context is not None and hasattr( - self.parent_context, "_resource_usage" - ): - self.parent_context._resource_usage += self._resource_usage - # reset them in case we get entered again - self._resource_usage.reset() + # the fact that we are here suggests that the caller thinks that everything + # is done and dusted for this logcontext, and further activity will not get + # recorded against the correct metrics. + self.finished = True - def copy_to(self, record): + def copy_to(self, record) -> None: """Copy logging fields from this context to a log record or another LoggingContext """ @@ -340,44 +370,72 @@ class LoggingContext(object): # we also track the current scope: record.scope = self.scope - def copy_to_twisted_log_entry(self, record): + def copy_to_twisted_log_entry(self, record) -> None: """ Copy logging fields from this context to a Twisted log record. """ record["request"] = self.request record["scope"] = self.scope - def start(self): + def start(self, rusage: "Optional[resource._RUsage]") -> None: + """ + Record that this logcontext is currently running. + + This should not be called directly: use set_current_context + + Args: + rusage: the resources used by the current thread, at the point of + switching to this logcontext. May be None if this platform doesn't + support getrusuage. + """ if get_thread_id() != self.main_thread: logger.warning("Started logcontext %s on different thread", self) return + if self.finished: + logger.warning("Re-starting finished log context %s", self) + # If we haven't already started record the thread resource usage so # far - if not self.usage_start: - self.usage_start = get_thread_resource_usage() + if self.usage_start: + logger.warning("Re-starting already-active log context %s", self) + else: + self.usage_start = rusage - def stop(self): - if get_thread_id() != self.main_thread: - logger.warning("Stopped logcontext %s on different thread", self) - return + def stop(self, rusage: "Optional[resource._RUsage]") -> None: + """ + Record that this logcontext is no longer running. - # When we stop, let's record the cpu used since we started - if not self.usage_start: - # Log a warning on platforms that support thread usage tracking - if is_thread_resource_usage_supported: + This should not be called directly: use set_current_context + + Args: + rusage: the resources used by the current thread, at the point of + switching away from this logcontext. May be None if this platform + doesn't support getrusuage. 
+ """ + + try: + if get_thread_id() != self.main_thread: + logger.warning("Stopped logcontext %s on different thread", self) + return + + if not rusage: + return + + # Record the cpu used since we started + if not self.usage_start: logger.warning( - "Called stop on logcontext %s without calling start", self + "Called stop on logcontext %s without recording a start rusage", + self, ) - return + return - utime_delta, stime_delta = self._get_cputime() - self._resource_usage.ru_utime += utime_delta - self._resource_usage.ru_stime += stime_delta - - self.usage_start = None + utime_delta, stime_delta = self._get_cputime(rusage) + self.add_cputime(utime_delta, stime_delta) + finally: + self.usage_start = None - def get_resource_usage(self): + def get_resource_usage(self) -> ContextResourceUsage: """Get resources used by this logcontext so far. Returns: @@ -390,19 +448,24 @@ class LoggingContext(object): # If we are on the correct thread and we're currently running then we # can include resource usage so far. is_main_thread = get_thread_id() == self.main_thread - if self.alive and self.usage_start and is_main_thread: - utime_delta, stime_delta = self._get_cputime() + if self.usage_start and is_main_thread: + rusage = get_thread_resource_usage() + assert rusage is not None + utime_delta, stime_delta = self._get_cputime(rusage) res.ru_utime += utime_delta res.ru_stime += stime_delta return res - def _get_cputime(self): - """Get the cpu usage time so far + def _get_cputime(self, current: "resource._RUsage") -> Tuple[float, float]: + """Get the cpu usage time between start() and the given rusage + + Args: + rusage: the current resource usage Returns: Tuple[float, float]: seconds in user mode, seconds in system mode """ - current = get_thread_resource_usage() + assert self.usage_start is not None utime_delta = current.ru_utime - self.usage_start.ru_utime stime_delta = current.ru_stime - self.usage_start.ru_stime @@ -426,30 +489,52 @@ class LoggingContext(object): return utime_delta, stime_delta - def add_database_transaction(self, duration_sec): + def add_cputime(self, utime_delta: float, stime_delta: float) -> None: + """Update the CPU time usage of this context (and any parents, recursively). + + Args: + utime_delta: additional user time, in seconds, spent in this context. + stime_delta: additional system time, in seconds, spent in this context. + """ + self._resource_usage.ru_utime += utime_delta + self._resource_usage.ru_stime += stime_delta + if self.parent_context: + self.parent_context.add_cputime(utime_delta, stime_delta) + + def add_database_transaction(self, duration_sec: float) -> None: + """Record the use of a database transaction and the length of time it took. + + Args: + duration_sec: The number of seconds the database transaction took. 
+ """ if duration_sec < 0: raise ValueError("DB txn time can only be non-negative") self._resource_usage.db_txn_count += 1 self._resource_usage.db_txn_duration_sec += duration_sec + if self.parent_context: + self.parent_context.add_database_transaction(duration_sec) - def add_database_scheduled(self, sched_sec): + def add_database_scheduled(self, sched_sec: float) -> None: """Record a use of the database pool Args: - sched_sec (float): number of seconds it took us to get a - connection + sched_sec: number of seconds it took us to get a connection """ if sched_sec < 0: raise ValueError("DB scheduling time can only be non-negative") self._resource_usage.db_sched_duration_sec += sched_sec + if self.parent_context: + self.parent_context.add_database_scheduled(sched_sec) - def record_event_fetch(self, event_count): + def record_event_fetch(self, event_count: int) -> None: """Record a number of events being fetched from the db Args: - event_count (int): number of events being fetched + event_count: number of events being fetched """ self._resource_usage.evt_db_fetch_count += event_count + if self.parent_context: + self.parent_context.record_event_fetch(event_count) class LoggingContextFilter(logging.Filter): @@ -460,15 +545,15 @@ class LoggingContextFilter(logging.Filter): missing fields """ - def __init__(self, **defaults): + def __init__(self, **defaults) -> None: self.defaults = defaults - def filter(self, record): + def filter(self, record) -> Literal[True]: """Add each fields from the logging contexts to the record. Returns: True to include the record in the log output. """ - context = LoggingContext.current_context() + context = current_context() for key, value in self.defaults.items(): setattr(record, key, value) @@ -488,26 +573,24 @@ class PreserveLoggingContext(object): __slots__ = ["current_context", "new_context", "has_parent"] - def __init__(self, new_context=None): - if new_context is None: - new_context = LoggingContext.sentinel + def __init__( + self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT + ) -> None: self.new_context = new_context - def __enter__(self): + def __enter__(self) -> None: """Captures the current logging context""" - self.current_context = LoggingContext.set_current_context(self.new_context) + self.current_context = set_current_context(self.new_context) if self.current_context: self.has_parent = self.current_context.previous_context is not None - if not self.current_context.alive: - logger.debug("Entering dead context: %s", self.current_context) - def __exit__(self, type, value, traceback): + def __exit__(self, type, value, traceback) -> None: """Restores the current logging context""" - context = LoggingContext.set_current_context(self.current_context) + context = set_current_context(self.current_context) if context != self.new_context: - if context is LoggingContext.sentinel: + if not context: logger.warning("Expected logging context %s was lost", self.new_context) else: logger.warning( @@ -516,12 +599,42 @@ class PreserveLoggingContext(object): context, ) - if self.current_context is not LoggingContext.sentinel: - if not self.current_context.alive: - logger.debug("Restoring dead context: %s", self.current_context) +_thread_local = threading.local() +_thread_local.current_context = SENTINEL_CONTEXT + + +def current_context() -> LoggingContextOrSentinel: + """Get the current logging context from thread local storage""" + return getattr(_thread_local, "current_context", SENTINEL_CONTEXT) -def nested_logging_context(suffix, parent_context=None): + 
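
This module-level accessor (together with set_current_context below) replaces the deprecated classmethods on LoggingContext. A minimal sketch of the intended calling pattern, assuming synapse.logging.context is importable as laid out in this diff (the request string is illustrative):

    from synapse.logging.context import LoggingContext, current_context

    def describe_current_request():
        # current_context() returns either the active LoggingContext or the
        # sentinel; the sentinel is falsy, so a plain truth test tells them apart.
        ctx = current_context()
        if ctx:
            return "request %s" % (ctx.request,)
        return "no logcontext active"

    with LoggingContext("main") as ctx:
        ctx.request = "GET-12345"
        print(describe_current_request())    # -> request GET-12345

    print(describe_current_request())        # -> no logcontext active
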
+def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel: + """Set the current logging context in thread local storage + Args: + context(LoggingContext): The context to activate. + Returns: + The context that was previously active + """ + # everything blows up if we allow current_context to be set to None, so sanity-check + # that now. + if context is None: + raise TypeError("'context' argument may not be None") + + current = current_context() + + if current is not context: + rusage = get_thread_resource_usage() + current.stop(rusage) + _thread_local.current_context = context + context.start(rusage) + + return current + + +def nested_logging_context( + suffix: str, parent_context: Optional[LoggingContext] = None +) -> LoggingContext: """Creates a new logging context as a child of another. The nested logging context will have a 'request' made up of the parent context's @@ -542,10 +655,12 @@ def nested_logging_context(suffix, parent_context=None): Returns: LoggingContext: new logging context. """ - if parent_context is None: - parent_context = LoggingContext.current_context() + if parent_context is not None: + context = parent_context # type: LoggingContextOrSentinel + else: + context = current_context() return LoggingContext( - parent_context=parent_context, request=parent_context.request + "-" + suffix + parent_context=context, request=str(context.request) + "-" + suffix ) @@ -567,12 +682,15 @@ def run_in_background(f, *args, **kwargs): yield or await on (for instance because you want to pass it to deferred.gatherResults()). + If f returns a Coroutine object, it will be wrapped into a Deferred (which will have + the side effect of executing the coroutine). + Note that if you completely discard the result, you should make sure that `f` doesn't raise any deferred exceptions, otherwise a scary-looking CRITICAL error about an unhandled error will be logged without much indication about where it came from. """ - current = LoggingContext.current_context() + current = current_context() try: res = f(*args, **kwargs) except: # noqa: E722 @@ -593,7 +711,7 @@ def run_in_background(f, *args, **kwargs): # The function may have reset the context before returning, so # we need to restore it now. - ctx = LoggingContext.set_current_context(current) + ctx = set_current_context(current) # The original context will be restored when the deferred # completes, but there is nothing waiting for it, so it will @@ -612,7 +730,8 @@ def run_in_background(f, *args, **kwargs): def make_deferred_yieldable(deferred): - """Given a deferred, make it follow the Synapse logcontext rules: + """Given a deferred (or coroutine), make it follow the Synapse logcontext + rules: If the deferred has completed (or is not actually a Deferred), essentially does nothing (just returns another completed deferred with the @@ -624,6 +743,13 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ + if inspect.isawaitable(deferred): + # If we're given a coroutine we convert it to a deferred so that we + # run it and find out if it immediately finishes, it it does then we + # don't need to fiddle with log contexts at all and can return + # immediately. + deferred = defer.ensureDeferred(deferred) + if not isinstance(deferred, defer.Deferred): return deferred @@ -634,14 +760,17 @@ def make_deferred_yieldable(deferred): # ok, we can't be sure that a yield won't block, so let's reset the # logcontext, and add a callback to the deferred to restore it. 
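
Both run_in_background and make_deferred_yieldable now accept coroutines as well as Deferreds, wrapping them with defer.ensureDeferred. A rough sketch of the resulting calling pattern, in which the worker functions are purely illustrative:

    from synapse.logging.context import (
        LoggingContext,
        make_deferred_yieldable,
        run_in_background,
    )

    async def fetch_answer():
        # stand-in for real asynchronous work
        return 42

    async def handler():
        # run_in_background starts the coroutine straight away and hands back
        # a Deferred that follows the logcontext rules.
        d = run_in_background(fetch_answer)
        # make_deferred_yieldable takes a Deferred or a coroutine; the calling
        # logcontext is restored once the result is available.
        return await make_deferred_yieldable(d)

    with LoggingContext("example"):
        d = run_in_background(handler)       # a Deferred that fires with 42
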
- prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) + prev_context = set_current_context(SENTINEL_CONTEXT) deferred.addBoth(_set_context_cb, prev_context) return deferred -def _set_context_cb(result, context): +ResultT = TypeVar("ResultT") + + +def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: """A callback function which just sets the logging context""" - LoggingContext.set_current_context(context) + set_current_context(context) return result @@ -709,7 +838,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): Deferred: A Deferred which fires a callback with the result of `f`, or an errback if `f` throws an exception. """ - logcontext = LoggingContext.current_context() + logcontext = current_context() def g(): with LoggingContext(parent_context=logcontext): diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 308a27213b..5dddf57008 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -169,7 +169,9 @@ import contextlib import inspect import logging import re +import types from functools import wraps +from typing import TYPE_CHECKING, Dict from canonicaljson import json @@ -177,6 +179,9 @@ from twisted.internet import defer from synapse.config import ConfigError +if TYPE_CHECKING: + from synapse.server import HomeServer + # Helper class @@ -295,14 +300,11 @@ def _noop_context_manager(*args, **kwargs): # Setup -def init_tracer(config): +def init_tracer(hs: "HomeServer"): """Set the whitelists and initialise the JaegerClient tracer - - Args: - config (HomeserverConfig): The config used by the homeserver """ global opentracing - if not config.opentracer_enabled: + if not hs.config.opentracer_enabled: # We don't have a tracer opentracing = None return @@ -313,18 +315,15 @@ def init_tracer(config): "installed." ) - # Include the worker name - name = config.worker_name if config.worker_name else "master" - # Pull out the jaeger config if it was given. Otherwise set it to something sensible. 
# See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py - set_homeserver_whitelist(config.opentracer_whitelist) + set_homeserver_whitelist(hs.config.opentracer_whitelist) JaegerConfig( - config=config.jaeger_config, - service_name="{} {}".format(config.server_name, name), - scope_manager=LogContextScopeManager(config), + config=hs.config.jaeger_config, + service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()), + scope_manager=LogContextScopeManager(hs.config), ).initialize_tracer() @@ -547,7 +546,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T return span = opentracing.tracer.active_span - carrier = {} + carrier = {} # type: Dict[str, str] opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier) for key, value in carrier.items(): @@ -584,7 +583,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True): span = opentracing.tracer.active_span - carrier = {} + carrier = {} # type: Dict[str, str] opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier) for key, value in carrier.items(): @@ -639,7 +638,7 @@ def get_active_span_text_map(destination=None): if destination and not whitelisted_homeserver(destination): return {} - carrier = {} + carrier = {} # type: Dict[str, str] opentracing.tracer.inject( opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier ) @@ -653,7 +652,7 @@ def active_span_context_as_string(): Returns: The active span context encoded as a string. """ - carrier = {} + carrier = {} # type: Dict[str, str] if opentracing: opentracing.tracer.inject( opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier @@ -777,8 +776,7 @@ def trace_servlet(servlet_name, extract_context=False): return func @wraps(func) - @defer.inlineCallbacks - def _trace_servlet_inner(request, *args, **kwargs): + async def _trace_servlet_inner(request, *args, **kwargs): request_tags = { "request_id": request.get_request_id(), tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, @@ -795,8 +793,14 @@ def trace_servlet(servlet_name, extract_context=False): scope = start_active_span(servlet_name, tags=request_tags) with scope: - result = yield defer.maybeDeferred(func, request, *args, **kwargs) - return result + result = func(request, *args, **kwargs) + + if not isinstance(result, (types.CoroutineType, defer.Deferred)): + # Some servlets aren't async and just return results + # directly, so we handle that here. + return result + + return await result return _trace_servlet_inner diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py index 4eed4f2338..dc3ab00cbb 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py @@ -19,7 +19,7 @@ from opentracing import Scope, ScopeManager import twisted -from synapse.logging.context import LoggingContext, nested_logging_context +from synapse.logging.context import current_context, nested_logging_context logger = logging.getLogger(__name__) @@ -49,11 +49,8 @@ class LogContextScopeManager(ScopeManager): (Scope) : the Scope that is active, or None if not available. 
""" - ctx = LoggingContext.current_context() - if ctx is LoggingContext.sentinel: - return None - else: - return ctx.scope + ctx = current_context() + return ctx.scope def activate(self, span, finish_on_close): """ @@ -70,9 +67,9 @@ class LogContextScopeManager(ScopeManager): """ enter_logcontext = False - ctx = LoggingContext.current_context() + ctx = current_context() - if ctx is LoggingContext.sentinel: + if not ctx: # We don't want this scope to affect. logger.error("Tried to activate scope outside of loggingcontext") return Scope(None, span) diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index 7df0fa6087..99049bb5d8 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -20,8 +20,6 @@ import time from functools import wraps from inspect import getcallargs -from six import PY3 - _TIME_FUNC_ID = 0 @@ -30,12 +28,8 @@ def _log_debug_as_f(f, msg, msg_args): logger = logging.getLogger(name) if logger.isEnabledFor(logging.DEBUG): - if PY3: - lineno = f.__code__.co_firstlineno - pathname = f.__code__.co_filename - else: - lineno = f.func_code.co_firstlineno - pathname = f.func_code.co_filename + lineno = f.__code__.co_firstlineno + pathname = f.__code__.co_filename record = logging.LogRecord( name=name, @@ -119,7 +113,11 @@ def trace_function(f): logger = logging.getLogger(name) level = logging.DEBUG - s = inspect.currentframe().f_back + frame = inspect.currentframe() + if frame is None: + raise Exception("Can't get current frame!") + + s = frame.f_back to_print = [ "\t%s:%s %s. Args: args=%s, kwargs=%s" @@ -144,7 +142,7 @@ def trace_function(f): pathname=pathname, lineno=lineno, msg=msg, - args=None, + args=(), exc_info=None, ) @@ -157,7 +155,12 @@ def trace_function(f): def get_previous_frames(): - s = inspect.currentframe().f_back.f_back + + frame = inspect.currentframe() + if frame is None: + raise Exception("Can't get current frame!") + + s = frame.f_back.f_back to_return = [] while s: if s.f_globals["__name__"].startswith("synapse"): @@ -174,7 +177,10 @@ def get_previous_frames(): def get_previous_frame(ignore=[]): - s = inspect.currentframe().f_back.f_back + frame = inspect.currentframe() + if frame is None: + raise Exception("Can't get current frame!") + s = frame.f_back.f_back while s: if s.f_globals["__name__"].startswith("synapse"): diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index bec3b13397..9cf31f96b3 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -20,13 +20,18 @@ import os import platform import threading import time -from typing import Dict, Union +from typing import Callable, Dict, Iterable, Optional, Tuple, Union import six import attr from prometheus_client import Counter, Gauge, Histogram -from prometheus_client.core import REGISTRY, GaugeMetricFamily, HistogramMetricFamily +from prometheus_client.core import ( + REGISTRY, + CounterMetricFamily, + GaugeMetricFamily, + HistogramMetricFamily, +) from twisted.internet import reactor @@ -59,10 +64,12 @@ class RegistryProxy(object): @attr.s(hash=True) class LaterGauge(object): - name = attr.ib() - desc = attr.ib() - labels = attr.ib(hash=False) - caller = attr.ib() + name = attr.ib(type=str) + desc = attr.ib(type=str) + labels = attr.ib(hash=False, type=Optional[Iterable[str]]) + # callback: should either return a value (if there are no labels for this metric), + # or dict mapping from a label tuple to a value + caller = attr.ib(type=Callable[[], Union[Dict[Tuple[str, ...], float], float]]) def collect(self): @@ -125,7 +132,7 @@ class 
InFlightGauge(object): ) # Counts number of in flight blocks for a given set of label values - self._registrations = {} + self._registrations = {} # type: Dict # Protects access to _registrations self._lock = threading.Lock() @@ -226,7 +233,7 @@ class BucketCollector(object): # Fetch the data -- this must be synchronous! data = self.data_collector() - buckets = {} + buckets = {} # type: Dict[float, int] res = [] for x in data.keys(): @@ -240,7 +247,7 @@ class BucketCollector(object): res.append(["+Inf", sum(data.values())]) metric = HistogramMetricFamily( - self.name, "", buckets=res, sum_value=sum([x * y for x, y in data.items()]) + self.name, "", buckets=res, sum_value=sum(x * y for x, y in data.items()) ) yield metric @@ -336,6 +343,78 @@ class GCCounts(object): if not running_on_pypy: REGISTRY.register(GCCounts()) + +# +# PyPy GC / memory metrics +# + + +class PyPyGCStats(object): + def collect(self): + + # @stats is a pretty-printer object with __str__() returning a nice table, + # plus some fields that contain data from that table. + # unfortunately, fields are pretty-printed themselves (i. e. '4.5MB'). + stats = gc.get_stats(memory_pressure=False) # type: ignore + # @s contains same fields as @stats, but as actual integers. + s = stats._s # type: ignore + + # also note that field naming is completely braindead + # and only vaguely correlates with the pretty-printed table. + # >>>> gc.get_stats(False) + # Total memory consumed: + # GC used: 8.7MB (peak: 39.0MB) # s.total_gc_memory, s.peak_memory + # in arenas: 3.0MB # s.total_arena_memory + # rawmalloced: 1.7MB # s.total_rawmalloced_memory + # nursery: 4.0MB # s.nursery_size + # raw assembler used: 31.0kB # s.jit_backend_used + # ----------------------------- + # Total: 8.8MB # stats.memory_used_sum + # + # Total memory allocated: + # GC allocated: 38.7MB (peak: 41.1MB) # s.total_allocated_memory, s.peak_allocated_memory + # in arenas: 30.9MB # s.peak_arena_memory + # rawmalloced: 4.1MB # s.peak_rawmalloced_memory + # nursery: 4.0MB # s.nursery_size + # raw assembler allocated: 1.0MB # s.jit_backend_allocated + # ----------------------------- + # Total: 39.7MB # stats.memory_allocated_sum + # + # Total time spent in GC: 0.073 # s.total_gc_time + + pypy_gc_time = CounterMetricFamily( + "pypy_gc_time_seconds_total", "Total time spent in PyPy GC", labels=[], + ) + pypy_gc_time.add_metric([], s.total_gc_time / 1000) + yield pypy_gc_time + + pypy_mem = GaugeMetricFamily( + "pypy_memory_bytes", + "Memory tracked by PyPy allocator", + labels=["state", "class", "kind"], + ) + # memory used by JIT assembler + pypy_mem.add_metric(["used", "", "jit"], s.jit_backend_used) + pypy_mem.add_metric(["allocated", "", "jit"], s.jit_backend_allocated) + # memory used by GCed objects + pypy_mem.add_metric(["used", "", "arenas"], s.total_arena_memory) + pypy_mem.add_metric(["allocated", "", "arenas"], s.peak_arena_memory) + pypy_mem.add_metric(["used", "", "rawmalloced"], s.total_rawmalloced_memory) + pypy_mem.add_metric(["allocated", "", "rawmalloced"], s.peak_rawmalloced_memory) + pypy_mem.add_metric(["used", "", "nursery"], s.nursery_size) + pypy_mem.add_metric(["allocated", "", "nursery"], s.nursery_size) + # totals + pypy_mem.add_metric(["used", "totals", "gc"], s.total_gc_memory) + pypy_mem.add_metric(["allocated", "totals", "gc"], s.total_allocated_memory) + pypy_mem.add_metric(["used", "totals", "gc_peak"], s.peak_memory) + pypy_mem.add_metric(["allocated", "totals", "gc_peak"], s.peak_allocated_memory) + yield pypy_mem + + +if running_on_pypy: + 
REGISTRY.register(PyPyGCStats()) + + # # Twisted reactor metrics # diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index 74d9c3ecd3..ab7f948ed4 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -33,12 +33,14 @@ from prometheus_client import REGISTRY from twisted.web.resource import Resource +from synapse.util import caches + try: from prometheus_client.samples import Sample except ImportError: - Sample = namedtuple( + Sample = namedtuple( # type: ignore[no-redef] # noqa "Sample", ["name", "labels", "value", "timestamp", "exemplar"] - ) # type: ignore + ) CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8") @@ -103,13 +105,15 @@ def nameify_sample(sample): def generate_latest(registry, emit_help=False): - output = [] - for metric in registry.collect(): + # Trigger the cache metrics to be rescraped, which updates the common + # metrics but do not produce metrics themselves + for collector in caches.collectors_by_name.values(): + collector.collect() - if metric.name.startswith("__unused"): - continue + output = [] + for metric in registry.collect(): if not metric.samples: # No samples, don't bother. continue diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index edd6b42db3..13785038ad 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -15,15 +15,20 @@ import logging import threading +from asyncio import iscoroutine +from functools import wraps +from typing import TYPE_CHECKING, Dict, Optional, Set -import six - -from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily +from prometheus_client.core import REGISTRY, Counter, Gauge from twisted.internet import defer from synapse.logging.context import LoggingContext, PreserveLoggingContext +if TYPE_CHECKING: + import resource + + logger = logging.getLogger(__name__) @@ -33,6 +38,12 @@ _background_process_start_count = Counter( ["name"], ) +_background_process_in_flight_count = Gauge( + "synapse_background_process_in_flight_count", + "Number of background processes in flight", + labelnames=["name"], +) + # we set registry=None in all of these to stop them getting registered with # the default registry. Instead we collect them all via the CustomCollector, # which ensures that we can update them before they are collected. @@ -78,15 +89,19 @@ _background_process_db_sched_duration = Counter( # map from description to a counter, so that we can name our logcontexts # incrementally. (It actually duplicates _background_process_start_count, but # it's much simpler to do so than to try to combine them.) -_background_process_counts = dict() # type: dict[str, int] +_background_process_counts = {} # type: Dict[str, int] -# map from description to the currently running background processes. +# Set of all running background processes that became active active since the +# last time metrics were scraped (i.e. background processes that performed some +# work since the last scrape.) # -# it's kept as a dict of sets rather than a big set so that we can keep track -# of process descriptions that no longer have any active processes. 
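
The "active since the last scrape" set introduced here is drained with a swap-under-lock pattern, so a scrape only pays for processes that actually did work. The idea in isolation, with purely illustrative names:

    import threading

    _active_since_last_scrape = set()
    _lock = threading.Lock()

    def mark_active(proc):
        # called whenever a background process becomes active again
        with _lock:
            _active_since_last_scrape.add(proc)

    def scrape():
        global _active_since_last_scrape
        # swap the set out under the lock, then do the per-process work
        # without holding it; idle processes cost nothing here
        with _lock:
            active = _active_since_last_scrape
            _active_since_last_scrape = set()
        for proc in active:
            print("updating metrics for", proc)

    mark_active("purge_history-1")
    scrape()    # prints one line; the set is empty again afterwards
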
-_background_processes = dict() # type: dict[str, set[_BackgroundProcess]] +# We do it like this to handle the case where we have a large number of +# background processes stacking up behind a lock or linearizer, where we then +# only need to iterate over and update metrics for the process that have +# actually been active and can ignore the idle ones. +_background_processes_active_since_last_scrape = set() # type: Set[_BackgroundProcess] -# A lock that covers the above dicts +# A lock that covers the above set and dict _bg_metrics_lock = threading.Lock() @@ -98,25 +113,16 @@ class _Collector(object): """ def collect(self): - background_process_in_flight_count = GaugeMetricFamily( - "synapse_background_process_in_flight_count", - "Number of background processes in flight", - labels=["name"], - ) + global _background_processes_active_since_last_scrape - # We copy the dict so that it doesn't change from underneath us. - # We also copy the process lists as that can also change + # We swap out the _background_processes set with an empty one so that + # we can safely iterate over the set without holding the lock. with _bg_metrics_lock: - _background_processes_copy = { - k: list(v) for k, v in six.iteritems(_background_processes) - } - - for desc, processes in six.iteritems(_background_processes_copy): - background_process_in_flight_count.add_metric((desc,), len(processes)) - for process in processes: - process.update_metrics() + _background_processes_copy = _background_processes_active_since_last_scrape + _background_processes_active_since_last_scrape = set() - yield background_process_in_flight_count + for process in _background_processes_copy: + process.update_metrics() # now we need to run collect() over each of the static Counters, and # yield each metric they return. @@ -173,7 +179,7 @@ def run_as_background_process(desc, func, *args, **kwargs): Args: desc (str): a description for this background process type - func: a function, which may return a Deferred + func: a function, which may return a Deferred or a coroutine args: positional args for func kwargs: keyword args for func @@ -188,23 +194,83 @@ def run_as_background_process(desc, func, *args, **kwargs): _background_process_counts[desc] = count + 1 _background_process_start_count.labels(desc).inc() + _background_process_in_flight_count.labels(desc).inc() - with LoggingContext(desc) as context: + with BackgroundProcessLoggingContext(desc) as context: context.request = "%s-%i" % (desc, count) - proc = _BackgroundProcess(desc, context) - - with _bg_metrics_lock: - _background_processes.setdefault(desc, set()).add(proc) try: - yield func(*args, **kwargs) + result = func(*args, **kwargs) + + # We probably don't have an ensureDeferred in our call stack to handle + # coroutine results, so we need to ensureDeferred here. + # + # But we need this check because ensureDeferred doesn't like being + # called on immediate values (as opposed to Deferreds or coroutines). + if iscoroutine(result): + result = defer.ensureDeferred(result) + + return (yield result) except Exception: logger.exception("Background process '%s' threw an exception", desc) finally: - proc.update_metrics() - - with _bg_metrics_lock: - _background_processes[desc].remove(proc) + _background_process_in_flight_count.labels(desc).dec() with PreserveLoggingContext(): return run() + + +def wrap_as_background_process(desc): + """Decorator that wraps a function that gets called as a background + process. 
+ + Equivalent of calling the function with `run_as_background_process` + """ + + def wrap_as_background_process_inner(func): + @wraps(func) + def wrap_as_background_process_inner_2(*args, **kwargs): + return run_as_background_process(desc, func, *args, **kwargs) + + return wrap_as_background_process_inner_2 + + return wrap_as_background_process_inner + + +class BackgroundProcessLoggingContext(LoggingContext): + """A logging context that tracks in flight metrics for background + processes. + """ + + __slots__ = ["_proc"] + + def __init__(self, name: str): + super().__init__(name) + + self._proc = _BackgroundProcess(name, self) + + def start(self, rusage: "Optional[resource._RUsage]"): + """Log context has started running (again). + """ + + super().start(rusage) + + # We've become active again so we make sure we're in the list of active + # procs. (Note that "start" here means we've become active, as opposed + # to starting for the first time.) + with _bg_metrics_lock: + _background_processes_active_since_last_scrape.add(self._proc) + + def __exit__(self, type, value, traceback) -> None: + """Log context has finished. + """ + + super().__exit__(type, value, traceback) + + # The background process has finished. We explictly remove and manually + # update the metrics here so that if nothing is scraping metrics the set + # doesn't infinitely grow. + with _bg_metrics_lock: + _background_processes_active_since_last_scrape.discard(self._proc) + + self._proc.update_metrics() diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 735b882363..ecdf1ad69f 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,18 +17,27 @@ import logging from twisted.internet import defer +from synapse.http.site import SynapseRequest +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import UserID +""" +This package defines the 'stable' API which can be used by extension modules which +are loaded into Synapse. +""" + +__all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"] + logger = logging.getLogger(__name__) class ModuleApi(object): - """A proxy object that gets passed to password auth providers so they + """A proxy object that gets passed to various plugin modules so they can register new users etc if necessary. """ def __init__(self, hs, auth_handler): - self.hs = hs + self._hs = hs self._store = hs.get_datastore() self._auth = hs.get_auth() @@ -64,7 +74,7 @@ class ModuleApi(object): """ if username.startswith("@"): return username - return UserID(username, self.hs.hostname).to_string() + return UserID(username, self._hs.hostname).to_string() def check_user_exists(self, user_id): """Check if user exists. @@ -76,7 +86,7 @@ class ModuleApi(object): Deferred[str|None]: Canonical (case-corrected) user_id, or None if the user is not registered. """ - return self._auth_handler.check_user_exists(user_id) + return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id)) @defer.inlineCallbacks def register(self, localpart, displayname=None, emails=[]): @@ -111,11 +121,19 @@ class ModuleApi(object): displayname (str|None): The displayname of the new user. emails (List[str]): Emails to bind to the new user. 
+ Raises: + SynapseError if there is an error performing the registration. Check the + 'errcode' property for more information on the reason for failure + Returns: Deferred[str]: user_id """ - return self.hs.get_registration_handler().register_user( - localpart=localpart, default_display_name=displayname, bind_emails=emails + return defer.ensureDeferred( + self._hs.get_registration_handler().register_user( + localpart=localpart, + default_display_name=displayname, + bind_emails=emails, + ) ) def register_device(self, user_id, device_id=None, initial_display_name=None): @@ -131,12 +149,34 @@ class ModuleApi(object): Returns: defer.Deferred[tuple[str, str]]: Tuple of device ID and access token """ - return self.hs.get_registration_handler().register_device( + return self._hs.get_registration_handler().register_device( user_id=user_id, device_id=device_id, initial_display_name=initial_display_name, ) + def record_user_external_id( + self, auth_provider_id: str, remote_user_id: str, registered_user_id: str + ) -> defer.Deferred: + """Record a mapping from an external user id to a mxid + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + return self._store.record_user_external_id( + auth_provider_id, remote_user_id, registered_user_id + ) + + def generate_short_term_login_token( + self, user_id: str, duration_in_ms: int = (2 * 60 * 1000) + ) -> str: + """Generate a login token suitable for m.login.token authentication""" + return self._hs.get_macaroon_generator().generate_short_term_login_token( + user_id, duration_in_ms + ) + @defer.inlineCallbacks def invalidate_access_token(self, access_token): """Invalidate an access token for a user @@ -157,10 +197,12 @@ class ModuleApi(object): user_id = user_info["user"].to_string() if device_id: # delete the device, which will also delete its access tokens - yield self.hs.get_device_handler().delete_device(user_id, device_id) + yield self._hs.get_device_handler().delete_device(user_id, device_id) else: # no associated device. Just delete the access token. - yield self._auth_handler.delete_access_token(access_token) + yield defer.ensureDeferred( + self._auth_handler.delete_access_token(access_token) + ) def run_db_interaction(self, desc, func, *args, **kwargs): """Run a function with a database connection @@ -175,4 +217,42 @@ class ModuleApi(object): Returns: Deferred[object]: result of func """ - return self._store.runInteraction(desc, func, *args, **kwargs) + return self._store.db.runInteraction(desc, func, *args, **kwargs) + + def complete_sso_login( + self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str + ): + """Complete a SSO login by redirecting the user to a page to confirm whether they + want their access token sent to `client_redirect_url`, or redirect them to that + URL with a token directly if the URL matches with one of the whitelisted clients. + + This is deprecated in favor of complete_sso_login_async. + + Args: + registered_user_id: The MXID that has been registered as a previous step of + of this SSO login. + request: The request to respond to. + client_redirect_url: The URL to which to offer to redirect the user (or to + redirect them directly if whitelisted). 
+ """ + self._auth_handler._complete_sso_login( + registered_user_id, request, client_redirect_url, + ) + + async def complete_sso_login_async( + self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str + ): + """Complete a SSO login by redirecting the user to a page to confirm whether they + want their access token sent to `client_redirect_url`, or redirect them to that + URL with a token directly if the URL matches with one of the whitelisted clients. + + Args: + registered_user_id: The MXID that has been registered as a previous step of + of this SSO login. + request: The request to respond to. + client_redirect_url: The URL to which to offer to redirect the user (or to + redirect them directly if whitelisted). + """ + await self._auth_handler.complete_sso_login( + registered_user_id, request, client_redirect_url, + ) diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py new file mode 100644 index 0000000000..b15441772c --- /dev/null +++ b/synapse/module_api/errors.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Exception types which are exposed as part of the stable module API""" + +from synapse.api.errors import RedirectException, SynapseError # noqa: F401 diff --git a/synapse/notifier.py b/synapse/notifier.py index 4e091314e6..87c120a59c 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -15,11 +15,13 @@ import logging from collections import namedtuple +from typing import Callable, Iterable, List, TypeVar from prometheus_client import Counter from twisted.internet import defer +import synapse.server from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state @@ -40,12 +42,14 @@ users_woken_by_stream_counter = Counter( "synapse_notifier_users_woken_by_stream", "", ["stream"] ) +T = TypeVar("T") + # TODO(paul): Should be shared somewhere -def count(func, l): - """Return the number of items in l for which func returns true.""" +def count(func: Callable[[T], bool], it: Iterable[T]) -> int: + """Return the number of items in it for which func returns true.""" n = 0 - for x in l: + for x in it: if func(x): n += 1 return n @@ -154,16 +158,22 @@ class Notifier(object): UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.user_to_user_stream = {} self.room_to_user_streams = {} self.hs = hs + self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() self.pending_new_room_events = [] - self.replication_callbacks = [] + # Called when there are new things to stream over replication + self.replication_callbacks = [] # type: List[Callable[[], None]] + + # Called when remote servers have come back online after having been + # down. 
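
The ModuleApi methods added above (check_user_exists, record_user_external_id and generate_short_term_login_token) are what an external authentication module would typically combine. A hedged sketch of such a module-side helper; every name other than those three methods is illustrative:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def finish_external_login(module_api, auth_provider_id, remote_user_id, user_id):
        # check_user_exists yields the canonical (case-corrected) mxid, or None
        canonical = yield module_api.check_user_exists(user_id)
        if canonical is None:
            return None

        # remember the mapping from the remote identity to the mxid ...
        yield module_api.record_user_external_id(
            auth_provider_id, remote_user_id, canonical
        )

        # ... and hand back a token usable with m.login.token
        token = module_api.generate_short_term_login_token(canonical)
        return (canonical, token)
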
+ self.remote_server_up_callbacks = [] # type: List[Callable[[str], None]] self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() @@ -204,7 +214,7 @@ class Notifier(object): "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream) ) - def add_replication_callback(self, cb): + def add_replication_callback(self, cb: Callable[[], None]): """Add a callback that will be called when some new data is available. Callback is not given any arguments. It should *not* return a Deferred - if it needs to do any asynchronous work, a background thread should be started and @@ -265,10 +275,9 @@ class Notifier(object): "room_key", room_stream_id, users=extra_users, rooms=[event.room_id] ) - @defer.inlineCallbacks - def _notify_app_services(self, room_stream_id): + async def _notify_app_services(self, room_stream_id): try: - yield self.appservice_handler.notify_interested_services(room_stream_id) + await self.appservice_handler.notify_interested_services(room_stream_id) except Exception: logger.exception("Error notifying application services of event") @@ -303,8 +312,7 @@ class Notifier(object): without waking up any of the normal user event streams""" self.notify_replication() - @defer.inlineCallbacks - def wait_for_events( + async def wait_for_events( self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START ): """Wait until the callback returns a non empty response or the @@ -312,9 +320,9 @@ class Notifier(object): """ user_stream = self.user_to_user_stream.get(user_id) if user_stream is None: - current_token = yield self.event_sources.get_current_token() + current_token = await self.event_sources.get_current_token() if room_ids is None: - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) user_stream = _NotifierUserStream( user_id=user_id, rooms=room_ids, @@ -343,11 +351,11 @@ class Notifier(object): self.hs.get_reactor(), ) with PreserveLoggingContext(): - yield listener.deferred + await listener.deferred current_token = user_stream.current_token - result = yield callback(prev_token, current_token) + result = await callback(prev_token, current_token) if result: break @@ -363,12 +371,11 @@ class Notifier(object): # This happened if there was no timeout or if the timeout had # already expired. 
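
The Notifier methods above follow the mechanical conversion applied throughout this commit: drop @defer.inlineCallbacks, declare the method async, and turn each yield into an await. As a self-contained sketch (the store call is only an example of the shape):

    from twisted.internet import defer

    # before
    @defer.inlineCallbacks
    def get_rooms_old(store, user_id):
        rooms = yield store.get_rooms_for_user(user_id)
        return rooms

    # after: a coroutine; call sites that still need a Deferred can wrap it
    # with defer.ensureDeferred()
    async def get_rooms(store, user_id):
        rooms = await store.get_rooms_for_user(user_id)
        return rooms
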
current_token = user_stream.current_token - result = yield callback(prev_token, current_token) + result = await callback(prev_token, current_token) return result - @defer.inlineCallbacks - def get_events_for( + async def get_events_for( self, user, pagination_config, @@ -390,15 +397,14 @@ class Notifier(object): """ from_token = pagination_config.from_token if not from_token: - from_token = yield self.event_sources.get_current_token() + from_token = await self.event_sources.get_current_token() limit = pagination_config.limit - room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id) + room_ids, is_joined = await self._get_room_ids(user, explicit_room_id) is_peeking = not is_joined - @defer.inlineCallbacks - def check_for_updates(before_token, after_token): + async def check_for_updates(before_token, after_token): if not after_token.is_after(before_token): return EventStreamResult([], (from_token, from_token)) @@ -414,7 +420,7 @@ class Notifier(object): if only_keys and name not in only_keys: continue - new_events, new_key = yield source.get_new_events( + new_events, new_key = await source.get_new_events( user=user, from_key=getattr(from_token, keyname), limit=limit, @@ -424,8 +430,11 @@ class Notifier(object): ) if name == "room": - new_events = yield filter_events_for_client( - self.store, user.to_string(), new_events, is_peeking=is_peeking + new_events = await filter_events_for_client( + self.storage, + user.to_string(), + new_events, + is_peeking=is_peeking, ) elif name == "presence": now = self.clock.time_msec() @@ -457,7 +466,7 @@ class Notifier(object): user_id_for_stream, ) - result = yield self.wait_for_events( + result = await self.wait_for_events( user_id_for_stream, timeout, check_for_updates, @@ -467,20 +476,18 @@ class Notifier(object): return result - @defer.inlineCallbacks - def _get_room_ids(self, user, explicit_room_id): - joined_room_ids = yield self.store.get_rooms_for_user(user.to_string()) + async def _get_room_ids(self, user, explicit_room_id): + joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: return [explicit_room_id], True - if (yield self._is_world_readable(explicit_room_id)): + if await self._is_world_readable(explicit_room_id): return [explicit_room_id], False raise AuthError(403, "Non-joined access not allowed") return joined_room_ids, True - @defer.inlineCallbacks - def _is_world_readable(self, room_id): - state = yield self.state_handler.get_current_state( + async def _is_world_readable(self, room_id): + state = await self.state_handler.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" ) if state and "history_visibility" in state.content: @@ -521,3 +528,12 @@ class Notifier(object): """Notify the any replication listeners that there's a new event""" for cb in self.replication_callbacks: cb() + + def notify_remote_server_up(self, server: str): + """Notify any replication that a remote server has come back up + """ + # We call federation_sender directly rather than registering as a + # callback as a) we already have a reference to it and b) it introduces + # circular dependencies. 
+ if self.federation_sender: + self.federation_sender.wake_destination(server) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 22491f3700..e75d964ac8 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -51,6 +51,7 @@ push_rules_delta_state_cache_metric = register_cache( "cache", "push_rules_delta_state_cache_metric", cache=[], # Meaningless size, as this isn't a cache that stores values + resizable=False, ) @@ -67,7 +68,8 @@ class BulkPushRuleEvaluator(object): self.room_push_rule_cache_metrics = register_cache( "cache", "room_push_rule_cache", - cache=[], # Meaningless size, as this isn't a cache that stores values + cache=[], # Meaningless size, as this isn't a cache that stores values, + resizable=False, ) @defer.inlineCallbacks @@ -79,7 +81,7 @@ class BulkPushRuleEvaluator(object): dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = self._get_rules_for_room(room_id) + rules_for_room = yield self._get_rules_for_room(room_id) rules_by_user = yield rules_for_room.get_rules(event, context) @@ -116,7 +118,7 @@ class BulkPushRuleEvaluator(object): @defer.inlineCallbacks def _get_power_levels_and_sender_level(self, event, context): - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and @@ -149,9 +151,10 @@ class BulkPushRuleEvaluator(object): room_members = yield self.store.get_joined_users_from_context(event, context) - (power_levels, sender_power_level) = ( - yield self._get_power_levels_and_sender_level(event, context) - ) + ( + power_levels, + sender_power_level, + ) = yield self._get_power_levels_and_sender_level(event, context) evaluator = PushRuleEvaluatorForEvent( event, len(room_members), sender_power_level, power_levels @@ -303,7 +306,7 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) @@ -385,15 +388,7 @@ class RulesForRoom(object): """ sequence = self.sequence - rows = yield self.store._simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=member_event_ids.values(), - retcols=("user_id", "membership", "event_id"), - keyvalues={}, - batch_size=500, - desc="_get_rules_for_member_event_ids", - ) + rows = yield self.store.get_membership_from_event_ids(member_event_ids.values()) members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} @@ -407,11 +402,11 @@ class RulesForRoom(object): if logger.isEnabledFor(logging.DEBUG): logger.debug("Found members %r: %r", self.room_id, members.values()) - interested_in_user_ids = set( + interested_in_user_ids = { user_id for user_id, membership in itervalues(members) if membership == Membership.JOIN - ) + } logger.debug("Joined: %r", interested_in_user_ids) @@ -419,9 +414,9 @@ class RulesForRoom(object): interested_in_user_ids, on_invalidate=self.invalidate_all_cb ) - user_ids = set( + user_ids = { uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher - ) + } logger.debug("With pushers: %r", user_ids) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 
42e5b0c0a5..568c13eaea 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -15,7 +15,6 @@ import logging -from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled from synapse.metrics.background_process_metrics import run_as_background_process @@ -132,8 +131,7 @@ class EmailPusher(object): self._is_processing = False self._start_processing() - @defer.inlineCallbacks - def _process(self): + async def _process(self): # we should never get here if we are already processing assert not self._is_processing @@ -142,7 +140,7 @@ class EmailPusher(object): if self.throttle_params is None: # this is our first loop: load up the throttle params - self.throttle_params = yield self.store.get_throttle_params_by_room( + self.throttle_params = await self.store.get_throttle_params_by_room( self.pusher_id ) @@ -151,7 +149,7 @@ class EmailPusher(object): while True: starting_max_ordering = self.max_stream_ordering try: - yield self._unsafe_process() + await self._unsafe_process() except Exception: logger.exception("Exception processing notifs") if self.max_stream_ordering == starting_max_ordering: @@ -159,8 +157,7 @@ class EmailPusher(object): finally: self._is_processing = False - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): """ Main logic of the push loop without the wrapper function that sets up logging, measures and guards against multiple instances of it @@ -168,12 +165,12 @@ class EmailPusher(object): """ start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering fn = self.store.get_unread_push_actions_for_user_in_range_for_email - unprocessed = yield fn(self.user_id, start, self.max_stream_ordering) + unprocessed = await fn(self.user_id, start, self.max_stream_ordering) soonest_due_at = None if not unprocessed: - yield self.save_last_stream_ordering_and_success(self.max_stream_ordering) + await self.save_last_stream_ordering_and_success(self.max_stream_ordering) return for push_action in unprocessed: @@ -201,15 +198,15 @@ class EmailPusher(object): "throttle_ms": self.get_room_throttle_ms(push_action["room_id"]), } - yield self.send_notification(unprocessed, reason) + await self.send_notification(unprocessed, reason) - yield self.save_last_stream_ordering_and_success( - max([ea["stream_ordering"] for ea in unprocessed]) + await self.save_last_stream_ordering_and_success( + max(ea["stream_ordering"] for ea in unprocessed) ) # we update the throttle on all the possible unprocessed push actions for ea in unprocessed: - yield self.sent_notif_update_throttle(ea["room_id"], ea) + await self.sent_notif_update_throttle(ea["room_id"], ea) break else: if soonest_due_at is None or should_notify_at < soonest_due_at: @@ -227,21 +224,18 @@ class EmailPusher(object): self.seconds_until(soonest_due_at), self.on_timer ) - @defer.inlineCallbacks - def save_last_stream_ordering_and_success(self, last_stream_ordering): + async def save_last_stream_ordering_and_success(self, last_stream_ordering): if last_stream_ordering is None: # This happens if we haven't yet processed anything return self.last_stream_ordering = last_stream_ordering - pusher_still_exists = ( - yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.email, - self.user_id, - last_stream_ordering, - self.clock.time_msec(), - ) + pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, + self.email, + self.user_id, + last_stream_ordering, + 
self.clock.time_msec(), ) if not pusher_still_exists: # The pusher has been deleted while we were processing, so @@ -277,13 +271,12 @@ class EmailPusher(object): may_send_at = last_sent_ts + throttle_ms return may_send_at - @defer.inlineCallbacks - def sent_notif_update_throttle(self, room_id, notified_push_action): + async def sent_notif_update_throttle(self, room_id, notified_push_action): # We have sent a notification, so update the throttle accordingly. # If the event that triggered the notif happened more than # THROTTLE_RESET_AFTER_MS after the previous one that triggered a # notif, we release the throttle. Otherwise, the throttle is increased. - time_of_previous_notifs = yield self.store.get_time_of_last_push_action_before( + time_of_previous_notifs = await self.store.get_time_of_last_push_action_before( notified_push_action["stream_ordering"] ) @@ -312,14 +305,13 @@ class EmailPusher(object): "last_sent_ts": self.clock.time_msec(), "throttle_ms": new_throttle_ms, } - yield self.store.set_throttle_params( + await self.store.set_throttle_params( self.pusher_id, room_id, self.throttle_params[room_id] ) - @defer.inlineCallbacks - def send_notification(self, push_actions, reason): + async def send_notification(self, push_actions, reason): logger.info("Sending notif email for user %r", self.user_id) - yield self.mailer.send_notification_mail( + await self.mailer.send_notification_mail( self.app_id, self.user_id, self.email, push_actions, reason ) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 6299587808..eaaa7afc91 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -15,8 +15,6 @@ # limitations under the License. import logging -import six - from prometheus_client import Counter from twisted.internet import defer @@ -28,9 +26,6 @@ from synapse.push import PusherConfigException from . import push_rule_evaluator, push_tools -if six.PY3: - long = int - logger = logging.getLogger(__name__) http_push_processed_counter = Counter( @@ -64,6 +59,7 @@ class HttpPusher(object): def __init__(self, hs, pusherdict): self.hs = hs self.store = self.hs.get_datastore() + self.storage = self.hs.get_storage() self.clock = self.hs.get_clock() self.state_handler = self.hs.get_state_handler() self.user_id = pusherdict["user_name"] @@ -102,7 +98,7 @@ class HttpPusher(object): if "url" not in self.data: raise PusherConfigException("'url' required in data for HTTP pusher") self.url = self.data["url"] - self.http_client = hs.get_simple_http_client() + self.http_client = hs.get_proxied_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url["url"] @@ -210,14 +206,12 @@ class HttpPusher(object): http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = ( - yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.pushkey, - self.user_id, - self.last_stream_ordering, - self.clock.time_msec(), - ) + pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, + self.pushkey, + self.user_id, + self.last_stream_ordering, + self.clock.time_msec(), ) if not pusher_still_exists: # The pusher has been deleted while we were processing, so @@ -246,8 +240,8 @@ class HttpPusher(object): # we really only give up so that if the URL gets # fixed, we don't suddenly deliver a load # of old notifications. 
- logger.warn( - "Giving up on a notification to user %s, " "pushkey %s", + logger.warning( + "Giving up on a notification to user %s, pushkey %s", self.user_id, self.pushkey, ) @@ -299,9 +293,8 @@ class HttpPusher(object): if pk != self.pushkey: # for sanity, we only remove the pushkey if it # was the one we actually sent... - logger.warn( - ("Ignoring rejected pushkey %s because we" " didn't send it"), - pk, + logger.warning( + ("Ignoring rejected pushkey %s because we didn't send it"), pk, ) else: logger.info("Pushkey %s was rejected: removing", pk) @@ -320,7 +313,7 @@ class HttpPusher(object): { "app_id": self.app_id, "pushkey": self.pushkey, - "pushkey_ts": long(self.pushkey_ts / 1000), + "pushkey_ts": int(self.pushkey_ts / 1000), "data": self.data_minus_url, } ], @@ -329,7 +322,7 @@ class HttpPusher(object): return d ctx = yield push_tools.get_context_for_event( - self.store, self.state_handler, event, self.user_id + self.storage, self.state_handler, event, self.user_id ) d = { @@ -349,7 +342,7 @@ class HttpPusher(object): { "app_id": self.app_id, "pushkey": self.pushkey, - "pushkey_ts": long(self.pushkey_ts / 1000), + "pushkey_ts": int(self.pushkey_ts / 1000), "data": self.data_minus_url, "tweaks": tweaks, } @@ -400,7 +393,7 @@ class HttpPusher(object): Args: badge (int): number of unread messages """ - logger.info("Sending updated badge count %d to %s", badge, self.name) + logger.debug("Sending updated badge count %d to %s", badge, self.name) d = { "notification": { "id": "", @@ -411,7 +404,7 @@ class HttpPusher(object): { "app_id": self.app_id, "pushkey": self.pushkey, - "pushkey_ts": long(self.pushkey_ts / 1000), + "pushkey_ts": int(self.pushkey_ts / 1000), "data": self.data_minus_url, } ], diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 3dfd527849..d57a66a697 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -19,14 +19,13 @@ import logging import time from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from typing import Iterable, List, TypeVar from six.moves import urllib import bleach import jinja2 -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.logging.context import make_deferred_yieldable @@ -41,9 +40,11 @@ from synapse.visibility import filter_events_for_client logger = logging.getLogger(__name__) +T = TypeVar("T") + MESSAGE_FROM_PERSON_IN_ROOM = ( - "You have a message on %(app)s from %(person)s " "in the %(room)s room..." + "You have a message on %(app)s from %(person)s in the %(room)s room..." ) MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..." MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..." @@ -55,7 +56,7 @@ MESSAGES_FROM_PERSON_AND_OTHERS = ( "You have messages on %(app)s from %(person)s and others..." ) INVITE_FROM_PERSON_TO_ROOM = ( - "%(person)s has invited you to join the " "%(room)s room on %(app)s..." + "%(person)s has invited you to join the %(room)s room on %(app)s..." ) INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..." 
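The mailer.py constants above (and the logger calls earlier in this hunk) drop the implicit concatenation of adjacent string literals left over from earlier line-wrapping; the resulting strings are identical. A quick illustration:

# Adjacent string literals are joined at compile time, so collapsing them is
# purely cosmetic: both spellings evaluate to the same string.
a = "You have a message on %(app)s from %(person)s " "in the %(room)s room..."
b = "You have a message on %(app)s from %(person)s in the %(room)s room..."
assert a == b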
@@ -119,12 +120,12 @@ class Mailer(object): self.store = self.hs.get_datastore() self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() + self.storage = hs.get_storage() self.app_name = app_name logger.info("Created Mailer for app_name %s" % app_name) - @defer.inlineCallbacks - def send_password_reset_mail(self, email_address, token, client_secret, sid): + async def send_password_reset_mail(self, email_address, token, client_secret, sid): """Send an email with a password reset link to a user Args: @@ -136,22 +137,22 @@ class Mailer(object): group together multiple email sending attempts sid (str): The generated session ID """ + params = {"token": token, "client_secret": client_secret, "sid": sid} link = ( self.hs.config.public_baseurl - + "_matrix/client/unstable/password_reset/email/submit_token" - "?token=%s&client_secret=%s&sid=%s" % (token, client_secret, sid) + + "_matrix/client/unstable/password_reset/email/submit_token?%s" + % urllib.parse.urlencode(params) ) template_vars = {"link": link} - yield self.send_email( + await self.send_email( email_address, "[%s] Password Reset" % self.hs.config.server_name, template_vars, ) - @defer.inlineCallbacks - def send_registration_mail(self, email_address, token, client_secret, sid): + async def send_registration_mail(self, email_address, token, client_secret, sid): """Send an email with a registration confirmation link to a user Args: @@ -163,28 +164,56 @@ class Mailer(object): group together multiple email sending attempts sid (str): The generated session ID """ + params = {"token": token, "client_secret": client_secret, "sid": sid} link = ( self.hs.config.public_baseurl - + "_matrix/client/unstable/registration/email/submit_token" - "?token=%s&client_secret=%s&sid=%s" % (token, client_secret, sid) + + "_matrix/client/unstable/registration/email/submit_token?%s" + % urllib.parse.urlencode(params) ) template_vars = {"link": link} - yield self.send_email( + await self.send_email( email_address, "[%s] Register your Email Address" % self.hs.config.server_name, template_vars, ) - @defer.inlineCallbacks - def send_notification_mail( + async def send_add_threepid_mail(self, email_address, token, client_secret, sid): + """Send an email with a validation link to a user for adding a 3pid to their account + + Args: + email_address (str): Email address we're sending the validation link to + + token (str): Unique token generated by the server to verify the email was received + + client_secret (str): Unique token generated by the client to group together + multiple email sending attempts + + sid (str): The generated session ID + """ + params = {"token": token, "client_secret": client_secret, "sid": sid} + link = ( + self.hs.config.public_baseurl + + "_matrix/client/unstable/add_threepid/email/submit_token?%s" + % urllib.parse.urlencode(params) + ) + + template_vars = {"link": link} + + await self.send_email( + email_address, + "[%s] Validate Your Email" % self.hs.config.server_name, + template_vars, + ) + + async def send_notification_mail( self, app_id, user_id, email_address, push_actions, reason ): """Send email regarding a user's room notifications""" rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions]) - notif_events = yield self.store.get_events( + notif_events = await self.store.get_events( [pa["event_id"] for pa in push_actions] ) @@ -197,7 +226,7 @@ class Mailer(object): state_by_room = {} try: - user_display_name = yield self.store.get_profile_displayname( + user_display_name 
= await self.store.get_profile_displayname( UserID.from_string(user_id).localpart ) if user_display_name is None: @@ -205,14 +234,13 @@ class Mailer(object): except StoreError: user_display_name = user_id - @defer.inlineCallbacks - def _fetch_room_state(room_id): - room_state = yield self.store.get_current_state_ids(room_id) + async def _fetch_room_state(room_id): + room_state = await self.store.get_current_state_ids(room_id) state_by_room[room_id] = room_state # Run at most 3 of these at once: sync does 10 at a time but email # notifs are much less realtime than sync so we can afford to wait a bit. - yield concurrently_execute(_fetch_room_state, rooms_in_order, 3) + await concurrently_execute(_fetch_room_state, rooms_in_order, 3) # actually sort our so-called rooms_in_order list, most recent room first rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0)) @@ -220,19 +248,19 @@ class Mailer(object): rooms = [] for r in rooms_in_order: - roomvars = yield self.get_room_vars( + roomvars = await self.get_room_vars( r, user_id, notifs_by_room[r], notif_events, state_by_room[r] ) rooms.append(roomvars) - reason["room_name"] = yield calculate_room_name( + reason["room_name"] = await calculate_room_name( self.store, state_by_room[reason["room_id"]], user_id, fallback_to_members=True, ) - summary_text = yield self.make_summary_text( + summary_text = await self.make_summary_text( notifs_by_room, state_by_room, notif_events, user_id, reason ) @@ -247,12 +275,11 @@ class Mailer(object): "reason": reason, } - yield self.send_email( + await self.send_email( email_address, "[%s] %s" % (self.app_name, summary_text), template_vars ) - @defer.inlineCallbacks - def send_email(self, email_address, subject, template_vars): + async def send_email(self, email_address, subject, template_vars): """Send an email with the given information and template text""" try: from_string = self.hs.config.email_notif_from % {"app": self.app_name} @@ -280,9 +307,9 @@ class Mailer(object): multipart_msg.attach(text_part) multipart_msg.attach(html_part) - logger.info("Sending email notification to %s" % email_address) + logger.info("Sending email to %s" % email_address) - yield make_deferred_yieldable( + await make_deferred_yieldable( self.sendmail( self.hs.config.email_smtp_host, raw_from, @@ -297,13 +324,14 @@ class Mailer(object): ) ) - @defer.inlineCallbacks - def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids): + async def get_room_vars( + self, room_id, user_id, notifs, notif_events, room_state_ids + ): my_member_event_id = room_state_ids[("m.room.member", user_id)] - my_member_event = yield self.store.get_event(my_member_event_id) + my_member_event = await self.store.get_event(my_member_event_id) is_invite = my_member_event.content["membership"] == "invite" - room_name = yield calculate_room_name(self.store, room_state_ids, user_id) + room_name = await calculate_room_name(self.store, room_state_ids, user_id) room_vars = { "title": room_name, @@ -315,7 +343,7 @@ class Mailer(object): if not is_invite: for n in notifs: - notifvars = yield self.get_notif_vars( + notifvars = await self.get_notif_vars( n, user_id, notif_events[n["event_id"]], room_state_ids ) @@ -342,9 +370,8 @@ class Mailer(object): return room_vars - @defer.inlineCallbacks - def get_notif_vars(self, notif, user_id, notif_event, room_state_ids): - results = yield self.store.get_events_around( + async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids): + results = await 
self.store.get_events_around( notif["room_id"], notif["event_id"], before_limit=CONTEXT_BEFORE, @@ -357,25 +384,24 @@ class Mailer(object): "messages": [], } - the_events = yield filter_events_for_client( - self.store, user_id, results["events_before"] + the_events = await filter_events_for_client( + self.storage, user_id, results["events_before"] ) the_events.append(notif_event) for event in the_events: - messagevars = yield self.get_message_vars(notif, event, room_state_ids) + messagevars = await self.get_message_vars(notif, event, room_state_ids) if messagevars is not None: ret["messages"].append(messagevars) return ret - @defer.inlineCallbacks - def get_message_vars(self, notif, event, room_state_ids): + async def get_message_vars(self, notif, event, room_state_ids): if event.type != EventTypes.Message: return sender_state_event_id = room_state_ids[("m.room.member", event.sender)] - sender_state_event = yield self.store.get_event(sender_state_event_id) + sender_state_event = await self.store.get_event(sender_state_event_id) sender_name = name_from_member_event(sender_state_event) sender_avatar_url = sender_state_event.content.get("avatar_url") @@ -425,8 +451,7 @@ class Mailer(object): return messagevars - @defer.inlineCallbacks - def make_summary_text( + async def make_summary_text( self, notifs_by_room, room_state_ids, notif_events, user_id, reason ): if len(notifs_by_room) == 1: @@ -436,17 +461,17 @@ class Mailer(object): # If the room has some kind of name, use it, but we don't # want the generated-from-names one here otherwise we'll # end up with, "new message from Bob in the Bob room" - room_name = yield calculate_room_name( + room_name = await calculate_room_name( self.store, room_state_ids[room_id], user_id, fallback_to_members=False ) my_member_event_id = room_state_ids[room_id][("m.room.member", user_id)] - my_member_event = yield self.store.get_event(my_member_event_id) + my_member_event = await self.store.get_event(my_member_event_id) if my_member_event.content["membership"] == "invite": inviter_member_event_id = room_state_ids[room_id][ ("m.room.member", my_member_event.sender) ] - inviter_member_event = yield self.store.get_event( + inviter_member_event = await self.store.get_event( inviter_member_event_id ) inviter_name = name_from_member_event(inviter_member_event) @@ -471,7 +496,7 @@ class Mailer(object): state_event_id = room_state_ids[room_id][ ("m.room.member", event.sender) ] - state_event = yield self.store.get_event(state_event_id) + state_event = await self.store.get_event(state_event_id) sender_name = name_from_member_event(state_event) if sender_name is not None and room_name is not None: @@ -494,15 +519,13 @@ class Mailer(object): # If the room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" sender_ids = list( - set( - [ - notif_events[n["event_id"]].sender - for n in notifs_by_room[room_id] - ] - ) + { + notif_events[n["event_id"]].sender + for n in notifs_by_room[room_id] + } ) - member_events = yield self.store.get_events( + member_events = await self.store.get_events( [ room_state_ids[room_id][("m.room.member", s)] for s in sender_ids @@ -525,16 +548,16 @@ class Mailer(object): else: # If the reason room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" + room_id = reason["room_id"] + sender_ids = list( - set( - [ - notif_events[n["event_id"]].sender - for n in notifs_by_room[reason["room_id"]] - ] - ) + { + notif_events[n["event_id"]].sender + for n in 
notifs_by_room[room_id] + } ) - member_events = yield self.store.get_events( + member_events = await self.store.get_events( [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids] ) @@ -608,10 +631,10 @@ def safe_text(raw_text): ) -def deduped_ordered_list(l): +def deduped_ordered_list(it: Iterable[T]) -> List[T]: seen = set() ret = [] - for item in l: + for item in it: if item not in seen: seen.add(item) ret.append(item) diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 16a7e8e31d..0644a13cfc 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -18,6 +18,8 @@ import re from twisted.internet import defer +from synapse.api.constants import EventTypes + logger = logging.getLogger(__name__) # intentionally looser than what aliases we allow to be registered since @@ -50,17 +52,17 @@ def calculate_room_name( (string or None) A human readable name for the room. """ # does it have a name? - if ("m.room.name", "") in room_state_ids: + if (EventTypes.Name, "") in room_state_ids: m_room_name = yield store.get_event( - room_state_ids[("m.room.name", "")], allow_none=True + room_state_ids[(EventTypes.Name, "")], allow_none=True ) if m_room_name and m_room_name.content and m_room_name.content["name"]: return m_room_name.content["name"] # does it have a canonical alias? - if ("m.room.canonical_alias", "") in room_state_ids: + if (EventTypes.CanonicalAlias, "") in room_state_ids: canon_alias = yield store.get_event( - room_state_ids[("m.room.canonical_alias", "")], allow_none=True + room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True ) if ( canon_alias @@ -74,32 +76,22 @@ def calculate_room_name( # for an event type, so rearrange the data structure room_state_bytype_ids = _state_as_two_level_dict(room_state_ids) - # right then, any aliases at all? - if "m.room.aliases" in room_state_bytype_ids: - m_room_aliases = room_state_bytype_ids["m.room.aliases"] - for alias_id in m_room_aliases.values(): - alias_event = yield store.get_event(alias_id, allow_none=True) - if alias_event and alias_event.content.get("aliases"): - the_aliases = alias_event.content["aliases"] - if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]): - return the_aliases[0] - if not fallback_to_members: return None my_member_event = None - if ("m.room.member", user_id) in room_state_ids: + if (EventTypes.Member, user_id) in room_state_ids: my_member_event = yield store.get_event( - room_state_ids[("m.room.member", user_id)], allow_none=True + room_state_ids[(EventTypes.Member, user_id)], allow_none=True ) if ( my_member_event is not None and my_member_event.content["membership"] == "invite" ): - if ("m.room.member", my_member_event.sender) in room_state_ids: + if (EventTypes.Member, my_member_event.sender) in room_state_ids: inviter_member_event = yield store.get_event( - room_state_ids[("m.room.member", my_member_event.sender)], + room_state_ids[(EventTypes.Member, my_member_event.sender)], allow_none=True, ) if inviter_member_event: @@ -114,9 +106,9 @@ def calculate_room_name( # we're going to have to generate a name based on who's in the room, # so find out who is in the room that isn't the user. 
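The presentable_names.py hunks in this diff replace hard-coded event type strings with the `EventTypes` constants from `synapse.api.constants`. The constants are just those literal strings (the substitutions in the surrounding hunks confirm this for the types touched here), so the state-map lookups behave exactly as before. A small check, assuming a Synapse checkout on the path:

from synapse.api.constants import EventTypes

# Each constant is the literal Matrix event type string it replaces.
assert EventTypes.Name == "m.room.name"
assert EventTypes.CanonicalAlias == "m.room.canonical_alias"
assert EventTypes.Member == "m.room.member"
assert EventTypes.ThirdPartyInvite == "m.room.third_party_invite"

# So a state-map lookup such as this one is unchanged in behaviour:
room_state_ids = {(EventTypes.Member, "@alice:example.org"): "$some_event_id"}
assert ("m.room.member", "@alice:example.org") in room_state_ids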
- if "m.room.member" in room_state_bytype_ids: + if EventTypes.Member in room_state_bytype_ids: member_events = yield store.get_events( - list(room_state_bytype_ids["m.room.member"].values()) + list(room_state_bytype_ids[EventTypes.Member].values()) ) all_members = [ ev @@ -138,9 +130,9 @@ def calculate_room_name( # self-chat, peeked room with 1 participant, # or inbound invite, or outbound 3PID invite. if all_members[0].sender == user_id: - if "m.room.third_party_invite" in room_state_bytype_ids: + if EventTypes.ThirdPartyInvite in room_state_bytype_ids: third_party_invites = room_state_bytype_ids[ - "m.room.third_party_invite" + EventTypes.ThirdPartyInvite ].values() if len(third_party_invites) > 0: diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 5ed9147de4..11032491af 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -16,11 +16,13 @@ import logging import re +from typing import Pattern from six import string_types +from synapse.events import EventBase from synapse.types import UserID -from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache +from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -56,18 +58,18 @@ def _test_ineq_condition(condition, number): rhs = m.group(2) if not rhs.isdigit(): return False - rhs = int(rhs) + rhs_int = int(rhs) if ineq == "" or ineq == "==": - return number == rhs + return number == rhs_int elif ineq == "<": - return number < rhs + return number < rhs_int elif ineq == ">": - return number > rhs + return number > rhs_int elif ineq == ">=": - return number >= rhs + return number >= rhs_int elif ineq == "<=": - return number <= rhs + return number <= rhs_int else: return False @@ -83,7 +85,13 @@ def tweaks_for_actions(actions): class PushRuleEvaluatorForEvent(object): - def __init__(self, event, room_member_count, sender_power_level, power_levels): + def __init__( + self, + event: EventBase, + room_member_count: int, + sender_power_level: int, + power_levels: dict, + ): self._event = event self._room_member_count = room_member_count self._sender_power_level = sender_power_level @@ -92,7 +100,7 @@ class PushRuleEvaluatorForEvent(object): # Maps strings of e.g. 
'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) - def matches(self, condition, user_id, display_name): + def matches(self, condition: dict, user_id: str, display_name: str) -> bool: if condition["kind"] == "event_match": return self._event_match(condition, user_id) elif condition["kind"] == "contains_display_name": @@ -106,7 +114,7 @@ class PushRuleEvaluatorForEvent(object): else: return True - def _event_match(self, condition, user_id): + def _event_match(self, condition: dict, user_id: str) -> bool: pattern = condition.get("pattern", None) if not pattern: @@ -117,7 +125,7 @@ class PushRuleEvaluatorForEvent(object): pattern = UserID.from_string(user_id).localpart if not pattern: - logger.warn("event_match condition with no pattern") + logger.warning("event_match condition with no pattern") return False # XXX: optimisation: cache our pattern regexps @@ -134,7 +142,7 @@ class PushRuleEvaluatorForEvent(object): return _glob_matches(pattern, haystack) - def _contains_display_name(self, display_name): + def _contains_display_name(self, display_name: str) -> bool: if not display_name: return False @@ -142,51 +150,52 @@ class PushRuleEvaluatorForEvent(object): if not body: return False - return _glob_matches(display_name, body, word_boundary=True) + # Similar to _glob_matches, but do not treat display_name as a glob. + r = regex_cache.get((display_name, False, True), None) + if not r: + r = re.escape(display_name) + r = _re_word_boundary(r) + r = re.compile(r, flags=re.IGNORECASE) + regex_cache[(display_name, False, True)] = r + + return r.search(body) - def _get_value(self, dotted_key): + def _get_value(self, dotted_key: str) -> str: return self._value_cache.get(dotted_key, None) -# Caches (glob, word_boundary) -> regex for push. See _glob_matches -regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) +# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches +regex_cache = LruCache(50000) register_cache("cache", "regex_push_cache", regex_cache) -def _glob_matches(glob, value, word_boundary=False): +def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: """Tests if value matches glob. Args: - glob (string) - value (string): String to test against glob. - word_boundary (bool): Whether to match against word boundaries or entire + glob + value: String to test against glob. + word_boundary: Whether to match against word boundaries or entire string. Defaults to False. - - Returns: - bool """ try: - r = regex_cache.get((glob, word_boundary), None) + r = regex_cache.get((glob, True, word_boundary), None) if not r: r = _glob_to_re(glob, word_boundary) - regex_cache[(glob, word_boundary)] = r + regex_cache[(glob, True, word_boundary)] = r return r.search(value) except re.error: - logger.warn("Failed to parse glob to regex: %r", glob) + logger.warning("Failed to parse glob to regex: %r", glob) return False -def _glob_to_re(glob, word_boundary): +def _glob_to_re(glob: str, word_boundary: bool) -> Pattern: """Generates regex for a given glob. Args: - glob (string) - word_boundary (bool): Whether to match against word boundaries or entire - string. Defaults to False. - - Returns: - regex object + glob + word_boundary: Whether to match against word boundaries or entire string. 
""" if IS_GLOB.search(glob): r = re.escape(glob) @@ -219,7 +228,7 @@ def _glob_to_re(glob, word_boundary): return re.compile(r, flags=re.IGNORECASE) -def _re_word_boundary(r): +def _re_word_boundary(r: str) -> str: """ Adds word boundary characters to the start and end of an expression to require that the match occur as a whole word, diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index a54051a726..5dae4648c0 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -16,11 +16,12 @@ from twisted.internet import defer from synapse.push.presentable_names import calculate_room_name, name_from_member_event +from synapse.storage import Storage @defer.inlineCallbacks def get_badge_count(store, user_id): - invites = yield store.get_invited_rooms_for_user(user_id) + invites = yield store.get_invited_rooms_for_local_user(user_id) joins = yield store.get_rooms_for_user(user_id) my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read") @@ -43,22 +44,22 @@ def get_badge_count(store, user_id): @defer.inlineCallbacks -def get_context_for_event(store, state_handler, ev, user_id): +def get_context_for_event(storage: Storage, state_handler, ev, user_id): ctx = {} - room_state_ids = yield store.get_state_ids_for_event(ev.event_id) + room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or # a list of people in the room name = yield calculate_room_name( - store, room_state_ids, user_id, fallback_to_single_member=False + storage.main, room_state_ids, user_id, fallback_to_single_member=False ) if name: ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] - sender_state_event = yield store.get_event(sender_state_event_id) + sender_state_event = yield storage.main.get_event(sender_state_event_id) ctx["sender_display_name"] = name_from_member_event(sender_state_event) return ctx diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index f277aeb131..8ad0bf5936 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -80,9 +80,11 @@ class PusherFactory(object): return EmailPusher(self.hs, pusherdict, mailer) def _app_name_from_pusherdict(self, pusherdict): - if "data" in pusherdict and "brand" in pusherdict["data"]: - app_name = pusherdict["data"]["brand"] - else: - app_name = self.config.email_app_name + data = pusherdict["data"] - return app_name + if isinstance(data, dict): + brand = data.get("brand") + if isinstance(brand, str): + return brand + + return self.config.email_app_name diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 08e840fdc2..88d203aa44 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -15,11 +15,17 @@ # limitations under the License. 
import logging +from collections import defaultdict +from threading import Lock +from typing import Dict, Tuple, Union from twisted.internet import defer +from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException +from synapse.push.emailpusher import EmailPusher +from synapse.push.httppusher import HttpPusher from synapse.push.pusher import PusherFactory from synapse.util.async_helpers import concurrently_execute @@ -47,7 +53,29 @@ class PusherPool: self._should_start_pushers = _hs.config.start_pushers self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self.pushers = {} + + # map from user id to app_id:pushkey to pusher + self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]] + + # a lock for the pushers dict, since `count_pushers` is called from an different + # and we otherwise get concurrent modification errors + self._pushers_lock = Lock() + + def count_pushers(): + results = defaultdict(int) # type: Dict[Tuple[str, str], int] + with self._pushers_lock: + for pushers in self.pushers.values(): + for pusher in pushers.values(): + k = (type(pusher).__name__, pusher.app_id) + results[k] += 1 + return results + + LaterGauge( + name="synapse_pushers", + desc="the number of active pushers", + labels=["kind", "app_id"], + caller=count_pushers, + ) def start(self): """Starts the pushers off in a background process. @@ -103,9 +131,7 @@ class PusherPool: # create the pusher setting last_stream_ordering to the current maximum # stream ordering in event_push_actions, so it will process # pushes from this point onwards. - last_stream_ordering = ( - yield self.store.get_latest_push_action_stream_ordering() - ) + last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering() yield self.store.add_pusher( user_id=user_id, @@ -193,7 +219,7 @@ class PusherPool: min_stream_id - 1, max_stream_id ) # This returns a tuple, user_id is at index 3 - users_affected = set([r[3] for r in updated_receipts]) + users_affected = {r[3] for r in updated_receipts} for u in users_affected: if u in self.pushers: @@ -234,7 +260,6 @@ class PusherPool: Deferred """ pushers = yield self.store.get_all_pushers() - logger.info("Starting %d pushers", len(pushers)) # Stagger starting up the pushers so we don't completely drown the # process on start up. 
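The pusherpool.py hunk above adds a `LaterGauge` metric whose value is computed on demand by `count_pushers`, plus a `threading.Lock` around the nested pushers dict, because the metrics collector runs on a different thread from the code that adds and removes pushers. A condensed, standalone sketch of that snapshot-under-a-lock pattern (plain objects stand in for the real pusher classes):

from collections import defaultdict
from threading import Lock
from typing import Dict, Tuple

class FakePusher:
    # Stand-in for HttpPusher/EmailPusher; only app_id matters for the metric.
    def __init__(self, app_id: str):
        self.app_id = app_id

pushers = {}  # type: Dict[str, Dict[str, FakePusher]]
pushers_lock = Lock()

def count_pushers() -> Dict[Tuple[str, str], int]:
    # Take the lock so a pusher added or removed concurrently cannot mutate
    # the dicts while we iterate over them.
    results = defaultdict(int)  # type: Dict[Tuple[str, str], int]
    with pushers_lock:
        for by_key in pushers.values():
            for pusher in by_key.values():
                results[(type(pusher).__name__, pusher.app_id)] += 1
    return dict(results)

with pushers_lock:
    pushers.setdefault("@alice:example.org", {})["app:key"] = FakePusher("app")

assert count_pushers() == {("FakePusher", "app"): 1}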
@@ -247,7 +272,7 @@ class PusherPool: """Start the given pusher Args: - pusherdict (dict): + pusherdict (dict): dict with the values pulled from the db table Returns: Deferred[EmailPusher|HttpPusher] @@ -256,7 +281,8 @@ class PusherPool: p = self.pusher_factory.create_pusher(pusherdict) except PusherConfigException as e: logger.warning( - "Pusher incorrectly configured user=%s, appid=%s, pushkey=%s: %s", + "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s", + pusherdict["id"], pusherdict.get("user_name"), pusherdict.get("app_id"), pusherdict.get("pushkey"), @@ -264,18 +290,21 @@ class PusherPool: ) return except Exception: - logger.exception("Couldn't start a pusher: caught Exception") + logger.exception( + "Couldn't start pusher id %i: caught Exception", pusherdict["id"], + ) return if not p: return appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) - byuser = self.pushers.setdefault(pusherdict["user_name"], {}) - if appid_pushkey in byuser: - byuser[appid_pushkey].on_stop() - byuser[appid_pushkey] = p + with self._pushers_lock: + byuser = self.pushers.setdefault(pusherdict["user_name"], {}) + if appid_pushkey in byuser: + byuser[appid_pushkey].on_stop() + byuser[appid_pushkey] = p # Check if there *may* be push to process. We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to @@ -304,7 +333,9 @@ class PusherPool: if appid_pushkey in byuser: logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) byuser[appid_pushkey].on_stop() - del byuser[appid_pushkey] + with self._pushers_lock: + del byuser[appid_pushkey] + yield self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id ) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 0bd563edc7..8b4312e5a3 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -1,6 +1,7 @@ # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd # Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +16,7 @@ # limitations under the License. import logging -from typing import Set +from typing import List, Set from pkg_resources import ( DistributionNotFound, @@ -43,7 +44,8 @@ REQUIREMENTS = [ "frozendict>=1", "unpaddedbase64>=1.1.0", "canonicaljson>=1.1.3", - "signedjson>=1.0.0", + # we use the type definitions added in signedjson 1.1. + "signedjson>=1.1.0", "pynacl>=1.2.1", "idna>=2.5", # validating SSL certs for IP addresses requires service_identity 18.1. @@ -61,7 +63,6 @@ REQUIREMENTS = [ "bcrypt>=3.1.0", "pillow>=4.3.0", "sortedcontainers>=1.4.4", - "psutil>=2.0.0", "pymacaroons>=0.13.0", "msgpack>=0.5.2", "phonenumbers>=8.2.0", @@ -73,6 +74,7 @@ REQUIREMENTS = [ "netaddr>=0.7.18", "Jinja2>=2.9", "bleach>=1.4.3", + "typing-extensions>=3.7.4", ] CONDITIONAL_REQUIREMENTS = { @@ -90,12 +92,16 @@ CONDITIONAL_REQUIREMENTS = { 'eliot<1.8.0;python_version<"3.5.3"', ], "saml2": ["pysaml2>=4.5.0"], + "oidc": ["authlib>=0.14.0"], "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], "test": ["mock>=2.0", "parameterized"], "sentry": ["sentry-sdk>=0.7.2"], "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"], "jwt": ["pyjwt>=1.6.4"], + # hiredis is not a *strict* dependency, but it makes things much faster. + # (if it is not installed, we fall back to slow code.) 
+ "redis": ["txredisapi>=1.4.7", "hiredis"], } ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] @@ -144,7 +150,11 @@ def check_requirements(for_feature=None): deps_needed.append(dependency) errors.append( "Needed %s, got %s==%s" - % (dependency, e.dist.project_name, e.dist.version) + % ( + dependency, + e.dist.project_name, # type: ignore[attr-defined] # noqa + e.dist.version, # type: ignore[attr-defined] # noqa + ) ) except DistributionNotFound: deps_needed.append(dependency) @@ -159,7 +169,7 @@ def check_requirements(for_feature=None): if not for_feature: # Check the optional dependencies are up to date. We allow them to not be # installed. - OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) + OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str] for dependency in OPTS: try: @@ -168,7 +178,11 @@ def check_requirements(for_feature=None): deps_needed.append(dependency) errors.append( "Needed optional %s, got %s==%s" - % (dependency, e.dist.project_name, e.dist.version) + % ( + dependency, + e.dist.project_name, # type: ignore[attr-defined] # noqa + e.dist.version, # type: ignore[attr-defined] # noqa + ) ) except DistributionNotFound: # If it's not found, we don't care diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 81b85352b1..19b69e0e11 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -14,7 +14,16 @@ # limitations under the License. from synapse.http.server import JsonResource -from synapse.replication.http import federation, login, membership, register, send_event +from synapse.replication.http import ( + devices, + federation, + login, + membership, + presence, + register, + send_event, + streams, +) REPLICATION_PREFIX = "/_synapse/replication" @@ -26,7 +35,13 @@ class ReplicationRestResource(JsonResource): def register_servlets(self, hs): send_event.register_servlets(hs, self) - membership.register_servlets(hs, self) federation.register_servlets(hs, self) - login.register_servlets(hs, self) - register.register_servlets(hs, self) + presence.register_servlets(hs, self) + membership.register_servlets(hs, self) + + # The following can't currently be instantiated on workers. + if hs.config.worker.worker_app is None: + login.register_servlets(hs, self) + register.register_servlets(hs, self) + devices.register_servlets(hs, self) + streams.register_servlets(hs, self) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 03560c1f0e..793cef6c26 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -16,6 +16,8 @@ import abc import logging import re +from inspect import signature +from typing import Dict, List, Tuple from six import raise_from from six.moves import urllib @@ -43,7 +45,7 @@ class ReplicationEndpoint(object): """Helper base class for defining new replication HTTP endpoints. This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..` - (with an `/:txn_id` prefix for cached requests.), where NAME is a name, + (with a `/:txn_id` suffix for cached requests), where NAME is a name, PATH_ARGS are a tuple of parameters to be encoded in the URL. For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`, @@ -59,6 +61,8 @@ class ReplicationEndpoint(object): must call `register` to register the path with the HTTP server. Requests can be sent by calling the client returned by `make_client`. 
+ Requests are sent to master process by default, but can be sent to other + named processes by specifying an `instance_name` keyword argument. Attributes: NAME (str): A name for the endpoint, added to the path as well as used @@ -78,9 +82,8 @@ class ReplicationEndpoint(object): __metaclass__ = abc.ABCMeta - NAME = abc.abstractproperty() - PATH_ARGS = abc.abstractproperty() - + NAME = abc.abstractproperty() # type: str # type: ignore + PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore METHOD = "POST" CACHE = True RETRY_ON_TIMEOUT = True @@ -91,6 +94,16 @@ class ReplicationEndpoint(object): hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) + # We reserve `instance_name` as a parameter to sending requests, so we + # assert here that sub classes don't try and use the name. + assert ( + "instance_name" not in self.PATH_ARGS + ), "`instance_name` is a reserved paramater name" + assert ( + "instance_name" + not in signature(self.__class__._serialize_payload).parameters + ), "`instance_name` is a reserved paramater name" + assert self.METHOD in ("PUT", "POST", "GET") @abc.abstractmethod @@ -110,14 +123,14 @@ class ReplicationEndpoint(object): return {} @abc.abstractmethod - def _handle_request(self, request, **kwargs): + async def _handle_request(self, request, **kwargs): """Handle incoming request. This is called with the request object and PATH_ARGS. Returns: - Deferred[dict]: A JSON serialisable dict to be used as response - body of request. + tuple[int, dict]: HTTP status code and a JSON serialisable dict + to be used as response body of request. """ pass @@ -128,14 +141,30 @@ class ReplicationEndpoint(object): Returns a callable that accepts the same parameters as `_serialize_payload`. """ clock = hs.get_clock() - host = hs.config.worker_replication_host - port = hs.config.worker_replication_http_port - client = hs.get_simple_http_client() + local_instance_name = hs.get_instance_name() + + master_host = hs.config.worker_replication_host + master_port = hs.config.worker_replication_http_port + + instance_map = hs.config.worker.instance_map @trace(opname="outgoing_replication_request") @defer.inlineCallbacks - def send_request(**kwargs): + def send_request(instance_name="master", **kwargs): + if instance_name == local_instance_name: + raise Exception("Trying to send HTTP request to self") + if instance_name == "master": + host = master_host + port = master_port + elif instance_name in instance_map: + host = instance_map[instance_name].host + port = instance_map[instance_name].port + else: + raise Exception( + "Instance %r not in 'instance_map' config" % (instance_name,) + ) + data = yield cls._serialize_payload(**kwargs) url_args = [ @@ -171,7 +200,7 @@ class ReplicationEndpoint(object): # have a good idea that the request has either succeeded or failed on # the master, and so whether we should clean up or not. while True: - headers = {} + headers = {} # type: Dict[bytes, List[bytes]] inject_active_span_byte_dict(headers, None, check_destination=False) try: result = yield request_func(uri, data, headers=headers) @@ -180,7 +209,7 @@ class ReplicationEndpoint(object): if e.code != 504 or not cls.RETRY_ON_TIMEOUT: raise - logger.warn("%s request timed out", cls.NAME) + logger.warning("%s request timed out", cls.NAME) # If we timed out we probably don't need to worry about backing # off too much, but lets just wait a little anyway. 
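The `send_request` changes above route a replication call either to the master (the default) or to a named worker taken from the `instance_map` config, and refuse to call the local instance. A standalone sketch of just that selection logic, with a plain dict of `(host, port)` tuples standing in for the parsed config objects:

from typing import Dict, Tuple

def pick_destination(
    instance_name: str,
    local_instance_name: str,
    master_host: str,
    master_port: int,
    instance_map: Dict[str, Tuple[str, int]],
) -> Tuple[str, int]:
    # Mirrors the branching added to send_request: self-calls are an error,
    # "master" uses the worker_replication_* settings, and anything else must
    # appear in the instance_map config.
    if instance_name == local_instance_name:
        raise Exception("Trying to send HTTP request to self")
    if instance_name == "master":
        return master_host, master_port
    if instance_name in instance_map:
        return instance_map[instance_name]
    raise Exception("Instance %r not in 'instance_map' config" % (instance_name,))

assert pick_destination("master", "worker1", "127.0.0.1", 9093, {}) == ("127.0.0.1", 9093)
assert pick_destination(
    "fed_sender", "worker1", "127.0.0.1", 9093, {"fed_sender": ("10.0.0.2", 9093)}
) == ("10.0.0.2", 9093)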
@@ -207,7 +236,7 @@ class ReplicationEndpoint(object): method = self.METHOD if self.CACHE: - handler = self._cached_handler + handler = self._cached_handler # type: ignore url_args.append("txn_id") args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py new file mode 100644 index 0000000000..e32aac0a25 --- /dev/null +++ b/synapse/replication/http/devices.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.replication.http._base import ReplicationEndpoint + +logger = logging.getLogger(__name__) + + +class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): + """Ask master to resync the device list for a user by contacting their + server. + + This must happen on master so that the results can be correctly cached in + the database and streamed to workers. + + Request format: + + POST /_synapse/replication/user_device_resync/:user_id + + {} + + Response is equivalent to ` /_matrix/federation/v1/user/devices/:user_id` + response, e.g.: + + { + "user_id": "@alice:example.org", + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": { ... }, + "device_display_name": "Alice's Mobile Phone" + } + ] + } + """ + + NAME = "user_device_resync" + PATH_ARGS = ("user_id",) + CACHE = False + + def __init__(self, hs): + super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs) + + self.device_list_updater = hs.get_device_handler().device_list_updater + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @staticmethod + def _serialize_payload(user_id): + return {} + + async def _handle_request(self, request, user_id): + user_devices = await self.device_list_updater.user_device_resync(user_id) + + return 200, user_devices + + +def register_servlets(hs, http_server): + ReplicationUserDevicesResyncRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 2f16955954..c287c4e269 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -17,7 +17,8 @@ import logging from twisted.internet import defer -from synapse.events import event_type_from_format_version +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -28,7 +29,7 @@ logger = logging.getLogger(__name__) class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): """Handles events newly received from federation, including persisting and - notifying. + notifying. Returns the maximum stream ID of the persisted events. The API looks like: @@ -37,11 +38,21 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): { "events": [{ "event": { .. 
serialized event .. }, + "room_version": .., // "1", "2", "3", etc: the version of the room + // containing the event + "event_format_version": .., // 1,2,3 etc: the event format version "internal_metadata": { .. serialized internal_metadata .. }, "rejected_reason": .., // The event.rejected_reason field "context": { .. serialized event context .. }, }], "backfilled": false + } + + 200 OK + + { + "max_stream_id": 32443, + } """ NAME = "fed_send_events" @@ -51,6 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): super(ReplicationFederationSendEventsRestServlet, self).__init__(hs) self.store = hs.get_datastore() + self.storage = hs.get_storage() self.clock = hs.get_clock() self.federation_handler = hs.get_handlers().federation_handler @@ -71,6 +83,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): event_payloads.append( { "event": event.get_pdu_json(), + "room_version": event.room_version.identifier, "event_format_version": event.format_version, "internal_metadata": event.internal_metadata.get_dict(), "rejected_reason": event.rejected_reason, @@ -82,8 +95,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): return payload - @defer.inlineCallbacks - def _handle_request(self, request): + async def _handle_request(self, request): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) @@ -94,26 +106,27 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): event_and_contexts = [] for event_payload in event_payloads: event_dict = event_payload["event"] - format_ver = event_payload["event_format_version"] + room_ver = KNOWN_ROOM_VERSIONS[event_payload["room_version"]] internal_metadata = event_payload["internal_metadata"] rejected_reason = event_payload["rejected_reason"] - EventType = event_type_from_format_version(format_ver) - event = EventType(event_dict, internal_metadata, rejected_reason) + event = make_event_from_dict( + event_dict, room_ver, internal_metadata, rejected_reason + ) - context = yield EventContext.deserialize( - self.store, event_payload["context"] + context = EventContext.deserialize( + self.storage, event_payload["context"] ) event_and_contexts.append((event, context)) logger.info("Got %d events from federation", len(event_and_contexts)) - yield self.federation_handler.persist_events_and_notify( + max_stream_id = await self.federation_handler.persist_events_and_notify( event_and_contexts, backfilled ) - return 200, {} + return 200, {"max_stream_id": max_stream_id} class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): @@ -144,8 +157,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): def _serialize_payload(edu_type, origin, content): return {"origin": origin, "content": content} - @defer.inlineCallbacks - def _handle_request(self, request, edu_type): + async def _handle_request(self, request, edu_type): with Measure(self.clock, "repl_fed_send_edu_parse"): content = parse_json_object_from_request(request) @@ -154,7 +166,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): logger.info("Got %r edu from %s", edu_type, origin) - result = yield self.registry.on_edu(edu_type, origin, edu_content) + result = await self.registry.on_edu(edu_type, origin, edu_content) return 200, result @@ -193,8 +205,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): """ return {"args": args} - @defer.inlineCallbacks - def _handle_request(self, request, query_type): + async def 
_handle_request(self, request, query_type): with Measure(self.clock, "repl_fed_query_parse"): content = parse_json_object_from_request(request) @@ -202,7 +213,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): logger.info("Got %r query", query_type) - result = yield self.registry.on_query(query_type, args) + result = await self.registry.on_query(query_type, args) return 200, result @@ -213,7 +224,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): Request format: - POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id + POST /_synapse/replication/fed_cleanup_room/:room_id/:txn_id {} """ @@ -234,10 +245,41 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): """ return {} - @defer.inlineCallbacks - def _handle_request(self, request, room_id): - yield self.store.clean_room_for_join(room_id) + async def _handle_request(self, request, room_id): + await self.store.clean_room_for_join(room_id) + + return 200, {} + + +class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint): + """Called to clean up any data in DB for a given room, ready for the + server to join the room. + + Request format: + + POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id + + { + "room_version": "1", + } + """ + + NAME = "store_room_on_invite" + PATH_ARGS = ("room_id",) + + def __init__(self, hs): + super().__init__(hs) + + self.store = hs.get_datastore() + + @staticmethod + def _serialize_payload(room_id, room_version): + return {"room_version": room_version.identifier} + async def _handle_request(self, request, room_id): + content = parse_json_object_from_request(request) + room_version = KNOWN_ROOM_VERSIONS[content["room_version"]] + await self.store.maybe_store_room_on_invite(room_id, room_version) return 200, {} @@ -246,3 +288,4 @@ def register_servlets(hs, http_server): ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server) ReplicationCleanRoomRestServlet(hs).register(http_server) + ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 786f5232b2..798b9d3af5 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -52,15 +50,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): "is_guest": is_guest, } - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) device_id = content["device_id"] initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest ) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index b9ce3477ad..a7174c4a8f 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -14,14 +14,16 @@ # limitations under the License. 
import logging - -from twisted.internet import defer +from typing import TYPE_CHECKING from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import Requester, UserID from synapse.util.distributor import user_joined_room, user_left_room +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -65,8 +67,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): "content": content, } - @defer.inlineCallbacks - def _handle_request(self, request, room_id, user_id): + async def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] @@ -79,11 +80,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): logger.info("remote_join: %s into room: %s", user_id, room_id) - yield self.federation_handler.do_invite_join( + event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user_id, event_content ) - return 200, {} + return 200, {"event_id": event_id, "stream_id": stream_id} class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): @@ -96,6 +97,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): { "requester": ..., "remote_room_hosts": [...], + "content": { ... } } """ @@ -108,9 +110,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler self.store = hs.get_datastore() self.clock = hs.get_clock() + self.member_handler = hs.get_room_member_handler() @staticmethod - def _serialize_payload(requester, room_id, user_id, remote_room_hosts): + def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): """ Args: requester(Requester) @@ -121,13 +124,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): return { "requester": requester.serialize(), "remote_room_hosts": remote_room_hosts, + "content": content, } - @defer.inlineCallbacks - def _handle_request(self, request, room_id, user_id): + async def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] + event_content = content["content"] requester = Requester.deserialize(self.store, content["requester"]) @@ -137,10 +141,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: - event = yield self.federation_handler.do_remotely_reject_invite( - remote_room_hosts, room_id, user_id + event, stream_id = await self.federation_handler.do_remotely_reject_invite( + remote_room_hosts, room_id, user_id, event_content, ) - ret = event.get_pdu_json() + event_id = event.event_id except Exception as e: # if we were unable to reject the exception, just mark # it as rejected on our end and plough ahead. 
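ReplicationRemoteJoinRestServlet now returns the new event's ID and the stream position it was persisted at, so the worker that initiated the join can wait for that position to replicate. A hypothetical caller-side sketch: the client callable and its keyword arguments follow the `make_client` convention described earlier, but the wrapper class itself is illustrative and not taken from this diff.

from synapse.replication.http.membership import ReplicationRemoteJoinRestServlet

class RemoteJoinerSketch:
    # Illustrative wrapper, not a class from this diff.
    def __init__(self, hs):
        # make_client returns a callable accepting the same keyword arguments
        # as the servlet's _serialize_payload.
        self._remote_join_client = ReplicationRemoteJoinRestServlet.make_client(hs)

    async def remote_join(self, requester, remote_room_hosts, room_id, user_id, content):
        ret = await self._remote_join_client(
            requester=requester,
            remote_room_hosts=remote_room_hosts,
            room_id=room_id,
            user_id=user_id,
            content=content,
        )
        # New in this diff: the response carries the created event's ID and
        # the stream ordering it was persisted at.
        return ret["event_id"], ret["stream_id"]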
@@ -148,12 +152,44 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # The 'except' clause is very broad, but we need to # capture everything from DNS failures upwards # - logger.warn("Failed to reject invite: %s", e) + logger.warning("Failed to reject invite: %s", e) + + stream_id = await self.member_handler.locally_reject_invite( + user_id, room_id + ) + event_id = None + + return 200, {"event_id": event_id, "stream_id": stream_id} + + +class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint): + """Rejects the invite for the user and room locally. + + Request format: + + POST /_synapse/replication/locally_reject_invite/:room_id/:user_id + + {} + """ + + NAME = "locally_reject_invite" + PATH_ARGS = ("room_id", "user_id") + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.member_handler = hs.get_room_member_handler() + + @staticmethod + def _serialize_payload(room_id, user_id): + return {} + + async def _handle_request(self, request, room_id, user_id): + logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id) - yield self.store.locally_reject_invite(user_id, room_id) - ret = {} + stream_id = await self.member_handler.locally_reject_invite(user_id, room_id) - return 200, ret + return 200, {"stream_id": stream_id} class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): @@ -209,3 +245,4 @@ def register_servlets(hs, http_server): ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) + ReplicationLocallyRejectInviteRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py new file mode 100644 index 0000000000..ea1b33331b --- /dev/null +++ b/synapse/replication/http/presence.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ReplicationBumpPresenceActiveTime(ReplicationEndpoint): + """We've seen the user do something that indicates they're interacting + with the app. 
+ + The POST looks like: + + POST /_synapse/replication/bump_presence_active_time/<user_id> + + 200 OK + + {} + """ + + NAME = "bump_presence_active_time" + PATH_ARGS = ("user_id",) + METHOD = "POST" + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self._presence_handler = hs.get_presence_handler() + + @staticmethod + def _serialize_payload(user_id): + return {} + + async def _handle_request(self, request, user_id): + await self._presence_handler.bump_presence_active_time( + UserID.from_string(user_id) + ) + + return ( + 200, + {}, + ) + + +class ReplicationPresenceSetState(ReplicationEndpoint): + """Set the presence state for a user. + + The POST looks like: + + POST /_synapse/replication/presence_set_state/<user_id> + + { + "state": { ... }, + "ignore_status_msg": false, + } + + 200 OK + + {} + """ + + NAME = "presence_set_state" + PATH_ARGS = ("user_id",) + METHOD = "POST" + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self._presence_handler = hs.get_presence_handler() + + @staticmethod + def _serialize_payload(user_id, state, ignore_status_msg=False): + return { + "state": state, + "ignore_status_msg": ignore_status_msg, + } + + async def _handle_request(self, request, user_id): + content = parse_json_object_from_request(request) + + await self._presence_handler.set_state( + UserID.from_string(user_id), content["state"], content["ignore_status_msg"] + ) + + return ( + 200, + {}, + ) + + +def register_servlets(hs, http_server): + ReplicationBumpPresenceActiveTime(hs).register(http_server) + ReplicationPresenceSetState(hs).register(http_server) diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 38260256cf..0c4aca1291 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -74,11 +72,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "address": address, } - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) - yield self.registration_handler.register_with_store( + self.registration_handler.check_registration_ratelimit(content["address"]) + + await self.registration_handler.register_with_store( user_id=user_id, password_hash=content["password_hash"], was_guest=content["was_guest"], @@ -117,14 +116,13 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): """ return {"auth_result": auth_result, "access_token": access_token} - @defer.inlineCallbacks - def _handle_request(self, request, user_id): + async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) auth_result = content["auth_result"] access_token = content["access_token"] - yield self.registration_handler.post_registration_actions( + await self.registration_handler.post_registration_actions( user_id=user_id, auth_result=auth_result, access_token=access_token ) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index adb9b2f7f4..c981723c1a 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -17,7 +17,8 @@ import logging from twisted.internet import defer -from synapse.events import event_type_from_format_version 
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -37,6 +38,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): { "event": { .. serialized event .. }, + "room_version": .., // "1", "2", "3", etc: the version of the room + // containing the event + "event_format_version": .., // 1,2,3 etc: the event format version "internal_metadata": { .. serialized internal_metadata .. }, "rejected_reason": .., // The event.rejected_reason field "context": { .. serialized event context .. }, @@ -54,6 +58,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.clock = hs.get_clock() @staticmethod @@ -76,6 +81,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): payload = { "event": event.get_pdu_json(), + "room_version": event.room_version.identifier, "event_format_version": event.format_version, "internal_metadata": event.internal_metadata.get_dict(), "rejected_reason": event.rejected_reason, @@ -87,21 +93,21 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): return payload - @defer.inlineCallbacks - def _handle_request(self, request, event_id): + async def _handle_request(self, request, event_id): with Measure(self.clock, "repl_send_event_parse"): content = parse_json_object_from_request(request) event_dict = content["event"] - format_ver = content["event_format_version"] + room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]] internal_metadata = content["internal_metadata"] rejected_reason = content["rejected_reason"] - EventType = event_type_from_format_version(format_ver) - event = EventType(event_dict, internal_metadata, rejected_reason) + event = make_event_from_dict( + event_dict, room_ver, internal_metadata, rejected_reason + ) requester = Requester.deserialize(self.store, content["requester"]) - context = yield EventContext.deserialize(self.store, content["context"]) + context = EventContext.deserialize(self.storage, content["context"]) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] @@ -113,11 +119,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - yield self.event_creation_handler.persist_and_notify_client_event( + stream_id = await self.event_creation_handler.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) - return 200, {} + return 200, {"stream_id": stream_id} def register_servlets(hs, http_server): diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py new file mode 100644 index 0000000000..bde97eef32 --- /dev/null +++ b/synapse/replication/http/streams.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import SynapseError +from synapse.http.servlet import parse_integer +from synapse.replication.http._base import ReplicationEndpoint + +logger = logging.getLogger(__name__) + + +class ReplicationGetStreamUpdates(ReplicationEndpoint): + """Fetches stream updates from a server. Used for streams not persisted to + the database, e.g. typing notifications. + + The API looks like: + + GET /_synapse/replication/get_repl_stream_updates/<stream name>?from_token=0&to_token=10 + + 200 OK + + { + updates: [ ... ], + upto_token: 10, + limited: False, + } + + If there are more rows than can sensibly be returned in one lump, `limited` will be + set to true, and the caller should call again with a new `from_token`. + + """ + + NAME = "get_repl_stream_updates" + PATH_ARGS = ("stream_name",) + METHOD = "GET" + + def __init__(self, hs): + super().__init__(hs) + + self._instance_name = hs.get_instance_name() + self.streams = hs.get_replication_streams() + + @staticmethod + def _serialize_payload(stream_name, from_token, upto_token): + return {"from_token": from_token, "upto_token": upto_token} + + async def _handle_request(self, request, stream_name): + stream = self.streams.get(stream_name) + if stream is None: + raise SynapseError(400, "Unknown stream") + + from_token = parse_integer(request, "from_token", required=True) + upto_token = parse_integer(request, "upto_token", required=True) + + updates, upto_token, limited = await stream.get_updates_since( + self._instance_name, from_token, upto_token + ) + + return ( + 200, + {"updates": updates, "upto_token": upto_token, "limited": limited}, + ) + + +def register_servlets(hs, http_server): + ReplicationGetStreamUpdates(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 182cb2a1d8..f9e2533e96 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -14,56 +14,30 @@ # limitations under the License. 
import logging +from typing import Optional -import six - -from synapse.storage._base import _CURRENT_STATE_CACHE_NAME, SQLBaseStore +from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine - -from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage.util.id_generators import MultiWriterIdGenerator logger = logging.getLogger(__name__) -def __func__(inp): - if six.PY3: - return inp - else: - return inp.__func__ - - -class BaseSlavedStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(BaseSlavedStore, self).__init__(db_conn, hs) +class BaseSlavedStore(CacheInvalidationWorkerStore): + def __init__(self, database: Database, db_conn, hs): + super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): - self._cache_id_gen = SlavedIdTracker( - db_conn, "cache_invalidation_stream", "stream_id" - ) + self._cache_id_gen = MultiWriterIdGenerator( + db_conn, + database, + instance_name=hs.get_instance_name(), + table="cache_invalidation_stream_by_instance", + instance_column="instance_name", + id_column="stream_id", + sequence_name="cache_invalidation_stream_seq", + ) # type: Optional[MultiWriterIdGenerator] else: self._cache_id_gen = None self.hs = hs - - def stream_positions(self): - pos = {} - if self._cache_id_gen: - pos["caches"] = self._cache_id_gen.get_current_token() - return pos - - def process_replication_rows(self, stream_name, token, rows): - if stream_name == "caches": - self._cache_id_gen.advance(token) - for row in rows: - if row.cache_func == _CURRENT_STATE_CACHE_NAME: - room_id = row.keys[0] - members_changed = set(row.keys[1:]) - self._invalidate_state_caches(room_id, members_changed) - else: - self._attempt_to_invalidate_cache(row.cache_func, tuple(row.keys)) - - def _invalidate_cache_and_stream(self, txn, cache_func, keys): - txn.call_after(cache_func.invalidate, keys) - txn.call_after(self._send_invalidation_poke, cache_func, keys) - - def _send_invalidation_poke(self, cache_func, keys): - self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys) diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 3c44d1d48d..9db6c62bc7 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -16,30 +16,29 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.account_data import AccountDataWorkerStore -from synapse.storage.tags import TagsWorkerStore +from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore +from synapse.storage.data_stores.main.tags import TagsWorkerStore +from synapse.storage.database import Database class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( - db_conn, "account_data_max_stream_id", "stream_id" + db_conn, + "account_data", + "stream_id", + extra_tables=[ + ("room_account_data", "stream_id"), + ("room_tags_revisions", "stream_id"), + ], ) - super(SlavedAccountDataStore, self).__init__(db_conn, hs) + super(SlavedAccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): return 
self._account_data_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedAccountDataStore, self).stream_positions() - position = self._account_data_id_gen.get_current_token() - result["user_account_data"] = position - result["room_account_data"] = position - result["tag_account_data"] = position - return result - - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "tag_account_data": self._account_data_id_gen.advance(token) for row in rows: @@ -58,6 +57,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved (row.user_id, row.room_id, row.data_type) ) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedAccountDataStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index cda12ea70d..a67fbeffb7 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.appservice import ( +from synapse.storage.data_stores.main.appservice import ( ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, ) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 14ced32333..1a38f53dfb 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -13,19 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.storage.client_ips import LAST_SEEN_GRANULARITY -from synapse.util.caches import CACHE_SIZE_FACTOR +from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY +from synapse.storage.database import Database from synapse.util.caches.descriptors import Cache from ._base import BaseSlavedStore class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedClientIpStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR + name="client_ip_last_seen", keylen=4, max_entries=50000 ) def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 284fd30d89..6e7fd259d4 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -15,14 +15,15 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.deviceinbox import DeviceInboxWorkerStore +from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( db_conn, "device_max_stream_id", "stream_id" ) @@ -42,12 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): expiry_ms=30 * 60 * 1000, ) - def stream_positions(self): - result = super(SlavedDeviceInboxStore, self).stream_positions() - result["to_device"] = self._device_inbox_id_gen.get_current_token() - return result - - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "to_device": self._device_inbox_id_gen.advance(token) for row in rows: @@ -59,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): self._device_federation_outbox_stream_cache.entity_has_changed( row.entity, token ) - return super(SlavedDeviceInboxStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index d9300fce33..9d8067342f 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -15,50 +15,61 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.devices import DeviceWorkerStore -from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore +from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream +from synapse.storage.data_stores.main.devices import DeviceWorkerStore 
+from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedDeviceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedDeviceStore, self).__init__(database, db_conn, hs) self.hs = hs self._device_list_id_gen = SlavedIdTracker( - db_conn, "device_lists_stream", "stream_id" + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ], ) device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( "DeviceListStreamChangeCache", device_list_max ) + self._user_signature_stream_cache = StreamChangeCache( + "UserSignatureStreamChangeCache", device_list_max + ) self._device_list_federation_stream_cache = StreamChangeCache( "DeviceListFederationStreamChangeCache", device_list_max ) - def stream_positions(self): - result = super(SlavedDeviceStore, self).stream_positions() - result["device_lists"] = self._device_list_id_gen.get_current_token() - return result - - def process_replication_rows(self, stream_name, token, rows): - if stream_name == "device_lists": + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == DeviceListsStream.NAME: + self._device_list_id_gen.advance(token) + self._invalidate_caches_for_devices(token, rows) + elif stream_name == UserSignatureStream.NAME: self._device_list_id_gen.advance(token) for row in rows: - self._invalidate_caches_for_devices(token, row.user_id, row.destination) - return super(SlavedDeviceStore, self).process_replication_rows( - stream_name, token, rows - ) - - def _invalidate_caches_for_devices(self, token, user_id, destination): - self._device_list_stream_cache.entity_has_changed(user_id, token) + self._user_signature_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) - if destination: - self._device_list_federation_stream_cache.entity_has_changed( - destination, token - ) + def _invalidate_caches_for_devices(self, token, rows): + for row in rows: + # The entities are either user IDs (starting with '@') whose devices + # have changed, or remote servers that we need to tell about + # changes. + if row.entity.startswith("@"): + self._device_list_stream_cache.entity_has_changed(row.entity, token) + self.get_cached_devices_for_user.invalidate((row.entity,)) + self._get_cached_user_device.invalidate_many((row.entity,)) + self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) - self._get_cached_devices_for_user.invalidate((user_id,)) - self._get_cached_user_device.invalidate_many((user_id,)) - self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) + else: + self._device_list_federation_stream_cache.entity_has_changed( + row.entity, token + ) diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py index 1d1d48709a..8b9717c46f 100644 --- a/synapse/replication/slave/storage/directory.py +++ b/synapse/replication/slave/storage/directory.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.storage.directory import DirectoryWorkerStore +from synapse.storage.data_stores.main.directory import DirectoryWorkerStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index ab5937e638..1a1a50a24f 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -15,23 +15,21 @@ # limitations under the License. import logging -from synapse.api.constants import EventTypes -from synapse.replication.tcp.streams.events import ( - EventsStreamCurrentStateRow, - EventsStreamEventRow, +from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore +from synapse.storage.data_stores.main.event_push_actions import ( + EventPushActionsWorkerStore, ) -from synapse.storage.event_federation import EventFederationWorkerStore -from synapse.storage.event_push_actions import EventPushActionsWorkerStore -from synapse.storage.events_worker import EventsWorkerStore -from synapse.storage.relations import RelationsWorkerStore -from synapse.storage.roommember import RoomMemberWorkerStore -from synapse.storage.signatures import SignatureWorkerStore -from synapse.storage.state import StateGroupWorkerStore -from synapse.storage.stream import StreamWorkerStore -from synapse.storage.user_erasure_store import UserErasureWorkerStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.data_stores.main.relations import RelationsWorkerStore +from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore +from synapse.storage.data_stores.main.signatures import SignatureWorkerStore +from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.storage.data_stores.main.stream import StreamWorkerStore +from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore +from synapse.storage.database import Database +from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker logger = logging.getLogger(__name__) @@ -57,13 +55,23 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, db_conn, hs): - self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 + def __init__(self, database: Database, db_conn, hs): + super(SlavedEventStore, self).__init__(database, db_conn, hs) + + events_max = self._stream_id_gen.get_current_token() + curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( + db_conn, + "current_state_delta_stream", + entity_column="room_id", + stream_column="stream_id", + max_value=events_max, # As we share the stream id with events token + limit=1000, + ) + self._curr_state_delta_stream_cache = StreamChangeCache( + "_curr_state_delta_stream_cache", + min_curr_state_delta_id, + prefilled_cache=curr_state_delta_prefill, ) - - super(SlavedEventStore, self).__init__(db_conn, hs) # Cached functions can't be accessed through a class instance so we need # to reach inside the __dict__ to extract them. 
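[Editor's note, not part of the patch.] The hunk above changes SlavedEventStore so that, instead of building its own SlavedIdTracker, the worker prefills a StreamChangeCache for current-state deltas from the database via get_cache_dict. A minimal sketch of how such a prefilled cache is consulted, assuming the cache's has_entity_changed helper (which returns True whenever it cannot rule a change out); the room IDs and stream positions below are invented for the example:

    # Illustrative sketch only: querying a prefilled StreamChangeCache.
    from synapse.util.caches.stream_change_cache import StreamChangeCache

    cache = StreamChangeCache(
        "_curr_state_delta_stream_cache",
        500,  # the earliest stream position the cache has knowledge of
        prefilled_cache={"!room:example.com": 520},  # entity -> last change position
    )

    cache.has_entity_changed("!room:example.com", 510)   # True: a change happened at 520 > 510
    cache.has_entity_changed("!room:example.com", 530)   # False: no change recorded after 530
    cache.has_entity_changed("!other:example.com", 400)  # True: 400 predates the cache's
                                                         # horizon, so it cannot safely say "no"
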
@@ -73,85 +81,3 @@ class SlavedEventStore( def get_room_min_stream_ordering(self): return self._backfill_id_gen.get_current_token() - - def stream_positions(self): - result = super(SlavedEventStore, self).stream_positions() - result["events"] = self._stream_id_gen.get_current_token() - result["backfill"] = -self._backfill_id_gen.get_current_token() - return result - - def process_replication_rows(self, stream_name, token, rows): - if stream_name == "events": - self._stream_id_gen.advance(token) - for row in rows: - self._process_event_stream_row(token, row) - elif stream_name == "backfill": - self._backfill_id_gen.advance(-token) - for row in rows: - self.invalidate_caches_for_event( - -token, - row.event_id, - row.room_id, - row.type, - row.state_key, - row.redacts, - row.relates_to, - backfilled=True, - ) - return super(SlavedEventStore, self).process_replication_rows( - stream_name, token, rows - ) - - def _process_event_stream_row(self, token, row): - data = row.data - - if row.type == EventsStreamEventRow.TypeId: - self.invalidate_caches_for_event( - token, - data.event_id, - data.room_id, - data.type, - data.state_key, - data.redacts, - data.relates_to, - backfilled=False, - ) - elif row.type == EventsStreamCurrentStateRow.TypeId: - if data.type == EventTypes.Member: - self.get_rooms_for_user_with_stream_ordering.invalidate( - (data.state_key,) - ) - else: - raise Exception("Unknown events stream row type %s" % (row.type,)) - - def invalidate_caches_for_event( - self, - stream_ordering, - event_id, - room_id, - etype, - state_key, - redacts, - relates_to, - backfilled, - ): - self._invalidate_get_event_cache(event_id) - - self.get_latest_event_ids_in_room.invalidate((room_id,)) - - self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) - - if not backfilled: - self._events_stream_cache.entity_has_changed(room_id, stream_ordering) - - if redacts: - self._invalidate_get_event_cache(redacts) - - if etype == EventTypes.Member: - self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) - self.get_invited_rooms_for_user.invalidate((state_key,)) - - if relates_to: - self.get_relations_for_event.invalidate_many((relates_to,)) - self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) - self.get_applicable_edit.invalidate((relates_to,)) diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 456a14cd5c..bcb0688954 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -13,14 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.storage.filtering import FilteringStore +from synapse.storage.data_stores.main.filtering import FilteringStore +from synapse.storage.database import Database from ._base import BaseSlavedStore class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedFilteringStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedFilteringStore, self).__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired get_user_filter = FilteringStore.__dict__["get_user_filter"] diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 28a46edd28..1851e7d525 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -13,16 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import DataStore +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore, __func__ -from ._slaved_id_tracker import SlavedIdTracker - -class SlavedGroupServerStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedGroupServerStore, self).__init__(db_conn, hs) +class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): + def __init__(self, database: Database, db_conn, hs): + super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) self.hs = hs @@ -34,21 +34,13 @@ class SlavedGroupServerStore(BaseSlavedStore): self._group_updates_id_gen.get_current_token(), ) - get_groups_changes_for_user = __func__(DataStore.get_groups_changes_for_user) - get_group_stream_token = __func__(DataStore.get_group_stream_token) - get_all_groups_for_user = __func__(DataStore.get_all_groups_for_user) - - def stream_positions(self): - result = super(SlavedGroupServerStore, self).stream_positions() - result["groups"] = self._group_updates_id_gen.get_current_token() - return result + def get_group_stream_token(self): + return self._group_updates_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "groups": self._group_updates_id_gen.advance(token) for row in rows: self._group_updates_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedGroupServerStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py index cc6f7f009f..3def367ae9 100644 --- a/synapse/replication/slave/storage/keys.py +++ b/synapse/replication/slave/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import KeyStore +from synapse.storage.data_stores.main.keys import KeyStore # KeyStore isn't really safe to use from a worker, but for now we do so and hope that # the races it creates aren't too bad. 
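[Editor's note, not part of the patch.] The slaved-store diffs above and below all follow the same shape: constructors now take the Database object alongside db_conn, the per-store stream_positions() methods are removed, and process_replication_rows() gains the instance_name of the writer before delegating to the superclass. A minimal sketch of that post-refactor shape; "SlavedExampleStore", the "example" stream and the "example_stream" table are hypothetical and exist only to illustrate the pattern:

    # Illustrative sketch only: the shape shared by the slaved stores in this commit.
    from synapse.replication.slave.storage._base import BaseSlavedStore
    from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
    from synapse.storage.database import Database


    class SlavedExampleStore(BaseSlavedStore):
        def __init__(self, database: Database, db_conn, hs):
            super(SlavedExampleStore, self).__init__(database, db_conn, hs)
            self._example_id_gen = SlavedIdTracker(db_conn, "example_stream", "stream_id")

        def process_replication_rows(self, stream_name, instance_name, token, rows):
            if stream_name == "example":
                # Advance the local view of the stream; per-row cache
                # invalidation would happen here.
                self._example_id_gen.advance(token)
            return super().process_replication_rows(
                stream_name, instance_name, token, rows
            )
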
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 82d808af4c..4e0124842d 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -14,47 +14,37 @@ # limitations under the License. from synapse.storage import DataStore -from synapse.storage.presence import PresenceStore +from synapse.storage.data_stores.main.presence import PresenceStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore, __func__ +from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedPresenceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedPresenceStore, self).__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") - self._presence_on_startup = self._get_active_presence(db_conn) + self._presence_on_startup = self._get_active_presence(db_conn) # type: ignore - self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache( + self.presence_stream_cache = StreamChangeCache( "PresenceStreamChangeCache", self._presence_id_gen.get_current_token() ) - _get_active_presence = __func__(DataStore._get_active_presence) - take_presence_startup_info = __func__(DataStore.take_presence_startup_info) + _get_active_presence = DataStore._get_active_presence + take_presence_startup_info = DataStore.take_presence_startup_info _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"] get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"] def get_current_presence_token(self): return self._presence_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedPresenceStore, self).stream_positions() - - if self.hs.config.use_presence: - position = self._presence_id_gen.get_current_token() - result["presence"] = position - - return result - - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "presence": self._presence_id_gen.advance(token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) - return super(SlavedPresenceStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py index 46c28d4171..28c508aad3 100644 --- a/synapse/replication/slave/storage/profile.py +++ b/synapse/replication/slave/storage/profile.py @@ -14,7 +14,7 @@ # limitations under the License. 
from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.storage.profile import ProfileWorkerStore +from synapse.storage.data_stores.main.profile import ProfileWorkerStore class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore): diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index af7012702e..6adb19463a 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -14,19 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.push_rule import PushRulesWorkerStore +from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore -from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def __init__(self, db_conn, hs): - self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id" - ) - super(SlavedPushRuleStore, self).__init__(db_conn, hs) - def get_push_rules_stream_token(self): return ( self._push_rules_stream_id_gen.get_current_token(), @@ -36,18 +29,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedPushRuleStore, self).stream_positions() - result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() - return result - - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "push_rules": self._push_rules_stream_id_gen.advance(token) for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.push_rules_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedPushRuleStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 8eeb267d61..cb78b49acb 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -14,27 +14,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.storage.pusher import PusherWorkerStore +from synapse.storage.data_stores.main.pusher import PusherWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedPusherStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedPusherStore, self).__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) - def stream_positions(self): - result = super(SlavedPusherStore, self).stream_positions() - result["pushers"] = self._pushers_id_gen.get_current_token() - return result + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "pushers": self._pushers_id_gen.advance(token) - return super(SlavedPusherStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 91afa5a72b..be716cc558 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -14,7 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.receipts import ReceiptsWorkerStore +from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -29,23 +30,18 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) - super(SlavedReceiptsStore, self).__init__(db_conn, hs) + super(SlavedReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedReceiptsStore, self).stream_positions() - result["receipts"] = self._receipts_id_gen.get_current_token() - return result - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate_many((room_id,)) @@ -55,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) self.get_receipts_for_room.invalidate((room_id, receipt_type)) - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "receipts": self._receipts_id_gen.advance(token) for row in rows: @@ -64,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): ) self._receipts_stream_cache.entity_has_changed(row.room_id, token) - return 
super(SlavedReceiptsStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py index 408d91df1c..4b8553e250 100644 --- a/synapse/replication/slave/storage/registration.py +++ b/synapse/replication/slave/storage/registration.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.registration import RegistrationWorkerStore +from synapse.storage.data_stores.main.registration import RegistrationWorkerStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index f68b3378e3..8873bf37e5 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.room import RoomWorkerStore +from synapse.storage.data_stores.main.room import RoomWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class RoomStore(RoomWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(RoomStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" ) @@ -29,13 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def stream_positions(self): - result = super(RoomStore, self).stream_positions() - result["public_rooms"] = self._public_room_id_gen.get_current_token() - return result - - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "public_rooms": self._public_room_id_gen.advance(token) - return super(RoomStore, self).process_replication_rows(stream_name, token, rows) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py index 3527beb3c9..ac88e6b8c3 100644 --- a/synapse/replication/slave/storage/transactions.py +++ b/synapse/replication/slave/storage/transactions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.storage.transactions import TransactionStore +from synapse.storage.data_stores.main.transactions import TransactionStore from ._base import BaseSlavedStore diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py index 81c2ea7ee9..523a1358d4 100644 --- a/synapse/replication/tcp/__init__.py +++ b/synapse/replication/tcp/__init__.py @@ -20,11 +20,31 @@ Further details can be found in docs/tcp_replication.rst Structure of the module: - * client.py - the client classes used for workers to connect to master + * handler.py - the classes used to handle sending/receiving commands to + replication * command.py - the definitions of all the valid commands - * protocol.py - contains bot the client and server protocol implementations, - these should not be used directly - * resource.py - the server classes that accepts and handle client connections - * streams.py - the definitons of all the valid streams + * protocol.py - the TCP protocol classes + * resource.py - handles streaming stream updates to replications + * streams/ - the definitons of all the valid streams + +The general interaction of the classes are: + + +---------------------+ + | ReplicationStreamer | + +---------------------+ + | + v + +---------------------------+ +----------------------+ + | ReplicationCommandHandler |---->|ReplicationDataHandler| + +---------------------------+ +----------------------+ + | ^ + v | + +-------------+ + | Protocols | + | (TCP/redis) | + +-------------+ + +Where the ReplicationDataHandler (or subclasses) handles incoming stream +updates. """ diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index a44ceb00e7..df29732f51 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -14,38 +14,55 @@ # limitations under the License. """A replication client for use by synapse workers. """ - +import heapq import logging +from typing import TYPE_CHECKING, Dict, List, Tuple -from twisted.internet import defer +from twisted.internet.defer import Deferred from twisted.internet.protocol import ReconnectingClientFactory -from .commands import ( - FederationAckCommand, - InvalidateCacheCommand, - RemovePusherCommand, - UserIpCommand, - UserSyncCommand, +from synapse.api.constants import EventTypes +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.streams.events import ( + EventsStream, + EventsStreamEventRow, + EventsStreamRow, ) -from .protocol import ClientReplicationStreamProtocol +from synapse.util.async_helpers import timeout_deferred +from synapse.util.metrics import Measure + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.replication.tcp.handler import ReplicationCommandHandler logger = logging.getLogger(__name__) -class ReplicationClientFactory(ReconnectingClientFactory): +# How long we allow callers to wait for replication updates before timing out. +_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30 + + +class DirectTcpReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. - Accepts a handler that will be called when new data is available or data - is required. + Accepts a handler that is passed to `ClientReplicationStreamProtocol`. 
""" - maxDelay = 30 # Try at least once every N seconds + initialDelay = 0.1 + maxDelay = 1 # Try at least once every N seconds - def __init__(self, hs, client_name, handler): + def __init__( + self, + hs: "HomeServer", + client_name: str, + command_handler: "ReplicationCommandHandler", + ): self.client_name = client_name - self.handler = handler + self.command_handler = command_handler self.server_name = hs.config.server_name + self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) @@ -56,7 +73,11 @@ class ReplicationClientFactory(ReconnectingClientFactory): def buildProtocol(self, addr): logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( - self.client_name, self.server_name, self._clock, self.handler + self.hs, + self.client_name, + self.server_name, + self._clock, + self.command_handler, ) def clientConnectionLost(self, connector, reason): @@ -68,162 +89,136 @@ class ReplicationClientFactory(ReconnectingClientFactory): ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) -class ReplicationClientHandler(object): - """A base handler that can be passed to the ReplicationClientFactory. +class ReplicationDataHandler: + """Handles incoming stream updates from replication. - By default proxies incoming replication data to the SlaveStore. + This instance notifies the slave data store about updates. Can be subclassed + to handle updates in additional ways. """ - def __init__(self, store): - self.store = store - - # The current connection. None if we are currently (re)connecting - self.connection = None - - # Any pending commands to be sent once a new connection has been - # established - self.pending_commands = [] - - # Map from string -> deferred, to wake up when receiveing a SYNC with - # the given string. - # Used for tests. - self.awaiting_syncs = {} - - # The factory used to create connections. - self.factory = None - - def start_replication(self, hs): - """Helper method to start a replication connection to the remote server - using TCP. - """ - client_name = hs.config.worker_name - self.factory = ReplicationClientFactory(hs, client_name, self) - host = hs.config.worker_replication_host - port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self.factory) - - def on_rdata(self, stream_name, token, rows): + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastore() + self.pusher_pool = hs.get_pusherpool() + self.notifier = hs.get_notifier() + self._reactor = hs.get_reactor() + self._clock = hs.get_clock() + self._streams = hs.get_replication_streams() + self._instance_name = hs.get_instance_name() + + # Map from stream to list of deferreds waiting for the stream to + # arrive at a particular position. The lists are sorted by stream position. + self._streams_to_waiters = ( + {} + ) # type: Dict[str, List[Tuple[int, Deferred[None]]]] + + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to handle more. Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. 
- - Returns: - Deferred|None - """ - logger.debug("Received rdata %s -> %s", stream_name, token) - return self.store.process_replication_rows(stream_name, token, rows) - - def on_position(self, stream_name, token): - """Called when we get new position data. By default this just pokes - the slave store. - - Can be overriden in subclasses to handle more. - """ - return self.store.process_replication_rows(stream_name, token, []) - - def on_sync(self, data): - """When we received a SYNC we wake up any deferreds that were waiting - for the sync with the given data. - - Used by tests. + stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. + token: stream token for this batch of rows + rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - d = self.awaiting_syncs.pop(data, None) - if d: - d.callback(data) - - def get_streams_to_replicate(self): - """Called when a new connection has been established and we need to - subscribe to streams. - - Returns a dictionary of stream name to token. - """ - args = self.store.stream_positions() - user_account_data = args.pop("user_account_data", None) - room_account_data = args.pop("room_account_data", None) - if user_account_data: - args["account_data"] = user_account_data - elif room_account_data: - args["account_data"] = room_account_data - - return args - - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users. (Overriden by the synchrotron's only) - """ - return [] - - def send_command(self, cmd): - """Send a command to master (when we get establish a connection if we - don't have one already.) - """ - if self.connection: - self.connection.send_command(cmd) - else: - logger.warn("Queuing command as not connected: %r", cmd.NAME) - self.pending_commands.append(cmd) - - def send_federation_ack(self, token): - """Ack data for the federation stream. This allows the master to drop - data stored purely in memory. - """ - self.send_command(FederationAckCommand(token)) - - def send_user_sync(self, user_id, is_syncing, last_sync_ms): - """Poke the master that a user has started/stopped syncing. - """ - self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms)) - - def send_remove_pusher(self, app_id, push_key, user_id): - """Poke the master to remove a pusher for a user - """ - cmd = RemovePusherCommand(app_id, push_key, user_id) - self.send_command(cmd) - - def send_invalidate_cache(self, cache_func, keys): - """Poke the master to invalidate a cache. - """ - cmd = InvalidateCacheCommand(cache_func.__name__, keys) - self.send_command(cmd) - - def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): - """Tell the master that the user made a request. + self.store.process_replication_rows(stream_name, instance_name, token, rows) + + if stream_name == EventsStream.NAME: + # We shouldn't get multiple rows per token for events stream, so + # we don't need to optimise this for multiple rows. + for row in rows: + if row.type != EventsStreamEventRow.TypeId: + continue + assert isinstance(row, EventsStreamRow) + + event = await self.store.get_event( + row.data.event_id, allow_rejected=True + ) + if event.rejected_reason: + continue + + extra_users = () # type: Tuple[str, ...] 
+ if event.type == EventTypes.Member: + extra_users = (event.state_key,) + max_token = self.store.get_room_max_stream_ordering() + self.notifier.on_new_room_event(event, token, max_token, extra_users) + + await self.pusher_pool.on_new_notifications(token, token) + + # Notify any waiting deferreds. The list is ordered by position so we + # just iterate through the list until we reach a position that is + # greater than the received row position. + waiting_list = self._streams_to_waiters.get(stream_name, []) + + # Index of first item with a position after the current token, i.e we + # have called all deferreds before this index. If not overwritten by + # loop below means either a) no items in list so no-op or b) all items + # in list were called and so the list should be cleared. Setting it to + # `len(list)` works for both cases. + index_of_first_deferred_not_called = len(waiting_list) + + for idx, (position, deferred) in enumerate(waiting_list): + if position <= token: + try: + with PreserveLoggingContext(): + deferred.callback(None) + except Exception: + # The deferred has been cancelled or timed out. + pass + else: + # The list is sorted by position so we don't need to continue + # checking any further entries in the list. + index_of_first_deferred_not_called = idx + break + + # Drop all entries in the waiting list that were called in the above + # loop. (This maintains the order so no need to resort) + waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] + + async def on_position(self, stream_name: str, instance_name: str, token: int): + self.store.process_replication_rows(stream_name, instance_name, token, []) + + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + + async def wait_for_stream_position( + self, instance_name: str, stream_name: str, position: int + ): + """Wait until this instance has received updates up to and including + the given stream position. """ - cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) - self.send_command(cmd) - def await_sync(self, data): - """Returns a deferred that is resolved when we receive a SYNC command - with given data. + if instance_name == self._instance_name: + # We don't get told about updates written by this process, and + # anyway in that case we don't need to wait. + return + + current_position = self._streams[stream_name].current_token(self._instance_name) + if position <= current_position: + # We're already past the position + return + + # Create a new deferred that times out after N seconds, as we don't want + # to wedge here forever. + deferred = Deferred() + deferred = timeout_deferred( + deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor + ) - [Not currently] used by tests. - """ - return self.awaiting_syncs.setdefault(data, defer.Deferred()) + waiting_list = self._streams_to_waiters.setdefault(stream_name, []) - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). - """ - self.connection = connection - if connection: - for cmd in self.pending_commands: - connection.send_command(cmd) - self.pending_commands = [] - - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. - """ - logger.info("Finished connecting to server") + # We insert into the list using heapq as it is more efficient than + # pushing then resorting each time. 
+ heapq.heappush(waiting_list, (position, deferred)) - # We don't reset the delay any earlier as otherwise if there is a - # problem during start up we'll end up tight looping connecting to the - # server. - self.factory.resetDelay() + # We measure here to get in flight counts and average waiting time. + with Measure(self._clock, "repl.wait_for_stream_position"): + logger.info("Waiting for repl stream %r to reach %s", stream_name, position) + await make_deferred_yieldable(deferred) + logger.info( + "Finished waiting for repl stream %r to reach %s", stream_name, position + ) diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 0ff2a7199f..c04f622816 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -17,50 +17,46 @@ The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are allowed to be sent by which side. """ - +import abc import logging import platform +from typing import Tuple, Type if platform.python_implementation() == "PyPy": import json _json_encoder = json.JSONEncoder() else: - import simplejson as json + import simplejson as json # type: ignore[no-redef] # noqa: F821 - _json_encoder = json.JSONEncoder(namedtuple_as_object=False) + _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821 logger = logging.getLogger(__name__) -class Command(object): +class Command(metaclass=abc.ABCMeta): """The base command class. All subclasses must set the NAME variable which equates to the name of the command on the wire. A full command line on the wire is constructed from `NAME + " " + to_line()` - - The default implementation creates a command of form `<NAME> <data>` """ - NAME = None - - def __init__(self, data): - self.data = data + NAME = None # type: str @classmethod + @abc.abstractmethod def from_line(cls, line): """Deserialises a line from the wire into this command. `line` does not include the command. """ - return cls(line) - def to_line(self): + @abc.abstractmethod + def to_line(self) -> str: """Serialises the comamnd for the wire. Does not include the command prefix. """ - return self.data def get_logcontext_id(self): """Get a suitable string for the logcontext when processing this command""" @@ -69,7 +65,21 @@ class Command(object): return self.NAME -class ServerCommand(Command): +class _SimpleCommand(Command): + """An implementation of Command whose argument is just a 'data' string.""" + + def __init__(self, data): + self.data = data + + @classmethod + def from_line(cls, line): + return cls(line) + + def to_line(self) -> str: + return self.data + + +class ServerCommand(_SimpleCommand): """Sent by the server on new connection and includes the server_name. Format:: @@ -85,7 +95,7 @@ class RdataCommand(Command): Format:: - RDATA <stream_name> <token> <row_json> + RDATA <stream_name> <instance_name> <token> <row_json> The `<token>` may either be a numeric stream id OR "batch". The latter case is used to support sending multiple updates with the same stream ID. This @@ -95,33 +105,40 @@ class RdataCommand(Command): The client should batch all incoming RDATA with a token of "batch" (per stream_name) until it sees an RDATA with a numeric stream ID. + The `<instance_name>` is the source of the new data (usually "master"). + `<token>` of "batch" maps to the instance variable `token` being None. An example of a batched series of RDATA:: - RDATA presence batch ["@foo:example.com", "online", ...] - RDATA presence batch ["@bar:example.com", "online", ...] 
- RDATA presence 59 ["@baz:example.com", "online", ...] + RDATA presence master batch ["@foo:example.com", "online", ...] + RDATA presence master batch ["@bar:example.com", "online", ...] + RDATA presence master 59 ["@baz:example.com", "online", ...] """ NAME = "RDATA" - def __init__(self, stream_name, token, row): + def __init__(self, stream_name, instance_name, token, row): self.stream_name = stream_name + self.instance_name = instance_name self.token = token self.row = row @classmethod def from_line(cls, line): - stream_name, token, row_json = line.split(" ", 2) + stream_name, instance_name, token, row_json = line.split(" ", 3) return cls( - stream_name, None if token == "batch" else int(token), json.loads(row_json) + stream_name, + instance_name, + None if token == "batch" else int(token), + json.loads(row_json), ) def to_line(self): return " ".join( ( self.stream_name, + self.instance_name, str(self.token) if self.token is not None else "batch", _json_encoder.encode(self.row), ) @@ -135,26 +152,34 @@ class PositionCommand(Command): """Sent by the server to tell the client the stream postition without needing to send an RDATA. - Sent to the client after all missing updates for a stream have been sent - to the client and they're now up to date. + Format:: + + POSITION <stream_name> <instance_name> <token> + + On receipt of a POSITION command clients should check if they have missed + any updates, and if so then fetch them out of band. + + The `<instance_name>` is the process that sent the command and is the source + of the stream. """ NAME = "POSITION" - def __init__(self, stream_name, token): + def __init__(self, stream_name, instance_name, token): self.stream_name = stream_name + self.instance_name = instance_name self.token = token @classmethod def from_line(cls, line): - stream_name, token = line.split(" ", 1) - return cls(stream_name, int(token)) + stream_name, instance_name, token = line.split(" ", 2) + return cls(stream_name, instance_name, int(token)) def to_line(self): - return " ".join((self.stream_name, str(self.token))) + return " ".join((self.stream_name, self.instance_name, str(self.token))) -class ErrorCommand(Command): +class ErrorCommand(_SimpleCommand): """Sent by either side if there was an ERROR. The data is a string describing the error. """ @@ -162,14 +187,14 @@ class ErrorCommand(Command): NAME = "ERROR" -class PingCommand(Command): +class PingCommand(_SimpleCommand): """Sent by either side as a keep alive. The data is arbitary (often timestamp) """ NAME = "PING" -class NameCommand(Command): +class NameCommand(_SimpleCommand): """Sent by client to inform the server of the client's identity. The data is the name """ @@ -178,76 +203,63 @@ class NameCommand(Command): class ReplicateCommand(Command): - """Sent by the client to subscribe to the stream. + """Sent by the client to subscribe to streams. Format:: - REPLICATE <stream_name> <token> - - Where <token> may be either: - * a numeric stream_id to stream updates from - * "NOW" to stream all subsequent updates. 
- - The <stream_name> can be "ALL" to subscribe to all known streams, in which - case the <token> must be set to "NOW", i.e.:: - - REPLICATE ALL NOW + REPLICATE """ NAME = "REPLICATE" - def __init__(self, stream_name, token): - self.stream_name = stream_name - self.token = token + def __init__(self): + pass @classmethod def from_line(cls, line): - stream_name, token = line.split(" ", 1) - if token in ("NOW", "now"): - token = "NOW" - else: - token = int(token) - return cls(stream_name, token) + return cls() def to_line(self): - return " ".join((self.stream_name, str(self.token))) - - def get_logcontext_id(self): - return "REPLICATE-" + self.stream_name + return "" class UserSyncCommand(Command): """Sent by the client to inform the server that a user has started or - stopped syncing. Used to calculate presence on the master. + stopped syncing on this process. + + This is used by the process handling presence (typically the master) to + calculate who is online and who is not. Includes a timestamp of when the last user sync was. Format:: - USER_SYNC <user_id> <state> <last_sync_ms> + USER_SYNC <instance_id> <user_id> <state> <last_sync_ms> - Where <state> is either "start" or "stop" + Where <state> is either "start" or "end" """ NAME = "USER_SYNC" - def __init__(self, user_id, is_syncing, last_sync_ms): + def __init__(self, instance_id, user_id, is_syncing, last_sync_ms): + self.instance_id = instance_id self.user_id = user_id self.is_syncing = is_syncing self.last_sync_ms = last_sync_ms @classmethod def from_line(cls, line): - user_id, state, last_sync_ms = line.split(" ", 2) + instance_id, user_id, state, last_sync_ms = line.split(" ", 3) if state not in ("start", "end"): raise Exception("Invalid USER_SYNC state %r" % (state,)) - return cls(user_id, state == "start", int(last_sync_ms)) + return cls(instance_id, user_id, state == "start", int(last_sync_ms)) def to_line(self): return " ".join( ( + self.instance_id, self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms), @@ -255,6 +267,30 @@ class UserSyncCommand(Command): ) +class ClearUserSyncsCommand(Command): + """Sent by the client to inform the server that it should drop all + information about syncing users sent by the client. + + Mainly used when client is about to shut down. + + Format:: + + CLEAR_USER_SYNC <instance_id> + """ + + NAME = "CLEAR_USER_SYNC" + + def __init__(self, instance_id): + self.instance_id = instance_id + + @classmethod + def from_line(cls, line): + return cls(line) + + def to_line(self): + return self.instance_id + + class FederationAckCommand(Command): """Sent by the client when it has processed up to a given point in the federation stream. This allows the master to drop in-memory caches of the @@ -280,14 +316,6 @@ class FederationAckCommand(Command): return str(self.token) -class SyncCommand(Command): - """Used for testing. The client protocol implementation allows waiting - on a SYNC command with a specified data. - """ - - NAME = "SYNC" - - class RemovePusherCommand(Command): """Sent by the client to request the master remove the given pusher. @@ -313,37 +341,6 @@ class RemovePusherCommand(Command): return " ".join((self.app_id, self.push_key, self.user_id)) -class InvalidateCacheCommand(Command): - """Sent by the client to invalidate an upstream cache. - - THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE - NOT DISASTROUS IF WE DROP ON THE FLOOR. - - Mainly used to invalidate destination retry timing caches. 
- - Format:: - - INVALIDATE_CACHE <cache_func> <keys_json> - - Where <keys_json> is a json list. - """ - - NAME = "INVALIDATE_CACHE" - - def __init__(self, cache_func, keys): - self.cache_func = cache_func - self.keys = keys - - @classmethod - def from_line(cls, line): - cache_func, keys_json = line.split(" ", 1) - - return cls(cache_func, json.loads(keys_json)) - - def to_line(self): - return " ".join((self.cache_func, _json_encoder.encode(self.keys))) - - class UserIpCommand(Command): """Sent periodically when a worker sees activity from a client. @@ -386,25 +383,38 @@ class UserIpCommand(Command): ) +class RemoteServerUpCommand(_SimpleCommand): + """Sent when a worker has detected that a remote server is no longer + "down" and retry timings should be reset. + + If sent from a client the server will relay to all other workers. + + Format:: + + REMOTE_SERVER_UP <server> + """ + + NAME = "REMOTE_SERVER_UP" + + +_COMMANDS = ( + ServerCommand, + RdataCommand, + PositionCommand, + ErrorCommand, + PingCommand, + NameCommand, + ReplicateCommand, + UserSyncCommand, + FederationAckCommand, + RemovePusherCommand, + UserIpCommand, + RemoteServerUpCommand, + ClearUserSyncsCommand, +) # type: Tuple[Type[Command], ...] + # Map of command name to command type. -COMMAND_MAP = { - cmd.NAME: cmd - for cmd in ( - ServerCommand, - RdataCommand, - PositionCommand, - ErrorCommand, - PingCommand, - NameCommand, - ReplicateCommand, - UserSyncCommand, - FederationAckCommand, - SyncCommand, - RemovePusherCommand, - InvalidateCacheCommand, - UserIpCommand, - ) -} +COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS} # The commands the server is allowed to send VALID_SERVER_COMMANDS = ( @@ -413,7 +423,7 @@ VALID_SERVER_COMMANDS = ( PositionCommand.NAME, ErrorCommand.NAME, PingCommand.NAME, - SyncCommand.NAME, + RemoteServerUpCommand.NAME, ) # The commands the client is allowed to send @@ -422,9 +432,28 @@ VALID_CLIENT_COMMANDS = ( ReplicateCommand.NAME, PingCommand.NAME, UserSyncCommand.NAME, + ClearUserSyncsCommand.NAME, FederationAckCommand.NAME, RemovePusherCommand.NAME, - InvalidateCacheCommand.NAME, UserIpCommand.NAME, ErrorCommand.NAME, + RemoteServerUpCommand.NAME, ) + + +def parse_command_from_line(line: str) -> Command: + """Parses a command from a received line. + + Line should already be stripped of whitespace and be checked if blank. + """ + + idx = line.find(" ") + if idx >= 0: + cmd_name = line[:idx] + rest_of_line = line[idx + 1 :] + else: + cmd_name = line + rest_of_line = "" + + cmd_cls = COMMAND_MAP[cmd_name] + return cmd_cls.from_line(rest_of_line) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py new file mode 100644 index 0000000000..cbcf46f3ae --- /dev/null +++ b/synapse/replication/tcp/handler.py @@ -0,0 +1,596 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
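As a rough illustration of the commands.py changes above (this sketch is not part of the patch; the stream name, token and row payload are made up), a received line now carries the originating instance name and round-trips through parse_command_from_line and to_line:

    from synapse.replication.tcp.commands import RdataCommand, parse_command_from_line

    # An RDATA line as it would appear on the wire after this change:
    #   RDATA <stream_name> <instance_name> <token> <row_json>
    line = 'RDATA caches master 341 ["get_user_by_id", ["@alice:example.com"], 1588000000000]'

    cmd = parse_command_from_line(line)
    assert isinstance(cmd, RdataCommand)
    assert (cmd.stream_name, cmd.instance_name, cmd.token) == ("caches", "master", 341)

    # to_line() omits the command name; the protocol prepends it when writing,
    # giving back an equivalent line.
    print(cmd.NAME, cmd.to_line())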
+ +import logging +from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar + +from prometheus_client import Counter + +from twisted.internet.protocol import ReconnectingClientFactory + +from synapse.metrics import LaterGauge +from synapse.replication.tcp.client import DirectTcpReplicationClientFactory +from synapse.replication.tcp.commands import ( + ClearUserSyncsCommand, + Command, + FederationAckCommand, + PositionCommand, + RdataCommand, + RemoteServerUpCommand, + RemovePusherCommand, + ReplicateCommand, + UserIpCommand, + UserSyncCommand, +) +from synapse.replication.tcp.protocol import AbstractConnection +from synapse.replication.tcp.streams import ( + STREAMS_MAP, + BackfillStream, + CachesStream, + EventsStream, + FederationStream, + Stream, +) +from synapse.util.async_helpers import Linearizer + +logger = logging.getLogger(__name__) + + +# number of updates received for each RDATA stream +inbound_rdata_count = Counter( + "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] +) +user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") +federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") +remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") +invalidate_cache_counter = Counter( + "synapse_replication_tcp_resource_invalidate_cache", "" +) +user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") + + +class ReplicationCommandHandler: + """Handles incoming commands from replication as well as sending commands + back out to connections. + """ + + def __init__(self, hs): + self._replication_data_handler = hs.get_replication_data_handler() + self._presence_handler = hs.get_presence_handler() + self._store = hs.get_datastore() + self._notifier = hs.get_notifier() + self._clock = hs.get_clock() + self._instance_id = hs.get_instance_id() + self._instance_name = hs.get_instance_name() + + self._streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + + # List of streams that this instance is the source of + self._streams_to_replicate = [] # type: List[Stream] + + for stream in self._streams.values(): + if stream.NAME == CachesStream.NAME: + # All workers can write to the cache invalidation stream. + self._streams_to_replicate.append(stream) + continue + + if isinstance(stream, (EventsStream, BackfillStream)): + # Only add EventStream and BackfillStream as a source on the + # instance in charge of event persistence. + if hs.config.worker.writers.events == hs.get_instance_name(): + self._streams_to_replicate.append(stream) + + continue + + # Only add any other streams if we're on master. + if hs.config.worker_app is not None: + continue + + if stream.NAME == FederationStream.NAME and hs.config.send_federation: + # We only support federation stream if federation sending + # has been disabled on the master. + continue + + self._streams_to_replicate.append(stream) + + self._position_linearizer = Linearizer( + "replication_position", clock=self._clock + ) + + # Map of stream to batched updates. See RdataCommand for info on how + # batching works. + self._pending_batches = {} # type: Dict[str, List[Any]] + + # The factory used to create connections. + self._factory = None # type: Optional[ReconnectingClientFactory] + + # The currently connected connections. (The list of places we need to send + # outgoing replication commands to.) 
+ self._connections = [] # type: List[AbstractConnection] + + # For each connection, the incoming streams that are coming from that connection + self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] + + LaterGauge( + "synapse_replication_tcp_resource_total_connections", + "", + [], + lambda: len(self._connections), + ) + + self._is_master = hs.config.worker_app is None + + self._federation_sender = None + if self._is_master and not hs.config.send_federation: + self._federation_sender = hs.get_federation_sender() + + self._server_notices_sender = None + if self._is_master: + self._server_notices_sender = hs.get_server_notices_sender() + + def start_replication(self, hs): + """Helper method to start a replication connection to the remote server + using TCP. + """ + if hs.config.redis.redis_enabled: + from synapse.replication.tcp.redis import ( + RedisDirectTcpReplicationClientFactory, + ) + import txredisapi + + logger.info( + "Connecting to redis (host=%r port=%r)", + hs.config.redis_host, + hs.config.redis_port, + ) + + # First let's ensure that we have a ReplicationStreamer started. + hs.get_replication_streamer() + + # We need two connections to redis, one for the subscription stream and + # one to send commands to (as you can't send further redis commands to a + # connection after SUBSCRIBE is called). + + # First create the connection for sending commands. + outbound_redis_connection = txredisapi.lazyConnection( + host=hs.config.redis_host, + port=hs.config.redis_port, + password=hs.config.redis.redis_password, + reconnect=True, + ) + + # Now create the factory/connection for the subscription stream. + self._factory = RedisDirectTcpReplicationClientFactory( + hs, outbound_redis_connection + ) + hs.get_reactor().connectTCP( + hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory, + ) + else: + client_name = hs.get_instance_name() + self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + hs.get_reactor().connectTCP(host, port, self._factory) + + def get_streams(self) -> Dict[str, Stream]: + """Get a map from stream name to all streams. + """ + return self._streams + + def get_streams_to_replicate(self) -> List[Stream]: + """Get a list of streams that this instances replicates. + """ + return self._streams_to_replicate + + async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): + self.send_positions_to_connection(conn) + + def send_positions_to_connection(self, conn: AbstractConnection): + """Send current position of all streams this process is source of to + the connection. + """ + + # We respond with current position of all streams this instance + # replicates. 
+ for stream in self.get_streams_to_replicate(): + self.send_command( + PositionCommand( + stream.NAME, + self._instance_name, + stream.current_token(self._instance_name), + ) + ) + + async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): + user_sync_counter.inc() + + if self._is_master: + await self._presence_handler.update_external_syncs_row( + cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + ) + + async def on_CLEAR_USER_SYNC( + self, conn: AbstractConnection, cmd: ClearUserSyncsCommand + ): + if self._is_master: + await self._presence_handler.update_external_syncs_clear(cmd.instance_id) + + async def on_FEDERATION_ACK( + self, conn: AbstractConnection, cmd: FederationAckCommand + ): + federation_ack_counter.inc() + + if self._federation_sender: + self._federation_sender.federation_ack(cmd.token) + + async def on_REMOVE_PUSHER( + self, conn: AbstractConnection, cmd: RemovePusherCommand + ): + remove_pusher_counter.inc() + + if self._is_master: + await self._store.delete_pusher_by_app_id_pushkey_user_id( + app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id + ) + + self._notifier.on_new_replication_data() + + async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): + user_ip_cache_counter.inc() + + if self._is_master: + await self._store.insert_client_ip( + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, + cmd.last_seen, + ) + + if self._server_notices_sender: + await self._server_notices_sender.on_user_ip(cmd.user_id) + + async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): + if cmd.instance_name == self._instance_name: + # Ignore RDATA that are just our own echoes + return + + stream_name = cmd.stream_name + inbound_rdata_count.labels(stream_name).inc() + + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception: + logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) + raise + + # We linearize here for two reasons: + # 1. so we don't try and concurrently handle multiple rows for the + # same stream, and + # 2. so we don't race with getting a POSITION command and fetching + # missing RDATA. + with await self._position_linearizer.queue(cmd.stream_name): + # make sure that we've processed a POSITION for this stream *on this + # connection*. (A POSITION on another connection is no good, as there + # is no guarantee that we have seen all the intermediate updates.) + sbc = self._streams_by_connection.get(conn) + if not sbc or stream_name not in sbc: + # Let's drop the row for now, on the assumption we'll receive a + # `POSITION` soon and we'll catch up correctly then. + logger.debug( + "Discarding RDATA for unconnected stream %s -> %s", + stream_name, + cmd.token, + ) + return + + if cmd.token is None: + # I.e. this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token). + self._pending_batches.setdefault(stream_name, []).append(row) + else: + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) + + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): + """Called to handle a batch of replication data with a given stream token. + + Args: + stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. 
+ token: stream token for this batch of rows + rows: a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + """ + logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token) + await self._replication_data_handler.on_rdata( + stream_name, instance_name, token, rows + ) + + async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): + if cmd.instance_name == self._instance_name: + # Ignore POSITION that are just our own echoes + return + + logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line()) + + stream_name = cmd.stream_name + stream = self._streams.get(stream_name) + if not stream: + logger.error("Got POSITION for unknown stream: %s", stream_name) + return + + # We protect catching up with a linearizer in case the replication + # connection reconnects under us. + with await self._position_linearizer.queue(stream_name): + # We're about to go and catch up with the stream, so remove from set + # of connected streams. + for streams in self._streams_by_connection.values(): + streams.discard(stream_name) + + # We clear the pending batches for the stream as the fetching of the + # missing updates below will fetch all rows in the batch. + self._pending_batches.pop(stream_name, []) + + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) + + # If the position token matches our current token then we're up to + # date and there's nothing to do. Otherwise, fetch all updates + # between then and now. + missing_updates = cmd.token != current_token + while missing_updates: + logger.info( + "Fetching replication rows for '%s' between %i and %i", + stream_name, + current_token, + cmd.token, + ) + ( + updates, + current_token, + missing_updates, + ) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token + ) + + # TODO: add some tests for this + + # Some streams return multiple rows with the same stream IDs, + # which need to be processed in batches. + + for token, rows in _batch_updates(updates): + await self.on_rdata( + stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], + ) + + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + + # We've now caught up to position sent to us, notify handler. + await self._replication_data_handler.on_position( + cmd.stream_name, cmd.instance_name, cmd.token + ) + + self._streams_by_connection.setdefault(conn, set()).add(stream_name) + + async def on_REMOTE_SERVER_UP( + self, conn: AbstractConnection, cmd: RemoteServerUpCommand + ): + """"Called when get a new REMOTE_SERVER_UP command.""" + self._replication_data_handler.on_remote_server_up(cmd.data) + + self._notifier.notify_remote_server_up(cmd.data) + + # We relay to all other connections to ensure every instance gets the + # notification. + # + # When configured to use redis we'll always only have one connection and + # so this is a no-op (all instances will have already received the same + # REMOTE_SERVER_UP command). + # + # For direct TCP connections this will relay to all other connections + # connected to us. When on master this will correctly fan out to all + # other direct TCP clients and on workers there'll only be the one + # connection to master. + # + # (The logic here should also be sound if we have a mix of Redis and + # direct TCP connections so long as there is only one traffic route + # between two instances, but that is not currently supported). 
+ self.send_command(cmd, ignore_conn=conn) + + def new_connection(self, connection: AbstractConnection): + """Called when we have a new connection. + """ + self._connections.append(connection) + + # If we are connected to replication as a client (rather than a server) + # we need to reset the reconnection delay on the client factory (which + # is used to do exponential back off when the connection drops). + # + # Ideally we would reset the delay when we've "fully established" the + # connection (for some definition thereof) to stop us from tightlooping + # on reconnection if something fails after this point and we drop the + # connection. Unfortunately, we don't really have a better definition of + # "fully established" than the connection being established. + if self._factory: + self._factory.resetDelay() + + # Tell the other end if we have any users currently syncing. + currently_syncing = ( + self._presence_handler.get_currently_syncing_users_for_replication() + ) + + now = self._clock.time_msec() + for user_id in currently_syncing: + connection.send_command( + UserSyncCommand(self._instance_id, user_id, True, now) + ) + + def lost_connection(self, connection: AbstractConnection): + """Called when a connection is closed/lost. + """ + # we no longer need _streams_by_connection for this connection. + streams = self._streams_by_connection.pop(connection, None) + if streams: + logger.info( + "Lost replication connection; streams now disconnected: %s", streams + ) + try: + self._connections.remove(connection) + except ValueError: + pass + + def connected(self) -> bool: + """Do we have any replication connections open? + + Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected. + """ + return bool(self._connections) + + def send_command( + self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None + ): + """Send a command to all connected connections. + + Args: + cmd + ignore_conn: If set don't send command to the given connection. + Used when relaying commands from one connection to all others. + """ + if self._connections: + for connection in self._connections: + if connection == ignore_conn: + continue + + try: + connection.send_command(cmd) + except Exception: + # We probably want to catch some types of exceptions here + # and log them as warnings (e.g. connection gone), but I + # can't find what those exception types they would be. + logger.exception( + "Failed to write command %s to connection %s", + cmd.NAME, + connection, + ) + else: + logger.warning("Dropping command as not connected: %r", cmd.NAME) + + def send_federation_ack(self, token: int): + """Ack data for the federation stream. This allows the master to drop + data stored purely in memory. + """ + self.send_command(FederationAckCommand(token)) + + def send_user_sync( + self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + ): + """Poke the master that a user has started/stopped syncing. + """ + self.send_command( + UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + ) + + def send_remove_pusher(self, app_id: str, push_key: str, user_id: str): + """Poke the master to remove a pusher for a user + """ + cmd = RemovePusherCommand(app_id, push_key, user_id) + self.send_command(cmd) + + def send_user_ip( + self, + user_id: str, + access_token: str, + ip: str, + user_agent: str, + device_id: str, + last_seen: int, + ): + """Tell the master that the user made a request. 
+ """ + cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) + self.send_command(cmd) + + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) + + def stream_update(self, stream_name: str, token: str, data: Any): + """Called when a new update is available to stream to clients. + + We need to check if the client is interested in the stream or not + """ + self.send_command(RdataCommand(stream_name, self._instance_name, token, data)) + + +UpdateToken = TypeVar("UpdateToken") +UpdateRow = TypeVar("UpdateRow") + + +def _batch_updates( + updates: Iterable[Tuple[UpdateToken, UpdateRow]] +) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]: + """Collect stream updates with the same token together + + Given a series of updates returned by Stream.get_updates_since(), collects + the updates which share the same stream_id together. + + For example: + + [(1, a), (1, b), (2, c), (3, d), (3, e)] + + becomes: + + [ + (1, [a, b]), + (2, [c]), + (3, [d, e]), + ] + """ + + update_iter = iter(updates) + + first_update = next(update_iter, None) + if first_update is None: + # empty input + return + + current_batch_token = first_update[0] + current_batch = [first_update[1]] + + for token, row in update_iter: + if token != current_batch_token: + # different token to the previous row: flush the previous + # batch and start anew + yield current_batch_token, current_batch + current_batch_token = token + current_batch = [] + + current_batch.append(row) + + # flush the final batch + yield current_batch_token, current_batch diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 5ffdf2675d..4198eece71 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire:: > PING 1490197665618 < NAME synapse.app.appservice < PING 1490197665618 - < REPLICATE events 1 - < REPLICATE backfill 1 - < REPLICATE caches 1 + < REPLICATE > POSITION events 1 > POSITION backfill 1 > POSITION caches 1 @@ -48,45 +46,55 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ - +import abc import fcntl import logging import struct -from collections import defaultdict - -from six import iteritems, iterkeys +from typing import TYPE_CHECKING, List from prometheus_client import Counter -from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure -from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.stringutils import random_string - -from .commands import ( - COMMAND_MAP, +from synapse.replication.tcp.commands import ( VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, + Command, ErrorCommand, NameCommand, PingCommand, - PositionCommand, - RdataCommand, ReplicateCommand, ServerCommand, - SyncCommand, - UserSyncCommand, + parse_command_from_line, ) -from .streams import STREAMS_MAP +from synapse.types import Collection +from synapse.util import Clock +from synapse.util.stringutils import random_string + +if TYPE_CHECKING: + from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer + connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", 
["reason_type"] ) +tcp_inbound_commands_counter = Counter( + "synapse_replication_tcp_protocol_inbound_commands", + "Number of commands received from replication, by command and name of process connected to", + ["command", "name"], +) + +tcp_outbound_commands_counter = Counter( + "synapse_replication_tcp_protocol_outbound_commands", + "Number of commands sent to replication, by command and name of process connected to", + ["command", "name"], +) + # A list of all connected protocols. This allows us to send metrics about the # connections. connected_connections = [] @@ -115,7 +123,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): are only sent by the server. On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed - command. + command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`. It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) @@ -123,13 +131,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): delimiter = b"\n" - VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive - VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send + # Valid commands we expect to receive + VALID_INBOUND_COMMANDS = [] # type: Collection[str] + + # Valid commands we can send + VALID_OUTBOUND_COMMANDS = [] # type: Collection[str] max_line_buffer = 10000 - def __init__(self, clock): + def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"): self.clock = clock + self.command_handler = handler self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 @@ -143,14 +155,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.conn_id = random_string(5) # To dedupe in case of name clashes. # List of pending commands to send once we've established the connection - self.pending_commands = [] + self.pending_commands = [] # type: List[Command] # The LoopingCall for sending pings. self._send_ping_loop = None - self.inbound_commands_counter = defaultdict(int) - self.outbound_commands_counter = defaultdict(int) - def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -169,6 +178,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # can time us out. self.send_command(PingCommand(self.clock.time_msec())) + self.command_handler.new_connection(self) + def send_ping(self): """Periodically sends a ping and checks if we should close the connection due to the other side timing out. 
@@ -196,60 +207,67 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): ) self.send_error("ping timeout") - def lineReceived(self, line): + def lineReceived(self, line: bytes): """Called when we've received a line """ if line.strip() == "": # Ignore blank lines return - line = line.decode("utf-8") - cmd_name, rest_of_line = line.split(" ", 1) + linestr = line.decode("utf-8") - if cmd_name not in self.VALID_INBOUND_COMMANDS: - logger.error("[%s] invalid command %s", self.id(), cmd_name) - self.send_error("invalid command: %s", cmd_name) + try: + cmd = parse_command_from_line(linestr) + except Exception as e: + logger.exception("[%s] failed to parse line: %r", self.id(), linestr) + self.send_error("failed to parse line: %r (%r):" % (e, linestr)) return - self.last_received_command = self.clock.time_msec() + if cmd.NAME not in self.VALID_INBOUND_COMMANDS: + logger.error("[%s] invalid command %s", self.id(), cmd.NAME) + self.send_error("invalid command: %s", cmd.NAME) + return - self.inbound_commands_counter[cmd_name] = ( - self.inbound_commands_counter[cmd_name] + 1 - ) + self.last_received_command = self.clock.time_msec() - cmd_cls = COMMAND_MAP[cmd_name] - try: - cmd = cmd_cls.from_line(rest_of_line) - except Exception as e: - logger.exception( - "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line - ) - self.send_error( - "failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line) - ) - return + tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc() # Now lets try and call on_<CMD_NAME> function run_as_background_process( "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd ) - def handle_command(self, cmd): + async def handle_command(self, cmd: Command): """Handle a command we have received over the replication stream. - By default delegates to on_<COMMAND> + First calls `self.on_<COMMAND>` if it exists, then calls + `self.command_handler.on_<COMMAND>` if it exists. This allows for + protocol level handling of commands (e.g. PINGs), before delegating to + the handler. Args: - cmd (synapse.replication.tcp.commands.Command): received command - - Returns: - Deferred + cmd: received command """ - handler = getattr(self, "on_%s" % (cmd.NAME,)) - return handler(cmd) + handled = False + + # First call any command handlers on this instance. These are for TCP + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + # Then call out to the handler. 
+ cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(self, cmd) + handled = True + + if not handled: + logger.warning("Unhandled command: %r", cmd) def close(self): - logger.warn("[%s] Closing connection", self.id()) + logger.warning("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() self.transport.loseConnection() self.on_connection_closed() @@ -278,9 +296,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self._queue_command(cmd) return - self.outbound_commands_counter[cmd.NAME] = ( - self.outbound_commands_counter[cmd.NAME] + 1 - ) + tcp_outbound_commands_counter.labels(cmd.NAME, self.name).inc() + string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -319,10 +336,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): for cmd in pending: self.send_command(cmd) - def on_PING(self, line): + async def on_PING(self, line): self.received_ping = True - def on_ERROR(self, cmd): + async def on_ERROR(self, cmd): logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) def pauseProducing(self): @@ -375,6 +392,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.state = ConnectionStates.CLOSED self.pending_commands = [] + self.command_handler.lost_connection(self) + if self.transport: self.transport.unregisterProducer() @@ -401,264 +420,73 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS - def __init__(self, server_name, clock, streamer): - BaseReplicationStreamProtocol.__init__(self, clock) # Old style class + def __init__( + self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler" + ): + super().__init__(clock, handler) self.server_name = server_name - self.streamer = streamer - - # The streams the client has subscribed to and is up to date with - self.replication_streams = set() - - # The streams the client is currently subscribing to. - self.connecting_streams = set() - - # Map from stream name to list of updates to send once we've finished - # subscribing the client to the stream. - self.pending_rdata = {} def connectionMade(self): self.send_command(ServerCommand(self.server_name)) - BaseReplicationStreamProtocol.connectionMade(self) - self.streamer.new_connection(self) + super().connectionMade() - def on_NAME(self, cmd): + async def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data - def on_USER_SYNC(self, cmd): - return self.streamer.on_user_sync( - self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms - ) - - def on_REPLICATE(self, cmd): - stream_name = cmd.stream_name - token = cmd.token - - if stream_name == "ALL": - # Subscribe to all streams we're publishing to. 
- deferreds = [ - run_in_background(self.subscribe_to_stream, stream, token) - for stream in iterkeys(self.streamer.streams_by_name) - ] - - return make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True) - ) - else: - return self.subscribe_to_stream(stream_name, token) - - def on_FEDERATION_ACK(self, cmd): - return self.streamer.federation_ack(cmd.token) - - def on_REMOVE_PUSHER(self, cmd): - return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) - - def on_INVALIDATE_CACHE(self, cmd): - return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) - - def on_USER_IP(self, cmd): - return self.streamer.on_user_ip( - cmd.user_id, - cmd.access_token, - cmd.ip, - cmd.user_agent, - cmd.device_id, - cmd.last_seen, - ) - - @defer.inlineCallbacks - def subscribe_to_stream(self, stream_name, token): - """Subscribe the remote to a stream. - - This invloves checking if they've missed anything and sending those - updates down if they have. During that time new updates for the stream - are queued and sent once we've sent down any missed updates. - """ - self.replication_streams.discard(stream_name) - self.connecting_streams.add(stream_name) - - try: - # Get missing updates - updates, current_token = yield self.streamer.get_stream_updates( - stream_name, token - ) - - # Send all the missing updates - for update in updates: - token, row = update[0], update[1] - self.send_command(RdataCommand(stream_name, token, row)) - - # We send a POSITION command to ensure that they have an up to - # date token (especially useful if we didn't send any updates - # above) - self.send_command(PositionCommand(stream_name, current_token)) - - # Now we can send any updates that came in while we were subscribing - pending_rdata = self.pending_rdata.pop(stream_name, []) - updates = [] - for token, update in pending_rdata: - # If the token is null, it is part of a batch update. Batches - # are multiple updates that share a single token. To denote - # this, the token is set to None for all tokens in the batch - # except for the last. If we find a None token, we keep looking - # through tokens until we find one that is not None and then - # process all previous updates in the batch as if they had the - # final token. - if token is None: - # Store this update as part of a batch - updates.append(update) - continue - - if token <= current_token: - # This update or batch of updates is older than - # current_token, dismiss it - updates = [] - continue - - updates.append(update) - - # Send all updates that are part of this batch with the - # found token - for update in updates: - self.send_command(RdataCommand(stream_name, token, update)) - - # Clear stored updates - updates = [] - - # They're now fully subscribed - self.replication_streams.add(stream_name) - except Exception as e: - logger.exception("[%s] Failed to handle REPLICATE command", self.id()) - self.send_error("failed to handle replicate: %r", e) - finally: - self.connecting_streams.discard(stream_name) - - def stream_update(self, stream_name, token, data): - """Called when a new update is available to stream to clients. 
- - We need to check if the client is interested in the stream or not - """ - if stream_name in self.replication_streams: - # The client is subscribed to the stream - self.send_command(RdataCommand(stream_name, token, data)) - elif stream_name in self.connecting_streams: - # The client is being subscribed to the stream - logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token) - self.pending_rdata.setdefault(stream_name, []).append((token, data)) - else: - # The client isn't subscribed - logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) - - def send_sync(self, data): - self.send_command(SyncCommand(data)) - - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.streamer.lost_connection(self) - class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS - def __init__(self, client_name, server_name, clock, handler): - BaseReplicationStreamProtocol.__init__(self, clock) + def __init__( + self, + hs: "HomeServer", + client_name: str, + server_name: str, + clock: Clock, + command_handler: "ReplicationCommandHandler", + ): + super().__init__(clock, command_handler) self.client_name = client_name self.server_name = server_name - self.handler = handler - - # Set of stream names that have been subscribe to, but haven't yet - # caught up with. This is used to track when the client has been fully - # connected to the remote. - self.streams_connecting = set() - - # Map of stream to batched updates. See RdataCommand for info on how - # batching works. - self.pending_batches = {} def connectionMade(self): self.send_command(NameCommand(self.client_name)) - BaseReplicationStreamProtocol.connectionMade(self) + super().connectionMade() # Once we've connected subscribe to the necessary streams - for stream_name, token in iteritems(self.handler.get_streams_to_replicate()): - self.replicate(stream_name, token) - - # Tell the server if we have any users currently syncing (should only - # happen on synchrotrons) - currently_syncing = self.handler.get_currently_syncing_users() - now = self.clock.time_msec() - for user_id in currently_syncing: - self.send_command(UserSyncCommand(user_id, True, now)) - - # We've now finished connecting to so inform the client handler - self.handler.update_connection(self) + self.replicate() - # This will happen if we don't actually subscribe to any streams - if not self.streams_connecting: - self.handler.finished_connecting() - - def on_SERVER(self, cmd): + async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") - def on_RDATA(self, cmd): - stream_name = cmd.stream_name - inbound_rdata_count.labels(stream_name).inc() - - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception( - "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row - ) - raise - - if cmd.token is None: - # I.e. this is part of a batch of updates for this stream. 
Batch - # until we get an update for the stream with a non None token - self.pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self.pending_batches.pop(stream_name, []) - rows.append(row) - return self.handler.on_rdata(stream_name, cmd.token, rows) + def replicate(self): + """Send the subscription request to the server + """ + logger.info("[%s] Subscribing to replication streams", self.id()) - def on_POSITION(self, cmd): - # When we get a `POSITION` command it means we've finished getting - # missing updates for the given stream, and are now up to date. - self.streams_connecting.discard(cmd.stream_name) - if not self.streams_connecting: - self.handler.finished_connecting() + self.send_command(ReplicateCommand()) - return self.handler.on_position(cmd.stream_name, cmd.token) - def on_SYNC(self, cmd): - return self.handler.on_sync(cmd.data) +class AbstractConnection(abc.ABC): + """An interface for replication connections. + """ - def replicate(self, stream_name, token): - """Send the subscription request to the server + @abc.abstractmethod + def send_command(self, cmd: Command): + """Send the command down the connection """ - if stream_name not in STREAMS_MAP: - raise Exception("Invalid stream name %r" % (stream_name,)) - - logger.info( - "[%s] Subscribing to replication stream: %r from %r", - self.id(), - stream_name, - token, - ) - - self.streams_connecting.add(stream_name) + pass - self.send_command(ReplicateCommand(stream_name, token)) - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.handler.update_connection(None) +# This tells python that `BaseReplicationStreamProtocol` implements the +# interface. +AbstractConnection.register(BaseReplicationStreamProtocol) # The following simply registers metrics for the replication connections @@ -696,7 +524,7 @@ def transport_kernel_read_buffer_size(protocol, read=True): op = SIOCINQ else: op = SIOCOUTQ - size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0] + size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0] return size return 0 @@ -721,31 +549,3 @@ tcp_transport_kernel_read_buffer = LaterGauge( for p in connected_connections }, ) - - -tcp_inbound_commands = LaterGauge( - "synapse_replication_tcp_protocol_inbound_commands", - "", - ["command", "name"], - lambda: { - (k, p.name): count - for p in connected_connections - for k, count in iteritems(p.inbound_commands_counter) - }, -) - -tcp_outbound_commands = LaterGauge( - "synapse_replication_tcp_protocol_outbound_commands", - "", - ["command", "name"], - lambda: { - (k, p.name): count - for p in connected_connections - for k, count in iteritems(p.outbound_commands_counter) - }, -) - -# number of updates received for each RDATA stream -inbound_rdata_count = Counter( - "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] -) diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py new file mode 100644 index 0000000000..e776b63183 --- /dev/null +++ b/synapse/replication/tcp/redis.py @@ -0,0 +1,215 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +import txredisapi + +from synapse.logging.context import make_deferred_yieldable +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.tcp.commands import ( + Command, + ReplicateCommand, + parse_command_from_line, +) +from synapse.replication.tcp.protocol import ( + AbstractConnection, + tcp_inbound_commands_counter, + tcp_outbound_commands_counter, +) + +if TYPE_CHECKING: + from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): + """Connection to redis subscribed to replication stream. + + This class fulfils two functions: + + (a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd redis + connection, parsing *incoming* messages into replication commands, and passing them + to `ReplicationCommandHandler` + + (b) it implements the AbstractConnection API, where it sends *outgoing* commands + onto outbound_redis_connection. + + Due to the vagaries of `txredisapi` we don't want to have a custom + constructor, so instead we expect the defined attributes below to be set + immediately after initialisation. + + Attributes: + handler: The command handler to handle incoming commands. + stream_name: The *redis* stream name to subscribe to and publish from + (not anything to do with Synapse replication streams). + outbound_redis_connection: The connection to redis to use to send + commands. + """ + + handler = None # type: ReplicationCommandHandler + stream_name = None # type: str + outbound_redis_connection = None # type: txredisapi.RedisProtocol + + def connectionMade(self): + logger.info("Connected to redis") + super().connectionMade() + run_as_background_process("subscribe-replication", self._send_subscribe) + + async def _send_subscribe(self): + # it's important to make sure that we only send the REPLICATE command once we + # have successfully subscribed to the stream - otherwise we might miss the + # POSITION response sent back by the other end. + logger.info("Sending redis SUBSCRIBE for %s", self.stream_name) + await make_deferred_yieldable(self.subscribe(self.stream_name)) + logger.info( + "Successfully subscribed to redis stream, sending REPLICATE command" + ) + self.handler.new_connection(self) + await self._async_send_command(ReplicateCommand()) + logger.info("REPLICATE successfully sent") + + # We send out our positions when there is a new connection in case the + # other side missed updates. We do this for Redis connections as the + # otherside won't know we've connected and so won't issue a REPLICATE. + self.handler.send_positions_to_connection(self) + + def messageReceived(self, pattern: str, channel: str, message: str): + """Received a message from redis. 
+ """ + + if message.strip() == "": + # Ignore blank lines + return + + try: + cmd = parse_command_from_line(message) + except Exception: + logger.exception( + "Failed to parse replication line: %r", message, + ) + return + + # We use "redis" as the name here as we don't have 1:1 connections to + # remote instances. + tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc() + + # Now lets try and call on_<CMD_NAME> function + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd + ) + + async def handle_command(self, cmd: Command): + """Handle a command we have received over the replication stream. + + By default delegates to on_<COMMAND>, which should return an awaitable. + + Args: + cmd: received command + """ + handled = False + + # First call any command handlers on this instance. These are for redis + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + # Then call out to the handler. + cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(self, cmd) + handled = True + + if not handled: + logger.warning("Unhandled command: %r", cmd) + + def connectionLost(self, reason): + logger.info("Lost connection to redis") + super().connectionLost(reason) + self.handler.lost_connection(self) + + def send_command(self, cmd: Command): + """Send a command if connection has been established. + + Args: + cmd (Command) + """ + run_as_background_process("send-cmd", self._async_send_command, cmd) + + async def _async_send_command(self, cmd: Command): + """Encode a replication command and send it over our outbound connection""" + string = "%s %s" % (cmd.NAME, cmd.to_line()) + if "\n" in string: + raise Exception("Unexpected newline in command: %r", string) + + encoded_string = string.encode("utf-8") + + # We use "redis" as the name here as we don't have 1:1 connections to + # remote instances. + tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc() + + await make_deferred_yieldable( + self.outbound_redis_connection.publish(self.stream_name, encoded_string) + ) + + +class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): + """This is a reconnecting factory that connects to redis and immediately + subscribes to a stream. + + Args: + hs + outbound_redis_connection: A connection to redis that will be used to + send outbound commands (this is seperate to the redis connection + used to subscribe). + """ + + maxDelay = 5 + continueTrying = True + protocol = RedisSubscriber + + def __init__( + self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol + ): + + super().__init__() + + # This sets the password on the RedisFactory base class (as + # SubscriberFactory constructor doesn't pass it through). + self.password = hs.config.redis.redis_password + + self.handler = hs.get_tcp_replication() + self.stream_name = hs.hostname + + self.outbound_redis_connection = outbound_redis_connection + + def buildProtocol(self, addr): + p = super().buildProtocol(addr) # type: RedisSubscriber + + # We do this here rather than add to the constructor of `RedisSubcriber` + # as to do so would involve overriding `buildProtocol` entirely, however + # the base method does some other things than just instantiating the + # protocol. 
+ p.handler = self.handler + p.outbound_redis_connection = self.outbound_redis_connection + p.stream_name = self.stream_name + p.password = self.password + + return p diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index d1e98428bc..41569305df 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -18,31 +18,17 @@ import logging import random -from six import itervalues - from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.protocol import Factory -from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.metrics import Measure, measure_func - -from .protocol import ServerReplicationStreamProtocol -from .streams import STREAMS_MAP -from .streams.federation import FederationStream +from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol +from synapse.util.metrics import Measure stream_updates_counter = Counter( "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] ) -user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") -federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") -remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter( - "synapse_replication_tcp_resource_invalidate_cache", "" -) -user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -52,13 +38,23 @@ class ReplicationStreamProtocolFactory(Factory): """ def __init__(self, hs): - self.streamer = ReplicationStreamer(hs) + self.command_handler = hs.get_tcp_replication() self.clock = hs.get_clock() self.server_name = hs.config.server_name + # If we've created a `ReplicationStreamProtocolFactory` then we're + # almost certainly registering a replication listener, so let's ensure + # that we've started a `ReplicationStreamer` instance to actually push + # data. + # + # (This is a bit of a weird place to do this, but the alternatives such + # as putting this in `HomeServer.setup()`, requires either passing the + # listener config again or always starting a `ReplicationStreamer`.) + hs.get_replication_streamer() + def buildProtocol(self, addr): return ServerReplicationStreamProtocol( - self.server_name, self.clock, self.streamer + self.server_name, self.clock, self.command_handler ) @@ -71,66 +67,22 @@ class ReplicationStreamer(object): def __init__(self, hs): self.store = hs.get_datastore() - self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() self.notifier = hs.get_notifier() - self._server_notices_sender = hs.get_server_notices_sender() + self._instance_name = hs.get_instance_name() self._replication_torture_level = hs.config.replication_torture_level - # Current connections. - self.connections = [] - - LaterGauge( - "synapse_replication_tcp_resource_total_connections", - "", - [], - lambda: len(self.connections), - ) - - # List of streams that clients can subscribe to. - # We only support federation stream if federation sending hase been - # disabled on the master. 
- self.streams = [ - stream(hs) - for stream in itervalues(STREAMS_MAP) - if stream != FederationStream or not hs.config.send_federation - ] - - self.streams_by_name = {stream.NAME: stream for stream in self.streams} - - LaterGauge( - "synapse_replication_tcp_resource_connections_per_stream", - "", - ["stream_name"], - lambda: { - (stream_name,): len( - [ - conn - for conn in self.connections - if stream_name in conn.replication_streams - ] - ) - for stream_name in self.streams_by_name - }, - ) - - self.federation_sender = None - if not hs.config.send_federation: - self.federation_sender = hs.get_federation_sender() - self.notifier.add_replication_callback(self.on_notifier_poke) # Keeps track of whether we are currently checking for updates self.is_looping = False self.pending_updates = False - hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown) + self.command_handler = hs.get_tcp_replication() - def on_shutdown(self): - # close all connections on shutdown - for conn in self.connections: - conn.send_error("server shutting down") + # Set of streams to replicate. + self.streams = self.command_handler.get_streams_to_replicate() def on_notifier_poke(self): """Checks if there is actually any new data and sends it to the @@ -139,7 +91,7 @@ class ReplicationStreamer(object): This should get called each time new data is available, even if it is currently being executed, so that nothing gets missed """ - if not self.connections: + if not self.command_handler.connected(): # Don't bother if nothing is listening. We still need to advance # the stream tokens otherwise they'll fall beihind forever for stream in self.streams: @@ -154,8 +106,7 @@ class ReplicationStreamer(object): run_as_background_process("replication_notifier", self._run_notifier_loop) - @defer.inlineCallbacks - def _run_notifier_loop(self): + async def _run_notifier_loop(self): self.is_looping = True try: @@ -166,11 +117,6 @@ class ReplicationStreamer(object): self.pending_updates = False with Measure(self.clock, "repl.stream.get_updates"): - # First we tell the streams that they should update their - # current tokens. - for stream in self.streams: - stream.advance_current_token() - all_streams = self.streams if self._replication_torture_level is not None: @@ -180,11 +126,13 @@ class ReplicationStreamer(object): random.shuffle(all_streams) for stream in all_streams: - if stream.last_token == stream.upto_token: + if stream.last_token == stream.current_token( + self._instance_name + ): continue if self._replication_torture_level: - yield self.clock.sleep( + await self.clock.sleep( self._replication_torture_level / 1000.0 ) @@ -192,18 +140,17 @@ class ReplicationStreamer(object): "Getting stream: %s: %s -> %s", stream.NAME, stream.last_token, - stream.upto_token, + stream.current_token(self._instance_name), ) try: - updates, current_token = yield stream.get_updates() + updates, current_token, limited = await stream.get_updates() + self.pending_updates |= limited except Exception: logger.info("Failed to handle stream %s", stream.NAME) raise logger.debug( - "Sending %d updates to %d connections", - len(updates), - len(self.connections), + "Sending %d updates", len(updates), ) if updates: @@ -219,102 +166,19 @@ class ReplicationStreamer(object): # token. See RdataCommand for more details. 
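The loop just below pushes each entry of `batched_updates` out as a stream update; `_batch_updates` itself is only partly visible in this hunk. Based on its docstring and on how RDATA batching works (rows in a batch carry no token until the final row), a plausible sketch of the idea is the following. This is an assumption about its shape, not the actual implementation:

from typing import Any, List, Optional, Tuple


def batch_updates_sketch(
    updates: List[Tuple[int, Any]]
) -> List[Tuple[Optional[int], Any]]:
    """Blank the token on every row that shares a token with the next row, so
    that only the last row of a same-token run carries the token."""
    batched = []  # type: List[Tuple[Optional[int], Any]]
    for i, (token, row) in enumerate(updates):
        is_last = i == len(updates) - 1
        if not is_last and updates[i + 1][0] == token:
            batched.append((None, row))
        else:
            batched.append((token, row))
    return batched


assert batch_updates_sketch([(1, "a"), (1, "b"), (2, "c")]) == [
    (None, "a"),
    (1, "b"),
    (2, "c"),
]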
batched_updates = _batch_updates(updates) - for conn in self.connections: - for token, row in batched_updates: - try: - conn.stream_update(stream.NAME, token, row) - except Exception: - logger.exception("Failed to replicate") + for token, row in batched_updates: + try: + self.command_handler.stream_update( + stream.NAME, token, row + ) + except Exception: + logger.exception("Failed to replicate") logger.debug("No more pending updates, breaking poke loop") finally: self.pending_updates = False self.is_looping = False - @measure_func("repl.get_stream_updates") - def get_stream_updates(self, stream_name, token): - """For a given stream get all updates since token. This is called when - a client first subscribes to a stream. - """ - stream = self.streams_by_name.get(stream_name, None) - if not stream: - raise Exception("unknown stream %s", stream_name) - - return stream.get_updates_since(token) - - @measure_func("repl.federation_ack") - def federation_ack(self, token): - """We've received an ack for federation stream from a client. - """ - federation_ack_counter.inc() - if self.federation_sender: - self.federation_sender.federation_ack(token) - - @measure_func("repl.on_user_sync") - @defer.inlineCallbacks - def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): - """A client has started/stopped syncing on a worker. - """ - user_sync_counter.inc() - yield self.presence_handler.update_external_syncs_row( - conn_id, user_id, is_syncing, last_sync_ms - ) - - @measure_func("repl.on_remove_pusher") - @defer.inlineCallbacks - def on_remove_pusher(self, app_id, push_key, user_id): - """A client has asked us to remove a pusher - """ - remove_pusher_counter.inc() - yield self.store.delete_pusher_by_app_id_pushkey_user_id( - app_id=app_id, pushkey=push_key, user_id=user_id - ) - - self.notifier.on_new_replication_data() - - @measure_func("repl.on_invalidate_cache") - def on_invalidate_cache(self, cache_func, keys): - """The client has asked us to invalidate a cache - """ - invalidate_cache_counter.inc() - getattr(self.store, cache_func).invalidate(tuple(keys)) - - @measure_func("repl.on_user_ip") - @defer.inlineCallbacks - def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): - """The client saw a user request - """ - user_ip_cache_counter.inc() - yield self.store.insert_client_ip( - user_id, access_token, ip, user_agent, device_id, last_seen - ) - yield self._server_notices_sender.on_user_ip(user_id) - - def send_sync_to_all_connections(self, data): - """Sends a SYNC command to all clients. - - Used in tests. - """ - for conn in self.connections: - conn.send_sync(data) - - def new_connection(self, connection): - """A new client connection has been established - """ - self.connections.append(connection) - - def lost_connection(self, connection): - """A client connection has been lost - """ - try: - self.connections.remove(connection) - except ValueError: - pass - - # We need to tell the presence handler that the connection has been - # lost so that it can handle any ongoing syncs on that connection. 
- self.presence_handler.update_external_syncs_clear(connection.conn_id) - def _batch_updates(updates): """Takes a list of updates of form [(token, row)] and sets the token to diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 634f636dc9..d1a61c3314 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -25,25 +25,63 @@ Each stream is defined by the following information: update_function: The function that returns a list of updates between two tokens """ -from . import _base, events, federation +from synapse.replication.tcp.streams._base import ( + AccountDataStream, + BackfillStream, + CachesStream, + DeviceListsStream, + GroupServerStream, + PresenceStream, + PublicRoomsStream, + PushersStream, + PushRulesStream, + ReceiptsStream, + Stream, + TagAccountDataStream, + ToDeviceStream, + TypingStream, + UserSignatureStream, +) +from synapse.replication.tcp.streams.events import EventsStream +from synapse.replication.tcp.streams.federation import FederationStream STREAMS_MAP = { stream.NAME: stream for stream in ( - events.EventsStream, - _base.BackfillStream, - _base.PresenceStream, - _base.TypingStream, - _base.ReceiptsStream, - _base.PushRulesStream, - _base.PushersStream, - _base.CachesStream, - _base.PublicRoomsStream, - _base.DeviceListsStream, - _base.ToDeviceStream, - federation.FederationStream, - _base.TagAccountDataStream, - _base.AccountDataStream, - _base.GroupServerStream, + EventsStream, + BackfillStream, + PresenceStream, + TypingStream, + ReceiptsStream, + PushRulesStream, + PushersStream, + CachesStream, + PublicRoomsStream, + DeviceListsStream, + ToDeviceStream, + FederationStream, + TagAccountDataStream, + AccountDataStream, + GroupServerStream, + UserSignatureStream, ) } + +__all__ = [ + "STREAMS_MAP", + "Stream", + "BackfillStream", + "PresenceStream", + "TypingStream", + "ReceiptsStream", + "PushRulesStream", + "PushersStream", + "CachesStream", + "PublicRoomsStream", + "DeviceListsStream", + "ToDeviceStream", + "TagAccountDataStream", + "AccountDataStream", + "GroupServerStream", + "UserSignatureStream", +] diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index f03111c259..4acefc8a96 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,102 +14,84 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import itertools +import heapq import logging from collections import namedtuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + List, + Optional, + Tuple, + TypeVar, +) + +import attr + +from synapse.replication.http.streams import ReplicationGetStreamUpdates -from twisted.internet import defer +if TYPE_CHECKING: + import synapse.server logger = logging.getLogger(__name__) +# the number of rows to request from an update_function. 
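The constant that follows is the per-call row target for a stream's update_function, and the type aliases further down spell out its contract: given (instance_name, from_token, to_token, target_row_count), return (updates, new_last_token, limited). A minimal conforming sketch over an invented in-memory data source (nothing here is real Synapse storage):

import asyncio
from typing import List, Tuple


async def demo_update_function(
    instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
    # Pretend these rows came from a database query keyed by stream id.
    all_rows = [(stream_id, ("row-%d" % stream_id,)) for stream_id in range(1, 1000)]

    rows = [
        (stream_id, row)
        for stream_id, row in all_rows
        if from_token < stream_id <= to_token
    ][:target_row_count]

    if len(rows) >= target_row_count:
        # There may be more to send: clamp the returned token to the last row
        # and report `limited` so the caller asks again.
        return rows, rows[-1][0], True

    return rows, to_token, False


updates, new_token, limited = asyncio.run(demo_update_function("master", 0, 500, 100))
assert len(updates) == 100 and new_token == 100 and limited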
+_STREAM_UPDATE_TARGET_ROW_COUNT = 100 -MAX_EVENTS_BEHIND = 10000 -BackfillStreamRow = namedtuple( - "BackfillStreamRow", - ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional - ), -) -PresenceStreamRow = namedtuple( - "PresenceStreamRow", - ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool - ), -) -TypingStreamRow = namedtuple( - "TypingStreamRow", ("room_id", "user_ids") # str # list(str) -) -ReceiptsStreamRow = namedtuple( - "ReceiptsStreamRow", - ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict - ), -) -PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str -PushersStreamRow = namedtuple( - "PushersStreamRow", - ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool -) -CachesStreamRow = namedtuple( - "CachesStreamRow", - ("cache_func", "keys", "invalidation_ts"), # str # list(str) # int -) -PublicRoomsStreamRow = namedtuple( - "PublicRoomsStreamRow", - ( - "room_id", # str - "visibility", # str - "appservice_id", # str, optional - "network_id", # str, optional - ), -) -DeviceListsStreamRow = namedtuple( - "DeviceListsStreamRow", ("user_id", "destination") # str # str -) -ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str -TagAccountDataStreamRow = namedtuple( - "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict -) -AccountDataStreamRow = namedtuple( - "AccountDataStream", - ("user_id", "room_id", "data_type", "data"), # str # str # str # dict -) -GroupsStreamRow = namedtuple( - "GroupsStreamRow", - ("group_id", "user_id", "type", "content"), # str # str # str # dict -) +# Some type aliases to make things a bit easier. + +# A stream position token +Token = int + +# The type of a stream update row, after JSON deserialisation, but before +# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's +# just a row from a database query, though this is dependent on the stream in question. +# +StreamRow = TypeVar("StreamRow", bound=Tuple) + +# The type returned by the update_function of a stream, as well as get_updates(), +# get_updates_since, etc. +# +# It consists of a triplet `(updates, new_last_token, limited)`, where: +# * `updates` is a list of `(token, row)` entries. +# * `new_last_token` is the new position in stream. +# * `limited` is whether there are more updates to fetch. +# +StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] + +# The type of an update_function for a stream +# +# The arguments are: +# +# * instance_name: the writer of the stream +# * from_token: the previous stream token: the starting point for fetching the +# updates +# * to_token: the new stream token: the point to get updates up to +# * target_row_count: a target for the number of rows to be returned. +# +# The update_function is expected to return up to _approximately_ target_row_count rows. +# If there are more updates available, it should set `limited` in the result, and +# it will be called again to get the next batch. +# +UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] class Stream(object): """Base class for the streams. Provides a `get_updates()` function that returns new updates since the last - time it was called up until the point `advance_current_token` was called. 
+ time it was called. """ - NAME = None # The name of the stream - ROW_TYPE = None # The type of the row. Used by the default impl of parse_row. - _LIMITED = True # Whether the update function takes a limit + NAME = None # type: str # The name of the stream + # The type of the row. Used by the default impl of parse_row. + ROW_TYPE = None # type: Any @classmethod - def parse_row(cls, row): + def parse_row(cls, row: StreamRow): """Parse a row received over replication By default, assumes that the row data is an array object and passes its contents @@ -123,102 +105,138 @@ class Stream(object): """ return cls.ROW_TYPE(*row) - def __init__(self, hs): - # The token from which we last asked for updates - self.last_token = self.current_token() - - # The token that we will get updates up to - self.upto_token = self.current_token() + def __init__( + self, + local_instance_name: str, + current_token_function: Callable[[str], Token], + update_function: UpdateFunction, + ): + """Instantiate a Stream + + `current_token_function` and `update_function` are callbacks which + should be implemented by subclasses. + + `current_token_function` takes an instance name, which is a writer to + the stream, and returns the position in the stream of the writer (as + viewed from the current process). On the writer process this is where + the writer has successfully written up to, whereas on other processes + this is the position which we have received updates up to over + replication. (Note that most streams have a single writer and so their + implementations ignore the instance name passed in). + + `update_function` is called to get updates for this stream between a + pair of stream tokens. See the `UpdateFunction` type definition for more + info. - def advance_current_token(self): - """Updates `upto_token` to "now", which updates up until which point - get_updates[_since] will fetch rows till. + Args: + local_instance_name: The instance name of the current process + current_token_function: callback to get the current token, as above + update_function: callback go get stream updates, as above """ - self.upto_token = self.current_token() + self.local_instance_name = local_instance_name + self.current_token = current_token_function + self.update_function = update_function + + # The token from which we last asked for updates + self.last_token = self.current_token(self.local_instance_name) def discard_updates_and_advance(self): """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. """ - self.upto_token = self.current_token() - self.last_token = self.upto_token + self.last_token = self.current_token(self.local_instance_name) - @defer.inlineCallbacks - def get_updates(self): + async def get_updates(self) -> StreamUpdateResult: """Gets all updates since the last time this function was called (or - since the stream was constructed if it hadn't been called before), - until the `upto_token` + since the stream was constructed if it hadn't been called before). Returns: - Deferred[Tuple[List[Tuple[int, Any]], int]: - Resolves to a pair ``(updates, current_token)``, where ``updates`` is a - list of ``(token, row)`` entries. ``row`` will be json-serialised and - sent over the replication steam. + A triplet `(updates, new_last_token, limited)`, where `updates` is + a list of `(token, row)` entries, `new_last_token` is the new + position in stream, and `limited` is whether there are more updates + to fetch. 
""" - updates, current_token = yield self.get_updates_since(self.last_token) + current_token = self.current_token(self.local_instance_name) + updates, current_token, limited = await self.get_updates_since( + self.local_instance_name, self.last_token, current_token + ) self.last_token = current_token - return updates, current_token + return updates, current_token, limited - @defer.inlineCallbacks - def get_updates_since(self, from_token): + async def get_updates_since( + self, instance_name: str, from_token: Token, upto_token: Token + ) -> StreamUpdateResult: """Like get_updates except allows specifying from when we should stream updates Returns: - Deferred[Tuple[List[Tuple[int, Any]], int]: - Resolves to a pair ``(updates, current_token)``, where ``updates`` is a - list of ``(token, row)`` entries. ``row`` will be json-serialised and - sent over the replication steam. + A triplet `(updates, new_last_token, limited)`, where `updates` is + a list of `(token, row)` entries, `new_last_token` is the new + position in stream, and `limited` is whether there are more updates + to fetch. """ - if from_token in ("NOW", "now"): - return [], self.upto_token - - current_token = self.upto_token from_token = int(from_token) - if from_token == current_token: - return [], current_token + if from_token == upto_token: + return [], upto_token, False - if self._LIMITED: - rows = yield self.update_function( - from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 - ) + updates, upto_token, limited = await self.update_function( + instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, + ) + return updates, upto_token, limited - # never turn more than MAX_EVENTS_BEHIND + 1 into updates. - rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) - else: - rows = yield self.update_function(from_token, current_token) +def current_token_without_instance( + current_token: Callable[[], int] +) -> Callable[[str], int]: + """Takes a current token callback function for a single writer stream + that doesn't take an instance name parameter and wraps it in a function that + does accept an instance name parameter but ignores it. + """ + return lambda instance_name: current_token() + + +def db_query_to_update_function( + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] +) -> UpdateFunction: + """Wraps a db query function which returns a list of rows to make it + suitable for use as an `update_function` for the Stream class + """ + + async def update_function(instance_name, from_token, upto_token, limit): + rows = await query_function(from_token, upto_token, limit) updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True - # check we didn't get more rows than the limit. - # doing it like this allows the update_function to be a generator. - if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND: - raise Exception("stream %s has fallen behind" % (self.NAME)) + return updates, upto_token, limited - return updates, current_token + return update_function - def current_token(self): - """Gets the current token of the underlying streams. Should be provided - by the sub classes - Returns: - int - """ - raise NotImplementedError() +def make_http_update_function(hs, stream_name: str) -> UpdateFunction: + """Makes a suitable function for use as an `update_function` that queries + the master process for updates. 
+ """ - def update_function(self, from_token, current_token, limit=None): - """Get updates between from_token and to_token. If Stream._LIMITED is - True then limit is provided, otherwise it's not. + client = ReplicationGetStreamUpdates.make_client(hs) - Returns: - Deferred(list(tuple)): the first entry in the tuple is the token for - that update, and the rest of the tuple gets used to construct - a ``ROW_TYPE`` instance - """ - raise NotImplementedError() + async def update_function( + instance_name: str, from_token: int, upto_token: int, limit: int + ) -> StreamUpdateResult: + result = await client( + instance_name=instance_name, + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, + ) + return result["updates"], result["upto_token"], result["limited"] + + return update_function class BackfillStream(Stream): @@ -226,94 +244,170 @@ class BackfillStream(Stream): or it went from being an outlier to not. """ + BackfillStreamRow = namedtuple( + "BackfillStreamRow", + ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional + "relates_to", # str, optional + ), + ) + NAME = "backfill" ROW_TYPE = BackfillStreamRow def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_current_backfill_token - self.update_function = store.get_all_new_backfill_event_rows - - super(BackfillStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_current_backfill_token), + db_query_to_update_function(store.get_all_new_backfill_event_rows), + ) class PresenceStream(Stream): + PresenceStreamRow = namedtuple( + "PresenceStreamRow", + ( + "user_id", # str + "state", # str + "last_active_ts", # int + "last_federation_update_ts", # int + "last_user_sync_ts", # int + "status_msg", # str + "currently_active", # bool + ), + ) + NAME = "presence" - _LIMITED = False ROW_TYPE = PresenceStreamRow def __init__(self, hs): store = hs.get_datastore() - presence_handler = hs.get_presence_handler() - self.current_token = store.get_current_presence_token - self.update_function = presence_handler.get_all_presence_updates + if hs.config.worker_app is None: + # on the master, query the presence handler + presence_handler = hs.get_presence_handler() + update_function = db_query_to_update_function( + presence_handler.get_all_presence_updates + ) + else: + # Query master process + update_function = make_http_update_function(hs, self.NAME) - super(PresenceStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_current_presence_token), + update_function, + ) class TypingStream(Stream): + TypingStreamRow = namedtuple( + "TypingStreamRow", ("room_id", "user_ids") # str # list(str) + ) + NAME = "typing" - _LIMITED = False ROW_TYPE = TypingStreamRow def __init__(self, hs): typing_handler = hs.get_typing_handler() - self.current_token = typing_handler.get_current_token - self.update_function = typing_handler.get_all_typing_updates + if hs.config.worker_app is None: + # on the master, query the typing handler + update_function = db_query_to_update_function( + typing_handler.get_all_typing_updates + ) + else: + # Query master process + update_function = make_http_update_function(hs, self.NAME) - super(TypingStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(typing_handler.get_current_token), + update_function, + ) class ReceiptsStream(Stream): + ReceiptsStreamRow = 
namedtuple( + "ReceiptsStreamRow", + ( + "room_id", # str + "receipt_type", # str + "user_id", # str + "event_id", # str + "data", # dict + ), + ) + NAME = "receipts" ROW_TYPE = ReceiptsStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_max_receipt_stream_id - self.update_function = store.get_all_updated_receipts - - super(ReceiptsStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_max_receipt_stream_id), + db_query_to_update_function(store.get_all_updated_receipts), + ) class PushRulesStream(Stream): """A user has changed their push rules """ + PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str + NAME = "push_rules" ROW_TYPE = PushRulesStreamRow def __init__(self, hs): self.store = hs.get_datastore() - super(PushRulesStream, self).__init__(hs) + super(PushRulesStream, self).__init__( + hs.get_instance_name(), self._current_token, self._update_function + ) - def current_token(self): + def _current_token(self, instance_name: str) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit) - return [(row[0], row[2]) for row in rows] + async def _update_function( + self, instance_name: str, from_token: Token, to_token: Token, limit: int + ): + rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) + + limited = False + if len(rows) == limit: + to_token = rows[-1][0] + limited = True + + return [(row[0], (row[2],)) for row in rows], to_token, limited class PushersStream(Stream): """A user has added/changed/removed a pusher """ + PushersStreamRow = namedtuple( + "PushersStreamRow", + ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool + ) + NAME = "pushers" ROW_TYPE = PushersStreamRow def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_pushers_stream_token - self.update_function = store.get_all_updated_pushers_rows - - super(PushersStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_pushers_stream_token), + db_query_to_update_function(store.get_all_updated_pushers_rows), + ) class CachesStream(Stream): @@ -321,120 +415,235 @@ class CachesStream(Stream): the cache on the workers """ + @attr.s + class CachesStreamRow: + """Stream to inform workers they should invalidate their cache. + + Attributes: + cache_func: Name of the cached function. + keys: The entry in the cache to invalidate. If None then will + invalidate all. + invalidation_ts: Timestamp of when the invalidation took place. 
+ """ + + cache_func = attr.ib(type=str) + keys = attr.ib(type=Optional[List[Any]]) + invalidation_ts = attr.ib(type=int) + NAME = "caches" ROW_TYPE = CachesStreamRow def __init__(self, hs): - store = hs.get_datastore() + self.store = hs.get_datastore() + super().__init__( + hs.get_instance_name(), + self.store.get_cache_stream_token, + self._update_function, + ) - self.current_token = store.get_cache_stream_token - self.update_function = store.get_all_updated_caches + async def _update_function( + self, instance_name: str, from_token: int, upto_token: int, limit: int + ): + rows = await self.store.get_all_updated_caches( + instance_name, from_token, upto_token, limit + ) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True - super(CachesStream, self).__init__(hs) + return updates, upto_token, limited class PublicRoomsStream(Stream): """The public rooms list changed """ + PublicRoomsStreamRow = namedtuple( + "PublicRoomsStreamRow", + ( + "room_id", # str + "visibility", # str + "appservice_id", # str, optional + "network_id", # str, optional + ), + ) + NAME = "public_rooms" ROW_TYPE = PublicRoomsStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_current_public_room_stream_id - self.update_function = store.get_all_new_public_rooms - - super(PublicRoomsStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_current_public_room_stream_id), + db_query_to_update_function(store.get_all_new_public_rooms), + ) class DeviceListsStream(Stream): - """Someone added/changed/removed a device + """Either a user has updated their devices or a remote server needs to be + told about a device update. 
""" + @attr.s + class DeviceListsStreamRow: + entity = attr.ib(type=str) + NAME = "device_lists" - _LIMITED = False ROW_TYPE = DeviceListsStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_device_stream_token - self.update_function = store.get_all_device_list_changes_for_remotes - - super(DeviceListsStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_device_stream_token), + db_query_to_update_function(store.get_all_device_list_changes_for_remotes), + ) class ToDeviceStream(Stream): """New to_device messages for a client """ + ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str + NAME = "to_device" ROW_TYPE = ToDeviceStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_to_device_stream_token - self.update_function = store.get_all_new_device_messages - - super(ToDeviceStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_to_device_stream_token), + db_query_to_update_function(store.get_all_new_device_messages), + ) class TagAccountDataStream(Stream): """Someone added/removed a tag for a room """ + TagAccountDataStreamRow = namedtuple( + "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict + ) + NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow def __init__(self, hs): store = hs.get_datastore() - - self.current_token = store.get_max_account_data_stream_id - self.update_function = store.get_all_updated_tags - - super(TagAccountDataStream, self).__init__(hs) + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_max_account_data_stream_id), + db_query_to_update_function(store.get_all_updated_tags), + ) class AccountDataStream(Stream): """Global or per room account data was changed """ + AccountDataStreamRow = namedtuple( + "AccountDataStream", + ("user_id", "room_id", "data_type"), # str # Optional[str] # str + ) + NAME = "account_data" ROW_TYPE = AccountDataStreamRow - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.store = hs.get_datastore() + super().__init__( + hs.get_instance_name(), + current_token_without_instance(self.store.get_max_account_data_stream_id), + self._update_function, + ) + + async def _update_function( + self, instance_name: str, from_token: int, to_token: int, limit: int + ) -> StreamUpdateResult: + limited = False + global_results = await self.store.get_updated_global_account_data( + from_token, to_token, limit + ) - self.current_token = self.store.get_max_account_data_stream_id + # if the global results hit the limit, we'll need to limit the room results to + # the same stream token. + if len(global_results) >= limit: + to_token = global_results[-1][0] + limited = True - super(AccountDataStream, self).__init__(hs) + room_results = await self.store.get_updated_room_account_data( + from_token, to_token, limit + ) - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - global_results, room_results = yield self.store.get_all_updated_account_data( - from_token, from_token, to_token, limit + # likewise, if the room results hit the limit, limit the global results to + # the same stream token. 
+ if len(room_results) >= limit: + to_token = room_results[-1][0] + limited = True + + # convert the global results to the right format, and limit them to the to_token + # at the same time + global_rows = ( + (stream_id, (user_id, None, account_data_type)) + for stream_id, user_id, account_data_type in global_results + if stream_id <= to_token ) - results = list(room_results) - results.extend( - (stream_id, user_id, None, account_data_type, content) - for stream_id, user_id, account_data_type, content in global_results + # we know that the room_results are already limited to `to_token` so no need + # for a check on `stream_id` here. + room_rows = ( + (stream_id, (user_id, room_id, account_data_type)) + for stream_id, user_id, room_id, account_data_type in room_results ) - return results + # We need to return a sorted list, so merge them together. + # + # Note: We order only by the stream ID to work around a bug where the + # same stream ID could appear in both `global_rows` and `room_rows`, + # leading to a comparison between the data tuples. The comparison could + # fail due to attempting to compare the `room_id` which results in a + # `TypeError` from comparing a `str` vs `None`. + updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0])) + return updates, to_token, limited class GroupServerStream(Stream): + GroupsStreamRow = namedtuple( + "GroupsStreamRow", + ("group_id", "user_id", "type", "content"), # str # str # str # dict + ) + NAME = "groups" ROW_TYPE = GroupsStreamRow def __init__(self, hs): store = hs.get_datastore() + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_group_stream_token), + db_query_to_update_function(store.get_all_groups_changes), + ) - self.current_token = store.get_group_stream_token - self.update_function = store.get_all_groups_changes - super(GroupServerStream, self).__init__(hs) +class UserSignatureStream(Stream): + """A user has signed their own device with their user-signing key + """ + + UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str + + NAME = "user_signature" + ROW_TYPE = UserSignatureStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + super().__init__( + hs.get_instance_name(), + current_token_without_instance(store.get_device_stream_token), + db_query_to_update_function( + store.get_all_user_signature_changes_for_remotes + ), + ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index d97669c886..f370390331 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -13,13 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import heapq +from collections import Iterable +from typing import List, Tuple, Type import attr -from twisted.internet import defer - -from ._base import Stream +from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance """Handling of the 'events' replication stream @@ -63,7 +64,8 @@ class BaseEventsStreamRow(object): Specifies how to identify, serialize and deserialize the different types. """ - TypeId = None # Unique string that ids the type. Must be overriden in sub classes. + # Unique string that ids the type. Must be overriden in sub classes. 
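Stepping back to `AccountDataStream._update_function` above: the keyed `heapq.merge` is what keeps the merged rows ordered by stream ID without ever comparing the data tuples themselves (which can contain None). A standalone illustration with invented rows:

import heapq

room_rows = ((1, ("@a:example.com", "!room:example.com", "m.tag")),)
global_rows = ((1, ("@a:example.com", None, "m.push_rules")),)

merged = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
assert [stream_id for stream_id, _ in merged] == [1, 1]
# Without the key, Python would fall back to comparing the data tuples and hit
# a TypeError ("!room:example.com" vs None) when the stream IDs tie.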
+ TypeId = None # type: str @classmethod def from_data(cls, data): @@ -99,9 +101,12 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow): event_id = attr.ib() # str, optional -TypeToRow = { - Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow) -} +_EventRows = ( + EventsStreamEventRow, + EventsStreamCurrentStateRow, +) # type: Tuple[Type[BaseEventsStreamRow], ...] + +TypeToRow = {Row.TypeId: Row for Row in _EventRows} class EventsStream(Stream): @@ -112,29 +117,107 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() - self.current_token = self._store.get_current_events_token - - super(EventsStream, self).__init__(hs) - - @defer.inlineCallbacks - def update_function(self, from_token, current_token, limit=None): - event_rows = yield self._store.get_all_new_forward_event_rows( - from_token, current_token, limit - ) - event_updates = ( - (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows + super().__init__( + hs.get_instance_name(), + current_token_without_instance(self._store.get_current_events_token), + self._update_function, ) - state_rows = yield self._store.get_all_updated_current_state_deltas( - from_token, current_token, limit - ) - state_updates = ( - (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows + async def _update_function( + self, + instance_name: str, + from_token: Token, + current_token: Token, + target_row_count: int, + ) -> StreamUpdateResult: + + # the events stream merges together three separate sources: + # * new events + # * current_state changes + # * events which were previously outliers, but have now been de-outliered. + # + # The merge operation is complicated by the fact that we only have a single + # "stream token" which is supposed to indicate how far we have got through + # all three streams. It's therefore no good to return rows 1-1000 from the + # "new events" table if the state_deltas are limited to rows 1-100 by the + # target_row_count. + # + # In other words: we must pick a new upper limit, and must return *all* rows + # up to that point for each of the three sources. + # + # Start by trying to split the target_row_count up. We expect to have a + # negligible number of ex-outliers, and a rough approximation based on recent + # traffic on sw1v.org shows that there are approximately the same number of + # event rows between a given pair of stream ids as there are state + # updates, so let's split our target_row_count among those two types. The target + # is only an approximation - it doesn't matter if we end up going a bit over it. + + target_row_count //= 2 + + # now we fetch up to that many rows from the events table + + event_rows = await self._store.get_all_new_forward_event_rows( + from_token, current_token, target_row_count + ) # type: List[Tuple] + + # we rely on get_all_new_forward_event_rows strictly honouring the limit, so + # that we know it is safe to just take upper_limit = event_rows[-1][0]. + assert ( + len(event_rows) <= target_row_count + ), "get_all_new_forward_event_rows did not honour row limit" + + # if we hit the limit on event_updates, there's no point in going beyond the + # last stream_id in the batch for the other sources. + + if len(event_rows) == target_row_count: + limited = True + upper_limit = event_rows[-1][0] # type: int + else: + limited = False + upper_limit = current_token + + # next up is the state delta table. 
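The comments above describe the key trick of this update function: once the first source hits its row budget, every other source has to be capped at the same upper limit so that a single stream token still describes progress through all of them. A toy version of that pattern, with invented rows rather than the real queries:

def merge_with_shared_limit(new_events, state_deltas, current_token, target_row_count):
    """`new_events` and `state_deltas` are (stream_id, row) lists sorted by
    stream_id, with stream_id <= current_token. Returns (rows, upper_limit, limited)."""
    event_rows = new_events[:target_row_count]

    if len(event_rows) == target_row_count:
        # The first source was truncated: cap everything at its last stream id.
        limited = True
        upper_limit = event_rows[-1][0]
    else:
        limited = False
        upper_limit = current_token

    capped_deltas = [row for row in state_deltas if row[0] <= upper_limit]
    return sorted(event_rows + capped_deltas), upper_limit, limited


rows, upper_limit, limited = merge_with_shared_limit(
    [(1, "ev1"), (2, "ev2"), (3, "ev3")], [(2, "delta2"), (4, "delta4")], 4, 2
)
assert (upper_limit, limited) == (2, True)
assert rows == [(1, "ev1"), (2, "delta2"), (2, "ev2")]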
+ ( + state_rows, + upper_limit, + state_rows_limited, + ) = await self._store.get_all_updated_current_state_deltas( + from_token, upper_limit, target_row_count ) - all_updates = heapq.merge(event_updates, state_updates) + limited = limited or state_rows_limited + + # finally, fetch the ex-outliers rows. We assume there are few enough of these + # not to bother with the limit. + + ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( + from_token, upper_limit + ) # type: List[Tuple] + + # we now need to turn the raw database rows returned into tuples suitable + # for the replication protocol (basically, we add an identifier to + # distinguish the row type). At the same time, we can limit the event_rows + # to the max stream_id from state_rows. - return all_updates + event_updates = ( + (stream_id, (EventsStreamEventRow.TypeId, rest)) + for (stream_id, *rest) in event_rows + if stream_id <= upper_limit + ) # type: Iterable[Tuple[int, Tuple]] + + state_updates = ( + (stream_id, (EventsStreamCurrentStateRow.TypeId, rest)) + for (stream_id, *rest) in state_rows + ) # type: Iterable[Tuple[int, Tuple]] + + ex_outliers_updates = ( + (stream_id, (EventsStreamEventRow.TypeId, rest)) + for (stream_id, *rest) in ex_outliers_rows + ) # type: Iterable[Tuple[int, Tuple]] + + # we need to return a sorted list, so merge them together. + updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates)) + return updates, upper_limit, limited @classmethod def parse_row(cls, row): diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index dc2484109d..9bcd13b009 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,14 +15,10 @@ # limitations under the License. from collections import namedtuple -from ._base import Stream - -FederationStreamRow = namedtuple( - "FederationStreamRow", - ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow - ), +from synapse.replication.tcp.streams._base import ( + Stream, + current_token_without_instance, + make_http_update_function, ) @@ -31,13 +27,47 @@ class FederationStream(Stream): sending disabled. """ + FederationStreamRow = namedtuple( + "FederationStreamRow", + ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow + ), + ) + NAME = "federation" ROW_TYPE = FederationStreamRow def __init__(self, hs): - federation_sender = hs.get_federation_sender() + if hs.config.worker_app is None: + # master process: get updates from the FederationRemoteSendQueue. + # (if the master is configured to send federation itself, federation_sender + # will be a real FederationSender, which has stubs for current_token and + # get_replication_rows.) 
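On workers that have no interest in the federation stream, both callbacks handed to Stream are the stubs defined just below: a current-token function that always reports position 0 and an update function that returns nothing. A standalone, runnable copy of that shape:

import asyncio
from typing import Any, List, Tuple


def stub_current_token(instance_name: str) -> int:
    # Dummy current-token: this process never advances the stream.
    return 0


async def stub_update_function(
    instance_name: str, from_token: int, upto_token: int, limit: int
) -> Tuple[List[Any], int, bool]:
    # No rows, position unchanged, never limited.
    return [], upto_token, False


assert asyncio.run(stub_update_function("worker1", 0, 10, 100)) == ([], 10, False)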
+ federation_sender = hs.get_federation_sender() + current_token = current_token_without_instance( + federation_sender.get_current_token + ) + update_function = federation_sender.get_replication_rows + + elif hs.should_send_federation(): + # federation sender: Query master process + update_function = make_http_update_function(hs, self.NAME) + current_token = self._stub_current_token + + else: + # other worker: stub out the update function (we're not interested in + # any updates so when we get a POSITION we do nothing) + update_function = self._stub_update_function + current_token = self._stub_current_token + + super().__init__(hs.get_instance_name(), current_token, update_function) - self.current_token = federation_sender.get_current_token - self.update_function = federation_sender.get_replication_rows + @staticmethod + def _stub_current_token(instance_name: str) -> int: + # dummy current-token method for use on workers + return 0 - super(FederationStream, self).__init__(hs) + @staticmethod + async def _stub_update_function(instance_name, from_token, upto_token, limit): + return [], upto_token, False diff --git a/synapse/res/templates/add_threepid.html b/synapse/res/templates/add_threepid.html new file mode 100644 index 0000000000..cc4ab07e09 --- /dev/null +++ b/synapse/res/templates/add_threepid.html @@ -0,0 +1,9 @@ +<html> +<body> + <p>A request to add an email address to your Matrix account has been received. If this was you, please click the link below to confirm adding this email:</p> + + <a href="{{ link }}">{{ link }}</a> + + <p>If this was not you, you can safely ignore this email. Thank you.</p> +</body> +</html> diff --git a/synapse/res/templates/add_threepid.txt b/synapse/res/templates/add_threepid.txt new file mode 100644 index 0000000000..a60c1ff659 --- /dev/null +++ b/synapse/res/templates/add_threepid.txt @@ -0,0 +1,6 @@ +A request to add an email address to your Matrix account has been received. If this was you, +please click the link below to confirm adding this email: + +{{ link }} + +If this was not you, you can safely ignore this email. Thank you. diff --git a/synapse/res/templates/add_threepid_failure.html b/synapse/res/templates/add_threepid_failure.html new file mode 100644 index 0000000000..441d11c846 --- /dev/null +++ b/synapse/res/templates/add_threepid_failure.html @@ -0,0 +1,8 @@ +<html> +<head></head> +<body> +<p>The request failed for the following reason: {{ failure_reason }}.</p> + +<p>No changes have been made to your account.</p> +</body> +</html> diff --git a/synapse/res/templates/add_threepid_success.html b/synapse/res/templates/add_threepid_success.html new file mode 100644 index 0000000000..fbd6e4018f --- /dev/null +++ b/synapse/res/templates/add_threepid_success.html @@ -0,0 +1,6 @@ +<html> +<head></head> +<body> +<p>Your email has now been validated, please return to your client. You may now close this window.</p> +</body> +</html> diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html index f0d7c66e1b..6b94d8c367 100644 --- a/synapse/res/templates/notice_expiry.html +++ b/synapse/res/templates/notice_expiry.html @@ -30,7 +30,7 @@ <tr> <td colspan="2"> <div class="noticetext">Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. 
This means that you will lose access to your account after this date.</div> - <div class="noticetext">To extend the validity of your account, please click on the link bellow (or copy and paste it into a new browser tab):</div> + <div class="noticetext">To extend the validity of your account, please click on the link below (or copy and paste it into a new browser tab):</div> <div class="noticetext"><a href="{{ url }}">{{ url }}</a></div> </td> </tr> diff --git a/synapse/res/templates/notice_expiry.txt b/synapse/res/templates/notice_expiry.txt index 41f1c4279c..4ec27e8831 100644 --- a/synapse/res/templates/notice_expiry.txt +++ b/synapse/res/templates/notice_expiry.txt @@ -2,6 +2,6 @@ Hi {{ display_name }}, Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date. -To extend the validity of your account, please click on the link bellow (or copy and paste it to a new browser tab): +To extend the validity of your account, please click on the link below (or copy and paste it to a new browser tab): {{ url }} diff --git a/synapse/res/templates/saml_error.html b/synapse/res/templates/saml_error.html new file mode 100644 index 0000000000..bfd6449c5d --- /dev/null +++ b/synapse/res/templates/saml_error.html @@ -0,0 +1,45 @@ +<!DOCTYPE html> +<html lang="en"> +<head> + <meta charset="UTF-8"> + <title>SSO error</title> +</head> +<body> + <p>Oops! Something went wrong during authentication<span id="errormsg"></span>.</p> + <p> + If you are seeing this page after clicking a link sent to you via email, make + sure you only click the confirmation link once, and that you open the + validation link in the same client you're logging in from. + </p> + <p> + Try logging in again from your Matrix client and if the problem persists + please contact the server's administrator. + </p> + + <script type="text/javascript"> + // Error handling to support Auth0 errors that we might get through a GET request + // to the validation endpoint. If an error is provided, it's either going to be + // located in the query string or in a query string-like URI fragment. + // We try to locate the error from any of these two locations, but if we can't + // we just don't print anything specific. + let searchStr = ""; + if (window.location.search) { + // window.location.searchParams isn't always defined when + // window.location.search is, so it's more reliable to parse the latter. + searchStr = window.location.search; + } else if (window.location.hash) { + // Replace the # with a ? so that URLSearchParams does the right thing and + // doesn't parse the first parameter incorrectly. + searchStr = window.location.hash.replace("#", "?"); + } + + // We might end up with no error in the URL, so we need to check if we have one + // to print one. 
+ let errorDesc = new URLSearchParams(searchStr).get("error_description") + if (errorDesc) { + + document.getElementById("errormsg").innerText = ` ("${errorDesc}")`; + } + </script> +</body> +</html> \ No newline at end of file diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html new file mode 100644 index 0000000000..4eb8db9fb4 --- /dev/null +++ b/synapse/res/templates/sso_account_deactivated.html @@ -0,0 +1,10 @@ +<!DOCTYPE html> +<html lang="en"> +<head> + <meta charset="UTF-8"> + <title>SSO account deactivated</title> +</head> + <body> + <p>This account has been deactivated.</p> + </body> +</html> diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html new file mode 100644 index 0000000000..0d9de9d465 --- /dev/null +++ b/synapse/res/templates/sso_auth_confirm.html @@ -0,0 +1,14 @@ +<html> +<head> + <title>Authentication</title> +</head> + <body> + <div> + <p> + A client is trying to {{ description | e }}. To confirm this action, + <a href="{{ redirect_url | e }}">re-authenticate with single sign-on</a>. + If you did not expect this, your account may be compromised! + </p> + </div> + </body> +</html> diff --git a/synapse/res/templates/sso_auth_success.html b/synapse/res/templates/sso_auth_success.html new file mode 100644 index 0000000000..03f1419467 --- /dev/null +++ b/synapse/res/templates/sso_auth_success.html @@ -0,0 +1,18 @@ +<html> +<head> + <title>Authentication Successful</title> + <script> + if (window.onAuthDone) { + window.onAuthDone(); + } else if (window.opener && window.opener.postMessage) { + window.opener.postMessage("authDone", "*"); + } + </script> +</head> + <body> + <div> + <p>Thank you</p> + <p>You may now close this window and return to the application</p> + </div> + </body> +</html> diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html new file mode 100644 index 0000000000..43a211386b --- /dev/null +++ b/synapse/res/templates/sso_error.html @@ -0,0 +1,18 @@ +<!DOCTYPE html> +<html lang="en"> +<head> + <meta charset="UTF-8"> + <title>SSO error</title> +</head> +<body> + <p>Oops! Something went wrong during authentication.</p> + <p> + Try logging in again from your Matrix client and if the problem persists + please contact the server's administrator. 
+ </p> + <p>Error: <code>{{ error }}</code></p> + {% if error_description %} + <pre><code>{{ error_description }}</code></pre> + {% endif %} +</body> +</html> diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html new file mode 100644 index 0000000000..20a15e1e74 --- /dev/null +++ b/synapse/res/templates/sso_redirect_confirm.html @@ -0,0 +1,14 @@ +<!DOCTYPE html> +<html lang="en"> +<head> + <meta charset="UTF-8"> + <title>SSO redirect confirmation</title> +</head> + <body> + <p>The application at <span style="font-weight:bold">{{ display_url | e }}</span> is requesting full access to your <span style="font-weight:bold">{{ server_name }}</span> Matrix account.</p> + <p>If you don't recognise this address, you should ignore this and close this tab.</p> + <p> + <a href="{{ redirect_url | e }}">I trust this address</a> + </p> + </body> +</html> \ No newline at end of file diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 4a1fc2ec2b..46e458e95b 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import ( keys, notifications, openid, + password_policy, read_marker, receipts, register, @@ -118,6 +119,7 @@ class ClientRestResource(JsonResource): capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource) + password_policy.register_servlets(hs, client_resource) # moving to /_synapse/admin synapse.rest.admin.register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 81b6bd8816..9eda592de9 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -14,64 +14,50 @@ # See the License for the specific language governing permissions and # limitations under the License. 
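The `UserRegisterServlet` removed from this module below (it now lives in `synapse.rest.admin.users`, per the imports that follow) validates an HMAC-SHA1, keyed with the registration shared secret, over nonce, username, password and the admin flag separated by NUL bytes, plus an optional user type. A client-side sketch that produces a matching `mac` field, with made-up secret and credentials:

import hashlib
import hmac


def build_registration_mac(
    shared_secret: str, nonce: str, username: str, password: str, admin: bool = False
) -> str:
    mac = hmac.new(key=shared_secret.encode(), digestmod=hashlib.sha1)
    mac.update(nonce.encode("utf8"))
    mac.update(b"\x00")
    mac.update(username.encode("utf8"))
    mac.update(b"\x00")
    mac.update(password.encode("utf8"))
    mac.update(b"\x00")
    mac.update(b"admin" if admin else b"notadmin")
    return mac.hexdigest()


# Example with invented values: the resulting hex digest would be sent as the
# "mac" field alongside nonce/username/password in the registration POST body.
print(build_registration_mac("secret", "abc123", "alice", "wonderland", admin=False))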
-import hashlib -import hmac import logging import platform import re -from six import text_type -from six.moves import http_client - -from twisted.internet import defer - import synapse -from synapse.api.constants import Membership, UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import JsonResource -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_integer, - parse_json_object_from_request, - parse_string, -) +from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.rest.admin._base import ( assert_requester_is_admin, - assert_user_is_admin, historical_admin_path_patterns, ) +from synapse.rest.admin.devices import ( + DeleteDevicesRestServlet, + DeviceRestServlet, + DevicesRestServlet, +) +from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet +from synapse.rest.admin.rooms import ( + JoinRoomAliasServlet, + ListRoomRestServlet, + RoomRestServlet, + ShutdownRoomRestServlet, +) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet -from synapse.rest.admin.users import UserAdminServlet -from synapse.types import UserID, create_requester +from synapse.rest.admin.users import ( + AccountValidityRenewServlet, + DeactivateAccountRestServlet, + ResetPasswordRestServlet, + SearchUsersRestServlet, + UserAdminServlet, + UserRegisterServlet, + UserRestServletV2, + UsersRestServlet, + UsersRestServletV2, + WhoisRestServlet, +) from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) -class UsersRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$") - - def __init__(self, hs): - self.hs = hs - self.auth = hs.get_auth() - self.handlers = hs.get_handlers() - - @defer.inlineCallbacks - def on_GET(self, request, user_id): - target_user = UserID.from_string(user_id) - yield assert_requester_is_admin(self.auth, request) - - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only users a local user") - - ret = yield self.handlers.admin_handler.get_users() - - return 200, ret - - class VersionServlet(RestServlet): PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),) @@ -85,161 +71,6 @@ class VersionServlet(RestServlet): return 200, self.res -class UserRegisterServlet(RestServlet): - """ - Attributes: - NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted - nonces (dict[str, int]): The nonces that we will accept. A dict of - nonce to the time it was generated, in int seconds. - """ - - PATTERNS = historical_admin_path_patterns("/register") - NONCE_TIMEOUT = 60 - - def __init__(self, hs): - self.handlers = hs.get_handlers() - self.reactor = hs.get_reactor() - self.nonces = {} - self.hs = hs - - def _clear_old_nonces(self): - """ - Clear out old nonces that are older than NONCE_TIMEOUT. - """ - now = int(self.reactor.seconds()) - - for k, v in list(self.nonces.items()): - if now - v > self.NONCE_TIMEOUT: - del self.nonces[k] - - def on_GET(self, request): - """ - Generate a new nonce. 
- """ - self._clear_old_nonces() - - nonce = self.hs.get_secrets().token_hex(64) - self.nonces[nonce] = int(self.reactor.seconds()) - return 200, {"nonce": nonce} - - @defer.inlineCallbacks - def on_POST(self, request): - self._clear_old_nonces() - - if not self.hs.config.registration_shared_secret: - raise SynapseError(400, "Shared secret registration is not enabled") - - body = parse_json_object_from_request(request) - - if "nonce" not in body: - raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON) - - nonce = body["nonce"] - - if nonce not in self.nonces: - raise SynapseError(400, "unrecognised nonce") - - # Delete the nonce, so it can't be reused, even if it's invalid - del self.nonces[nonce] - - if "username" not in body: - raise SynapseError( - 400, "username must be specified", errcode=Codes.BAD_JSON - ) - else: - if ( - not isinstance(body["username"], text_type) - or len(body["username"]) > 512 - ): - raise SynapseError(400, "Invalid username") - - username = body["username"].encode("utf-8") - if b"\x00" in username: - raise SynapseError(400, "Invalid username") - - if "password" not in body: - raise SynapseError( - 400, "password must be specified", errcode=Codes.BAD_JSON - ) - else: - if ( - not isinstance(body["password"], text_type) - or len(body["password"]) > 512 - ): - raise SynapseError(400, "Invalid password") - - password = body["password"].encode("utf-8") - if b"\x00" in password: - raise SynapseError(400, "Invalid password") - - admin = body.get("admin", None) - user_type = body.get("user_type", None) - - if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: - raise SynapseError(400, "Invalid user type") - - got_mac = body["mac"] - - want_mac = hmac.new( - key=self.hs.config.registration_shared_secret.encode(), - digestmod=hashlib.sha1, - ) - want_mac.update(nonce.encode("utf8")) - want_mac.update(b"\x00") - want_mac.update(username) - want_mac.update(b"\x00") - want_mac.update(password) - want_mac.update(b"\x00") - want_mac.update(b"admin" if admin else b"notadmin") - if user_type: - want_mac.update(b"\x00") - want_mac.update(user_type.encode("utf8")) - want_mac = want_mac.hexdigest() - - if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")): - raise SynapseError(403, "HMAC incorrect") - - # Reuse the parts of RegisterRestServlet to reduce code duplication - from synapse.rest.client.v2_alpha.register import RegisterRestServlet - - register = RegisterRestServlet(self.hs) - - user_id = yield register.registration_handler.register_user( - localpart=body["username"].lower(), - password=body["password"], - admin=bool(admin), - user_type=user_type, - ) - - result = yield register._create_registration_details(user_id, body) - return 200, result - - -class WhoisRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)") - - def __init__(self, hs): - self.hs = hs - self.auth = hs.get_auth() - self.handlers = hs.get_handlers() - - @defer.inlineCallbacks - def on_GET(self, request, user_id): - target_user = UserID.from_string(user_id) - requester = yield self.auth.get_user_by_req(request) - auth_user = requester.user - - if target_user != auth_user: - yield assert_user_is_admin(self.auth, auth_user) - - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only whois a local user") - - ret = yield self.handlers.admin_handler.get_whois(target_user) - - return 200, ret - - class PurgeHistoryRestServlet(RestServlet): PATTERNS = historical_admin_path_patterns( 
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?" @@ -255,9 +86,8 @@ class PurgeHistoryRestServlet(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id): - yield assert_requester_is_admin(self.auth, request) + async def on_POST(self, request, room_id, event_id): + await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request, allow_empty_body=True) @@ -270,12 +100,12 @@ class PurgeHistoryRestServlet(RestServlet): event_id = body.get("purge_up_to_event_id") if event_id is not None: - event = yield self.store.get_event(event_id) + event = await self.store.get_event(event_id) if event.room_id != room_id: raise SynapseError(400, "Event is for wrong room.") - token = yield self.store.get_topological_token_for_event(event_id) + token = await self.store.get_topological_token_for_event(event_id) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: @@ -285,15 +115,13 @@ class PurgeHistoryRestServlet(RestServlet): 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON ) - stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts)) + stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) - r = ( - yield self.store.get_room_event_after_stream_ordering( - room_id, stream_ordering - ) + r = await self.store.get_room_event_before_stream_ordering( + room_id, stream_ordering ) if not r: - logger.warn( + logger.warning( "[purge] purging events not possible: No event found " "(received_ts %i => stream_ordering %i)", ts, @@ -318,7 +146,7 @@ class PurgeHistoryRestServlet(RestServlet): errcode=Codes.BAD_JSON, ) - purge_id = yield self.pagination_handler.start_purge_history( + purge_id = self.pagination_handler.start_purge_history( room_id, token, delete_local_events=delete_local_events ) @@ -339,9 +167,8 @@ class PurgeHistoryStatusRestServlet(RestServlet): self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, purge_id): - yield assert_requester_is_admin(self.auth, request) + async def on_GET(self, request, purge_id): + await assert_requester_is_admin(self.auth, request) purge_status = self.pagination_handler.get_purge_status(purge_id) if purge_status is None: @@ -350,375 +177,6 @@ class PurgeHistoryStatusRestServlet(RestServlet): return 200, purge_status.asdict() -class DeactivateAccountRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)") - - def __init__(self, hs): - self._deactivate_account_handler = hs.get_deactivate_account_handler() - self.auth = hs.get_auth() - - @defer.inlineCallbacks - def on_POST(self, request, target_user_id): - yield assert_requester_is_admin(self.auth, request) - body = parse_json_object_from_request(request, allow_empty_body=True) - erase = body.get("erase", False) - if not isinstance(erase, bool): - raise SynapseError( - http_client.BAD_REQUEST, - "Param 'erase' must be a boolean, if given", - Codes.BAD_JSON, - ) - - UserID.from_string(target_user_id) - - result = yield self._deactivate_account_handler.deactivate_account( - target_user_id, erase - ) - if result: - id_server_unbind_result = "success" - else: - id_server_unbind_result = "no-support" - - return 200, {"id_server_unbind_result": id_server_unbind_result} - - -class ShutdownRoomRestServlet(RestServlet): - """Shuts down a room by removing all local users 
from the room and blocking - all future invites and joins to the room. Any local aliases will be repointed - to a new room created by `new_room_user_id` and kicked users will be auto - joined to the new room. - """ - - PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)") - - DEFAULT_MESSAGE = ( - "Sharing illegal content on this server is not permitted and rooms in" - " violation will be blocked." - ) - - def __init__(self, hs): - self.hs = hs - self.store = hs.get_datastore() - self.state = hs.get_state_handler() - self._room_creation_handler = hs.get_room_creation_handler() - self.event_creation_handler = hs.get_event_creation_handler() - self.room_member_handler = hs.get_room_member_handler() - self.auth = hs.get_auth() - - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) - yield assert_user_is_admin(self.auth, requester.user) - - content = parse_json_object_from_request(request) - assert_params_in_dict(content, ["new_room_user_id"]) - new_room_user_id = content["new_room_user_id"] - - room_creator_requester = create_requester(new_room_user_id) - - message = content.get("message", self.DEFAULT_MESSAGE) - room_name = content.get("room_name", "Content Violation Notification") - - info = yield self._room_creation_handler.create_room( - room_creator_requester, - config={ - "preset": "public_chat", - "name": room_name, - "power_level_content_override": {"users_default": -10}, - }, - ratelimit=False, - ) - new_room_id = info["room_id"] - - requester_user_id = requester.user.to_string() - - logger.info( - "Shutting down room %r, joining to new room: %r", room_id, new_room_id - ) - - # This will work even if the room is already blocked, but that is - # desirable in case the first attempt at blocking the room failed below. 
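The shutdown_room servlet being removed here reappears, converted to async/await, in synapse/rest/admin/rooms.py further down the patch. A minimal sketch of driving the endpoint it serves, using the third-party requests package; the homeserver URL, admin token, room ID and new_room_user_id are placeholders, not part of this patch:

import requests

BASE = "http://localhost:8008/_synapse/admin/v1"            # placeholder homeserver
HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # placeholder admin token

# Kick local members, block the room and repoint its aliases at a fresh
# "content violation" room owned by new_room_user_id.
resp = requests.post(
    f"{BASE}/shutdown_room/!abc123:example.org",
    headers=HEADERS,
    json={
        "new_room_user_id": "@abuse-admin:example.org",
        "room_name": "Content Violation Notification",
        "message": "This room has been shut down by the server admins.",
    },
)
print(resp.json()["new_room_id"], resp.json()["kicked_users"])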
- yield self.store.block_room(room_id, requester_user_id) - - users = yield self.state.get_current_users_in_room(room_id) - kicked_users = [] - failed_to_kick_users = [] - for user_id in users: - if not self.hs.is_mine_id(user_id): - continue - - logger.info("Kicking %r from %r...", user_id, room_id) - - try: - target_requester = create_requester(user_id) - yield self.room_member_handler.update_membership( - requester=target_requester, - target=target_requester.user, - room_id=room_id, - action=Membership.LEAVE, - content={}, - ratelimit=False, - require_consent=False, - ) - - yield self.room_member_handler.forget(target_requester.user, room_id) - - yield self.room_member_handler.update_membership( - requester=target_requester, - target=target_requester.user, - room_id=new_room_id, - action=Membership.JOIN, - content={}, - ratelimit=False, - require_consent=False, - ) - - kicked_users.append(user_id) - except Exception: - logger.exception( - "Failed to leave old room and join new room for %r", user_id - ) - failed_to_kick_users.append(user_id) - - yield self.event_creation_handler.create_and_send_nonmember_event( - room_creator_requester, - { - "type": "m.room.message", - "content": {"body": message, "msgtype": "m.text"}, - "room_id": new_room_id, - "sender": new_room_user_id, - }, - ratelimit=False, - ) - - aliases_for_room = yield self.store.get_aliases_for_room(room_id) - - yield self.store.update_aliases_for_room( - room_id, new_room_id, requester_user_id - ) - - return ( - 200, - { - "kicked_users": kicked_users, - "failed_to_kick_users": failed_to_kick_users, - "local_aliases": aliases_for_room, - "new_room_id": new_room_id, - }, - ) - - -class ResetPasswordRestServlet(RestServlet): - """Post request to allow an administrator reset password for a user. - This needs user to have administrator access in Synapse. - Example: - http://localhost:8008/_synapse/admin/v1/reset_password/ - @user:to_reset_password?access_token=admin_access_token - JsonBodyToSend: - { - "new_password": "secret" - } - Returns: - 200 OK with empty object if success otherwise an error. - """ - - PATTERNS = historical_admin_path_patterns( - "/reset_password/(?P<target_user_id>[^/]*)" - ) - - def __init__(self, hs): - self.store = hs.get_datastore() - self.hs = hs - self.auth = hs.get_auth() - self._set_password_handler = hs.get_set_password_handler() - - @defer.inlineCallbacks - def on_POST(self, request, target_user_id): - """Post request to allow an administrator reset password for a user. - This needs user to have administrator access in Synapse. - """ - requester = yield self.auth.get_user_by_req(request) - yield assert_user_is_admin(self.auth, requester.user) - - UserID.from_string(target_user_id) - - params = parse_json_object_from_request(request) - assert_params_in_dict(params, ["new_password"]) - new_password = params["new_password"] - - yield self._set_password_handler.set_password( - target_user_id, new_password, requester - ) - return 200, {} - - -class GetUsersPaginatedRestServlet(RestServlet): - """Get request to get specific number of users from Synapse. - This needs user to have administrator access in Synapse. - Example: - http://localhost:8008/_synapse/admin/v1/users_paginate/ - @admin:user?access_token=admin_access_token&start=0&limit=10 - Returns: - 200 OK with json object {list[dict[str, Any]], count} or empty object. 
- """ - - PATTERNS = historical_admin_path_patterns( - "/users_paginate/(?P<target_user_id>[^/]*)" - ) - - def __init__(self, hs): - self.store = hs.get_datastore() - self.hs = hs - self.auth = hs.get_auth() - self.handlers = hs.get_handlers() - - @defer.inlineCallbacks - def on_GET(self, request, target_user_id): - """Get request to get specific number of users from Synapse. - This needs user to have administrator access in Synapse. - """ - yield assert_requester_is_admin(self.auth, request) - - target_user = UserID.from_string(target_user_id) - - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only users a local user") - - order = "name" # order by name in user table - start = parse_integer(request, "start", required=True) - limit = parse_integer(request, "limit", required=True) - - logger.info("limit: %s, start: %s", limit, start) - - ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) - return 200, ret - - @defer.inlineCallbacks - def on_POST(self, request, target_user_id): - """Post request to get specific number of users from Synapse.. - This needs user to have administrator access in Synapse. - Example: - http://localhost:8008/_synapse/admin/v1/users_paginate/ - @admin:user?access_token=admin_access_token - JsonBodyToSend: - { - "start": "0", - "limit": "10 - } - Returns: - 200 OK with json object {list[dict[str, Any]], count} or empty object. - """ - yield assert_requester_is_admin(self.auth, request) - UserID.from_string(target_user_id) - - order = "name" # order by name in user table - params = parse_json_object_from_request(request) - assert_params_in_dict(params, ["limit", "start"]) - limit = params["limit"] - start = params["start"] - logger.info("limit: %s, start: %s", limit, start) - - ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) - return 200, ret - - -class SearchUsersRestServlet(RestServlet): - """Get request to search user table for specific users according to - search term. - This needs user to have administrator access in Synapse. - Example: - http://localhost:8008/_synapse/admin/v1/search_users/ - @admin:user?access_token=admin_access_token&term=alice - Returns: - 200 OK with json object {list[dict[str, Any]], count} or empty object. - """ - - PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)") - - def __init__(self, hs): - self.store = hs.get_datastore() - self.hs = hs - self.auth = hs.get_auth() - self.handlers = hs.get_handlers() - - @defer.inlineCallbacks - def on_GET(self, request, target_user_id): - """Get request to search user table for specific users according to - search term. - This needs user to have a administrator access in Synapse. 
- """ - yield assert_requester_is_admin(self.auth, request) - - target_user = UserID.from_string(target_user_id) - - # To allow all users to get the users list - # if not is_admin and target_user != auth_user: - # raise AuthError(403, "You are not a server admin") - - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only users a local user") - - term = parse_string(request, "term", required=True) - logger.info("term: %s ", term) - - ret = yield self.handlers.admin_handler.search_users(term) - return 200, ret - - -class DeleteGroupAdminRestServlet(RestServlet): - """Allows deleting of local groups - """ - - PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)") - - def __init__(self, hs): - self.group_server = hs.get_groups_server_handler() - self.is_mine_id = hs.is_mine_id - self.auth = hs.get_auth() - - @defer.inlineCallbacks - def on_POST(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) - yield assert_user_is_admin(self.auth, requester.user) - - if not self.is_mine_id(group_id): - raise SynapseError(400, "Can only delete local groups") - - yield self.group_server.delete_group(group_id, requester.user.to_string()) - return 200, {} - - -class AccountValidityRenewServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/account_validity/validity$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - self.hs = hs - self.account_activity_handler = hs.get_account_validity_handler() - self.auth = hs.get_auth() - - @defer.inlineCallbacks - def on_POST(self, request): - yield assert_requester_is_admin(self.auth, request) - - body = parse_json_object_from_request(request) - - if "user_id" not in body: - raise SynapseError(400, "Missing property 'user_id' in the request body") - - expiration_ts = yield self.account_activity_handler.renew_account_for_user( - body["user_id"], - body.get("expiration_ts"), - not body.get("enable_renewal_emails", True), - ) - - res = {"expiration_ts": expiration_ts} - return 200, res - - ######################################################################################## # # please don't add more servlets here: this file is already long and unwieldy. Put @@ -740,10 +198,18 @@ def register_servlets(hs, http_server): Register all the admin servlets. 
""" register_servlets_for_client_rest_resource(hs, http_server) + ListRoomRestServlet(hs).register(http_server) + RoomRestServlet(hs).register(http_server) + JoinRoomAliasServlet(hs).register(http_server) PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) + UserRestServletV2(hs).register(http_server) + UsersRestServletV2(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) + DevicesRestServlet(hs).register(http_server) + DeleteDevicesRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): @@ -754,7 +220,6 @@ def register_servlets_for_client_rest_resource(hs, http_server): PurgeHistoryRestServlet(hs).register(http_server) UsersRestServlet(hs).register(http_server) ResetPasswordRestServlet(hs).register(http_server) - GetUsersPaginatedRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) ShutdownRoomRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index 5a9b08d3ef..d82eaf5e38 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -15,9 +15,11 @@ import re -from twisted.internet import defer +import twisted.web.server +import synapse.api.auth from synapse.api.errors import AuthError +from synapse.types import UserID def historical_admin_path_patterns(path_regex): @@ -31,7 +33,7 @@ def historical_admin_path_patterns(path_regex): Note that this should only be used for existing endpoints: new ones should just register for the /_synapse/admin path. """ - return list( + return [ re.compile(prefix + path_regex) for prefix in ( "^/_synapse/admin/v1", @@ -39,46 +41,50 @@ def historical_admin_path_patterns(path_regex): "^/_matrix/client/unstable/admin", "^/_matrix/client/r0/admin", ) - ) + ] -@defer.inlineCallbacks -def assert_requester_is_admin(auth, request): - """Verify that the requester is an admin user - - WARNING: MAKE SURE YOU YIELD ON THE RESULT! +def admin_patterns(path_regex: str): + """Returns the list of patterns for an admin endpoint Args: - auth (synapse.api.auth.Auth): - request (twisted.web.server.Request): incoming request + path_regex: The regex string to match. This should NOT have a ^ + as this will be prefixed. Returns: - Deferred + A list of regex patterns. + """ + admin_prefix = "^/_synapse/admin/v1" + patterns = [re.compile(admin_prefix + path_regex)] + return patterns + + +async def assert_requester_is_admin( + auth: synapse.api.auth.Auth, request: twisted.web.server.Request +) -> None: + """Verify that the requester is an admin user + + Args: + auth: api.auth.Auth singleton + request: incoming request Raises: - AuthError if the requester is not an admin + AuthError if the requester is not a server admin """ - requester = yield auth.get_user_by_req(request) - yield assert_user_is_admin(auth, requester.user) + requester = await auth.get_user_by_req(request) + await assert_user_is_admin(auth, requester.user) -@defer.inlineCallbacks -def assert_user_is_admin(auth, user_id): +async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None: """Verify that the given user is an admin user - WARNING: MAKE SURE YOU YIELD ON THE RESULT! 
- Args: - auth (synapse.api.auth.Auth): - user_id (UserID): - - Returns: - Deferred + auth: api.auth.Auth singleton + user_id: user to check Raises: - AuthError if the user is not an admin + AuthError if the user is not a server admin """ - - is_admin = yield auth.is_server_admin(user_id) + is_admin = await auth.is_server_admin(user_id) if not is_admin: raise AuthError(403, "You are not a server admin") diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py new file mode 100644 index 0000000000..8d32677339 --- /dev/null +++ b/synapse/rest/admin/devices.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import re + +from synapse.api.errors import NotFoundError, SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.rest.admin._base import assert_requester_is_admin +from synapse.types import UserID + +logger = logging.getLogger(__name__) + + +class DeviceRestServlet(RestServlet): + """ + Get, update or delete the given user's device + """ + + PATTERNS = ( + re.compile( + "^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$" + ), + ) + + def __init__(self, hs): + super(DeviceRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() + + async def on_GET(self, request, user_id, device_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + device = await self.device_handler.get_device( + target_user.to_string(), device_id + ) + return 200, device + + async def on_DELETE(self, request, user_id, device_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + await self.device_handler.delete_device(target_user.to_string(), device_id) + return 200, {} + + async def on_PUT(self, request, user_id, device_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + body = parse_json_object_from_request(request, allow_empty_body=True) + await self.device_handler.update_device( + target_user.to_string(), device_id, body + ) + return 200, {} + + +class DevicesRestServlet(RestServlet): + """ + Retrieve the given 
user's devices + """ + + PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices$"),) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() + + async def on_GET(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + devices = await self.device_handler.get_devices_by_user(target_user.to_string()) + return 200, {"devices": devices} + + +class DeleteDevicesRestServlet(RestServlet): + """ + API for bulk deletion of devices. Accepts a JSON object with a devices + key which lists the device_ids to delete. + """ + + PATTERNS = ( + re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/delete_devices$"), + ) + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() + + async def on_POST(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + body = parse_json_object_from_request(request, allow_empty_body=False) + assert_params_in_dict(body, ["devices"]) + + await self.device_handler.delete_devices( + target_user.to_string(), body["devices"] + ) + return 200, {} diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py new file mode 100644 index 0000000000..0b54ca09f4 --- /dev/null +++ b/synapse/rest/admin/groups.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
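The three device servlets added in synapse/rest/admin/devices.py above expose per-user device management under /_synapse/admin/v2. A usage sketch with the third-party requests package; the homeserver URL, admin token, user ID and display name are placeholders:

import requests

BASE = "http://localhost:8008/_synapse/admin/v2"            # placeholder homeserver
HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # placeholder admin token
user_id = "@alice:example.org"                              # placeholder local user

# List all devices for a user (DevicesRestServlet).
devices = requests.get(f"{BASE}/users/{user_id}/devices", headers=HEADERS).json()

# Rename a single device (DeviceRestServlet; PUT accepts an optional JSON body).
device_id = devices["devices"][0]["device_id"]
requests.put(
    f"{BASE}/users/{user_id}/devices/{device_id}",
    headers=HEADERS,
    json={"display_name": "Work laptop"},
)

# Bulk-delete devices (DeleteDevicesRestServlet expects a "devices" list).
requests.post(
    f"{BASE}/users/{user_id}/delete_devices",
    headers=HEADERS,
    json={"devices": [device_id]},
)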
+import logging + +from synapse.api.errors import SynapseError +from synapse.http.servlet import RestServlet +from synapse.rest.admin._base import ( + assert_user_is_admin, + historical_admin_path_patterns, +) + +logger = logging.getLogger(__name__) + + +class DeleteGroupAdminRestServlet(RestServlet): + """Allows deleting of local groups + """ + + PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)") + + def __init__(self, hs): + self.group_server = hs.get_groups_server_handler() + self.is_mine_id = hs.is_mine_id + self.auth = hs.get_auth() + + async def on_POST(self, request, group_id): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + if not self.is_mine_id(group_id): + raise SynapseError(400, "Can only delete local groups") + + await self.group_server.delete_group(group_id, requester.user.to_string()) + return 200, {} diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index ed7086d09c..ee75095c0e 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_integer from synapse.rest.admin._base import ( @@ -34,24 +32,85 @@ class QuarantineMediaInRoom(RestServlet): this server. """ - PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)") + PATTERNS = ( + historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine") + + + # This path kept around for legacy reasons + historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)") + ) def __init__(self, hs): self.store = hs.get_datastore() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) - yield assert_user_is_admin(self.auth, requester.user) + async def on_POST(self, request, room_id: str): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + logging.info("Quarantining room: %s", room_id) - num_quarantined = yield self.store.quarantine_media_ids_in_room( + # Quarantine all media in this room + num_quarantined = await self.store.quarantine_media_ids_in_room( room_id, requester.user.to_string() ) return 200, {"num_quarantined": num_quarantined} +class QuarantineMediaByUser(RestServlet): + """Quarantines all local media by a given user so that no one can download it via + this server. + """ + + PATTERNS = historical_admin_path_patterns( + "/user/(?P<user_id>[^/]+)/media/quarantine" + ) + + def __init__(self, hs): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request, user_id: str): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + logging.info("Quarantining local media by user: %s", user_id) + + # Quarantine all media this user has uploaded + num_quarantined = await self.store.quarantine_media_ids_by_user( + user_id, requester.user.to_string() + ) + + return 200, {"num_quarantined": num_quarantined} + + +class QuarantineMediaByID(RestServlet): + """Quarantines local or remote media by a given ID so that no one can download + it via this server. 
+ """ + + PATTERNS = historical_admin_path_patterns( + "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)" + ) + + def __init__(self, hs): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request, server_name: str, media_id: str): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + logging.info("Quarantining local media by ID: %s/%s", server_name, media_id) + + # Quarantine this media id + await self.store.quarantine_media_by_id( + server_name, media_id, requester.user.to_string() + ) + + return 200, {} + + class ListMediaInRoom(RestServlet): """Lists all of the media in a given room. """ @@ -62,14 +121,13 @@ class ListMediaInRoom(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(requester.user) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + is_admin = await self.auth.is_server_admin(requester.user) if not is_admin: raise AuthError(403, "You are not a server admin") - local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id) + local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id) return 200, {"local": local_mxcs, "remote": remote_mxcs} @@ -81,14 +139,13 @@ class PurgeMediaCacheRestServlet(RestServlet): self.media_repository = hs.get_media_repository() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request): - yield assert_requester_is_admin(self.auth, request) + async def on_POST(self, request): + await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) logger.info("before_ts: %r", before_ts) - ret = yield self.media_repository.delete_old_remote_media(before_ts) + ret = await self.media_repository.delete_old_remote_media(before_ts) return 200, ret @@ -99,4 +156,6 @@ def register_servlets_for_media_repo(hs, http_server): """ PurgeMediaCacheRestServlet(hs).register(http_server) QuarantineMediaInRoom(hs).register(http_server) + QuarantineMediaByID(hs).register(http_server) + QuarantineMediaByUser(hs).register(http_server) ListMediaInRoom(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py new file mode 100644 index 0000000000..8173baef8f --- /dev/null +++ b/synapse/rest/admin/rooms.py @@ -0,0 +1,364 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +from typing import List, Optional + +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_integer, + parse_json_object_from_request, + parse_string, +) +from synapse.rest.admin._base import ( + admin_patterns, + assert_requester_is_admin, + assert_user_is_admin, + historical_admin_path_patterns, +) +from synapse.storage.data_stores.main.room import RoomSortOrder +from synapse.types import RoomAlias, RoomID, UserID, create_requester +from synapse.util.async_helpers import maybe_awaitable + +logger = logging.getLogger(__name__) + + +class ShutdownRoomRestServlet(RestServlet): + """Shuts down a room by removing all local users from the room and blocking + all future invites and joins to the room. Any local aliases will be repointed + to a new room created by `new_room_user_id` and kicked users will be auto + joined to the new room. + """ + + PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)") + + DEFAULT_MESSAGE = ( + "Sharing illegal content on this server is not permitted and rooms in" + " violation will be blocked." + ) + + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.state = hs.get_state_handler() + self._room_creation_handler = hs.get_room_creation_handler() + self.event_creation_handler = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() + self._replication = hs.get_replication_data_handler() + + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + content = parse_json_object_from_request(request) + assert_params_in_dict(content, ["new_room_user_id"]) + new_room_user_id = content["new_room_user_id"] + + room_creator_requester = create_requester(new_room_user_id) + + message = content.get("message", self.DEFAULT_MESSAGE) + room_name = content.get("room_name", "Content Violation Notification") + + info, stream_id = await self._room_creation_handler.create_room( + room_creator_requester, + config={ + "preset": "public_chat", + "name": room_name, + "power_level_content_override": {"users_default": -10}, + }, + ratelimit=False, + ) + new_room_id = info["room_id"] + + requester_user_id = requester.user.to_string() + + logger.info( + "Shutting down room %r, joining to new room: %r", room_id, new_room_id + ) + + # This will work even if the room is already blocked, but that is + # desirable in case the first attempt at blocking the room failed below. + await self.store.block_room(room_id, requester_user_id) + + # We now wait for the create room to come back in via replication so + # that we can assume that all the joins/invites have propogated before + # we try and auto join below. 
+ # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", stream_id + ) + + users = await self.state.get_current_users_in_room(room_id) + kicked_users = [] + failed_to_kick_users = [] + for user_id in users: + if not self.hs.is_mine_id(user_id): + continue + + logger.info("Kicking %r from %r...", user_id, room_id) + + try: + target_requester = create_requester(user_id) + _, stream_id = await self.room_member_handler.update_membership( + requester=target_requester, + target=target_requester.user, + room_id=room_id, + action=Membership.LEAVE, + content={}, + ratelimit=False, + require_consent=False, + ) + + # Wait for leave to come in over replication before trying to forget. + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", stream_id + ) + + await self.room_member_handler.forget(target_requester.user, room_id) + + await self.room_member_handler.update_membership( + requester=target_requester, + target=target_requester.user, + room_id=new_room_id, + action=Membership.JOIN, + content={}, + ratelimit=False, + require_consent=False, + ) + + kicked_users.append(user_id) + except Exception: + logger.exception( + "Failed to leave old room and join new room for %r", user_id + ) + failed_to_kick_users.append(user_id) + + await self.event_creation_handler.create_and_send_nonmember_event( + room_creator_requester, + { + "type": "m.room.message", + "content": {"body": message, "msgtype": "m.text"}, + "room_id": new_room_id, + "sender": new_room_user_id, + }, + ratelimit=False, + ) + + aliases_for_room = await maybe_awaitable( + self.store.get_aliases_for_room(room_id) + ) + + await self.store.update_aliases_for_room( + room_id, new_room_id, requester_user_id + ) + + return ( + 200, + { + "kicked_users": kicked_users, + "failed_to_kick_users": failed_to_kick_users, + "local_aliases": aliases_for_room, + "new_room_id": new_room_id, + }, + ) + + +class ListRoomRestServlet(RestServlet): + """ + List all rooms that are known to the homeserver. Results are returned + in a dictionary containing room information. Supports pagination. 
+ """ + + PATTERNS = admin_patterns("/rooms$") + + def __init__(self, hs): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.admin_handler = hs.get_handlers().admin_handler + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + # Extract query parameters + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value) + if order_by not in ( + RoomSortOrder.ALPHABETICAL.value, + RoomSortOrder.SIZE.value, + RoomSortOrder.NAME.value, + RoomSortOrder.CANONICAL_ALIAS.value, + RoomSortOrder.JOINED_MEMBERS.value, + RoomSortOrder.JOINED_LOCAL_MEMBERS.value, + RoomSortOrder.VERSION.value, + RoomSortOrder.CREATOR.value, + RoomSortOrder.ENCRYPTION.value, + RoomSortOrder.FEDERATABLE.value, + RoomSortOrder.PUBLIC.value, + RoomSortOrder.JOIN_RULES.value, + RoomSortOrder.GUEST_ACCESS.value, + RoomSortOrder.HISTORY_VISIBILITY.value, + RoomSortOrder.STATE_EVENTS.value, + ): + raise SynapseError( + 400, + "Unknown value for order_by: %s" % (order_by,), + errcode=Codes.INVALID_PARAM, + ) + + search_term = parse_string(request, "search_term") + if search_term == "": + raise SynapseError( + 400, + "search_term cannot be an empty string", + errcode=Codes.INVALID_PARAM, + ) + + direction = parse_string(request, "dir", default="f") + if direction not in ("f", "b"): + raise SynapseError( + 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + ) + + reverse_order = True if direction == "b" else False + + # Return list of rooms according to parameters + rooms, total_rooms = await self.store.get_rooms_paginate( + start, limit, order_by, reverse_order, search_term + ) + response = { + # next_token should be opaque, so return a value the client can parse + "offset": start, + "rooms": rooms, + "total_rooms": total_rooms, + } + + # Are there more rooms to paginate through after this? + if (start + limit) < total_rooms: + # There are. Calculate where the query should start from next time + # to get the next part of the list + response["next_batch"] = start + limit + + # Is it possible to paginate backwards? Check if we currently have an + # offset + if start > 0: + if start > limit: + # Going back one iteration won't take us to the start. + # Calculate new offset + response["prev_batch"] = start - limit + else: + response["prev_batch"] = 0 + + return 200, response + + +class RoomRestServlet(RestServlet): + """Get room details. 
+ + TODO: Add on_POST to allow room creation without joining the room + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, room_id): + await assert_requester_is_admin(self.auth, request) + + ret = await self.store.get_room_with_stats(room_id) + if not ret: + raise NotFoundError("Room not found") + + return 200, ret + + +class JoinRoomAliasServlet(RestServlet): + + PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.room_member_handler = hs.get_room_member_handler() + self.admin_handler = hs.get_handlers().admin_handler + self.state_handler = hs.get_state_handler() + + async def on_POST(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + content = parse_json_object_from_request(request) + + assert_params_in_dict(content, ["user_id"]) + target_user = UserID.from_string(content["user_id"]) + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "This endpoint can only be used with local users") + + if not await self.admin_handler.get_user(target_user): + raise NotFoundError("User not found") + + if RoomID.is_valid(room_identifier): + room_id = room_identifier + try: + remote_room_hosts = [ + x.decode("ascii") for x in request.args[b"server_name"] + ] # type: Optional[List[str]] + except Exception: + remote_room_hosts = None + elif RoomAlias.is_valid(room_identifier): + handler = self.room_member_handler + room_alias = RoomAlias.from_string(room_identifier) + room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + + fake_requester = create_requester(target_user) + + # send invite if room has "JoinRules.INVITE" + room_state = await self.state_handler.get_current_state(room_id) + join_rules_event = room_state.get((EventTypes.JoinRules, "")) + if join_rules_event: + if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC): + await self.room_member_handler.update_membership( + requester=requester, + target=fake_requester.user, + room_id=room_id, + action="invite", + remote_room_hosts=remote_room_hosts, + ratelimit=False, + ) + + await self.room_member_handler.update_membership( + requester=fake_requester, + target=fake_requester.user, + room_id=room_id, + action="join", + remote_room_hosts=remote_room_hosts, + ratelimit=False, + ) + + return 200, {"room_id": room_id} diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index ae2cbe2e0a..6e9a874121 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -14,8 +14,6 @@ # limitations under the License. 
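Taken together, ListRoomRestServlet, RoomRestServlet and JoinRoomAliasServlet above form the new room admin API. A sketch with the third-party requests package; the homeserver URL, admin token and user ID are placeholders, and each list entry is assumed to carry a room_id field:

import requests

BASE = "http://localhost:8008/_synapse/admin/v1"            # placeholder homeserver
HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # placeholder admin token

# Page through known rooms, 100 at a time (ListRoomRestServlet).
page = requests.get(
    f"{BASE}/rooms",
    headers=HEADERS,
    params={"from": 0, "limit": 100, "dir": "f"},
).json()

# Fetch details for the first room returned (RoomRestServlet); assumes the
# entries in page["rooms"] expose a room_id field.
room_id = page["rooms"][0]["room_id"]
details = requests.get(f"{BASE}/rooms/{room_id}", headers=HEADERS).json()

# Force a local user into that room by ID or alias (JoinRoomAliasServlet).
requests.post(
    f"{BASE}/join/{room_id}",
    headers=HEADERS,
    json={"user_id": "@alice:example.org"},
)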
import re -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.http.servlet import ( @@ -69,9 +67,8 @@ class SendServerNoticeServlet(RestServlet): self.__class__.__name__, ) - @defer.inlineCallbacks - def on_POST(self, request, txn_id=None): - yield assert_requester_is_admin(self.auth, request) + async def on_POST(self, request, txn_id=None): + await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request) assert_params_in_dict(body, ("user_id", "content")) event_type = body.get("type", EventTypes.Message) @@ -85,7 +82,7 @@ class SendServerNoticeServlet(RestServlet): if not self.hs.is_mine_id(user_id): raise SynapseError(400, "Server notices can only be sent to local users") - event = yield self.snm.send_notice( + event = await self.snm.send_notice( user_id=body["user_id"], type=event_type, state_key=state_key, diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 9720a3bab0..fefc8f71fa 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -12,19 +12,607 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import hashlib +import hmac +import logging import re -from twisted.internet import defer +from six import text_type +from six.moves import http_client -from synapse.api.errors import SynapseError +from synapse.api.constants import UserTypes +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_boolean, + parse_integer, parse_json_object_from_request, + parse_string, +) +from synapse.rest.admin._base import ( + assert_requester_is_admin, + assert_user_is_admin, + historical_admin_path_patterns, ) -from synapse.rest.admin import assert_requester_is_admin, assert_user_is_admin from synapse.types import UserID +logger = logging.getLogger(__name__) + + +class UsersRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$") + + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.admin_handler = hs.get_handlers().admin_handler + + async def on_GET(self, request, user_id): + target_user = UserID.from_string(user_id) + await assert_requester_is_admin(self.auth, request) + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only users a local user") + + ret = await self.store.get_users() + + return 200, ret + + +class UsersRestServletV2(RestServlet): + PATTERNS = (re.compile("^/_synapse/admin/v2/users$"),) + + """Get request to list all local users. + This needs user to have administrator access in Synapse. + + GET /_synapse/admin/v2/users?from=0&limit=10&guests=false + + returns: + 200 OK with list of users if success otherwise an error. + + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + The parameter `user_id` can be used to filter by user id. + The parameter `guests` can be used to exclude guest users. + The parameter `deactivated` can be used to include deactivated users. 
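The UsersRestServletV2 docstring above describes the paginated v2 user list and its query parameters. A sketch of one page request with the third-party requests package; the homeserver URL and admin token are placeholders, and the "name" key on each entry is an assumption about the users-table rows:

import requests

BASE = "http://localhost:8008/_synapse/admin/v2"            # placeholder homeserver
HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # placeholder admin token

# First page of users, excluding guests and including deactivated accounts.
resp = requests.get(
    f"{BASE}/users",
    headers=HEADERS,
    params={"from": 0, "limit": 10, "guests": "false", "deactivated": "true"},
).json()

print(resp["total"], "users in total")
for user in resp["users"]:
    # Each entry is a row from the users table; the user ID is assumed to be
    # exposed under the "name" key.
    print(user["name"])

# next_token is only returned while more pages remain.
if "next_token" in resp:
    print("continue from offset", resp["next_token"])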
+ """ + + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.admin_handler = hs.get_handlers().admin_handler + + async def on_GET(self, request): + await assert_requester_is_admin(self.auth, request) + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + user_id = parse_string(request, "user_id", default=None) + guests = parse_boolean(request, "guests", default=True) + deactivated = parse_boolean(request, "deactivated", default=False) + + users, total = await self.store.get_users_paginate( + start, limit, user_id, guests, deactivated + ) + ret = {"users": users, "total": total} + if len(users) >= limit: + ret["next_token"] = str(start + len(users)) + + return 200, ret + + +class UserRestServletV2(RestServlet): + PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]+)$"),) + + """Get request to list user details. + This needs user to have administrator access in Synapse. + + GET /_synapse/admin/v2/users/<user_id> + + returns: + 200 OK with user details if success otherwise an error. + + Put request to allow an administrator to add or modify a user. + This needs user to have administrator access in Synapse. + We use PUT instead of POST since we already know the id of the user + object to create. POST could be used to create guests. + + PUT /_synapse/admin/v2/users/<user_id> + { + "password": "secret", + "displayname": "User" + } + + returns: + 201 OK with new user object if user was created or + 200 OK with modified user object if user was modified + otherwise an error. + """ + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.admin_handler = hs.get_handlers().admin_handler + self.store = hs.get_datastore() + self.auth_handler = hs.get_auth_handler() + self.profile_handler = hs.get_profile_handler() + self.set_password_handler = hs.get_set_password_handler() + self.deactivate_account_handler = hs.get_deactivate_account_handler() + self.registration_handler = hs.get_registration_handler() + self.pusher_pool = hs.get_pusherpool() + + async def on_GET(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + ret = await self.admin_handler.get_user(target_user) + + if not ret: + raise NotFoundError("User not found") + + return 200, ret + + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + target_user = UserID.from_string(user_id) + body = parse_json_object_from_request(request) + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "This endpoint can only be used with local users") + + user = await self.admin_handler.get_user(target_user) + user_id = target_user.to_string() + + if user: # modify user + if "displayname" in body: + await self.profile_handler.set_displayname( + target_user, requester, body["displayname"], True + ) + + if "threepids" in body: + # check for required parameters for each threepid + for threepid in body["threepids"]: + assert_params_in_dict(threepid, ["medium", "address"]) + + # remove old threepids from user + threepids = await self.store.user_get_threepids(user_id) + for threepid in threepids: + try: + await self.auth_handler.delete_threepid( + user_id, threepid["medium"], threepid["address"], None + ) + except Exception: + 
logger.exception("Failed to remove threepids") + raise SynapseError(500, "Failed to remove threepids") + + # add new threepids to user + current_time = self.hs.get_clock().time_msec() + for threepid in body["threepids"]: + await self.auth_handler.add_threepid( + user_id, threepid["medium"], threepid["address"], current_time + ) + + if "avatar_url" in body and type(body["avatar_url"]) == str: + await self.profile_handler.set_avatar_url( + target_user, requester, body["avatar_url"], True + ) + + if "admin" in body: + set_admin_to = bool(body["admin"]) + if set_admin_to != user["admin"]: + auth_user = requester.user + if target_user == auth_user and not set_admin_to: + raise SynapseError(400, "You may not demote yourself.") + + await self.store.set_server_admin(target_user, set_admin_to) + + if "password" in body: + if ( + not isinstance(body["password"], text_type) + or len(body["password"]) > 512 + ): + raise SynapseError(400, "Invalid password") + else: + new_password = body["password"] + logout_devices = True + + new_password_hash = await self.auth_handler.hash(new_password) + + await self.set_password_handler.set_password( + target_user.to_string(), + new_password_hash, + logout_devices, + requester, + ) + + if "deactivated" in body: + deactivate = body["deactivated"] + if not isinstance(deactivate, bool): + raise SynapseError( + 400, "'deactivated' parameter is not of type boolean" + ) + + if deactivate and not user["deactivated"]: + await self.deactivate_account_handler.deactivate_account( + target_user.to_string(), False + ) + + user = await self.admin_handler.get_user(target_user) + return 200, user + + else: # create user + password = body.get("password") + password_hash = None + if password is not None: + if not isinstance(password, text_type) or len(password) > 512: + raise SynapseError(400, "Invalid password") + password_hash = await self.auth_handler.hash(password) + + admin = body.get("admin", None) + user_type = body.get("user_type", None) + displayname = body.get("displayname", None) + threepids = body.get("threepids", None) + + if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: + raise SynapseError(400, "Invalid user type") + + user_id = await self.registration_handler.register_user( + localpart=target_user.localpart, + password_hash=password_hash, + admin=bool(admin), + default_display_name=displayname, + user_type=user_type, + by_admin=True, + ) + + if "threepids" in body: + # check for required parameters for each threepid + for threepid in body["threepids"]: + assert_params_in_dict(threepid, ["medium", "address"]) + + current_time = self.hs.get_clock().time_msec() + for threepid in body["threepids"]: + await self.auth_handler.add_threepid( + user_id, threepid["medium"], threepid["address"], current_time + ) + if ( + self.hs.config.email_enable_notifs + and self.hs.config.email_notif_for_new_users + ): + await self.pusher_pool.add_pusher( + user_id=user_id, + access_token=None, + kind="email", + app_id="m.email", + app_display_name="Email Notifications", + device_display_name=threepid["address"], + pushkey=threepid["address"], + lang=None, # We don't know a user's language here + data={}, + ) + + if "avatar_url" in body and type(body["avatar_url"]) == str: + await self.profile_handler.set_avatar_url( + user_id, requester, body["avatar_url"], True + ) + + ret = await self.admin_handler.get_user(target_user) + + return 201, ret + + +class UserRegisterServlet(RestServlet): + """ + Attributes: + NONCE_TIMEOUT (int): Seconds until a generated nonce won't 
be accepted + nonces (dict[str, int]): The nonces that we will accept. A dict of + nonce to the time it was generated, in int seconds. + """ + + PATTERNS = historical_admin_path_patterns("/register") + NONCE_TIMEOUT = 60 + + def __init__(self, hs): + self.auth_handler = hs.get_auth_handler() + self.reactor = hs.get_reactor() + self.nonces = {} + self.hs = hs + + def _clear_old_nonces(self): + """ + Clear out old nonces that are older than NONCE_TIMEOUT. + """ + now = int(self.reactor.seconds()) + + for k, v in list(self.nonces.items()): + if now - v > self.NONCE_TIMEOUT: + del self.nonces[k] + + def on_GET(self, request): + """ + Generate a new nonce. + """ + self._clear_old_nonces() + + nonce = self.hs.get_secrets().token_hex(64) + self.nonces[nonce] = int(self.reactor.seconds()) + return 200, {"nonce": nonce} + + async def on_POST(self, request): + self._clear_old_nonces() + + if not self.hs.config.registration_shared_secret: + raise SynapseError(400, "Shared secret registration is not enabled") + + body = parse_json_object_from_request(request) + + if "nonce" not in body: + raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON) + + nonce = body["nonce"] + + if nonce not in self.nonces: + raise SynapseError(400, "unrecognised nonce") + + # Delete the nonce, so it can't be reused, even if it's invalid + del self.nonces[nonce] + + if "username" not in body: + raise SynapseError( + 400, "username must be specified", errcode=Codes.BAD_JSON + ) + else: + if ( + not isinstance(body["username"], text_type) + or len(body["username"]) > 512 + ): + raise SynapseError(400, "Invalid username") + + username = body["username"].encode("utf-8") + if b"\x00" in username: + raise SynapseError(400, "Invalid username") + + if "password" not in body: + raise SynapseError( + 400, "password must be specified", errcode=Codes.BAD_JSON + ) + else: + password = body["password"] + if not isinstance(password, text_type) or len(password) > 512: + raise SynapseError(400, "Invalid password") + + password_bytes = password.encode("utf-8") + if b"\x00" in password_bytes: + raise SynapseError(400, "Invalid password") + + password_hash = await self.auth_handler.hash(password) + + admin = body.get("admin", None) + user_type = body.get("user_type", None) + + if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: + raise SynapseError(400, "Invalid user type") + + got_mac = body["mac"] + + want_mac_builder = hmac.new( + key=self.hs.config.registration_shared_secret.encode(), + digestmod=hashlib.sha1, + ) + want_mac_builder.update(nonce.encode("utf8")) + want_mac_builder.update(b"\x00") + want_mac_builder.update(username) + want_mac_builder.update(b"\x00") + want_mac_builder.update(password_bytes) + want_mac_builder.update(b"\x00") + want_mac_builder.update(b"admin" if admin else b"notadmin") + if user_type: + want_mac_builder.update(b"\x00") + want_mac_builder.update(user_type.encode("utf8")) + + want_mac = want_mac_builder.hexdigest() + + if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")): + raise SynapseError(403, "HMAC incorrect") + + # Reuse the parts of RegisterRestServlet to reduce code duplication + from synapse.rest.client.v2_alpha.register import RegisterRestServlet + + register = RegisterRestServlet(self.hs) + + user_id = await register.registration_handler.register_user( + localpart=body["username"].lower(), + password_hash=password_hash, + admin=bool(admin), + user_type=user_type, + by_admin=True, + ) + + result = await 
register._create_registration_details(user_id, body) + return 200, result + + +class WhoisRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.handlers = hs.get_handlers() + + async def on_GET(self, request, user_id): + target_user = UserID.from_string(user_id) + requester = await self.auth.get_user_by_req(request) + auth_user = requester.user + + if target_user != auth_user: + await assert_user_is_admin(self.auth, auth_user) + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only whois a local user") + + ret = await self.handlers.admin_handler.get_whois(target_user) + + return 200, ret + + +class DeactivateAccountRestServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)") + + def __init__(self, hs): + self._deactivate_account_handler = hs.get_deactivate_account_handler() + self.auth = hs.get_auth() + + async def on_POST(self, request, target_user_id): + await assert_requester_is_admin(self.auth, request) + body = parse_json_object_from_request(request, allow_empty_body=True) + erase = body.get("erase", False) + if not isinstance(erase, bool): + raise SynapseError( + http_client.BAD_REQUEST, + "Param 'erase' must be a boolean, if given", + Codes.BAD_JSON, + ) + + UserID.from_string(target_user_id) + + result = await self._deactivate_account_handler.deactivate_account( + target_user_id, erase + ) + if result: + id_server_unbind_result = "success" + else: + id_server_unbind_result = "no-support" + + return 200, {"id_server_unbind_result": id_server_unbind_result} + + +class AccountValidityRenewServlet(RestServlet): + PATTERNS = historical_admin_path_patterns("/account_validity/validity$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + self.hs = hs + self.account_activity_handler = hs.get_account_validity_handler() + self.auth = hs.get_auth() + + async def on_POST(self, request): + await assert_requester_is_admin(self.auth, request) + + body = parse_json_object_from_request(request) + + if "user_id" not in body: + raise SynapseError(400, "Missing property 'user_id' in the request body") + + expiration_ts = await self.account_activity_handler.renew_account_for_user( + body["user_id"], + body.get("expiration_ts"), + not body.get("enable_renewal_emails", True), + ) + + res = {"expiration_ts": expiration_ts} + return 200, res + + +class ResetPasswordRestServlet(RestServlet): + """Post request to allow an administrator reset password for a user. + This needs user to have administrator access in Synapse. + Example: + http://localhost:8008/_synapse/admin/v1/reset_password/ + @user:to_reset_password?access_token=admin_access_token + JsonBodyToSend: + { + "new_password": "secret" + } + Returns: + 200 OK with empty object if success otherwise an error. + """ + + PATTERNS = historical_admin_path_patterns( + "/reset_password/(?P<target_user_id>[^/]*)" + ) + + def __init__(self, hs): + self.store = hs.get_datastore() + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + self._set_password_handler = hs.get_set_password_handler() + + async def on_POST(self, request, target_user_id): + """Post request to allow an administrator reset password for a user. + This needs user to have administrator access in Synapse. 
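UserRegisterServlet above verifies a shared-secret HMAC over the nonce, username, password and admin flag, joined by NUL bytes. A client-side sketch of computing that MAC and registering a user, using only the standard library plus the third-party requests package; the homeserver URL, shared secret and credentials are placeholders:

import hashlib
import hmac
import requests

BASE = "http://localhost:8008/_synapse/admin/v1"   # placeholder homeserver
SHARED_SECRET = "<registration_shared_secret>"     # placeholder from homeserver.yaml

nonce = requests.get(f"{BASE}/register").json()["nonce"]
username, password, admin = "alice", "s3cret", False

# Same construction that UserRegisterServlet.on_POST recomputes server-side.
mac = hmac.new(SHARED_SECRET.encode("utf8"), digestmod=hashlib.sha1)
mac.update(nonce.encode("utf8"))
mac.update(b"\x00")
mac.update(username.encode("utf8"))
mac.update(b"\x00")
mac.update(password.encode("utf8"))
mac.update(b"\x00")
mac.update(b"admin" if admin else b"notadmin")

resp = requests.post(
    f"{BASE}/register",
    json={
        "nonce": nonce,
        "username": username,
        "password": password,
        "admin": admin,
        "mac": mac.hexdigest(),
    },
)
print(resp.json())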
+ """ + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + UserID.from_string(target_user_id) + + params = parse_json_object_from_request(request) + assert_params_in_dict(params, ["new_password"]) + new_password = params["new_password"] + logout_devices = params.get("logout_devices", True) + + new_password_hash = await self.auth_handler.hash(new_password) + + await self._set_password_handler.set_password( + target_user_id, new_password_hash, logout_devices, requester + ) + return 200, {} + + +class SearchUsersRestServlet(RestServlet): + """Get request to search user table for specific users according to + search term. + This needs user to have administrator access in Synapse. + Example: + http://localhost:8008/_synapse/admin/v1/search_users/ + @admin:user?access_token=admin_access_token&term=alice + Returns: + 200 OK with json object {list[dict[str, Any]], count} or empty object. + """ + + PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)") + + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.handlers = hs.get_handlers() + + async def on_GET(self, request, target_user_id): + """Get request to search user table for specific users according to + search term. + This needs user to have a administrator access in Synapse. + """ + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(target_user_id) + + # To allow all users to get the users list + # if not is_admin and target_user != auth_user: + # raise AuthError(403, "You are not a server admin") + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only users a local user") + + term = parse_string(request, "term", required=True) + logger.info("term: %s ", term) + + ret = await self.handlers.store.search_users(term) + return 200, ret + class UserAdminServlet(RestServlet): """ @@ -52,31 +640,28 @@ class UserAdminServlet(RestServlet): {} """ - PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>@[^/]*)/admin$"),) + PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>[^/]*)/admin$"),) def __init__(self, hs): self.hs = hs + self.store = hs.get_datastore() self.auth = hs.get_auth() - self.handlers = hs.get_handlers() - @defer.inlineCallbacks - def on_GET(self, request, user_id): - yield assert_requester_is_admin(self.auth, request) + async def on_GET(self, request, user_id): + await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): raise SynapseError(400, "Only local users can be admins of this homeserver") - is_admin = yield self.handlers.admin_handler.get_user_server_admin(target_user) - is_admin = bool(is_admin) + is_admin = await self.store.is_server_admin(target_user) return 200, {"admin": is_admin} - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) - yield assert_user_is_admin(self.auth, requester.user) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) auth_user = requester.user target_user = UserID.from_string(user_id) @@ -93,8 +678,6 @@ class UserAdminServlet(RestServlet): if target_user == auth_user and not set_admin_to: raise SynapseError(400, "You may not demote yourself.") - yield self.handlers.admin_handler.set_user_server_admin( - target_user, set_admin_to - ) + await 
self.store.set_server_admin(target_user, set_admin_to) return 200, {} diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 4ea3666874..5934b1fe8b 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import ( AuthError, Codes, @@ -47,17 +45,15 @@ class ClientDirectoryServer(RestServlet): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_alias): + async def on_GET(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) dir_handler = self.handlers.directory_handler - res = yield dir_handler.get_association(room_alias) + res = await dir_handler.get_association(room_alias) return 200, res - @defer.inlineCallbacks - def on_PUT(self, request, room_alias): + async def on_PUT(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) content = parse_json_object_from_request(request) @@ -77,26 +73,25 @@ class ClientDirectoryServer(RestServlet): # TODO(erikj): Check types. - room = yield self.store.get_room(room_id) + room = await self.store.get_room(room_id) if room is None: raise SynapseError(400, "Room does not exist") - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) - yield self.handlers.directory_handler.create_association( + await self.handlers.directory_handler.create_association( requester, room_alias, room_id, servers ) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, room_alias): + async def on_DELETE(self, request, room_alias): dir_handler = self.handlers.directory_handler try: - service = yield self.auth.get_appservice_by_req(request) + service = await self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_appservice_association(service, room_alias) + await dir_handler.delete_appservice_association(service, room_alias) logger.info( "Application service at %s deleted alias %s", service.url, @@ -107,12 +102,12 @@ class ClientDirectoryServer(RestServlet): # fallback to default user behaviour if they aren't an AS pass - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user = requester.user room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_association(requester, room_alias) + await dir_handler.delete_association(requester, room_alias) logger.info( "User %s deleted alias %s", user.to_string(), room_alias.to_string() @@ -130,32 +125,29 @@ class ClientDirectoryListServer(RestServlet): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - room = yield self.store.get_room(room_id) + async def on_GET(self, request, room_id): + room = await self.store.get_room(room_id) if room is None: raise NotFoundError("Unknown room") return 200, {"visibility": "public" if room["is_public"] else "private"} - @defer.inlineCallbacks - def on_PUT(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) visibility = content.get("visibility", "public") - yield self.handlers.directory_handler.edit_published_room_list( + await self.handlers.directory_handler.edit_published_room_list( 
requester, room_id, visibility ) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - yield self.handlers.directory_handler.edit_published_room_list( + await self.handlers.directory_handler.edit_published_room_list( requester, room_id, "private" ) @@ -181,15 +173,14 @@ class ClientAppserviceDirectoryListServer(RestServlet): def on_DELETE(self, request, network_id, room_id): return self._edit(request, network_id, room_id, "private") - @defer.inlineCallbacks - def _edit(self, request, network_id, room_id, visibility): - requester = yield self.auth.get_user_by_req(request) + async def _edit(self, request, network_id, room_id, visibility): + requester = await self.auth.get_user_by_req(request) if not requester.app_service: raise AuthError( 403, "Only appservices can edit the appservice published room list" ) - yield self.handlers.directory_handler.edit_published_appservice_room_list( + await self.handlers.directory_handler.edit_published_appservice_room_list( requester.app_service.id, network_id, room_id, visibility ) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 6651b4cf07..25effd0261 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -16,8 +16,6 @@ """This module contains REST servlets to do with event streaming, /events.""" import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -36,9 +34,8 @@ class EventStreamRestServlet(RestServlet): self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) is_guest = requester.is_guest room_id = None if is_guest: @@ -57,7 +54,7 @@ class EventStreamRestServlet(RestServlet): as_client_event = b"raw" not in request.args - chunk = yield self.event_stream_handler.get_stream( + chunk = await self.event_stream_handler.get_stream( requester.user.to_string(), pagin_config, timeout=timeout, @@ -73,7 +70,6 @@ class EventStreamRestServlet(RestServlet): return 200, {} -# TODO: Unit test gets, with and without auth, with different kinds of events. 
class EventRestServlet(RestServlet): PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True) @@ -81,16 +77,16 @@ class EventRestServlet(RestServlet): super(EventRestServlet, self).__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() + self.auth = hs.get_auth() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request, event_id): - requester = yield self.auth.get_user_by_req(request) - event = yield self.event_handler.get_event(requester.user, None, event_id) + async def on_GET(self, request, event_id): + requester = await self.auth.get_user_by_req(request) + event = await self.event_handler.get_event(requester.user, None, event_id) time_now = self.clock.time_msec() if event: - event = yield self._event_serializer.serialize_event(event, time_now) + event = await self._event_serializer.serialize_event(event, time_now) return 200, event else: return 404, "Event not found." diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 2da3cd7511..910b3b4eeb 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_boolean from synapse.rest.client.v2_alpha._base import client_patterns @@ -29,13 +28,12 @@ class InitialSyncRestServlet(RestServlet): self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) as_client_event = b"raw" not in request.args pagination_config = PaginationConfig.from_request(request) include_archived = parse_boolean(request, "archived", default=False) - content = yield self.initial_sync_handler.snapshot_all_rooms( + content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), pagin_config=pagination_config, as_client_event=as_client_event, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 25a1b67092..dceb2792fa 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,12 +14,6 @@ # limitations under the License. 
import logging -import xml.etree.ElementTree as ET - -from six.moves import urllib - -from twisted.internet import defer -from twisted.web.client import PartialDownloadError from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter @@ -29,9 +23,10 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import UserID from synapse.util.msisdn import phone_number_to_msisdn logger = logging.getLogger(__name__) @@ -88,29 +83,42 @@ class LoginRestServlet(RestServlet): self.jwt_algorithm = hs.config.jwt_algorithm self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled + self.oidc_enabled = hs.config.oidc_enabled self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() self._well_known_builder = WellKnownBuilder(hs) - self._address_ratelimiter = Ratelimiter() + self._address_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_address.per_second, + burst_count=self.hs.config.rc_login_address.burst_count, + ) + self._account_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_account.per_second, + burst_count=self.hs.config.rc_login_account.burst_count, + ) + self._failed_attempts_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) def on_GET(self, request): flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) - if self.saml2_enabled: - flows.append({"type": LoginRestServlet.SSO_TYPE}) - flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - if self.cas_enabled: - flows.append({"type": LoginRestServlet.SSO_TYPE}) + if self.cas_enabled: # we advertise CAS for backwards compat, though MSC1721 renamed it # to SSO. flows.append({"type": LoginRestServlet.CAS_TYPE}) + if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: + flows.append({"type": LoginRestServlet.SSO_TYPE}) # While its valid for us to advertise this login type generally, # synapse currently only gives out these tokens as part of the - # CAS login flow. + # SSO login flow. 
# Generally we don't want to advertise login flows that clients # don't know how to implement, since they (currently) will always # fall back to the fallback API if they don't understand one of the @@ -126,26 +134,19 @@ class LoginRestServlet(RestServlet): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): - self._address_ratelimiter.ratelimit( - request.getClientIP(), - time_now_s=self.hs.clock.time(), - rate_hz=self.hs.config.rc_login_address.per_second, - burst_count=self.hs.config.rc_login_address.burst_count, - update=True, - ) + async def on_POST(self, request): + self._address_ratelimiter.ratelimit(request.getClientIP()) login_submission = parse_json_object_from_request(request) try: if self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE ): - result = yield self.do_jwt_login(login_submission) + result = await self.do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - result = yield self.do_token_login(login_submission) + result = await self.do_token_login(login_submission) else: - result = yield self._do_other_login(login_submission) + result = await self._do_other_login(login_submission) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -154,8 +155,7 @@ class LoginRestServlet(RestServlet): result["well_known"] = well_known_data return 200, result - @defer.inlineCallbacks - def _do_other_login(self, login_submission): + async def _do_other_login(self, login_submission): """Handle non-token/saml/jwt logins Args: @@ -201,28 +201,43 @@ class LoginRestServlet(RestServlet): # (See add_threepid in synapse/handlers/auth.py) address = address.lower() + # We also apply account rate limiting using the 3PID as a key, as + # otherwise using 3PID bypasses the ratelimiting based on user ID. + self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False) + # Check for login providers that support 3pid login types - canonical_user_id, callback_3pid = ( - yield self.auth_handler.check_password_provider_3pid( - medium, address, login_submission["password"] - ) + ( + canonical_user_id, + callback_3pid, + ) = await self.auth_handler.check_password_provider_3pid( + medium, address, login_submission["password"] ) if canonical_user_id: # Authentication through password provider and 3pid succeeded - result = yield self._register_device_with_callback( + + result = await self._complete_login( canonical_user_id, login_submission, callback_3pid ) return result # No password providers were able to handle this 3pid # Check local store - user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastore().get_user_id_by_threepid( medium, address ) if not user_id: - logger.warn( + logger.warning( "unknown 3pid identifier medium %s, address %r", medium, address ) + # We mark that we've failed to log in here, as + # `check_password_provider_3pid` might have returned `None` due + # to an incorrect password, rather than the account not + # existing. + # + # If it returned None but the 3PID was bound then we won't hit + # this code path, which is fine as then the per-user ratelimit + # will kick in below. 
+ self._failed_attempts_ratelimiter.can_do_action((medium, address)) raise LoginError(403, "", errcode=Codes.FORBIDDEN) identifier = {"type": "m.id.user", "user": user_id} @@ -234,32 +249,71 @@ class LoginRestServlet(RestServlet): if "user" not in identifier: raise SynapseError(400, "User identifier is missing 'user' key") - canonical_user_id, callback = yield self.auth_handler.validate_login( - identifier["user"], login_submission + if identifier["user"].startswith("@"): + qualified_user_id = identifier["user"] + else: + qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string() + + # Check if we've hit the failed ratelimit (but don't update it) + self._failed_attempts_ratelimiter.ratelimit( + qualified_user_id.lower(), update=False ) - result = yield self._register_device_with_callback( + try: + canonical_user_id, callback = await self.auth_handler.validate_login( + identifier["user"], login_submission + ) + except LoginError: + # The user has failed to log in, so we need to update the rate + # limiter. Using `can_do_action` avoids us raising a ratelimit + # exception and masking the LoginError. The actual ratelimiting + # should have happened above. + self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower()) + raise + + result = await self._complete_login( canonical_user_id, login_submission, callback ) return result - @defer.inlineCallbacks - def _register_device_with_callback(self, user_id, login_submission, callback=None): - """ Registers a device with a given user_id. Optionally run a callback - function after registration has completed. + async def _complete_login( + self, user_id, login_submission, callback=None, create_non_existent_users=False + ): + """Called when we've successfully authed the user and now need to + actually login them in (e.g. create devices). This gets called on + all succesful logins. + + Applies the ratelimiting for succesful login attempts against an + account. Args: user_id (str): ID of the user to register. login_submission (dict): Dictionary of login information. callback (func|None): Callback function to run after registration. + create_non_existent_users (bool): Whether to create the user if + they don't exist. Defaults to False. Returns: result (Dict[str,str]): Dictionary of account information after successful registration. """ + + # Before we actually log them in we check if they've already logged in + # too often. This happens here rather than before as we don't + # necessarily know the user before now. 
+ self._account_ratelimiter.ratelimit(user_id.lower()) + + if create_non_existent_users: + canonical_uid = await self.auth_handler.check_user_exists(user_id) + if not canonical_uid: + canonical_uid = await self.registration_handler.register_user( + localpart=UserID.from_string(user_id).localpart + ) + user_id = canonical_uid + device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name ) @@ -271,23 +325,21 @@ class LoginRestServlet(RestServlet): } if callback is not None: - yield callback(result) + await callback(result) return result - @defer.inlineCallbacks - def do_token_login(self, login_submission): + async def do_token_login(self, login_submission): token = login_submission["token"] auth_handler = self.auth_handler - user_id = ( - yield auth_handler.validate_short_term_login_token_and_get_user_id(token) + user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( + token ) - result = yield self._register_device_with_callback(user_id, login_submission) + result = await self._complete_login(user_id, login_submission) return result - @defer.inlineCallbacks - def do_jwt_login(self, login_submission): + async def do_jwt_login(self, login_submission): token = login_submission.get("token", None) if token is None: raise LoginError( @@ -311,15 +363,8 @@ class LoginRestServlet(RestServlet): raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID(user, self.hs.hostname).to_string() - - registered_user_id = yield self.auth_handler.check_user_exists(user_id) - if not registered_user_id: - registered_user_id = yield self.registration_handler.register_user( - localpart=user - ) - - result = yield self._register_device_with_callback( - registered_user_id, login_submission + result = await self._complete_login( + user_id, login_submission, create_non_existent_users=True ) return result @@ -329,24 +374,27 @@ class BaseSSORedirectServlet(RestServlet): PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) - def on_GET(self, request): + async def on_GET(self, request: SynapseRequest): args = request.args if b"redirectUrl" not in args: return 400, "Redirect URL not specified for SSO auth" client_redirect_url = args[b"redirectUrl"][0] - sso_url = self.get_sso_url(client_redirect_url) + sso_url = await self.get_sso_url(request, client_redirect_url) request.redirect(sso_url) finish_request(request) - def get_sso_url(self, client_redirect_url): + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: """Get the URL to redirect to, to perform SSO auth Args: - client_redirect_url (bytes): the URL that we should redirect the + request: The client request to redirect. 
+ client_redirect_url: the URL that we should redirect the client to when everything is done Returns: - bytes: URL to redirect to + URL to redirect to """ # to be implemented by subclasses raise NotImplementedError() @@ -354,19 +402,14 @@ class BaseSSORedirectServlet(RestServlet): class CasRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): - super(CasRedirectServlet, self).__init__() - self.cas_server_url = hs.config.cas_server_url.encode("ascii") - self.cas_service_url = hs.config.cas_service_url.encode("ascii") + self._cas_handler = hs.get_cas_handler() - def get_sso_url(self, client_redirect_url): - client_redirect_url_param = urllib.parse.urlencode( - {b"redirectUrl": client_redirect_url} + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: + return self._cas_handler.get_redirect_url( + {"redirectUrl": client_redirect_url} ).encode("ascii") - hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket" - service_param = urllib.parse.urlencode( - {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)} - ).encode("ascii") - return b"%s/login?%s" % (self.cas_server_url, service_param) class CasTicketServlet(RestServlet): @@ -374,80 +417,25 @@ class CasTicketServlet(RestServlet): def __init__(self, hs): super(CasTicketServlet, self).__init__() - self.cas_server_url = hs.config.cas_server_url - self.cas_service_url = hs.config.cas_service_url - self.cas_required_attributes = hs.config.cas_required_attributes - self._sso_auth_handler = SSOAuthHandler(hs) - self._http_client = hs.get_simple_http_client() - - @defer.inlineCallbacks - def on_GET(self, request): - client_redirect_url = parse_string(request, "redirectUrl", required=True) - uri = self.cas_server_url + "/proxyValidate" - args = { - "ticket": parse_string(request, "ticket", required=True), - "service": self.cas_service_url, - } - try: - body = yield self._http_client.get_raw(uri, args) - except PartialDownloadError as pde: - # Twisted raises this error if the connection is closed, - # even if that's being used old-http style to signal end-of-data - body = pde.response - result = yield self.handle_cas_response(request, body, client_redirect_url) - return result + self._cas_handler = hs.get_cas_handler() - def handle_cas_response(self, request, cas_response_body, client_redirect_url): - user, attributes = self.parse_cas_response(cas_response_body) + async def on_GET(self, request: SynapseRequest) -> None: + client_redirect_url = parse_string(request, "redirectUrl") + ticket = parse_string(request, "ticket", required=True) - for required_attribute, required_value in self.cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in attributes: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + # Maybe get a session ID (if this ticket is from user interactive + # authentication). + session = parse_string(request, "session") - # Also need to check value - if required_value is not None: - actual_value = attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + # Either client_redirect_url or session must be provided. 
+ if not client_redirect_url and not session: + message = "Missing string query parameter redirectUrl or session" + raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) - return self._sso_auth_handler.on_successful_auth( - user, request, client_redirect_url + await self._cas_handler.handle_ticket( + request, ticket, client_redirect_url, session ) - def parse_cas_response(self, cas_response_body): - user = None - attributes = {} - try: - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise Exception("root of CAS response is not serviceResponse") - success = root[0].tag.endswith("authenticationSuccess") - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - for attribute in child: - # ElementTree library expands the namespace in - # attribute tags to the full URL of the namespace. - # We don't care about namespace here and it will always - # be encased in curly braces, so we remove them. - tag = attribute.tag - if "}" in tag: - tag = tag.split("}")[1] - attributes[tag] = attribute.text - if user is None: - raise Exception("CAS response does not contain user") - except Exception: - logger.error("Error parsing CAS response", exc_info=1) - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not success: - raise LoginError( - 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED - ) - return user, attributes - class SAMLRedirectServlet(BaseSSORedirectServlet): PATTERNS = client_patterns("/login/sso/redirect", v1=True) @@ -455,74 +443,26 @@ class SAMLRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): self._saml_handler = hs.get_saml_handler() - def get_sso_url(self, client_redirect_url): + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: return self._saml_handler.handle_redirect_request(client_redirect_url) -class SSOAuthHandler(object): - """ - Utility class for Resources and Servlets which handle the response from a SSO - service +class OIDCRedirectServlet(BaseSSORedirectServlet): + """Implementation for /login/sso/redirect for the OIDC login flow.""" - Args: - hs (synapse.server.HomeServer) - """ + PATTERNS = client_patterns("/login/sso/redirect", v1=True) def __init__(self, hs): - self._hostname = hs.hostname - self._auth_handler = hs.get_auth_handler() - self._registration_handler = hs.get_registration_handler() - self._macaroon_gen = hs.get_macaroon_generator() - - @defer.inlineCallbacks - def on_successful_auth( - self, username, request, client_redirect_url, user_display_name=None - ): - """Called once the user has successfully authenticated with the SSO. - - Registers the user if necessary, and then returns a redirect (with - a login token) to the client. - - Args: - username (unicode|bytes): the remote user id. We'll map this onto - something sane for a MXID localpath. - - request (SynapseRequest): the incoming request from the browser. We'll - respond to it with a redirect. - - client_redirect_url (unicode): the redirect_url the client gave us when - it first started the process. - - user_display_name (unicode|None): if set, and we have to register a new user, - we will set their displayname to this. + self._oidc_handler = hs.get_oidc_handler() - Returns: - Deferred[none]: Completes once we have handled the request. 
- """ - localpart = map_username_to_mxid_localpart(username) - user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = yield self._auth_handler.check_user_exists(user_id) - if not registered_user_id: - registered_user_id = yield self._registration_handler.register_user( - localpart=localpart, default_display_name=user_display_name - ) - - login_token = self._macaroon_gen.generate_short_term_login_token( - registered_user_id + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: + return await self._oidc_handler.handle_redirect_request( + request, client_redirect_url ) - redirect_url = self._add_login_token_to_redirect_url( - client_redirect_url, login_token - ) - request.redirect(redirect_url) - finish_request(request) - - @staticmethod - def _add_login_token_to_redirect_url(url, token): - url_parts = list(urllib.parse.urlparse(url)) - query = dict(urllib.parse.parse_qsl(url_parts[4])) - query.update({"loginToken": token}) - url_parts[4] = urllib.parse.urlencode(query) - return urllib.parse.urlunparse(url_parts) def register_servlets(hs, http_server): @@ -532,3 +472,5 @@ def register_servlets(hs, http_server): CasTicketServlet(hs).register(http_server) elif hs.config.saml2_enabled: SAMLRedirectServlet(hs).register(http_server) + elif hs.config.oidc_enabled: + OIDCRedirectServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 4785a34d75..b0c30b65be 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -35,17 +33,16 @@ class LogoutRestServlet(RestServlet): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_expired=True) if requester.device_id is None: - # the acccess token wasn't associated with a device. + # The access token wasn't associated with a device. # Just delete the access token access_token = self.auth.get_access_token_from_request(request) - yield self._auth_handler.delete_access_token(access_token) + await self._auth_handler.delete_access_token(access_token) else: - yield self._device_handler.delete_device( + await self._device_handler.delete_device( requester.user.to_string(), requester.device_id ) @@ -64,17 +61,16 @@ class LogoutAllRestServlet(RestServlet): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() # first delete all of the user's devices - yield self._device_handler.delete_all_devices_for_user(user_id) + await self._device_handler.delete_all_devices_for_user(user_id) # .. and then delete any access tokens which weren't associated with # devices. 
- yield self._auth_handler.delete_access_tokens_for_user(user_id) + await self._auth_handler.delete_access_tokens_for_user(user_id) return 200, {} diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 0153525cef..eec16f8ad8 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -19,8 +19,6 @@ import logging from six import string_types -from twisted.internet import defer - from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -40,27 +38,25 @@ class PresenceStatusRestServlet(RestServlet): self.clock = hs.get_clock() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if requester.user != user: - allowed = yield self.presence_handler.is_visible( + allowed = await self.presence_handler.is_visible( observed_user=user, observer_user=requester.user ) if not allowed: raise AuthError(403, "You are not allowed to see their presence.") - state = yield self.presence_handler.get_state(target_user=user) + state = await self.presence_handler.get_state(target_user=user) state = format_user_presence_state(state, self.clock.time_msec()) return 200, state - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if requester.user != user: @@ -86,7 +82,7 @@ class PresenceStatusRestServlet(RestServlet): raise SynapseError(400, "Unable to parse state") if self.hs.config.use_presence: - yield self.presence_handler.set_state(user, state) + await self.presence_handler.set_state(user, state) return 200, {} diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index bbce2e2b71..e7fe50ed72 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -14,8 +14,8 @@ # limitations under the License. 
""" This module contains REST servlets to do with profile: /profile/<paths> """ -from twisted.internet import defer +from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.rest.client.v2_alpha._base import client_patterns from synapse.types import UserID @@ -30,19 +30,18 @@ class ProfileDisplaynameRestServlet(RestServlet): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - displayname = yield self.profile_handler.get_displayname(user) + displayname = await self.profile_handler.get_displayname(user) ret = {} if displayname is not None: @@ -50,11 +49,10 @@ class ProfileDisplaynameRestServlet(RestServlet): return 200, ret - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) - is_admin = yield self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester.user) content = parse_json_object_from_request(request) @@ -63,7 +61,7 @@ class ProfileDisplaynameRestServlet(RestServlet): except Exception: return 400, "Unable to parse name" - yield self.profile_handler.set_displayname(user, requester, new_name, is_admin) + await self.profile_handler.set_displayname(user, requester, new_name, is_admin) return 200, {} @@ -80,19 +78,18 @@ class ProfileAvatarURLRestServlet(RestServlet): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - avatar_url = yield self.profile_handler.get_avatar_url(user) + avatar_url = await self.profile_handler.get_avatar_url(user) ret = {} if avatar_url is not None: @@ -100,19 +97,22 @@ class ProfileAvatarURLRestServlet(RestServlet): return 200, ret - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) - is_admin = yield self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester.user) content = parse_json_object_from_request(request) try: - new_name = content["avatar_url"] - except Exception: - return 400, "Unable to parse name" + new_avatar_url = content["avatar_url"] + except KeyError: + raise SynapseError( + 400, "Missing key 
'avatar_url'", errcode=Codes.MISSING_PARAM + ) - yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) + await self.profile_handler.set_avatar_url( + user, requester, new_avatar_url, is_admin + ) return 200, {} @@ -129,20 +129,19 @@ class ProfileRestServlet(RestServlet): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - displayname = yield self.profile_handler.get_displayname(user) - avatar_url = yield self.profile_handler.get_avatar_url(user) + displayname = await self.profile_handler.get_displayname(user) + avatar_url = await self.profile_handler.get_avatar_url(user) ret = {} if displayname is not None: diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 9f8c3d09e3..9fd4908136 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.api.errors import ( NotFoundError, @@ -46,18 +45,17 @@ class PushRuleRestServlet(RestServlet): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker_app is not None - @defer.inlineCallbacks - def on_PUT(self, request, path): + async def on_PUT(self, request, path): if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") - spec = _rule_spec_from_path([x for x in path.split("/")]) + spec = _rule_spec_from_path(path.split("/")) try: priority_class = _priority_class_from_spec(spec) except InvalidRuleException as e: raise SynapseError(400, str(e)) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: raise SynapseError(400, "rule_id may not contain slashes") @@ -67,7 +65,7 @@ class PushRuleRestServlet(RestServlet): user_id = requester.user.to_string() if "attr" in spec: - yield self.set_rule_attr(user_id, spec, content) + await self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) return 200, {} @@ -91,7 +89,7 @@ class PushRuleRestServlet(RestServlet): after = _namespaced_rule_id(spec, after) try: - yield self.store.add_push_rule( + await self.store.add_push_rule( user_id=user_id, rule_id=_namespaced_rule_id_from_spec(spec), priority_class=priority_class, @@ -108,20 +106,19 @@ class PushRuleRestServlet(RestServlet): return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, path): + async def on_DELETE(self, request, path): if self._is_worker: raise Exception("Cannot handle DELETE /push_rules on worker") - spec = _rule_spec_from_path([x for x in path.split("/")]) + spec = _rule_spec_from_path(path.split("/")) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() namespaced_rule_id = _namespaced_rule_id_from_spec(spec) try: - yield self.store.delete_push_rule(user_id, namespaced_rule_id) + await 
self.store.delete_push_rule(user_id, namespaced_rule_id) self.notify_user(user_id) return 200, {} except StoreError as e: @@ -130,19 +127,18 @@ class PushRuleRestServlet(RestServlet): else: raise - @defer.inlineCallbacks - def on_GET(self, request, path): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, path): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rules = yield self.store.get_push_rules_for_user(user_id) + rules = await self.store.get_push_rules_for_user(user_id) rules = format_push_rules_for_user(requester.user, rules) - path = [x for x in path.split("/")][1:] + path = path.split("/")[1:] if path == []: # we're a reference impl: pedantry is our job. diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 41660682d9..550a2f1b44 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, StoreError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import ( @@ -30,6 +28,17 @@ from synapse.rest.client.v2_alpha._base import client_patterns logger = logging.getLogger(__name__) +ALLOWED_KEYS = { + "app_display_name", + "app_id", + "data", + "device_display_name", + "kind", + "lang", + "profile_tag", + "pushkey", +} + class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) @@ -39,30 +48,17 @@ class PushersRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) user = requester.user - pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) - - allowed_keys = [ - "app_display_name", - "app_id", - "data", - "device_display_name", - "kind", - "lang", - "profile_tag", - "pushkey", - ] + pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) - for p in pushers: - for k, v in list(p.items()): - if k not in allowed_keys: - del p[k] + filtered_pushers = [ + {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers + ] - return 200, {"pushers": pushers} + return 200, {"pushers": filtered_pushers} def on_OPTIONS(self, _): return 200, {} @@ -78,9 +74,8 @@ class PushersSetRestServlet(RestServlet): self.notifier = hs.get_notifier() self.pusher_pool = self.hs.get_pusherpool() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user = requester.user content = parse_json_object_from_request(request) @@ -91,7 +86,7 @@ class PushersSetRestServlet(RestServlet): and "kind" in content and content["kind"] is None ): - yield self.pusher_pool.remove_pusher( + await self.pusher_pool.remove_pusher( content["app_id"], content["pushkey"], user_id=user.to_string() ) return 200, {} @@ -117,14 +112,14 @@ class PushersSetRestServlet(RestServlet): append = content["append"] if not append: - yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( + await 
self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( app_id=content["app_id"], pushkey=content["pushkey"], not_user_id=user.to_string(), ) try: - yield self.pusher_pool.add_pusher( + await self.pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, kind=content["kind"], @@ -164,16 +159,15 @@ class PushersRemoveRestServlet(RestServlet): self.auth = hs.get_auth() self.pusher_pool = self.hs.get_pusherpool() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, rights="delete_pusher") + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, rights="delete_pusher") user = requester.user app_id = parse_string(request, "app_id", required=True) pushkey = parse_string(request, "pushkey", required=True) try: - yield self.pusher_pool.remove_pusher( + await self.pusher_pool.remove_pusher( app_id=app_id, pushkey=pushkey, user_id=user.to_string() ) except StoreError as se: diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index a6a7b3b57e..105e0cf4d2 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -16,17 +16,18 @@ """ This module contains REST servlets to do with rooms: /rooms/<paths> """ import logging +import re +from typing import List, Optional from six.moves.urllib import parse as urlparse from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, + HttpResponseException, InvalidClientCredentialsError, SynapseError, ) @@ -39,12 +40,17 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.logging.opentracing import set_tag from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID +MYPY = False +if MYPY: + import synapse.server + logger = logging.getLogger(__name__) @@ -81,13 +87,13 @@ class RoomCreateRestServlet(TransactionRestServlet): ) def on_PUT(self, request, txn_id): + set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request(request, self.on_POST, request) - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) - info = yield self._room_creation_handler.create_room( + info, _ = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) @@ -152,15 +158,14 @@ class RoomStateEventRestServlet(TransactionRestServlet): def on_PUT_no_state_key(self, request, room_id, event_type): return self.on_PUT(request, room_id, event_type, "") - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_type, state_key): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_type, state_key): + requester = await self.auth.get_user_by_req(request, allow_guest=True) format = parse_string( request, "format", default="content", allowed_values=["content", "event"] ) msg_handler = self.message_handler - data = yield msg_handler.get_room_data( + data = await msg_handler.get_room_data( user_id=requester.user.to_string(), room_id=room_id, 
event_type=event_type, @@ -177,9 +182,11 @@ class RoomStateEventRestServlet(TransactionRestServlet): elif format == "content": return 200, data.get_dict()["content"] - @defer.inlineCallbacks - def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + requester = await self.auth.get_user_by_req(request) + + if txn_id: + set_tag("txn_id", txn_id) content = parse_json_object_from_request(request) @@ -195,7 +202,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): if event_type == EventTypes.Member: membership = content.get("membership", None) - event = yield self.room_member_handler.update_membership( + event_id, _ = await self.room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id, @@ -203,13 +210,18 @@ class RoomStateEventRestServlet(TransactionRestServlet): content=content, ) else: - event = yield self.event_creation_handler.create_and_send_nonmember_event( + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) + event_id = event.event_id - ret = {} - if event: - ret = {"event_id": event.event_id} + ret = {} # type: dict + if event_id: + set_tag("event_id", event_id) + ret = {"event_id": event_id} return 200, ret @@ -225,9 +237,8 @@ class RoomSendEventRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)" register_txn_path(self, PATTERNS, http_server, with_get=True) - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_type, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_id, event_type, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) event_dict = { @@ -240,16 +251,19 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) + set_tag("event_id", event.event_id) return 200, {"event_id": event.event_id} def on_GET(self, request, room_id, event_type, txn_id): return 200, "Not implemented" def on_PUT(self, request, room_id, event_type, txn_id): + set_tag("txn_id", txn_id) + return self.txns.fetch_or_execute_request( request, self.on_POST, request, room_id, event_type, txn_id ) @@ -267,9 +281,8 @@ class JoinRoomAliasServlet(TransactionRestServlet): PATTERNS = "/join/(?P<room_identifier>[^/]*)" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_identifier, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_identifier, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) try: content = parse_json_object_from_request(request) @@ -283,20 +296,20 @@ class JoinRoomAliasServlet(TransactionRestServlet): try: remote_room_hosts = [ x.decode("ascii") for x in request.args[b"server_name"] - ] + ] # type: Optional[List[str]] except Exception: remote_room_hosts = None elif RoomAlias.is_valid(room_identifier): handler = 
self.room_member_handler room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) + room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) room_id = room_id.to_string() else: raise SynapseError( 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester=requester, target=requester.user, room_id=room_id, @@ -310,6 +323,8 @@ class JoinRoomAliasServlet(TransactionRestServlet): return 200, {"room_id": room_id} def on_PUT(self, request, room_identifier, txn_id): + set_tag("txn_id", txn_id) + return self.txns.fetch_or_execute_request( request, self.on_POST, request, room_identifier, txn_id ) @@ -324,12 +339,11 @@ class PublicRoomListRestServlet(TransactionRestServlet): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): server = parse_string(request, "server", default=None) try: - yield self.auth.get_user_by_req(request, allow_guest=True) + await self.auth.get_user_by_req(request, allow_guest=True) except InvalidClientCredentialsError as e: # Option to allow servers to require auth when accessing # /publicRooms via CS API. This is especially helpful in private @@ -350,26 +364,32 @@ class PublicRoomListRestServlet(TransactionRestServlet): limit = parse_integer(request, "limit", 0) since_token = parse_string(request, "since", None) + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + handler = self.hs.get_room_list_handler() - if server: - data = yield handler.get_remote_public_room_list( - server, limit=limit, since_token=since_token - ) + if server and server != self.hs.config.server_name: + try: + data = await handler.get_remote_public_room_list( + server, limit=limit, since_token=since_token + ) + except HttpResponseException as e: + raise e.to_synapse_error() else: - data = yield handler.get_local_public_room_list( + data = await handler.get_local_public_room_list( limit=limit, since_token=since_token ) return 200, data - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) server = parse_string(request, "server", default=None) content = parse_json_object_from_request(request) - limit = int(content.get("limit", 100)) + limit = int(content.get("limit", 100)) # type: Optional[int] since_token = content.get("since", None) search_filter = content.get("filter", None) @@ -387,18 +407,25 @@ class PublicRoomListRestServlet(TransactionRestServlet): else: network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) + if limit == 0: + # zero is a special value which corresponds to no limit. 
+ limit = None + handler = self.hs.get_room_list_handler() - if server: - data = yield handler.get_remote_public_room_list( - server, - limit=limit, - since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) + if server and server != self.hs.config.server_name: + try: + data = await handler.get_remote_public_room_list( + server, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) + except HttpResponseException as e: + raise e.to_synapse_error() else: - data = yield handler.get_local_public_room_list( + data = await handler.get_local_public_room_list( limit=limit, since_token=since_token, search_filter=search_filter, @@ -417,10 +444,9 @@ class RoomMemberListRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): + async def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) handler = self.message_handler # request the state as of a given event, as identified by a stream token, @@ -440,7 +466,7 @@ class RoomMemberListRestServlet(RestServlet): membership = parse_string(request, "membership") not_membership = parse_string(request, "not_membership") - events = yield handler.get_state_events( + events = await handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), at_token=at_token, @@ -469,11 +495,10 @@ class JoinedRoomMemberListRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - users_with_profile = yield self.message_handler.get_joined_members( + users_with_profile = await self.message_handler.get_joined_members( requester, room_id ) @@ -489,20 +514,24 @@ class RoomMessageListRestServlet(RestServlet): self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request, default_limit=10) as_client_event = b"raw" not in request.args filter_bytes = parse_string(request, b"filter", encoding=None) if filter_bytes: filter_json = urlparse.unquote(filter_bytes.decode("UTF-8")) - event_filter = Filter(json.loads(filter_json)) - if event_filter.filter_json.get("event_format", "client") == "federation": + event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] + if ( + event_filter + and event_filter.filter_json.get("event_format", "client") + == "federation" + ): as_client_event = False else: event_filter = None - msgs = yield self.pagination_handler.get_messages( + + msgs = await self.pagination_handler.get_messages( room_id=room_id, requester=requester, pagin_config=pagination_config, @@ -522,11 +551,10 @@ class RoomStateRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - 
@defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) # Get all the current state for this room - events = yield self.message_handler.get_state_events( + events = await self.message_handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), is_guest=requester.is_guest, @@ -543,11 +571,10 @@ class RoomInitialSyncRestServlet(RestServlet): self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request) - content = yield self.initial_sync_handler.room_initial_sync( + content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) return 200, content @@ -565,11 +592,10 @@ class RoomEventServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) try: - event = yield self.event_handler.get_event( + event = await self.event_handler.get_event( requester.user, room_id, event_id ) except AuthError: @@ -580,7 +606,7 @@ class RoomEventServlet(RestServlet): time_now = self.clock.time_msec() if event: - event = yield self._event_serializer.serialize_event(event, time_now) + event = await self._event_serializer.serialize_event(event, time_now) return 200, event return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) @@ -598,9 +624,8 @@ class RoomEventContextServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) limit = parse_integer(request, "limit", default=10) @@ -608,11 +633,11 @@ class RoomEventContextServlet(RestServlet): filter_bytes = parse_string(request, "filter") if filter_bytes: filter_json = urlparse.unquote(filter_bytes) - event_filter = Filter(json.loads(filter_json)) + event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] else: event_filter = None - results = yield self.room_context_handler.get_event_context( + results = await self.room_context_handler.get_event_context( requester.user, room_id, event_id, limit, event_filter ) @@ -620,16 +645,16 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - results["events_before"] = yield self._event_serializer.serialize_events( + results["events_before"] = await self._event_serializer.serialize_events( results["events_before"], time_now ) - results["event"] = yield self._event_serializer.serialize_event( + results["event"] = await self._event_serializer.serialize_event( 
results["event"], time_now ) - results["events_after"] = yield self._event_serializer.serialize_events( + results["events_after"] = await self._event_serializer.serialize_events( results["events_after"], time_now ) - results["state"] = yield self._event_serializer.serialize_events( + results["state"] = await self._event_serializer.serialize_events( results["state"], time_now ) @@ -646,15 +671,16 @@ class RoomForgetRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + async def on_POST(self, request, room_id, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=False) - yield self.room_member_handler.forget(user=requester.user, room_id=room_id) + await self.room_member_handler.forget(user=requester.user, room_id=room_id) return 200, {} def on_PUT(self, request, room_id, txn_id): + set_tag("txn_id", txn_id) + return self.txns.fetch_or_execute_request( request, self.on_POST, request, room_id, txn_id ) @@ -675,9 +701,8 @@ class RoomMembershipRestServlet(TransactionRestServlet): ) register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, membership_action, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_id, membership_action, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) if requester.is_guest and membership_action not in { Membership.JOIN, @@ -693,7 +718,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): - yield self.room_member_handler.do_3pid_invite( + await self.room_member_handler.do_3pid_invite( room_id, requester.user, content["medium"], @@ -711,10 +736,10 @@ class RoomMembershipRestServlet(TransactionRestServlet): target = UserID.from_string(content["user_id"]) event_content = None - if "reason" in content and membership_action in ["kick", "ban"]: + if "reason" in content: event_content = {"reason": content["reason"]} - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester=requester, target=target, room_id=room_id, @@ -738,6 +763,8 @@ class RoomMembershipRestServlet(TransactionRestServlet): return True def on_PUT(self, request, room_id, membership_action, txn_id): + set_tag("txn_id", txn_id) + return self.txns.fetch_or_execute_request( request, self.on_POST, request, room_id, membership_action, txn_id ) @@ -754,12 +781,11 @@ class RoomRedactEventRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id, txn_id=None): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, event_id, txn_id=None): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Redaction, @@ -771,9 +797,12 @@ class RoomRedactEventRestServlet(TransactionRestServlet): txn_id=txn_id, ) + set_tag("event_id", 
event.event_id) return 200, {"event_id": event.event_id} def on_PUT(self, request, room_id, event_id, txn_id): + set_tag("txn_id", txn_id) + return self.txns.fetch_or_execute_request( request, self.on_POST, request, room_id, event_id, txn_id ) @@ -790,35 +819,57 @@ class RoomTypingRestServlet(RestServlet): self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_PUT(self, request, room_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id, user_id): + requester = await self.auth.get_user_by_req(request) room_id = urlparse.unquote(room_id) target_user = UserID.from_string(urlparse.unquote(user_id)) content = parse_json_object_from_request(request) - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) if content["typing"]: - yield self.typing_handler.started_typing( + await self.typing_handler.started_typing( target_user=target_user, auth_user=requester.user, room_id=room_id, timeout=timeout, ) else: - yield self.typing_handler.stopped_typing( + await self.typing_handler.stopped_typing( target_user=target_user, auth_user=requester.user, room_id=room_id ) return 200, {} +class RoomAliasListServlet(RestServlet): + PATTERNS = [ + re.compile( + r"^/_matrix/client/unstable/org\.matrix\.msc2432" + r"/rooms/(?P<room_id>[^/]*)/aliases" + ), + ] + + def __init__(self, hs: "synapse.server.HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.directory_handler = hs.get_handlers().directory_handler + + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + + alias_list = await self.directory_handler.get_aliases_for_room( + requester, room_id + ) + + return 200, {"aliases": alias_list} + + class SearchRestServlet(RestServlet): PATTERNS = client_patterns("/search$", v1=True) @@ -827,14 +878,13 @@ class SearchRestServlet(RestServlet): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) batch = parse_string(request, "next_batch") - results = yield self.handlers.search_handler.search( + results = await self.handlers.search_handler.search( requester.user, content, batch ) @@ -849,11 +899,10 @@ class JoinedRoomsRestServlet(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - room_ids = yield self.store.get_rooms_for_user(requester.user.to_string()) + room_ids = await self.store.get_rooms_for_user(requester.user.to_string()) return 200, {"joined_rooms": list(room_ids)} @@ -909,6 +958,7 @@ def register_servlets(hs, http_server): JoinedRoomsRestServlet(hs).register(http_server) RoomEventServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) + RoomAliasListServlet(hs).register(http_server) def register_deprecated_servlets(hs, http_server): diff --git a/synapse/rest/client/v1/voip.py 
b/synapse/rest/client/v1/voip.py index 2afdbb89e5..747d46eac2 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -17,8 +17,6 @@ import base64 import hashlib import hmac -from twisted.internet import defer - from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -31,9 +29,8 @@ class VoipRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req( + async def on_GET(self, request): + requester = await self.auth.get_user_by_req( request, self.hs.config.turn_allow_guests ) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 8250ae0ae1..bc11b4dda4 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -32,7 +32,7 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): Args: path_regex (str): The regex string to match. This should NOT have a ^ - as this will be prefixed. + as this will be prefixed. Returns: SRE_Pattern """ @@ -78,7 +78,7 @@ def interactive_auth_handler(orig): """ def wrapped(*args, **kwargs): - res = defer.maybeDeferred(orig, *args, **kwargs) + res = defer.ensureDeferred(orig(*args, **kwargs)) res.addErrback(_catch_incomplete_interactive_auth) return res diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 2ea515d2f6..d4f721b6b9 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -18,8 +18,6 @@ import logging from six.moves import http_client -from twisted.internet import defer - from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, ThreepidValidationError from synapse.config.emailconfig import ThreepidBehaviour @@ -32,6 +30,7 @@ from synapse.http.servlet import ( ) from synapse.push.mailer import Mailer, load_jinja2_templates from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -67,11 +66,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): template_text=template_text, ) - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "User password resets have been disabled due to lack of email config" ) raise SynapseError( @@ -84,6 +82,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Extract params from body client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + email = body["email"] send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -95,25 +95,23 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) if existing_user_id is None: + if self.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. 
+ return 200, {"sid": random_string(16)} + raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - # Have the configured identity server handle the request - if not self.hs.config.account_threepid_delegate_email: - logger.warn( - "No upstream email account_threepid_delegate configured on the server to " - "handle this request" - ) - raise SynapseError( - 400, "Password reset by email is not supported on this homeserver" - ) + assert self.hs.config.account_threepid_delegate_email - ret = yield self.identity_handler.requestEmailToken( + # Have the configured identity server handle the request + ret = await self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, client_secret, @@ -122,7 +120,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) else: # Send password reset emails from Synapse - sid = yield self.identity_handler.send_threepid_validation( + sid = await self.identity_handler.send_threepid_validation( email, client_secret, send_attempt, @@ -136,71 +134,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): return 200, ret -class MsisdnPasswordRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/account/password/msisdn/requestToken$") - - def __init__(self, hs): - super(MsisdnPasswordRequestTokenRestServlet, self).__init__() - self.hs = hs - self.datastore = self.hs.get_datastore() - self.identity_handler = hs.get_handlers().identity_handler - - @defer.inlineCallbacks - def on_POST(self, request): - body = parse_json_object_from_request(request) - - assert_params_in_dict( - body, ["client_secret", "country", "phone_number", "send_attempt"] - ) - client_secret = body["client_secret"] - country = body["country"] - phone_number = body["phone_number"] - send_attempt = body["send_attempt"] - next_link = body.get("next_link") # Optional param - - msisdn = phone_number_to_msisdn(country, phone_number) - - if not check_3pid_allowed(self.hs, "msisdn", msisdn): - raise SynapseError( - 403, - "Account phone numbers are not authorized on this server", - Codes.THREEPID_DENIED, - ) - - existing_user_id = yield self.datastore.get_user_id_by_threepid( - "msisdn", msisdn - ) - - if existing_user_id is None: - raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND) - - if not self.hs.config.account_threepid_delegate_msisdn: - logger.warn( - "No upstream msisdn account_threepid_delegate configured on the server to " - "handle this request" - ) - raise SynapseError( - 400, - "Password reset by phone number is not supported on this homeserver", - ) - - ret = yield self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, - country, - phone_number, - client_secret, - send_attempt, - next_link, - ) - - return 200, ret - - class PasswordResetSubmitTokenServlet(RestServlet): """Handles 3PID validation token submission""" PATTERNS = client_patterns( - "/password_reset/(?P<medium>[^/]*)/submit_token/*$", releases=(), unstable=True + "/password_reset/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True ) def __init__(self, hs): @@ -214,9 +152,13 @@ class PasswordResetSubmitTokenServlet(RestServlet): self.config = hs.config self.clock = hs.get_clock() self.store = hs.get_datastore() + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + (self.failure_email_template,) = load_jinja2_templates( + self.config.email_template_dir, + 
[self.config.email_password_reset_template_failure_html], + ) - @defer.inlineCallbacks - def on_GET(self, request, medium): + async def on_GET(self, request, medium): # We currently only handle threepid token submissions for email if medium != "email": raise SynapseError( @@ -224,7 +166,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): ) if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Password reset emails have been disabled due to lack of an email config" ) raise SynapseError( @@ -232,20 +174,21 @@ class PasswordResetSubmitTokenServlet(RestServlet): ) sid = parse_string(request, "sid", required=True) - client_secret = parse_string(request, "client_secret", required=True) token = parse_string(request, "token", required=True) + client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) # Attempt to validate a 3PID session try: # Mark the session as valid - next_link = yield self.store.validate_threepid_session( + next_link = await self.store.validate_threepid_session( sid, client_secret, token, self.clock.time_msec() ) # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: @@ -261,34 +204,12 @@ class PasswordResetSubmitTokenServlet(RestServlet): request.setResponseCode(e.code) # Show a failure page with a reason - html_template, = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_password_reset_template_failure_html], - ) - template_vars = {"failure_reason": e.msg} - html = html_template.render(**template_vars) + html = self.failure_email_template.render(**template_vars) request.write(html.encode("utf-8")) finish_request(request) - @defer.inlineCallbacks - def on_POST(self, request, medium): - if medium != "email": - raise SynapseError( - 400, "This medium is currently not supported for password resets" - ) - - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["sid", "client_secret", "token"]) - - valid, _ = yield self.store.validate_threepid_session( - body["sid"], body["client_secret"], body["token"], self.clock.time_msec() - ) - response_code = 200 if valid else 400 - - return response_code, {"success": valid} - class PasswordRestServlet(RestServlet): PATTERNS = client_patterns("/account/password$") @@ -299,13 +220,27 @@ class PasswordRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() + self.password_policy_handler = hs.get_password_policy_handler() self._set_password_handler = hs.get_set_password_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) + # we do basic sanity checks here because the auth layer will store these + # in sessions. Pull out the new password provided to us. + if "new_password" in body: + new_password = body.pop("new_password") + if not isinstance(new_password, str) or len(new_password) > 512: + raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(new_password) + + # If the password is valid, hash it and store it back on the body. + # This ensures that only the hashed password is handled everywhere. 
+ if "new_password_hash" in body: + raise SynapseError(400, "Unexpected property: new_password_hash") + body["new_password_hash"] = await self.auth_handler.hash(new_password) + # there are two possibilities here. Either the user does not have an # access token, and needs to do a password reset; or they have one and # need to validate their identity. @@ -317,17 +252,23 @@ class PasswordRestServlet(RestServlet): # In the second case, we require a password to confirm their identity. if self.auth.has_access_token(request): - requester = yield self.auth.get_user_by_req(request) - params = yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester = await self.auth.get_user_by_req(request) + params = await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", ) user_id = requester.user.to_string() else: requester = None - result, params, _ = yield self.auth_handler.check_auth( - [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]], + result, params, _ = await self.auth_handler.check_auth( + [[LoginType.EMAIL_IDENTITY]], + request, body, self.hs.get_ip_from_request(request), + "modify your account password", ) if LoginType.EMAIL_IDENTITY in result: @@ -340,7 +281,7 @@ class PasswordRestServlet(RestServlet): # (See add_threepid in synapse/handlers/auth.py) threepid["address"] = threepid["address"].lower() # if using email, we must know about the email they're authing with! - threepid_user_id = yield self.datastore.get_user_id_by_threepid( + threepid_user_id = await self.datastore.get_user_id_by_threepid( threepid["medium"], threepid["address"] ) if not threepid_user_id: @@ -350,10 +291,13 @@ class PasswordRestServlet(RestServlet): logger.error("Auth succeeded but no known type! 
%r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - assert_params_in_dict(params, ["new_password"]) - new_password = params["new_password"] + assert_params_in_dict(params, ["new_password_hash"]) + new_password_hash = params["new_password_hash"] + logout_devices = params.get("logout_devices", True) - yield self._set_password_handler.set_password(user_id, new_password, requester) + await self._set_password_handler.set_password( + user_id, new_password_hash, logout_devices, requester + ) return 200, {} @@ -372,8 +316,7 @@ class DeactivateAccountRestServlet(RestServlet): self._deactivate_account_handler = hs.get_deactivate_account_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -383,19 +326,23 @@ class DeactivateAccountRestServlet(RestServlet): Codes.BAD_JSON, ) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) # allow ASes to dectivate their own users if requester.app_service: - yield self._deactivate_account_handler.deactivate_account( + await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase ) return 200, {} - yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "deactivate your account", ) - result = yield self._deactivate_account_handler.deactivate_account( + result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, id_server=body.get("id_server") ) if result: @@ -416,14 +363,37 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler self.store = self.hs.get_datastore() - @defer.inlineCallbacks - def on_POST(self, request): + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + template_html, template_text = load_jinja2_templates( + self.config.email_template_dir, + [ + self.config.email_add_threepid_template_html, + self.config.email_add_threepid_template_text, + ], + public_baseurl=self.config.public_baseurl, + ) + self.mailer = Mailer( + hs=self.hs, + app_name=self.config.email_app_name, + template_html=template_html, + template_text=template_text, + ) + + async def on_POST(self, request): + if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.local_threepid_handling_disabled_due_to_email_config: + logger.warning( + "Adding emails have been disabled due to lack of an email config" + ) + raise SynapseError( + 400, "Adding an email to your account is disabled on this server" + ) + body = parse_json_object_from_request(request) - assert_params_in_dict( - body, ["id_server", "client_secret", "email", "send_attempt"] - ) - id_server = "https://" + body["id_server"] # Assume https + assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + email = body["email"] send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -435,16 +405,42 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.store.get_user_id_by_threepid( + existing_user_id = await 
self.store.get_user_id_by_threepid( "email", body["email"] ) if existing_user_id is not None: + if self.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - ret = yield self.identity_handler.requestEmailToken( - id_server, email, client_secret, send_attempt, next_link - ) + if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + assert self.hs.config.account_threepid_delegate_email + + # Have the configured identity server handle the request + ret = await self.identity_handler.requestEmailToken( + self.hs.config.account_threepid_delegate_email, + email, + client_secret, + send_attempt, + next_link, + ) + else: + # Send threepid validation emails from Synapse + sid = await self.identity_handler.send_threepid_validation( + email, + client_secret, + send_attempt, + self.mailer.send_add_threepid_mail, + next_link, + ) + + # Wrap the session id in a JSON object + ret = {"sid": sid} + return 200, ret @@ -457,15 +453,14 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): self.store = self.hs.get_datastore() self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( - body, - ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + body, ["client_secret", "country", "phone_number", "send_attempt"] ) - id_server = "https://" + body["id_server"] # Assume https client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + country = body["country"] phone_number = body["phone_number"] send_attempt = body["send_attempt"] @@ -480,17 +475,156 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.store.get_user_id_by_threepid("msisdn", msisdn) + existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. 
+ return 200, {"sid": random_string(16)} + raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) - ret = yield self.identity_handler.requestMsisdnToken( - id_server, country, phone_number, client_secret, send_attempt, next_link + if not self.hs.config.account_threepid_delegate_msisdn: + logger.warning( + "No upstream msisdn account_threepid_delegate configured on the server to " + "handle this request" + ) + raise SynapseError( + 400, + "Adding phone numbers to user account is not supported by this homeserver", + ) + + ret = await self.identity_handler.requestMsisdnToken( + self.hs.config.account_threepid_delegate_msisdn, + country, + phone_number, + client_secret, + send_attempt, + next_link, ) + return 200, ret +class AddThreepidEmailSubmitTokenServlet(RestServlet): + """Handles 3PID validation token submission for adding an email to a user's account""" + + PATTERNS = client_patterns( + "/add_threepid/email/submit_token$", releases=(), unstable=True + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.config = hs.config + self.clock = hs.get_clock() + self.store = hs.get_datastore() + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + (self.failure_email_template,) = load_jinja2_templates( + self.config.email_template_dir, + [self.config.email_add_threepid_template_failure_html], + ) + + async def on_GET(self, request): + if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.local_threepid_handling_disabled_due_to_email_config: + logger.warning( + "Adding emails have been disabled due to lack of an email config" + ) + raise SynapseError( + 400, "Adding an email to your account is disabled on this server" + ) + elif self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + raise SynapseError( + 400, + "This homeserver is not validating threepids. 
Use an identity server " + "instead.", + ) + + sid = parse_string(request, "sid", required=True) + token = parse_string(request, "token", required=True) + client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) + + # Attempt to validate a 3PID session + try: + # Mark the session as valid + next_link = await self.store.validate_threepid_session( + sid, client_secret, token, self.clock.time_msec() + ) + + # Perform a 302 redirect if next_link is set + if next_link: + if next_link.startswith("file:///"): + logger.warning( + "Not redirecting to next_link as it is a local file: address" + ) + else: + request.setResponseCode(302) + request.setHeader("Location", next_link) + finish_request(request) + return None + + # Otherwise show the success template + html = self.config.email_add_threepid_template_success_html_content + request.setResponseCode(200) + except ThreepidValidationError as e: + request.setResponseCode(e.code) + + # Show a failure page with a reason + template_vars = {"failure_reason": e.msg} + html = self.failure_email_template.render(**template_vars) + + request.write(html.encode("utf-8")) + finish_request(request) + + +class AddThreepidMsisdnSubmitTokenServlet(RestServlet): + """Handles 3PID validation token submission for adding a phone number to a user's + account + """ + + PATTERNS = client_patterns( + "/add_threepid/msisdn/submit_token$", releases=(), unstable=True + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.config = hs.config + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.identity_handler = hs.get_handlers().identity_handler + + async def on_POST(self, request): + if not self.config.account_threepid_delegate_msisdn: + raise SynapseError( + 400, + "This homeserver is not validating phone numbers. 
Use an identity server " + "instead.", + ) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["client_secret", "sid", "token"]) + assert_valid_client_secret(body["client_secret"]) + + # Proxy submit_token request to msisdn threepid delegate + response = await self.identity_handler.proxy_msisdn_submit_token( + self.config.account_threepid_delegate_msisdn, + body["client_secret"], + body["sid"], + body["token"], + ) + return 200, response + + class ThreepidRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid$") @@ -502,16 +636,21 @@ class ThreepidRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) - threepids = yield self.datastore.user_get_threepids(requester.user.to_string()) + threepids = await self.datastore.user_get_threepids(requester.user.to_string()) return 200, {"threepids": threepids} - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() body = parse_json_object_from_request(request) threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds") @@ -519,34 +658,111 @@ class ThreepidRestServlet(RestServlet): raise SynapseError( 400, "Missing param three_pid_creds", Codes.MISSING_PARAM ) + assert_params_in_dict(threepid_creds, ["client_secret", "sid"]) + + sid = threepid_creds["sid"] + client_secret = threepid_creds["client_secret"] + assert_valid_client_secret(client_secret) + + validation_session = await self.identity_handler.validate_threepid_session( + client_secret, sid + ) + if validation_session: + await self.auth_handler.add_threepid( + user_id, + validation_session["medium"], + validation_session["address"], + validation_session["validated_at"], + ) + return 200, {} + + raise SynapseError( + 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED + ) + + +class ThreepidAddRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/add$", releases=(), unstable=True) + + def __init__(self, hs): + super(ThreepidAddRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + + @interactive_auth_handler + async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() + body = parse_json_object_from_request(request) - # Specify None as the identity server to retrieve it from the request body instead - threepid = yield self.identity_handler.threepid_from_creds(None, threepid_creds) + assert_params_in_dict(body, ["client_secret", "sid"]) + sid = body["sid"] + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) - if not threepid: - raise SynapseError(400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED) + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + 
self.hs.get_ip_from_request(request), + "add a third-party identifier to your account", + ) - for reqd in ["medium", "address", "validated_at"]: - if reqd not in threepid: - logger.warn("Couldn't add 3pid: invalid response from ID server") - raise SynapseError(500, "Invalid response from ID Server") + validation_session = await self.identity_handler.validate_threepid_session( + client_secret, sid + ) + if validation_session: + await self.auth_handler.add_threepid( + user_id, + validation_session["medium"], + validation_session["address"], + validation_session["validated_at"], + ) + return 200, {} - yield self.auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + raise SynapseError( + 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED ) - if "bind" in body and body["bind"]: - logger.debug("Binding threepid %s to %s", threepid, user_id) - yield self.identity_handler.bind_threepid(threepid_creds, user_id) + +class ThreepidBindRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/bind$", releases=(), unstable=True) + + def __init__(self, hs): + super(ThreepidBindRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + self.auth = hs.get_auth() + + async def on_POST(self, request): + body = parse_json_object_from_request(request) + + assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) + id_server = body["id_server"] + sid = body["sid"] + id_access_token = body.get("id_access_token") # optional + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + await self.identity_handler.bind_threepid( + client_secret, sid, user_id, id_server, id_access_token + ) return 200, {} class ThreepidUnbindRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/unbind$") + PATTERNS = client_patterns("/account/3pid/unbind$", releases=(), unstable=True) def __init__(self, hs): super(ThreepidUnbindRestServlet, self).__init__() @@ -555,12 +771,11 @@ class ThreepidUnbindRestServlet(RestServlet): self.auth = hs.get_auth() self.datastore = self.hs.get_datastore() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """Unbind the given 3pid from a specific identity server, or identity servers that are known to have this 3pid bound """ - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) @@ -570,7 +785,7 @@ class ThreepidUnbindRestServlet(RestServlet): # Attempt to unbind the threepid from an identity server. 
If id_server is None, try to # unbind from all identity servers this threepid has been added to in the past - result = yield self.identity_handler.try_unbind_threepid( + result = await self.identity_handler.try_unbind_threepid( requester.user.to_string(), {"address": address, "medium": medium, "id_server": id_server}, ) @@ -582,19 +797,24 @@ class ThreepidDeleteRestServlet(RestServlet): def __init__(self, hs): super(ThreepidDeleteRestServlet, self).__init__() + self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() try: - ret = yield self.auth_handler.delete_threepid( + ret = await self.auth_handler.delete_threepid( user_id, body["medium"], body["address"], body.get("id_server") ) except Exception: @@ -619,22 +839,24 @@ class WhoamiRestServlet(RestServlet): super(WhoamiRestServlet, self).__init__() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) return 200, {"user_id": requester.user.to_string()} def register_servlets(hs, http_server): EmailPasswordRequestTokenRestServlet(hs).register(http_server) - MsisdnPasswordRequestTokenRestServlet(hs).register(http_server) PasswordResetSubmitTokenServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) EmailThreepidRequestTokenRestServlet(hs).register(http_server) MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) + AddThreepidEmailSubmitTokenServlet(hs).register(http_server) + AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) + ThreepidAddRestServlet(hs).register(http_server) + ThreepidBindRestServlet(hs).register(http_server) ThreepidUnbindRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index f0db204ffa..c1d4cd0caf 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -40,16 +38,19 @@ class AccountDataServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() + self._is_worker = hs.config.worker_app is not None + + async def on_PUT(self, request, user_id, account_data_type): + if self._is_worker: + raise Exception("Cannot handle PUT /account_data on worker") - @defer.inlineCallbacks - def on_PUT(self, request, user_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add 
account data for other users.") body = parse_json_object_from_request(request) - max_id = yield self.store.add_account_data_for_user( + max_id = await self.store.add_account_data_for_user( user_id, account_data_type, body ) @@ -57,13 +58,12 @@ class AccountDataServlet(RestServlet): return 200, {} - @defer.inlineCallbacks - def on_GET(self, request, user_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") - event = yield self.store.get_global_account_data_by_type_for_user( + event = await self.store.get_global_account_data_by_type_for_user( account_data_type, user_id ) @@ -90,10 +90,13 @@ class RoomAccountDataServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() + self._is_worker = hs.config.worker_app is not None + + async def on_PUT(self, request, user_id, room_id, account_data_type): + if self._is_worker: + raise Exception("Cannot handle PUT /account_data on worker") - @defer.inlineCallbacks - def on_PUT(self, request, user_id, room_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -106,7 +109,7 @@ class RoomAccountDataServlet(RestServlet): " Use /rooms/!roomId:server.name/read_markers", ) - max_id = yield self.store.add_account_data_to_room( + max_id = await self.store.add_account_data_to_room( user_id, room_id, account_data_type, body ) @@ -114,13 +117,12 @@ class RoomAccountDataServlet(RestServlet): return 200, {} - @defer.inlineCallbacks - def on_GET(self, request, user_id, room_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, room_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") - event = yield self.store.get_account_data_for_room_and_type( + event = await self.store.get_account_data_for_room_and_type( user_id, room_id, account_data_type ) diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 33f6a23028..2f10fa64e2 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import RestServlet @@ -45,13 +43,12 @@ class AccountValidityRenewServlet(RestServlet): self.success_html = hs.config.account_validity.account_renewed_html_content self.failure_html = hs.config.account_validity.invalid_token_html_content - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if b"token" not in request.args: raise SynapseError(400, "Missing renewal token") renewal_token = request.args[b"token"][0] - token_valid = yield self.account_activity_handler.renew_account( + token_valid = await self.account_activity_handler.renew_account( renewal_token.decode("utf8") ) @@ -67,7 +64,6 @@ class AccountValidityRenewServlet(RestServlet): 
request.setHeader(b"Content-Length", b"%d" % (len(response),)) request.write(response.encode("utf8")) finish_request(request) - defer.returnValue(None) class AccountValiditySendMailServlet(RestServlet): @@ -85,18 +81,17 @@ class AccountValiditySendMailServlet(RestServlet): self.auth = hs.get_auth() self.account_validity = self.hs.config.account_validity - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if not self.account_validity.renew_by_email_enabled: raise AuthError( 403, "Account renewal via email is disabled on this server." ) - requester = yield self.auth.get_user_by_req(request, allow_expired=True) + requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() - yield self.account_activity_handler.send_renewal_email_to_user(user_id) + await self.account_activity_handler.send_renewal_email_to_user(user_id) - defer.returnValue((200, {})) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index f21aff39e5..75590ebaeb 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX @@ -132,7 +130,22 @@ class AuthRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() - def on_GET(self, request, stagetype): + # SSO configuration. + self._cas_enabled = hs.config.cas_enabled + if self._cas_enabled: + self._cas_handler = hs.get_cas_handler() + self._cas_server_url = hs.config.cas_server_url + self._cas_service_url = hs.config.cas_service_url + self._saml_enabled = hs.config.saml2_enabled + if self._saml_enabled: + self._saml_handler = hs.get_saml_handler() + self._oidc_enabled = hs.config.oidc_enabled + if self._oidc_enabled: + self._oidc_handler = hs.get_oidc_handler() + self._cas_server_url = hs.config.cas_server_url + self._cas_service_url = hs.config.cas_service_url + + async def on_GET(self, request, stagetype): session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") @@ -144,14 +157,6 @@ class AuthRestServlet(RestServlet): % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), "sitekey": self.hs.config.recaptcha_public_key, } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None elif stagetype == LoginType.TERMS: html = TERMS_TEMPLATE % { "session": session, @@ -160,19 +165,51 @@ class AuthRestServlet(RestServlet): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None + + elif stagetype == LoginType.SSO: + # Display a confirmation page which prompts the user to + # re-authenticate with their SSO provider. + if self._cas_enabled: + # Generate a request to CAS that redirects back to an endpoint + # to verify the successful authentication. 
+ sso_redirect_url = self._cas_handler.get_redirect_url( + {"session": session}, + ) + + elif self._saml_enabled: + # Some SAML identity providers (e.g. Google) require a + # RelayState parameter on requests. It is not necessary here, so + # pass in a dummy redirect URL (which will never get used). + client_redirect_url = b"unused" + sso_redirect_url = self._saml_handler.handle_redirect_request( + client_redirect_url, session + ) + + elif self._oidc_enabled: + client_redirect_url = b"" + sso_redirect_url = await self._oidc_handler.handle_redirect_request( + request, client_redirect_url, session + ) + + else: + raise SynapseError(400, "Homeserver not configured for SSO.") + + html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) + else: raise SynapseError(404, "Unknown auth stage type") - @defer.inlineCallbacks - def on_POST(self, request, stagetype): + # Render the HTML and return. + html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + return None + + async def on_POST(self, request, stagetype): session = parse_string(request, "session") if not session: @@ -186,7 +223,7 @@ class AuthRestServlet(RestServlet): authdict = {"response": response, "session": session} - success = yield self.auth_handler.add_oob_auth( + success = await self.auth_handler.add_oob_auth( LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) ) @@ -199,23 +236,10 @@ class AuthRestServlet(RestServlet): % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), "sitekey": self.hs.config.recaptcha_public_key, } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - - return None elif stagetype == LoginType.TERMS: - if ("session" not in request.args or len(request.args["session"])) == 0: - raise SynapseError(400, "No session supplied") - - session = request.args["session"][0] authdict = {"session": session} - success = yield self.auth_handler.add_oob_auth( + success = await self.auth_handler.add_oob_auth( LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) ) @@ -232,17 +256,22 @@ class AuthRestServlet(RestServlet): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None + elif stagetype == LoginType.SSO: + # The SSO fallback workflow should not post here, + raise SynapseError(404, "Fallback SSO auth does not support POST requests.") else: raise SynapseError(404, "Unknown auth stage type") + # Render the HTML and return. 
+ html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + return None + def on_OPTIONS(self, _): return 200, {} diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index acd58af193..fe9d019c44 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -14,8 +14,6 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import RestServlet @@ -40,10 +38,9 @@ class CapabilitiesRestServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - user = yield self.store.get_user_by_id(requester.user.to_string()) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user = await self.store.get_user_by_id(requester.user.to_string()) change_password = bool(user["password_hash"]) response = { diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 26d0235208..c0714fcfb1 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api import errors from synapse.http.servlet import ( RestServlet, @@ -42,10 +40,9 @@ class DevicesRestServlet(RestServlet): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - devices = yield self.device_handler.get_devices_by_user( + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + devices = await self.device_handler.get_devices_by_user( requester.user.to_string() ) return 200, {"devices": devices} @@ -67,9 +64,8 @@ class DeleteDevicesRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) try: body = parse_json_object_from_request(request) @@ -84,11 +80,15 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) - yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "remove device(s) from your account", ) - yield self.device_handler.delete_devices( + await self.device_handler.delete_devices( requester.user.to_string(), body["devices"] ) return 200, {} @@ -108,18 +108,16 @@ class DeviceRestServlet(RestServlet): self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() - @defer.inlineCallbacks - def on_GET(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - device = yield self.device_handler.get_device( + async def on_GET(self, request, device_id): + requester = await 
self.auth.get_user_by_req(request, allow_guest=True) + device = await self.device_handler.get_device( requester.user.to_string(), device_id ) return 200, device @interactive_auth_handler - @defer.inlineCallbacks - def on_DELETE(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, device_id): + requester = await self.auth.get_user_by_req(request) try: body = parse_json_object_from_request(request) @@ -132,19 +130,22 @@ class DeviceRestServlet(RestServlet): else: raise - yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "remove a device from your account", ) - yield self.device_handler.delete_device(requester.user.to_string(), device_id) + await self.device_handler.delete_device(requester.user.to_string(), device_id) return 200, {} - @defer.inlineCallbacks - def on_PUT(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_PUT(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) body = parse_json_object_from_request(request) - yield self.device_handler.update_device( + await self.device_handler.update_device( requester.user.to_string(), device_id, body ) return 200, {} diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index c6ddf24c8d..b28da017cd 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -15,9 +15,7 @@ import logging -from twisted.internet import defer - -from synapse.api.errors import AuthError, Codes, StoreError, SynapseError +from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID @@ -35,10 +33,9 @@ class GetFilterRestServlet(RestServlet): self.auth = hs.get_auth() self.filtering = hs.get_filtering() - @defer.inlineCallbacks - def on_GET(self, request, user_id, filter_id): + async def on_GET(self, request, user_id, filter_id): target_user = UserID.from_string(user_id) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if target_user != requester.user: raise AuthError(403, "Cannot get filters for other users") @@ -52,13 +49,15 @@ class GetFilterRestServlet(RestServlet): raise SynapseError(400, "Invalid filter_id") try: - filter = yield self.filtering.get_user_filter( + filter_collection = await self.filtering.get_user_filter( user_localpart=target_user.localpart, filter_id=filter_id ) + except StoreError as e: + if e.code != 404: + raise + raise NotFoundError("No such filter") - return 200, filter.get_filter_json() - except (KeyError, StoreError): - raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND) + return 200, filter_collection.get_filter_json() class CreateFilterRestServlet(RestServlet): @@ -70,11 +69,10 @@ class CreateFilterRestServlet(RestServlet): self.auth = hs.get_auth() self.filtering = hs.get_filtering() - @defer.inlineCallbacks - def on_POST(self, request, user_id): + async def on_POST(self, request, user_id): target_user = UserID.from_string(user_id) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if target_user != requester.user: raise AuthError(403, "Cannot 
create filters for other users") @@ -85,7 +83,7 @@ class CreateFilterRestServlet(RestServlet): content = parse_json_object_from_request(request) set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) - filter_id = yield self.filtering.add_user_filter( + filter_id = await self.filtering.add_user_filter( user_localpart=target_user.localpart, user_filter=content ) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 999a0fa80c..d84a6d7e11 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import GroupID @@ -38,24 +36,22 @@ class GroupServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - group_description = yield self.groups_handler.get_group_profile( + group_description = await self.groups_handler.get_group_profile( group_id, requester_user_id ) return 200, group_description - @defer.inlineCallbacks - def on_POST(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - yield self.groups_handler.update_group_profile( + await self.groups_handler.update_group_profile( group_id, requester_user_id, content ) @@ -74,12 +70,11 @@ class GroupSummaryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - get_group_summary = yield self.groups_handler.get_group_summary( + get_group_summary = await self.groups_handler.get_group_summary( group_id, requester_user_id ) @@ -106,13 +101,12 @@ class GroupSummaryRoomsCatServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, category_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, category_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_summary_room( + resp = await self.groups_handler.update_group_summary_room( group_id, requester_user_id, room_id=room_id, @@ -122,12 +116,11 @@ class GroupSummaryRoomsCatServlet(RestServlet): return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, category_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, category_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = 
requester.user.to_string() - resp = yield self.groups_handler.delete_group_summary_room( + resp = await self.groups_handler.delete_group_summary_room( group_id, requester_user_id, room_id=room_id, category_id=category_id ) @@ -148,35 +141,32 @@ class GroupCategoryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_category( + category = await self.groups_handler.get_group_category( group_id, requester_user_id, category_id=category_id ) return 200, category - @defer.inlineCallbacks - def on_PUT(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_category( + resp = await self.groups_handler.update_group_category( group_id, requester_user_id, category_id=category_id, content=content ) return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_category( + resp = await self.groups_handler.delete_group_category( group_id, requester_user_id, category_id=category_id ) @@ -195,12 +185,11 @@ class GroupCategoriesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_categories( + category = await self.groups_handler.get_group_categories( group_id, requester_user_id ) @@ -219,35 +208,32 @@ class GroupRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_role( + category = await self.groups_handler.get_group_role( group_id, requester_user_id, role_id=role_id ) return 200, category - @defer.inlineCallbacks - def on_PUT(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_role( + resp = await 
self.groups_handler.update_group_role( group_id, requester_user_id, role_id=role_id, content=content ) return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_role( + resp = await self.groups_handler.delete_group_role( group_id, requester_user_id, role_id=role_id ) @@ -266,12 +252,11 @@ class GroupRolesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_roles( + category = await self.groups_handler.get_group_roles( group_id, requester_user_id ) @@ -298,13 +283,12 @@ class GroupSummaryUsersRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, role_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, role_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_summary_user( + resp = await self.groups_handler.update_group_summary_user( group_id, requester_user_id, user_id=user_id, @@ -314,12 +298,11 @@ class GroupSummaryUsersRoleServlet(RestServlet): return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, role_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, role_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_summary_user( + resp = await self.groups_handler.delete_group_summary_user( group_id, requester_user_id, user_id=user_id, role_id=role_id ) @@ -338,12 +321,11 @@ class GroupRoomServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_rooms_in_group( + result = await self.groups_handler.get_rooms_in_group( group_id, requester_user_id ) @@ -362,12 +344,11 @@ class GroupUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_users_in_group( + result = await self.groups_handler.get_users_in_group( 
group_id, requester_user_id ) @@ -386,12 +367,11 @@ class GroupInvitedUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_invited_users_in_group( + result = await self.groups_handler.get_invited_users_in_group( group_id, requester_user_id ) @@ -409,14 +389,13 @@ class GroupSettingJoinPolicyServlet(RestServlet): self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.set_group_join_policy( + result = await self.groups_handler.set_group_join_policy( group_id, requester_user_id, content ) @@ -436,9 +415,8 @@ class GroupCreateServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() self.server_name = hs.hostname - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() # TODO: Create group on remote server @@ -446,7 +424,7 @@ class GroupCreateServlet(RestServlet): localpart = content.pop("localpart") group_id = GroupID(localpart, self.server_name).to_string() - result = yield self.groups_handler.create_group( + result = await self.groups_handler.create_group( group_id, requester_user_id, content ) @@ -467,24 +445,22 @@ class GroupAdminRoomsServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.add_room_to_group( + result = await self.groups_handler.add_room_to_group( group_id, requester_user_id, room_id, content ) return 200, result - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.remove_room_from_group( + result = await self.groups_handler.remove_room_from_group( group_id, requester_user_id, room_id ) @@ -506,13 +482,12 @@ class GroupAdminRoomsConfigServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, room_id, config_key): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, room_id, config_key): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = 
parse_json_object_from_request(request) - result = yield self.groups_handler.update_room_in_group( + result = await self.groups_handler.update_room_in_group( group_id, requester_user_id, room_id, config_key, content ) @@ -535,14 +510,13 @@ class GroupAdminUsersInviteServlet(RestServlet): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id - @defer.inlineCallbacks - def on_PUT(self, request, group_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) config = content.get("config", {}) - result = yield self.groups_handler.invite( + result = await self.groups_handler.invite( group_id, user_id, requester_user_id, config ) @@ -563,13 +537,12 @@ class GroupAdminUsersKickServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.remove_user_from_group( + result = await self.groups_handler.remove_user_from_group( group_id, user_id, requester_user_id, content ) @@ -588,13 +561,12 @@ class GroupSelfLeaveServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.remove_user_from_group( + result = await self.groups_handler.remove_user_from_group( group_id, requester_user_id, requester_user_id, content ) @@ -613,13 +585,12 @@ class GroupSelfJoinServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.join_group( + result = await self.groups_handler.join_group( group_id, requester_user_id, content ) @@ -638,13 +609,12 @@ class GroupSelfAcceptInviteServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.accept_invite( + result = await self.groups_handler.accept_invite( group_id, requester_user_id, content ) @@ -663,14 +633,13 @@ class GroupSelfUpdatePublicityServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_PUT(self, request, 
group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) publicise = content["publicise"] - yield self.store.update_group_publicity(group_id, requester_user_id, publicise) + await self.store.update_group_publicity(group_id, requester_user_id, publicise) return 200, {} @@ -688,11 +657,10 @@ class PublicisedGroupsForUserServlet(RestServlet): self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, user_id): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, user_id): + await self.auth.get_user_by_req(request, allow_guest=True) - result = yield self.groups_handler.get_publicised_groups_for_user(user_id) + result = await self.groups_handler.get_publicised_groups_for_user(user_id) return 200, result @@ -710,14 +678,13 @@ class PublicisedGroupsForUsersServlet(RestServlet): self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) user_ids = content["user_ids"] - result = yield self.groups_handler.bulk_get_publicised_groups(user_ids) + result = await self.groups_handler.bulk_get_publicised_groups(user_ids) return 200, result @@ -734,12 +701,11 @@ class GroupsForUserServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_joined_groups(requester_user_id) + result = await self.groups_handler.get_joined_groups(requester_user_id) return 200, result diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 2e680134a0..24bb090822 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
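The hunks that follow convert keys.py's servlets from Twisted's generator-based coroutines to native async/await, the same mechanical change applied to every servlet touched by this patch. As a minimal, self-contained sketch of that before/after shape (the servlet classes and the fetch_profile() helper are illustrative stand-ins, not Synapse code):

    from twisted.internet import defer, task


    def fetch_profile(user_id):
        # Stand-in for a storage/handler call that returns a Deferred.
        return defer.succeed({"user_id": user_id, "displayname": "Alice"})


    class OldStyleServlet:
        # Old style: a generator decorated with inlineCallbacks; each
        # suspension point is a `yield` on a Deferred.
        @defer.inlineCallbacks
        def on_GET(self, request, user_id):
            profile = yield fetch_profile(user_id)
            return 200, profile


    class NewStyleServlet:
        # New style: a native coroutine; Deferreds can be awaited directly.
        async def on_GET(self, request, user_id):
            profile = await fetch_profile(user_id)
            return 200, profile


    def main(reactor):
        # ensureDeferred() bridges the native coroutine back into a Deferred,
        # which is what the Deferred-based HTTP layer ultimately consumes.
        return defer.ensureDeferred(NewStyleServlet().on_GET(None, "@alice:example.com"))


    if __name__ == "__main__":
        task.react(main)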
@@ -15,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import ( RestServlet, @@ -27,7 +26,7 @@ from synapse.http.servlet import ( from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.types import StreamToken -from ._base import client_patterns +from ._base import client_patterns, interactive_auth_handler logger = logging.getLogger(__name__) @@ -43,7 +42,7 @@ class KeyUploadServlet(RestServlet): "device_id": "<device_id>", "valid_until_ts": <millisecond_timestamp>, "algorithms": [ - "m.olm.curve25519-aes-sha256", + "m.olm.curve25519-aes-sha2", ] "keys": { "<algorithm>:<device_id>": "<key_base64>", @@ -70,9 +69,8 @@ class KeyUploadServlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() @trace(opname="upload_keys") - @defer.inlineCallbacks - def on_POST(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -102,7 +100,7 @@ class KeyUploadServlet(RestServlet): 400, "To upload keys, you must pass device_id when authenticating" ) - result = yield self.e2e_keys_handler.upload_keys_for_user( + result = await self.e2e_keys_handler.upload_keys_for_user( user_id, device_id, body ) return 200, result @@ -126,7 +124,7 @@ class KeyQueryServlet(RestServlet): "device_id": "<device_id>", // Duplicated to be signed "valid_until_ts": <millisecond_timestamp>, "algorithms": [ // List of supported algorithms - "m.olm.curve25519-aes-sha256", + "m.olm.curve25519-aes-sha2", ], "keys": { // Must include a ed25519 signing key "<algorithm>:<key_id>": "<key_base64>", @@ -153,12 +151,12 @@ class KeyQueryServlet(RestServlet): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.query_devices(body, timeout) + result = await self.e2e_keys_handler.query_devices(body, timeout, user_id) return 200, result @@ -183,9 +181,8 @@ class KeyChangesServlet(RestServlet): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) from_token_string = parse_string(request, "from") set_tag("from", from_token_string) @@ -198,7 +195,7 @@ class KeyChangesServlet(RestServlet): user_id = requester.user.to_string() - results = yield self.device_handler.get_user_ids_changed(user_id, from_token) + results = await self.device_handler.get_user_ids_changed(user_id, from_token) return 200, results @@ -229,12 +226,100 @@ class OneTimeKeyServlet(RestServlet): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await 
self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout) + result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout) + return 200, result + + +class SigningKeyUploadServlet(RestServlet): + """ + POST /keys/device_signing/upload HTTP/1.1 + Content-Type: application/json + + { + } + """ + + PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(SigningKeyUploadServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.auth_handler = hs.get_auth_handler() + + @interactive_auth_handler + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "add a device signing key to your account", + ) + + result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) + return 200, result + + +class SignaturesUploadServlet(RestServlet): + """ + POST /keys/signatures/upload HTTP/1.1 + Content-Type: application/json + + { + "@alice:example.com": { + "<device_id>": { + "user_id": "<user_id>", + "device_id": "<device_id>", + "algorithms": [ + "m.olm.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2" + ], + "keys": { + "<algorithm>:<device_id>": "<key_base64>", + }, + "signatures": { + "<signing_user_id>": { + "<algorithm>:<signing_key_base64>": "<signature_base64>>" + } + } + } + } + } + """ + + PATTERNS = client_patterns("/keys/signatures/upload$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(SignaturesUploadServlet, self).__init__() + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + result = await self.e2e_keys_handler.upload_signatures_for_device_keys( + user_id, body + ) return 200, result @@ -243,3 +328,5 @@ def register_servlets(hs, http_server): KeyQueryServlet(hs).register(http_server) KeyChangesServlet(hs).register(http_server) OneTimeKeyServlet(hs).register(http_server) + SigningKeyUploadServlet(hs).register(http_server) + SignaturesUploadServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 10c1ad5b07..aa911d75ee 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.http.servlet import RestServlet, parse_integer, parse_string @@ -35,9 +33,8 @@ class NotificationsServlet(RestServlet): self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) user_id = 
requester.user.to_string() from_token = parse_string(request, "from", required=False) @@ -46,16 +43,16 @@ class NotificationsServlet(RestServlet): limit = min(limit, 500) - push_actions = yield self.store.get_push_actions_for_user( + push_actions = await self.store.get_push_actions_for_user( user_id, from_token, limit, only_highlight=(only == "highlight") ) - receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( + receipts_by_room = await self.store.get_receipts_for_user_with_orderings( user_id, "m.read" ) notif_event_ids = [pa["event_id"] for pa in push_actions] - notif_events = yield self.store.get_events(notif_event_ids) + notif_events = await self.store.get_events(notif_event_ids) returned_push_actions = [] @@ -68,7 +65,7 @@ class NotificationsServlet(RestServlet): "actions": pa["actions"], "ts": pa["received_ts"], "event": ( - yield self._event_serializer.serialize_event( + await self._event_serializer.serialize_event( notif_events[pa["event_id"]], self.clock.time_msec(), event_format=format_event_for_client_v2_without_room_id, diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index b4925c0f59..6ae9a5a8e9 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.util.stringutils import random_string @@ -68,9 +66,8 @@ class IdTokenServlet(RestServlet): self.clock = hs.get_clock() self.server_name = hs.config.server_name - @defer.inlineCallbacks - def on_POST(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, user_id): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot request tokens for other users.") @@ -81,7 +78,7 @@ class IdTokenServlet(RestServlet): token = random_string(24) ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS - yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) + await self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) return ( 200, diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py new file mode 100644 index 0000000000..968403cca4 --- /dev/null +++ b/synapse/rest/client/v2_alpha/password_policy.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
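The new servlet whose body follows simply re-exposes the configured password policy under "m."-prefixed keys. A rough, runnable illustration of that mapping (the sample config values are invented for the example, not Synapse defaults):

    # Mirrors the parameter-to-"m.*" mapping done by PasswordPolicyServlet.on_GET
    # below; the sample config values are invented for illustration.

    def render_password_policy(config_policy):
        params = [
            "minimum_length",
            "require_digit",
            "require_symbol",
            "require_lowercase",
            "require_uppercase",
        ]
        return {"m.%s" % p: config_policy[p] for p in params if p in config_policy}


    print(render_password_policy({"minimum_length": 10, "require_digit": True}))
    # -> {'m.minimum_length': 10, 'm.require_digit': True}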
+ +import logging + +from synapse.http.servlet import RestServlet + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class PasswordPolicyServlet(RestServlet): + PATTERNS = client_patterns("/password_policy$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(PasswordPolicyServlet, self).__init__() + + self.policy = hs.config.password_policy + self.enabled = hs.config.password_policy_enabled + + def on_GET(self, request): + if not self.enabled or not self.policy: + return (200, {}) + + policy = {} + + for param in [ + "minimum_length", + "require_digit", + "require_symbol", + "require_lowercase", + "require_uppercase", + ]: + if param in self.policy: + policy["m.%s" % param] = self.policy[param] + + return (200, policy) + + +def register_servlets(hs, http_server): + PasswordPolicyServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index b3bf8567e1..67cbc37312 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet, parse_json_object_from_request from ._base import client_patterns @@ -34,17 +32,16 @@ class ReadMarkerRestServlet(RestServlet): self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) read_event_id = body.get("m.read", None) if read_event_id: - yield self.receipts_handler.received_client_receipt( + await self.receipts_handler.received_client_receipt( room_id, "m.read", user_id=requester.user.to_string(), @@ -53,7 +50,7 @@ class ReadMarkerRestServlet(RestServlet): read_marker_event_id = body.get("m.fully_read", None) if read_marker_event_id: - yield self.read_marker_handler.received_client_read_marker( + await self.read_marker_handler.received_client_read_marker( room_id, user_id=requester.user.to_string(), event_id=read_marker_event_id, diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 0dab03d227..92555bd4a9 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet @@ -39,16 +37,15 @@ class ReceiptRestServlet(RestServlet): self.receipts_handler = hs.get_receipts_handler() self.presence_handler = hs.get_presence_handler() - @defer.inlineCallbacks - def on_POST(self, request, room_id, receipt_type, event_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, receipt_type, event_id): + requester = await self.auth.get_user_by_req(request) if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) - yield 
self.receipts_handler.received_client_receipt( + await self.receipts_handler.received_client_receipt( room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5c7a5f3579..b9ffe86b2a 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -16,24 +16,28 @@ import hmac import logging +from typing import List, Union from six import string_types -from twisted.internet import defer - import synapse +import synapse.api.auth import synapse.types from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, - LimitExceededError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, ) +from synapse.config import ConfigError +from synapse.config.captcha import CaptchaConfig +from synapse.config.consent_config import ConsentConfig from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.config.registration import RegistrationConfig from synapse.config.server import is_threepid_reserved +from synapse.handlers.auth import AuthHandler from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, @@ -44,6 +48,7 @@ from synapse.http.servlet import ( from synapse.push.mailer import load_jinja2_templates from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter +from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -96,11 +101,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): template_text=template_text, ) - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Email registration has been disabled due to lack of email config" ) raise SynapseError( @@ -112,6 +116,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): # Extract params from body client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + email = body["email"] send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -123,24 +129,23 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", body["email"] ) if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. 
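The branch completed just below returns a 200 with a made-up session ID when the address is already registered, so that, with request_token_inhibit_3pid_errors enabled, the response cannot be used to probe which emails have accounts. A compact sketch of that behaviour, with stand-in storage and naming; only the response shapes are taken from the patch:

    import secrets


    def request_email_token(registered_emails, email, inhibit_3pid_errors):
        # Stand-in for the servlet logic: `registered_emails` plays the role of
        # the datastore's threepid lookup.
        if email in registered_emails:
            if inhibit_3pid_errors:
                # Pretend the validation email was sent: a random sid makes this
                # indistinguishable from the genuine success case.
                return 200, {"sid": secrets.token_hex(8)}
            return 400, {"errcode": "M_THREEPID_IN_USE", "error": "Email is already in use"}
        # ... here the real code sends a validation email and records a session ...
        return 200, {"sid": secrets.token_hex(8)}


    taken = {"alice@example.com"}
    print(request_email_token(taken, "alice@example.com", inhibit_3pid_errors=True))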
+ return 200, {"sid": random_string(16)} + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - if not self.hs.config.account_threepid_delegate_email: - logger.warn( - "No upstream email account_threepid_delegate configured on the server to " - "handle this request" - ) - raise SynapseError( - 400, "Registration by email is not supported on this homeserver" - ) + assert self.hs.config.account_threepid_delegate_email - ret = yield self.identity_handler.requestEmailToken( + # Have the configured identity server handle the request + ret = await self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, client_secret, @@ -149,7 +154,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) else: # Send registration emails from Synapse - sid = yield self.identity_handler.send_threepid_validation( + sid = await self.identity_handler.send_threepid_validation( email, client_secret, send_attempt, @@ -175,8 +180,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): self.hs = hs self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( @@ -197,17 +201,22 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "msisdn", msisdn ) if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. 
+ return 200, {"sid": random_string(16)} + raise SynapseError( 400, "Phone number is already in use", Codes.THREEPID_IN_USE ) if not self.hs.config.account_threepid_delegate_msisdn: - logger.warn( + logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" ) @@ -215,7 +224,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): 400, "Registration by phone number is not supported on this homeserver" ) - ret = yield self.identity_handler.requestMsisdnToken( + ret = await self.identity_handler.requestMsisdnToken( self.hs.config.account_threepid_delegate_msisdn, country, phone_number, @@ -246,15 +255,26 @@ class RegistrationSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request, medium): + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + (self.failure_email_template,) = load_jinja2_templates( + self.config.email_template_dir, + [self.config.email_registration_template_failure_html], + ) + + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + (self.failure_email_template,) = load_jinja2_templates( + self.config.email_template_dir, + [self.config.email_registration_template_failure_html], + ) + + async def on_GET(self, request, medium): if medium != "email": raise SynapseError( 400, "This medium is currently not supported for registration" ) if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "User registration via email has been disabled due to lack of email config" ) raise SynapseError( @@ -268,14 +288,14 @@ class RegistrationSubmitTokenServlet(RestServlet): # Attempt to validate a 3PID session try: # Mark the session as valid - next_link = yield self.store.validate_threepid_session( + next_link = await self.store.validate_threepid_session( sid, client_secret, token, self.clock.time_msec() ) # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: @@ -289,17 +309,11 @@ class RegistrationSubmitTokenServlet(RestServlet): request.setResponseCode(200) except ThreepidValidationError as e: - # Show a failure page with a reason request.setResponseCode(e.code) # Show a failure page with a reason - html_template, = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], - ) - template_vars = {"failure_reason": e.msg} - html = html_template.render(**template_vars) + html = self.failure_email_template.render(**template_vars) request.write(html.encode("utf-8")) finish_request(request) @@ -332,15 +346,19 @@ class UsernameAvailabilityRestServlet(RestServlet): ), ) - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): + if not self.hs.config.enable_registration: + raise SynapseError( + 403, "Registration has been disabled", errcode=Codes.FORBIDDEN + ) + ip = self.hs.get_ip_from_request(request) with self.ratelimiter.ratelimit(ip) as wait_deferred: - yield wait_deferred + await wait_deferred username = parse_string(request, "username", required=True) - yield self.registration_handler.check_username(username) + await self.registration_handler.check_username(username) return 200, {"available": True} @@ -364,50 +382,46 @@ class RegisterRestServlet(RestServlet): 
self.room_member_handler = hs.get_room_member_handler() self.macaroon_gen = hs.get_macaroon_generator() self.ratelimiter = hs.get_registration_ratelimiter() + self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() + self._registration_flows = _calculate_registration_flows( + hs.config, self.auth_handler + ) + @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) client_addr = request.getClientIP() - time_now = self.clock.time() - - allowed, time_allowed = self.ratelimiter.can_do_action( - client_addr, - time_now_s=time_now, - rate_hz=self.hs.config.rc_registration.per_second, - burst_count=self.hs.config.rc_registration.burst_count, - update=False, - ) - - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) + self.ratelimiter.ratelimit(client_addr, update=False) kind = b"user" if b"kind" in request.args: kind = request.args[b"kind"][0] if kind == b"guest": - ret = yield self._do_guest_registration(body, address=client_addr) + ret = await self._do_guest_registration(body, address=client_addr) return ret elif kind != b"user": raise UnrecognizedRequestError( - "Do not understand membership kind: %s" % (kind,) + "Do not understand membership kind: %s" % (kind.decode("utf8"),) ) # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the username/password provided to us. if "password" in body: - if ( - not isinstance(body["password"], string_types) - or len(body["password"]) > 512 - ): + password = body.pop("password") + if not isinstance(password, string_types) or len(password) > 512: raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(password) + + # If the password is valid, hash it and store it back on the body. + # This ensures that only the hashed password is handled everywhere. 
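As the comment above says, the plaintext password is popped off the request body, checked, hashed, and only the hash travels onward (the body["password_hash"] assignment follows just below). A stripped-down sketch of that flow; the bcrypt settings and the minimum-length check stand in for Synapse's auth handler and password-policy handler:

    import bcrypt  # Synapse hashes passwords with bcrypt; settings here are illustrative.

    MAX_PASSWORD_LENGTH = 512  # mirrors the length check in this hunk


    def absorb_password(body, minimum_length=8):
        # Replace body["password"] with body["password_hash"], as the patch does.
        password = body.pop("password")
        if not isinstance(password, str) or len(password) > MAX_PASSWORD_LENGTH:
            raise ValueError("Invalid password")
        if len(password) < minimum_length:  # stand-in for validate_password()
            raise ValueError("Password does not meet the policy")
        if "password_hash" in body:
            raise ValueError("Unexpected property: password_hash")
        hashed = bcrypt.hashpw(password.encode("utf8"), bcrypt.gensalt())
        body["password_hash"] = hashed.decode("ascii")
        return body


    print(absorb_password({"username": "alice", "password": "correct horse battery staple"}))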
+ if "password_hash" in body: + raise SynapseError(400, "Unexpected property: password_hash") + body["password_hash"] = await self.auth_handler.hash(password) desired_username = None if "username" in body: @@ -420,7 +434,7 @@ class RegisterRestServlet(RestServlet): appservice = None if self.auth.has_access_token(request): - appservice = yield self.auth.get_appservice_by_req(request) + appservice = await self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes which have completely # different registration flows to normal users @@ -440,7 +454,7 @@ class RegisterRestServlet(RestServlet): access_token = self.auth.get_access_token_from_request(request) if isinstance(desired_username, string_types): - result = yield self._do_appservice_registration( + result = await self._do_appservice_registration( desired_username, access_token, body ) return 200, result # we throw for non 200 responses @@ -460,12 +474,12 @@ class RegisterRestServlet(RestServlet): guest_access_token = body.get("guest_access_token", None) - if "initial_device_display_name" in body and "password" not in body: + if "initial_device_display_name" in body and "password_hash" not in body: # ignore 'initial_device_display_name' if sent without # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out # the original registration params - logger.warn("Ignoring initial_device_display_name without password") + logger.warning("Ignoring initial_device_display_name without password") del body["initial_device_display_name"] session_id = self.auth_handler.get_session_id(body) @@ -475,80 +489,23 @@ class RegisterRestServlet(RestServlet): # registered a user for this session, so we could just return the # user here. We carry on and go through the auth checks though, # for paranoia. - registered_user_id = self.auth_handler.get_session_data( + registered_user_id = await self.auth_handler.get_session_data( session_id, "registered_user_id", None ) if desired_username is not None: - yield self.registration_handler.check_username( + await self.registration_handler.check_username( desired_username, guest_access_token=guest_access_token, assigned_user_id=registered_user_id, ) - # FIXME: need a better error than "no auth flow found" for scenarios - # where we required 3PID for registration but the user didn't give one - require_email = "email" in self.hs.config.registrations_require_3pid - require_msisdn = "msisdn" in self.hs.config.registrations_require_3pid - - show_msisdn = True - if self.hs.config.disable_msisdn_registration: - show_msisdn = False - require_msisdn = False - - flows = [] - if self.hs.config.enable_registration_captcha: - # only support 3PIDless registration if no 3PIDs are required - if not require_email and not require_msisdn: - # Also add a dummy flow here, otherwise if a client completes - # recaptcha first we'll assume they were going for this flow - # and complete the request, when they could have been trying to - # complete one of the flows with email/msisdn auth. 
- flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]]) - # only support the email-only flow if we don't require MSISDN 3PIDs - if not require_msisdn: - flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]]) - - if show_msisdn: - # only support the MSISDN-only flow if we don't require email 3PIDs - if not require_email: - flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]]) - # always let users provide both MSISDN & email - flows.extend( - [[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY]] - ) - else: - # only support 3PIDless registration if no 3PIDs are required - if not require_email and not require_msisdn: - flows.extend([[LoginType.DUMMY]]) - # only support the email-only flow if we don't require MSISDN 3PIDs - if not require_msisdn: - flows.extend([[LoginType.EMAIL_IDENTITY]]) - - if show_msisdn: - # only support the MSISDN-only flow if we don't require email 3PIDs - if not require_email or require_msisdn: - flows.extend([[LoginType.MSISDN]]) - # always let users provide both MSISDN & email - flows.extend([[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]]) - - # Append m.login.terms to all flows if we're requiring consent - if self.hs.config.user_consent_at_registration: - new_flows = [] - for flow in flows: - inserted = False - # m.login.terms should go near the end but before msisdn or email auth - for i, stage in enumerate(flow): - if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN: - flow.insert(i, LoginType.TERMS) - inserted = True - break - if not inserted: - flow.append(LoginType.TERMS) - flows.extend(new_flows) - - auth_result, params, session_id = yield self.auth_handler.check_auth( - flows, body, self.hs.get_ip_from_request(request) + auth_result, params, session_id = await self.auth_handler.check_auth( + self._registration_flows, + request, + body, + self.hs.get_ip_from_request(request), + "register a new account", ) # Check that we're not trying to register a denied 3pid. 
@@ -579,11 +536,11 @@ class RegisterRestServlet(RestServlet): registered = False else: # NB: This may be from the auth handler and NOT from the POST - assert_params_in_dict(params, ["password"]) + assert_params_in_dict(params, ["password_hash"]) desired_username = params.get("username", None) guest_access_token = params.get("guest_access_token", None) - new_password = params.get("password", None) + new_password_hash = params.get("password_hash", None) if desired_username is not None: desired_username = desired_username.lower() @@ -603,7 +560,7 @@ class RegisterRestServlet(RestServlet): medium = auth_result[login_type]["medium"] address = auth_result[login_type]["address"] - existing_user_id = yield self.store.get_user_id_by_threepid( + existing_user_id = await self.store.get_user_id_by_threepid( medium, address ) @@ -614,9 +571,9 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_IN_USE, ) - registered_user_id = yield self.registration_handler.register_user( + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, - password=new_password, + password_hash=new_password_hash, guest_access_token=guest_access_token, threepid=threepid, address=client_addr, @@ -627,22 +584,22 @@ class RegisterRestServlet(RestServlet): if is_threepid_reserved( self.hs.config.mau_limits_reserved_threepids, threepid ): - yield self.store.upsert_monthly_active_user(registered_user_id) + await self.store.upsert_monthly_active_user(registered_user_id) # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) - self.auth_handler.set_session_data( + await self.auth_handler.set_session_data( session_id, "registered_user_id", registered_user_id ) registered = True - return_dict = yield self._create_registration_details( + return_dict = await self._create_registration_details( registered_user_id, params ) if registered: - yield self.registration_handler.post_registration_actions( + await self.registration_handler.post_registration_actions( user_id=registered_user_id, auth_result=auth_result, access_token=return_dict.get("access_token"), @@ -653,15 +610,13 @@ class RegisterRestServlet(RestServlet): def on_OPTIONS(self, _): return 200, {} - @defer.inlineCallbacks - def _do_appservice_registration(self, username, as_token, body): - user_id = yield self.registration_handler.appservice_register( + async def _do_appservice_registration(self, username, as_token, body): + user_id = await self.registration_handler.appservice_register( username, as_token ) - return (yield self._create_registration_details(user_id, body)) + return await self._create_registration_details(user_id, body) - @defer.inlineCallbacks - def _create_registration_details(self, user_id, params): + async def _create_registration_details(self, user_id, params): """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. 
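For orientation, the dictionary this helper builds looks roughly as follows; the access_token/device_id handling is what the hunk below shows, while the base user_id/home_server fields are assumed from the Matrix registration API rather than visible in this diff:

    def create_registration_details(user_id, params, register_device):
        # Assumed base fields (not shown in this hunk): user_id and home_server.
        result = {"user_id": user_id, "home_server": "example.com"}
        if not params.get("inhibit_login", False):
            device_id, access_token = register_device(
                user_id, params.get("device_id"), params.get("initial_device_display_name")
            )
            result.update({"access_token": access_token, "device_id": device_id})
        return result


    def fake_register_device(user_id, device_id, display_name):
        # Stand-in for registration_handler.register_device().
        return device_id or "GENERATEDDEVICE", "syt_example_access_token"


    print(create_registration_details("@alice:example.com", {}, fake_register_device))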
@@ -677,18 +632,17 @@ class RegisterRestServlet(RestServlet): if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=False ) result.update({"access_token": access_token, "device_id": device_id}) return result - @defer.inlineCallbacks - def _do_guest_registration(self, params, address=None): + async def _do_guest_registration(self, params, address=None): if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") - user_id = yield self.registration_handler.register_user( + user_id = await self.registration_handler.register_user( make_guest=True, address=address ) @@ -696,7 +650,7 @@ class RegisterRestServlet(RestServlet): # we have nowhere to store it. device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=True ) @@ -711,6 +665,83 @@ class RegisterRestServlet(RestServlet): ) +def _calculate_registration_flows( + # technically `config` has to provide *all* of these interfaces, not just one + config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], + auth_handler: AuthHandler, +) -> List[List[str]]: + """Get a suitable flows list for registration + + Args: + config: server configuration + auth_handler: authorization handler + + Returns: a list of supported flows + """ + # FIXME: need a better error than "no auth flow found" for scenarios + # where we required 3PID for registration but the user didn't give one + require_email = "email" in config.registrations_require_3pid + require_msisdn = "msisdn" in config.registrations_require_3pid + + show_msisdn = True + show_email = True + + if config.disable_msisdn_registration: + show_msisdn = False + require_msisdn = False + + enabled_auth_types = auth_handler.get_enabled_auth_types() + if LoginType.EMAIL_IDENTITY not in enabled_auth_types: + show_email = False + if require_email: + raise ConfigError( + "Configuration requires email address at registration, but email " + "validation is not configured" + ) + + if LoginType.MSISDN not in enabled_auth_types: + show_msisdn = False + if require_msisdn: + raise ConfigError( + "Configuration requires msisdn at registration, but msisdn " + "validation is not configured" + ) + + flows = [] + + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + # Add a dummy step here, otherwise if a client completes + # recaptcha first we'll assume they were going for this flow + # and complete the request, when they could have been trying to + # complete one of the flows with email/msisdn auth. 
+ flows.append([LoginType.DUMMY]) + + # only support the email-only flow if we don't require MSISDN 3PIDs + if show_email and not require_msisdn: + flows.append([LoginType.EMAIL_IDENTITY]) + + # only support the MSISDN-only flow if we don't require email 3PIDs + if show_msisdn and not require_email: + flows.append([LoginType.MSISDN]) + + if show_email and show_msisdn: + # always let users provide both MSISDN & email + flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY]) + + # Prepend m.login.terms to all flows if we're requiring consent + if config.user_consent_at_registration: + for flow in flows: + flow.insert(0, LoginType.TERMS) + + # Prepend recaptcha to all flows if we're requiring captcha + if config.enable_registration_captcha: + for flow in flows: + flow.insert(0, LoginType.RECAPTCHA) + + return flows + + def register_servlets(hs, http_server): EmailRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 040b37c504..89002ffbff 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -21,8 +21,6 @@ any time to reflect changes in the MSC. import logging -from twisted.internet import defer - from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.http.servlet import ( @@ -86,11 +84,10 @@ class RelationSendServlet(RestServlet): request, self.on_PUT_or_POST, request, *args, **kwargs ) - @defer.inlineCallbacks - def on_PUT_or_POST( + async def on_PUT_or_POST( self, request, room_id, parent_id, relation_type, event_type, txn_id=None ): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) if event_type == EventTypes.Member: # Add relations to a membership is meaningless, so we just deny it @@ -114,7 +111,7 @@ class RelationSendServlet(RestServlet): "sender": requester.user.to_string(), } - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict=event_dict, txn_id=txn_id ) @@ -140,17 +137,18 @@ class RelationPaginationServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET( + self, request, room_id, parent_id, relation_type=None, event_type=None + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( - room_id, requester.user.to_string() + await self.auth.check_user_in_room_or_world_readable( + room_id, requester.user.to_string(), allow_departed_users=True ) # This gets the original event and checks that a) the event exists and # b) the user is allowed to view it. 
- event = yield self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) limit = parse_integer(request, "limit", default=5) from_token = parse_string(request, "from") @@ -167,7 +165,7 @@ class RelationPaginationServlet(RestServlet): if to_token: to_token = RelationPaginationToken.from_string(to_token) - pagination_chunk = yield self.store.get_relations_for_event( + pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, event_type=event_type, @@ -176,7 +174,7 @@ class RelationPaginationServlet(RestServlet): to_token=to_token, ) - events = yield self.store.get_events_as_list( + events = await self.store.get_events_as_list( [c["event_id"] for c in pagination_chunk.chunk] ) @@ -184,13 +182,13 @@ class RelationPaginationServlet(RestServlet): # We set bundle_aggregations to False when retrieving the original # event because we want the content before relations were applied to # it. - original_event = yield self._event_serializer.serialize_event( + original_event = await self._event_serializer.serialize_event( event, now, bundle_aggregations=False ) # Similarly, we don't allow relations to be applied to relations, so we # return the original relations without any aggregations on top of them # here. - events = yield self._event_serializer.serialize_events( + events = await self._event_serializer.serialize_events( events, now, bundle_aggregations=False ) @@ -232,17 +230,18 @@ class RelationAggregationPaginationServlet(RestServlet): self.store = hs.get_datastore() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET( + self, request, room_id, parent_id, relation_type=None, event_type=None + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( - room_id, requester.user.to_string() + await self.auth.check_user_in_room_or_world_readable( + room_id, requester.user.to_string(), allow_departed_users=True, ) # This checks that a) the event exists and b) the user is allowed to # view it. 
- event = yield self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) if relation_type not in (RelationTypes.ANNOTATION, None): raise SynapseError(400, "Relation type must be 'annotation'") @@ -262,7 +261,7 @@ class RelationAggregationPaginationServlet(RestServlet): if to_token: to_token = AggregationPaginationToken.from_string(to_token) - pagination_chunk = yield self.store.get_aggregation_groups_for_event( + pagination_chunk = await self.store.get_aggregation_groups_for_event( event_id=parent_id, event_type=event_type, limit=limit, @@ -311,17 +310,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( - room_id, requester.user.to_string() + await self.auth.check_user_in_room_or_world_readable( + room_id, requester.user.to_string(), allow_departed_users=True, ) # This checks that a) the event exists and b) the user is allowed to # view it. - yield self.event_handler.get_event(requester.user, room_id, parent_id) + await self.event_handler.get_event(requester.user, room_id, parent_id) if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -336,7 +334,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): if to_token: to_token = RelationPaginationToken.from_string(to_token) - result = yield self.store.get_relations_for_event( + result = await self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, event_type=event_type, @@ -346,12 +344,12 @@ class RelationAggregationGroupPaginationServlet(RestServlet): to_token=to_token, ) - events = yield self.store.get_events_as_list( + events = await self.store.get_events_as_list( [c["event_id"] for c in result.chunk] ) now = self.clock.time_msec() - events = yield self._event_serializer.serialize_events(events, now) + events = await self._event_serializer.serialize_events(events, now) return_value = result.to_dict() return_value["chunk"] = events diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index e7449864cd..f067b5edac 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -18,8 +18,6 @@ import logging from six import string_types from six.moves import http_client -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( RestServlet, @@ -42,9 +40,8 @@ class ReportEventRestServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -63,7 +60,7 @@ class ReportEventRestServlet(RestServlet): Codes.BAD_JSON, ) - yield self.store.add_event_report( + await 
self.store.add_event_report( room_id=room_id, event_id=event_id, user_id=user_id, diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index df4f44cd36..59529707df 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, @@ -43,8 +41,7 @@ class RoomKeysServlet(RestServlet): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_PUT(self, request, room_id, session_id): + async def on_PUT(self, request, room_id, session_id): """ Uploads one or more encrypted E2E room keys for backup purposes. room_id: the ID of the room the keys are for (optional) @@ -123,7 +120,7 @@ class RoomKeysServlet(RestServlet): } } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() body = parse_json_object_from_request(request) version = parse_string(request, "version") @@ -134,11 +131,10 @@ class RoomKeysServlet(RestServlet): if room_id: body = {"rooms": {room_id: body}} - yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) - return 200, {} + ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) + return 200, ret - @defer.inlineCallbacks - def on_GET(self, request, room_id, session_id): + async def on_GET(self, request, room_id, session_id): """ Retrieves one or more encrypted E2E room keys for backup purposes. Symmetric with the PUT version of the API. @@ -190,11 +186,11 @@ class RoomKeysServlet(RestServlet): } } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - version = parse_string(request, "version") + version = parse_string(request, "version", required=True) - room_keys = yield self.e2e_room_keys_handler.get_room_keys( + room_keys = await self.e2e_room_keys_handler.get_room_keys( user_id, version, room_id, session_id ) @@ -220,8 +216,7 @@ class RoomKeysServlet(RestServlet): return 200, room_keys - @defer.inlineCallbacks - def on_DELETE(self, request, room_id, session_id): + async def on_DELETE(self, request, room_id, session_id): """ Deletes one or more encrypted E2E room keys for a user for backup purposes. @@ -235,14 +230,14 @@ class RoomKeysServlet(RestServlet): the version must already have been created via the /change_secret API. """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() version = parse_string(request, "version") - yield self.e2e_room_keys_handler.delete_room_keys( + ret = await self.e2e_room_keys_handler.delete_room_keys( user_id, version, room_id, session_id ) - return 200, {} + return 200, ret class RoomKeysNewVersionServlet(RestServlet): @@ -257,8 +252,7 @@ class RoomKeysNewVersionServlet(RestServlet): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """ Create a new backup version for this user's room_keys with the given info. 
The version is allocated by the server and returned to the user @@ -288,11 +282,11 @@ class RoomKeysNewVersionServlet(RestServlet): "version": 12345 } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() info = parse_json_object_from_request(request) - new_version = yield self.e2e_room_keys_handler.create_version(user_id, info) + new_version = await self.e2e_room_keys_handler.create_version(user_id, info) return 200, {"version": new_version} # we deliberately don't have a PUT /version, as these things really should @@ -311,8 +305,7 @@ class RoomKeysVersionServlet(RestServlet): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_GET(self, request, version): + async def on_GET(self, request, version): """ Retrieve the version information about a given version of the user's room_keys backup. If the version part is missing, returns info about the @@ -330,18 +323,17 @@ class RoomKeysVersionServlet(RestServlet): "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() try: - info = yield self.e2e_room_keys_handler.get_version_info(user_id, version) + info = await self.e2e_room_keys_handler.get_version_info(user_id, version) except SynapseError as e: if e.code == 404: raise SynapseError(404, "No backup found", Codes.NOT_FOUND) return 200, info - @defer.inlineCallbacks - def on_DELETE(self, request, version): + async def on_DELETE(self, request, version): """ Delete the information about a given version of the user's room_keys backup. If the version part is missing, deletes the most @@ -354,14 +346,13 @@ class RoomKeysVersionServlet(RestServlet): if version is None: raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - yield self.e2e_room_keys_handler.delete_version(user_id, version) + await self.e2e_room_keys_handler.delete_version(user_id, version) return 200, {} - @defer.inlineCallbacks - def on_PUT(self, request, version): + async def on_PUT(self, request, version): """ Update the information about a given version of the user's room_keys backup. 
@@ -375,14 +366,14 @@ class RoomKeysVersionServlet(RestServlet): "ed25519:something": "hijklmnop" } }, - "version": "42" + "version": "12345" } HTTP/1.1 200 OK Content-Type: application/json {} """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() info = parse_json_object_from_request(request) @@ -391,7 +382,7 @@ class RoomKeysVersionServlet(RestServlet): 400, "No version specified to update", Codes.MISSING_PARAM ) - yield self.e2e_room_keys_handler.update_version(user_id, version, info) + await self.e2e_room_keys_handler.update_version(user_id, version, info) return 200, {} diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index d2c3316eb7..f357015a70 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import ( @@ -59,22 +57,22 @@ class RoomUpgradeRestServlet(RestServlet): self._room_creation_handler = hs.get_room_creation_handler() self._auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self._auth.get_user_by_req(request) + async def on_POST(self, request, room_id): + requester = await self._auth.get_user_by_req(request) content = parse_json_object_from_request(request) assert_params_in_dict(content, ("new_version",)) new_version = content["new_version"] - if new_version not in KNOWN_ROOM_VERSIONS: + new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"]) + if new_version is None: raise SynapseError( 400, "Your homeserver does not support this room version", Codes.UNSUPPORTED_ROOM_VERSION, ) - new_room_id = yield self._room_creation_handler.upgrade_room( + new_room_id = await self._room_creation_handler.upgrade_room( requester, room_id, new_version ) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index d90e52ed1a..db829f3098 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -14,8 +14,7 @@ # limitations under the License. 
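On the room-version check reworked in room_upgrade_rest_servlet.py above: a minimal sketch of the new behaviour, assuming "5" is a supported room version identifier (the identifier and the imports are illustrative, not part of the patch). The servlet now resolves the requested version to a RoomVersion object rather than only testing membership, so upgrade_room is handed the full version description:

    from synapse.api.errors import Codes, SynapseError
    from synapse.api.room_versions import KNOWN_ROOM_VERSIONS

    new_version = KNOWN_ROOM_VERSIONS.get("5")  # RoomVersion object, or None if unsupported
    if new_version is None:
        raise SynapseError(
            400,
            "Your homeserver does not support this room version",
            Codes.UNSUPPORTED_ROOM_VERSION,
        )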
import logging - -from twisted.internet import defer +from typing import Tuple from synapse.http import servlet from synapse.http.servlet import parse_json_object_from_request @@ -51,19 +50,18 @@ class SendToDeviceRestServlet(servlet.RestServlet): request, self._put, request, message_type, txn_id ) - @defer.inlineCallbacks - def _put(self, request, message_type, txn_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def _put(self, request, message_type, txn_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) sender_user_id = requester.user.to_string() - yield self.device_message_handler.send_device_message( + await self.device_message_handler.send_device_message( sender_user_id, message_type, content["messages"] ) - response = (200, {}) + response = (200, {}) # type: Tuple[int, dict] return response diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index c98c5a3802..8fa68dd37f 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -18,10 +18,8 @@ import logging from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import PresenceState -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.events.utils import ( format_event_for_client_v2_without_room_id, @@ -74,7 +72,7 @@ class SyncRestServlet(RestServlet): """ PATTERNS = client_patterns("/sync$") - ALLOWED_PRESENCE = set(["online", "offline", "unavailable"]) + ALLOWED_PRESENCE = {"online", "offline", "unavailable"} def __init__(self, hs): super(SyncRestServlet, self).__init__() @@ -87,8 +85,7 @@ class SyncRestServlet(RestServlet): self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if b"from" in request.args: # /events used to use 'from', but /sync uses 'since'. # Lets be helpful and whine if we see a 'from'. @@ -96,7 +93,7 @@ class SyncRestServlet(RestServlet): 400, "'from' is not a valid query parameter. Did you mean 'since'?" 
) - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) user = requester.user device_id = requester.device_id @@ -112,32 +109,44 @@ class SyncRestServlet(RestServlet): full_state = parse_boolean(request, "full_state", default=False) logger.debug( - "/sync: user=%r, timeout=%r, since=%r," - " set_presence=%r, filter_id=%r, device_id=%r" - % (user, timeout, since, set_presence, filter_id, device_id) + "/sync: user=%r, timeout=%r, since=%r, " + "set_presence=%r, filter_id=%r, device_id=%r", + user, + timeout, + since, + set_presence, + filter_id, + device_id, ) request_key = (user, timeout, since, filter_id, full_state, device_id) - if filter_id: - if filter_id.startswith("{"): - try: - filter_object = json.loads(filter_id) - set_timeline_upper_limit( - filter_object, self.hs.config.filter_timeline_limit - ) - except Exception: - raise SynapseError(400, "Invalid filter JSON") - self.filtering.check_valid_filter(filter_object) - filter = FilterCollection(filter_object) - else: - filter = yield self.filtering.get_user_filter(user.localpart, filter_id) + if filter_id is None: + filter_collection = DEFAULT_FILTER_COLLECTION + elif filter_id.startswith("{"): + try: + filter_object = json.loads(filter_id) + set_timeline_upper_limit( + filter_object, self.hs.config.filter_timeline_limit + ) + except Exception: + raise SynapseError(400, "Invalid filter JSON") + self.filtering.check_valid_filter(filter_object) + filter_collection = FilterCollection(filter_object) else: - filter = DEFAULT_FILTER_COLLECTION + try: + filter_collection = await self.filtering.get_user_filter( + user.localpart, filter_id + ) + except StoreError as err: + if err.code != 404: + raise + # fix up the description and errcode to be more useful + raise SynapseError(400, "No such filter", errcode=Codes.INVALID_PARAM) sync_config = SyncConfig( user=user, - filter_collection=filter, + filter_collection=filter_collection, is_guest=requester.is_guest, request_key=request_key, device_id=device_id, @@ -149,20 +158,20 @@ class SyncRestServlet(RestServlet): since_token = None # send any outstanding server notices to the user. 
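On the filter handling reworked in sync.py above: a hedged, worked example of the three cases now distinguished (the filter IDs and JSON are invented for illustration):

    # ?filter omitted                          -> DEFAULT_FILTER_COLLECTION
    # ?filter={"room":{"timeline":{"limit":10}}}
    #                                           -> parsed as inline JSON, clamped by
    #                                              set_timeline_upper_limit, validated, then
    #                                              wrapped in FilterCollection
    # ?filter=66 where 66 is not a stored ID   -> get_user_filter raises StoreError(404),
    #                                              which the servlet rewrites to
    #                                              SynapseError(400, "No such filter",
    #                                                           errcode=Codes.INVALID_PARAM)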
- yield self._server_notices_sender.on_user_syncing(user.to_string()) + await self._server_notices_sender.on_user_syncing(user.to_string()) affect_presence = set_presence != PresenceState.OFFLINE if affect_presence: - yield self.presence_handler.set_state( + await self.presence_handler.set_state( user, {"presence": set_presence}, True ) - context = yield self.presence_handler.user_syncing( + context = await self.presence_handler.user_syncing( user.to_string(), affect_presence=affect_presence ) with context: - sync_result = yield self.sync_handler.wait_for_sync_for_user( + sync_result = await self.sync_handler.wait_for_sync_for_user( sync_config, since_token=since_token, timeout=timeout, @@ -170,14 +179,13 @@ class SyncRestServlet(RestServlet): ) time_now = self.clock.time_msec() - response_content = yield self.encode_response( - time_now, sync_result, requester.access_token_id, filter + response_content = await self.encode_response( + time_now, sync_result, requester.access_token_id, filter_collection ) return 200, response_content - @defer.inlineCallbacks - def encode_response(self, time_now, sync_result, access_token_id, filter): + async def encode_response(self, time_now, sync_result, access_token_id, filter): if filter.event_format == "client": event_formatter = format_event_for_client_v2_without_room_id elif filter.event_format == "federation": @@ -185,7 +193,7 @@ class SyncRestServlet(RestServlet): else: raise Exception("Unknown event format %s" % (filter.event_format,)) - joined = yield self.encode_joined( + joined = await self.encode_joined( sync_result.joined, time_now, access_token_id, @@ -193,11 +201,11 @@ class SyncRestServlet(RestServlet): event_formatter, ) - invited = yield self.encode_invited( + invited = await self.encode_invited( sync_result.invited, time_now, access_token_id, event_formatter ) - archived = yield self.encode_archived( + archived = await self.encode_archived( sync_result.archived, time_now, access_token_id, @@ -238,8 +246,9 @@ class SyncRestServlet(RestServlet): ] } - @defer.inlineCallbacks - def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter): + async def encode_joined( + self, rooms, time_now, token_id, event_fields, event_formatter + ): """ Encode the joined rooms in a sync result @@ -260,7 +269,7 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = yield self.encode_room( + joined[room.room_id] = await self.encode_room( room, time_now, token_id, @@ -271,8 +280,7 @@ class SyncRestServlet(RestServlet): return joined - @defer.inlineCallbacks - def encode_invited(self, rooms, time_now, token_id, event_formatter): + async def encode_invited(self, rooms, time_now, token_id, event_formatter): """ Encode the invited rooms in a sync result @@ -292,7 +300,7 @@ class SyncRestServlet(RestServlet): """ invited = {} for room in rooms: - invite = yield self._event_serializer.serialize_event( + invite = await self._event_serializer.serialize_event( room.invite, time_now, token_id=token_id, @@ -307,8 +315,9 @@ class SyncRestServlet(RestServlet): return invited - @defer.inlineCallbacks - def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter): + async def encode_archived( + self, rooms, time_now, token_id, event_fields, event_formatter + ): """ Encode the archived rooms in a sync result @@ -329,7 +338,7 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = yield self.encode_room( + joined[room.room_id] = await 
self.encode_room( room, time_now, token_id, @@ -340,8 +349,7 @@ class SyncRestServlet(RestServlet): return joined - @defer.inlineCallbacks - def encode_room( + async def encode_room( self, room, time_now, token_id, joined, only_fields, event_formatter ): """ @@ -382,15 +390,15 @@ class SyncRestServlet(RestServlet): # We've had bug reports that events were coming down under the # wrong room. if event.room_id != room.room_id: - logger.warn( + logger.warning( "Event %r is under room %r instead of %r", event.event_id, room.room_id, event.room_id, ) - serialized_state = yield serialize(state_events) - serialized_timeline = yield serialize(timeline_events) + serialized_state = await serialize(state_events) + serialized_timeline = await serialize(timeline_events) account_data = room.account_data diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 3b555669a0..a3f12e8a77 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -37,13 +35,12 @@ class TagListServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request, user_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, room_id): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get tags for other users.") - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) return 200, {"tags": tags} @@ -64,27 +61,25 @@ class TagServlet(RestServlet): self.store = hs.get_datastore() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_PUT(self, request, user_id, room_id, tag): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id, room_id, tag): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") body = parse_json_object_from_request(request) - max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) + max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, user_id, room_id, tag): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, user_id, room_id, tag): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) + max_id = await self.store.remove_tag_from_room(user_id, room_id, tag) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 2e8d672471..23709960ad 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import ThirdPartyEntityKind from synapse.http.servlet import RestServlet @@ -35,11 +33,10 @@ class 
ThirdPartyProtocolsServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) - protocols = yield self.appservice_handler.get_3pe_protocols() + protocols = await self.appservice_handler.get_3pe_protocols() return 200, protocols @@ -52,11 +49,10 @@ class ThirdPartyProtocolServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) - protocols = yield self.appservice_handler.get_3pe_protocols( + protocols = await self.appservice_handler.get_3pe_protocols( only_protocol=protocol ) if protocol in protocols: @@ -74,14 +70,13 @@ class ThirdPartyUserServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop(b"access_token", None) - results = yield self.appservice_handler.query_3pe( + results = await self.appservice_handler.query_3pe( ThirdPartyEntityKind.USER, protocol, fields ) @@ -97,14 +92,13 @@ class ThirdPartyLocationServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop(b"access_token", None) - results = yield self.appservice_handler.query_3pe( + results = await self.appservice_handler.query_3pe( ThirdPartyEntityKind.LOCATION, protocol, fields ) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 2da0f55811..83f3b6b70a 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet @@ -32,8 +30,7 @@ class TokenRefreshRestServlet(RestServlet): def __init__(self, hs): super(TokenRefreshRestServlet, self).__init__() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): raise AuthError(403, "tokenrefresh is no longer supported.") diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index 2863affbab..bef91a2d3e 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -38,8 +36,7 @@ class UserDirectorySearchRestServlet(RestServlet): self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """Searches for users in directory Returns: @@ -56,7 +53,7 @@ class UserDirectorySearchRestServlet(RestServlet): ] } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() if not self.hs.config.user_directory_search_enabled: @@ -72,7 +69,7 @@ class UserDirectorySearchRestServlet(RestServlet): except Exception: raise SynapseError(400, "`search_term` is required field") - results = yield self.user_directory_handler.search_users( + results = await self.user_directory_handler.search_users( user_id, search_term, limit ) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 0058b6b459..0d668df0b6 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -46,9 +49,18 @@ class VersionsRestServlet(RestServlet): "r0.3.0", "r0.4.0", "r0.5.0", + "r0.6.0", ], # as per MSC1497: - "unstable_features": {"m.lazy_load_members": True}, + "unstable_features": { + # Implements support for label-based filtering as described in + # MSC2326. + "org.matrix.label_based_filtering": True, + # Implements support for cross signing as described in MSC1756 + "org.matrix.e2e_cross_signing": True, + # Implements additional endpoints as described in MSC2432 + "org.matrix.msc2432": True, + }, }, ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 55580bc59e..ab671f7334 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -13,12 +13,11 @@ # limitations under the License. 
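On the versions.py change above: a hedged sketch of the resulting /_matrix/client/versions response body, written as a Python literal (older spec versions are elided; nothing beyond what the hunk shows is assumed):

    {
        "versions": ["r0.3.0", "r0.4.0", "r0.5.0", "r0.6.0"],  # plus earlier versions
        "unstable_features": {
            "org.matrix.label_based_filtering": True,  # MSC2326
            "org.matrix.e2e_cross_signing": True,      # MSC1756
            "org.matrix.msc2432": True,                # MSC2432
        },
    }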
import logging +from typing import Dict, Set from canonicaljson import encode_canonical_json, json from signedjson.sign import sign_json -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import ( @@ -102,8 +101,8 @@ class RemoteKey(DirectServeResource): @wrap_json_request_handler async def _async_render_GET(self, request): if len(request.postpath) == 1: - server, = request.postpath - query = {server.decode("ascii"): {}} + (server,) = request.postpath + query = {server.decode("ascii"): {}} # type: dict elif len(request.postpath) == 2: server, key_id = request.postpath minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") @@ -124,8 +123,7 @@ class RemoteKey(DirectServeResource): await self.query_keys(request, query, query_remote_on_cache_miss=True) - @defer.inlineCallbacks - def query_keys(self, request, query, query_remote_on_cache_miss=False): + async def query_keys(self, request, query, query_remote_on_cache_miss=False): logger.info("Handling query for keys %r", query) store_queries = [] @@ -142,13 +140,13 @@ class RemoteKey(DirectServeResource): for key_id in key_ids: store_queries.append((server_name, key_id, None)) - cached = yield self.store.get_server_keys_json(store_queries) + cached = await self.store.get_server_keys_json(store_queries) json_results = set() time_now_ms = self.clock.time_msec() - cache_misses = dict() + cache_misses = {} # type: Dict[str, Set[str]] for (server_name, key_id, from_server), results in cached.items(): results = [(result["ts_added_ms"], result) for result in results] @@ -214,8 +212,8 @@ class RemoteKey(DirectServeResource): json_results.add(bytes(result["key_json"])) if cache_misses and query_remote_on_cache_miss: - yield self.fetcher.get_keys(cache_misses) - yield self.query_keys(request, query, query_remote_on_cache_miss=False) + await self.fetcher.get_keys(cache_misses) + await self.query_keys(request, query, query_remote_on_cache_miss=False) else: signed_keys = [] for key_json in json_results: diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py deleted file mode 100644 index 86884c0ef4..0000000000 --- a/synapse/rest/media/v0/content_repository.py +++ /dev/null @@ -1,103 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import base64 -import logging -import os -import re - -from canonicaljson import json - -from twisted.protocols.basic import FileSender -from twisted.web import resource, server - -from synapse.api.errors import Codes, cs_error -from synapse.http.server import finish_request, respond_with_json_bytes - -logger = logging.getLogger(__name__) - - -class ContentRepoResource(resource.Resource): - """Provides file uploading and downloading. - - Uploads are POSTed to wherever this Resource is linked to. This resource - returns a "content token" which can be used to GET this content again. 
The - token is typically a path, but it may not be. Tokens can expire, be - one-time uses, etc. - - In this case, the token is a path to the file and contains 3 interesting - sections: - - User ID base64d (for namespacing content to each user) - - random 24 char string - - Content type base64d (so we can return it when clients GET it) - - """ - - isLeaf = True - - def __init__(self, hs, directory): - resource.Resource.__init__(self) - self.hs = hs - self.directory = directory - - def render_GET(self, request): - # no auth here on purpose, to allow anyone to view, even across home - # servers. - - # TODO: A little crude here, we could do this better. - filename = request.path.decode("ascii").split("/")[-1] - # be paranoid - filename = re.sub("[^0-9A-z.-_]", "", filename) - - file_path = self.directory + "/" + filename - - logger.debug("Searching for %s", file_path) - - if os.path.isfile(file_path): - # filename has the content type - base64_contentype = filename.split(".")[1] - content_type = base64.urlsafe_b64decode(base64_contentype) - logger.info("Sending file %s", file_path) - f = open(file_path, "rb") - request.setHeader("Content-Type", content_type) - - # cache for at least a day. - # XXX: we might want to turn this off for data we don't want to - # recommend caching as it's sensitive or private - or at least - # select private. don't bother setting Expires as all our matrix - # clients are smart enough to be happy with Cache-Control (right?) - request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") - - d = FileSender().beginFileTransfer(f, request) - - # after the file has been sent, clean up and finish the request - def cbFinished(ignored): - f.close() - finish_request(request) - - d.addCallback(cbFinished) - else: - respond_with_json_bytes( - request, - 404, - json.dumps(cs_error("Not found", code=Codes.NOT_FOUND)), - send_cors=True, - ) - - return server.NOT_DONE_YET - - def render_OPTIONS(self, request): - respond_with_json_bytes(request, 200, {}, send_cors=True) - return server.NOT_DONE_YET diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 5fefee4dde..3689777266 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -17,7 +17,6 @@ import logging import os -from six import PY3 from six.moves import urllib from twisted.internet import defer @@ -30,6 +29,22 @@ from synapse.util.stringutils import is_ascii logger = logging.getLogger(__name__) +# list all text content types that will have the charset default to UTF-8 when +# none is given +TEXT_CONTENT_TYPES = [ + "text/css", + "text/csv", + "text/html", + "text/calendar", + "text/plain", + "text/javascript", + "application/json", + "application/ld+json", + "application/rtf", + "image/svg+xml", + "text/xml", +] + def parse_media_id(request): try: @@ -96,7 +111,14 @@ def add_file_headers(request, media_type, file_size, upload_name): def _quote(x): return urllib.parse.quote(x.encode("utf-8")) - request.setHeader(b"Content-Type", media_type.encode("UTF-8")) + # Default to a UTF-8 charset for text content types. + # ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16' + if media_type.lower() in TEXT_CONTENT_TYPES: + content_type = media_type + "; charset=UTF-8" + else: + content_type = media_type + + request.setHeader(b"Content-Type", content_type.encode("UTF-8")) if upload_name: # RFC6266 section 4.1 [1] defines both `filename` and `filename*`. 
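On the Content-Type defaulting added to add_file_headers in _base.py above: a hedged illustration with arbitrary media types; only bare text-like types listed in TEXT_CONTENT_TYPES pick up a charset:

    # "text/html"                  -> sent as "text/html; charset=UTF-8"
    # "text/html; charset=UTF-16"  -> unchanged (the parameterised form is not in
    #                                 TEXT_CONTENT_TYPES, so the supplied charset is kept)
    # "image/png"                  -> unchanged (not a text type)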
# @@ -135,27 +157,25 @@ def add_file_headers(request, media_type, file_size, upload_name): # separators as defined in RFC2616. SP and HT are handled separately. # see _can_encode_filename_as_token. -_FILENAME_SEPARATOR_CHARS = set( - ( - "(", - ")", - "<", - ">", - "@", - ",", - ";", - ":", - "\\", - '"', - "/", - "[", - "]", - "?", - "=", - "{", - "}", - ) -) +_FILENAME_SEPARATOR_CHARS = { + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + '"', + "/", + "[", + "]", + "?", + "=", + "{", + "}", +} def _can_encode_filename_as_token(x): @@ -195,7 +215,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam respond_404(request) return - logger.debug("Responding to media request with responder %s") + logger.debug("Responding to media request with responder %s", responder) add_file_headers(request, media_type, file_size, upload_name) try: with responder: @@ -303,23 +323,15 @@ def get_filename_from_headers(headers): upload_name_utf8 = upload_name_utf8[7:] # We have a filename*= section. This MUST be ASCII, and any UTF-8 # bytes are %-quoted. - if PY3: - try: - # Once it is decoded, we can then unquote the %-encoded - # parts strictly into a unicode string. - upload_name = urllib.parse.unquote( - upload_name_utf8.decode("ascii"), errors="strict" - ) - except UnicodeDecodeError: - # Incorrect UTF-8. - pass - else: - # On Python 2, we first unquote the %-encoded parts and then - # decode it strictly using UTF-8. - try: - upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8") - except UnicodeDecodeError: - pass + try: + # Once it is decoded, we can then unquote the %-encoded + # parts strictly into a unicode string. + upload_name = urllib.parse.unquote( + upload_name_utf8.decode("ascii"), errors="strict" + ) + except UnicodeDecodeError: + # Incorrect UTF-8. + pass # If there isn't check for an ascii name. 
if not upload_name: diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 66a01559e1..24d3ae5bbc 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -50,6 +50,9 @@ class DownloadResource(DirectServeResource): b" media-src 'self';" b" object-src 'self';", ) + request.setHeader( + b"Referrer-Policy", b"no-referrer", + ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: await self.media_repo.get_local_media(request, media_id, name) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index b972e152a9..fd10d42f2f 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -18,12 +18,12 @@ import errno import logging import os import shutil +from typing import Dict, Tuple from six import iteritems import twisted.internet.error import twisted.web.http -from twisted.internet import defer from twisted.web.resource import Resource from synapse.api.errors import ( @@ -113,15 +113,14 @@ class MediaRepository(object): "update_recently_accessed_media", self._update_recently_accessed ) - @defer.inlineCallbacks - def _update_recently_accessed(self): + async def _update_recently_accessed(self): remote_media = self.recently_accessed_remotes self.recently_accessed_remotes = set() local_media = self.recently_accessed_locals self.recently_accessed_locals = set() - yield self.store.update_cached_last_access_time( + await self.store.update_cached_last_access_time( local_media, remote_media, self.clock.time_msec() ) @@ -137,8 +136,7 @@ class MediaRepository(object): else: self.recently_accessed_locals.add(media_id) - @defer.inlineCallbacks - def create_content( + async def create_content( self, media_type, upload_name, content, content_length, auth_user ): """Store uploaded content for a local user and return the mxc URL @@ -157,11 +155,11 @@ class MediaRepository(object): file_info = FileInfo(server_name=None, file_id=media_id) - fname = yield self.media_storage.store_file(content, file_info) + fname = await self.media_storage.store_file(content, file_info) logger.info("Stored local media in file %r", fname) - yield self.store.store_local_media( + await self.store.store_local_media( media_id=media_id, media_type=media_type, time_now_ms=self.clock.time_msec(), @@ -170,12 +168,11 @@ class MediaRepository(object): user_id=auth_user, ) - yield self._generate_thumbnails(None, media_id, media_id, media_type) + await self._generate_thumbnails(None, media_id, media_id, media_type) return "mxc://%s/%s" % (self.server_name, media_id) - @defer.inlineCallbacks - def get_local_media(self, request, media_id, name): + async def get_local_media(self, request, media_id, name): """Responds to reqests for local media, if exists, or returns 404. 
Args: @@ -189,7 +186,7 @@ class MediaRepository(object): Deferred: Resolves once a response has successfully been written to request """ - media_info = yield self.store.get_local_media(media_id) + media_info = await self.store.get_local_media(media_id) if not media_info or media_info["quarantined_by"]: respond_404(request) return @@ -203,13 +200,12 @@ class MediaRepository(object): file_info = FileInfo(None, media_id, url_cache=url_cache) - responder = yield self.media_storage.fetch_media(file_info) - yield respond_with_responder( + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder( request, responder, media_type, media_length, upload_name ) - @defer.inlineCallbacks - def get_remote_media(self, request, server_name, media_id, name): + async def get_remote_media(self, request, server_name, media_id, name): """Respond to requests for remote media. Args: @@ -235,8 +231,8 @@ class MediaRepository(object): # We linearize here to ensure that we don't try and download remote # media multiple times concurrently key = (server_name, media_id) - with (yield self.remote_media_linearizer.queue(key)): - responder, media_info = yield self._get_remote_media_impl( + with (await self.remote_media_linearizer.queue(key)): + responder, media_info = await self._get_remote_media_impl( server_name, media_id ) @@ -245,14 +241,13 @@ class MediaRepository(object): media_type = media_info["media_type"] media_length = media_info["media_length"] upload_name = name if name else media_info["upload_name"] - yield respond_with_responder( + await respond_with_responder( request, responder, media_type, media_length, upload_name ) else: respond_404(request) - @defer.inlineCallbacks - def get_remote_media_info(self, server_name, media_id): + async def get_remote_media_info(self, server_name, media_id): """Gets the media info associated with the remote file, downloading if necessary. @@ -273,8 +268,8 @@ class MediaRepository(object): # We linearize here to ensure that we don't try and download remote # media multiple times concurrently key = (server_name, media_id) - with (yield self.remote_media_linearizer.queue(key)): - responder, media_info = yield self._get_remote_media_impl( + with (await self.remote_media_linearizer.queue(key)): + responder, media_info = await self._get_remote_media_impl( server_name, media_id ) @@ -285,8 +280,7 @@ class MediaRepository(object): return media_info - @defer.inlineCallbacks - def _get_remote_media_impl(self, server_name, media_id): + async def _get_remote_media_impl(self, server_name, media_id): """Looks for media in local cache, if not there then attempt to download from remote server. @@ -298,7 +292,7 @@ class MediaRepository(object): Returns: Deferred[(Responder, media_info)] """ - media_info = yield self.store.get_cached_remote_media(server_name, media_id) + media_info = await self.store.get_cached_remote_media(server_name, media_id) # file_id is the ID we use to track the file locally. If we've already # seen the file then reuse the existing ID, otherwise genereate a new @@ -316,19 +310,18 @@ class MediaRepository(object): logger.info("Media is quarantined") raise NotFoundError() - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) if responder: return responder, media_info # Failed to find the file anywhere, lets download it. 
- media_info = yield self._download_remote_file(server_name, media_id, file_id) + media_info = await self._download_remote_file(server_name, media_id, file_id) - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) return responder, media_info - @defer.inlineCallbacks - def _download_remote_file(self, server_name, media_id, file_id): + async def _download_remote_file(self, server_name, media_id, file_id): """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -350,7 +343,7 @@ class MediaRepository(object): ("/_matrix/media/v1/download", server_name, media_id) ) try: - length, headers = yield self.client.get_file( + length, headers = await self.client.get_file( server_name, request_path, output_stream=f, @@ -363,7 +356,7 @@ class MediaRepository(object): }, ) except RequestSendFailed as e: - logger.warn( + logger.warning( "Request failed fetching remote media %s/%s: %r", server_name, media_id, @@ -372,7 +365,7 @@ class MediaRepository(object): raise SynapseError(502, "Failed to fetch remote media") except HttpResponseException as e: - logger.warn( + logger.warning( "HTTP error fetching remote media %s/%s: %s", server_name, media_id, @@ -383,10 +376,12 @@ class MediaRepository(object): raise SynapseError(502, "Failed to fetch remote media") except SynapseError: - logger.warn("Failed to fetch remote media %s/%s", server_name, media_id) + logger.warning( + "Failed to fetch remote media %s/%s", server_name, media_id + ) raise except NotRetryingDestination: - logger.warn("Not retrying destination %r", server_name) + logger.warning("Not retrying destination %r", server_name) raise SynapseError(502, "Failed to fetch remote media") except Exception: logger.exception( @@ -394,7 +389,7 @@ class MediaRepository(object): ) raise SynapseError(502, "Failed to fetch remote media") - yield finish() + await finish() media_type = headers[b"Content-Type"][0].decode("ascii") upload_name = get_filename_from_headers(headers) @@ -402,7 +397,7 @@ class MediaRepository(object): logger.info("Stored remote media in file %r", fname) - yield self.store.store_cached_remote_media( + await self.store.store_cached_remote_media( origin=server_name, media_id=media_id, media_type=media_type, @@ -420,7 +415,7 @@ class MediaRepository(object): "filesystem_id": file_id, } - yield self._generate_thumbnails(server_name, media_id, file_id, media_type) + await self._generate_thumbnails(server_name, media_id, file_id, media_type) return media_info @@ -455,16 +450,15 @@ class MediaRepository(object): return t_byte_source - @defer.inlineCallbacks - def generate_local_exact_thumbnail( + async def generate_local_exact_thumbnail( self, media_id, t_width, t_height, t_method, t_type, url_cache ): - input_path = yield self.media_storage.ensure_media_is_in_local_cache( + input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) ) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, @@ -487,7 +481,7 @@ class MediaRepository(object): thumbnail_type=t_type, ) - output_path = yield self.media_storage.store_file( + output_path = await self.media_storage.store_file( t_byte_source, file_info ) finally: @@ -497,22 +491,21 @@ class MediaRepository(object): t_len = os.path.getsize(output_path) - yield self.store.store_local_thumbnail( + await 
self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) return output_path - @defer.inlineCallbacks - def generate_remote_exact_thumbnail( + async def generate_remote_exact_thumbnail( self, server_name, file_id, media_id, t_width, t_height, t_method, t_type ): - input_path = yield self.media_storage.ensure_media_is_in_local_cache( + input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=False) ) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, @@ -534,7 +527,7 @@ class MediaRepository(object): thumbnail_type=t_type, ) - output_path = yield self.media_storage.store_file( + output_path = await self.media_storage.store_file( t_byte_source, file_info ) finally: @@ -544,7 +537,7 @@ class MediaRepository(object): t_len = os.path.getsize(output_path) - yield self.store.store_remote_media_thumbnail( + await self.store.store_remote_media_thumbnail( server_name, media_id, file_id, @@ -557,8 +550,7 @@ class MediaRepository(object): return output_path - @defer.inlineCallbacks - def _generate_thumbnails( + async def _generate_thumbnails( self, server_name, media_id, file_id, media_type, url_cache=False ): """Generate and store thumbnails for an image. @@ -579,7 +571,7 @@ class MediaRepository(object): if not requirements: return - input_path = yield self.media_storage.ensure_media_is_in_local_cache( + input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=url_cache) ) @@ -597,13 +589,13 @@ class MediaRepository(object): return if thumbnailer.transpose_method is not None: - m_width, m_height = yield defer_to_thread( + m_width, m_height = await defer_to_thread( self.hs.get_reactor(), thumbnailer.transpose ) # We deduplicate the thumbnail sizes by ignoring the cropped versions if # they have the same dimensions of a scaled one. 
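On the thumbnail de-duplication described in the comment above: a short, hedged example with invented sizes. If the requirements ask for both a 32x32 crop and a scale whose computed dimensions also come out at 32x32 for image/jpeg, the dict built below ends up holding a single entry,

    thumbnails == {(32, 32, "image/jpeg"): "scale"}

so only the scaled thumbnail is generated and stored.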
- thumbnails = {} + thumbnails = {} # type: Dict[Tuple[int, int, str], str] for r_width, r_height, r_method, r_type in requirements: if r_method == "crop": thumbnails.setdefault((r_width, r_height, r_type), r_method) @@ -617,11 +609,11 @@ class MediaRepository(object): for (t_width, t_height, t_type), t_method in iteritems(thumbnails): # Generate the thumbnail if t_method == "crop": - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type ) elif t_method == "scale": - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type ) else: @@ -643,7 +635,7 @@ class MediaRepository(object): url_cache=url_cache, ) - output_path = yield self.media_storage.store_file( + output_path = await self.media_storage.store_file( t_byte_source, file_info ) finally: @@ -653,7 +645,7 @@ class MediaRepository(object): # Write to database if server_name: - yield self.store.store_remote_media_thumbnail( + await self.store.store_remote_media_thumbnail( server_name, media_id, file_id, @@ -664,15 +656,14 @@ class MediaRepository(object): t_len, ) else: - yield self.store.store_local_thumbnail( + await self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) return {"width": m_width, "height": m_height} - @defer.inlineCallbacks - def delete_old_remote_media(self, before_ts): - old_media = yield self.store.get_remote_media_before(before_ts) + async def delete_old_remote_media(self, before_ts): + old_media = await self.store.get_remote_media_before(before_ts) deleted = 0 @@ -686,12 +677,12 @@ class MediaRepository(object): # TODO: Should we delete from the backup store - with (yield self.remote_media_linearizer.queue(key)): + with (await self.remote_media_linearizer.queue(key)): full_path = self.filepaths.remote_media_filepath(origin, file_id) try: os.remove(full_path) except OSError as e: - logger.warn("Failed to remove file: %r", full_path) + logger.warning("Failed to remove file: %r", full_path) if e.errno == errno.ENOENT: pass else: @@ -702,7 +693,7 @@ class MediaRepository(object): ) shutil.rmtree(thumbnail_dir, ignore_errors=True) - yield self.store.delete_remote_media(origin, media_id) + await self.store.delete_remote_media(origin, media_id) deleted += 1 return {"deleted": deleted} diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 3b87717a5a..683a79c966 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -148,6 +148,7 @@ class MediaStorage(object): for provider in self.storage_providers: res = yield provider.fetch(path, file_info) if res: + logger.debug("Streaming %s from %s", path, provider) return res return None diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 7a56cd4b6c..f206605727 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -23,6 +23,7 @@ import re import shutil import sys import traceback +from typing import Dict, Optional import six from six import string_types @@ -56,6 +57,9 @@ logger = logging.getLogger(__name__) _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I) _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) +OG_TAG_NAME_MAXLEN = 50 +OG_TAG_VALUE_MAXLEN = 1000 + class 
PreviewUrlResource(DirectServeResource): isLeaf = True @@ -74,12 +78,15 @@ class PreviewUrlResource(DirectServeResource): treq_args={"browser_like_redirects": True}, ip_whitelist=hs.config.url_preview_ip_range_whitelist, ip_blacklist=hs.config.url_preview_ip_range_blacklist, + http_proxy=os.getenvb(b"http_proxy"), + https_proxy=os.getenvb(b"HTTPS_PROXY"), ) self.media_repo = media_repo self.primary_base_path = media_repo.primary_base_path self.media_storage = media_storage self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist + self.url_preview_accept_language = hs.config.url_preview_accept_language # memory cache mapping urls to an ObservableDeferred returning # JSON-encoded OG metadata @@ -117,8 +124,10 @@ class PreviewUrlResource(DirectServeResource): pattern = entry[attrib] value = getattr(url_tuple, attrib) logger.debug( - ("Matching attrib '%s' with value '%s' against" " pattern '%s'") - % (attrib, value, pattern) + "Matching attrib '%s' with value '%s' against pattern '%s'", + attrib, + value, + pattern, ) if value is None: @@ -134,7 +143,7 @@ class PreviewUrlResource(DirectServeResource): match = False continue if match: - logger.warn("URL %s blocked by url_blacklist entry %s", url, entry) + logger.warning("URL %s blocked by url_blacklist entry %s", url, entry) raise SynapseError( 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN ) @@ -157,8 +166,7 @@ class PreviewUrlResource(DirectServeResource): og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) respond_with_json_bytes(request, 200, og, send_cors=True) - @defer.inlineCallbacks - def _do_preview(self, url, user, ts): + async def _do_preview(self, url, user, ts): """Check the db, and download the URL and build a preview Args: @@ -167,11 +175,11 @@ class PreviewUrlResource(DirectServeResource): ts (int): Returns: - Deferred[str]: json-encoded og data + Deferred[bytes]: json-encoded og data """ # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) - cache_result = yield self.store.get_url_cache(url, ts) + cache_result = await self.store.get_url_cache(url, ts) if ( cache_result and cache_result["expires_ts"] > ts @@ -184,13 +192,13 @@ class PreviewUrlResource(DirectServeResource): og = og.encode("utf8") return og - media_info = yield self._download_url(url, user) + media_info = await self._download_url(url, user) - logger.debug("got media_info of '%s'" % media_info) + logger.debug("got media_info of '%s'", media_info) if _is_media(media_info["media_type"]): file_id = media_info["filesystem_id"] - dims = yield self.media_repo._generate_thumbnails( + dims = await self.media_repo._generate_thumbnails( None, file_id, file_id, media_info["media_type"], url_cache=True ) @@ -206,7 +214,7 @@ class PreviewUrlResource(DirectServeResource): og["og:image:width"] = dims["width"] og["og:image:height"] = dims["height"] else: - logger.warn("Couldn't get dims for %s" % url) + logger.warning("Couldn't get dims for %s" % url) # define our OG response for this media elif _is_html(media_info["media_type"]): @@ -230,8 +238,8 @@ class PreviewUrlResource(DirectServeResource): # If we don't find a match, we'll look at the HTTP Content-Type, and # if that doesn't exist, we'll fall back to UTF-8. 
if not encoding: - match = _content_type_match.match(media_info["media_type"]) - encoding = match.group(1) if match else "utf-8" + content_match = _content_type_match.match(media_info["media_type"]) + encoding = content_match.group(1) if content_match else "utf-8" og = decode_and_calc_og(body, media_info["uri"], encoding) @@ -240,21 +248,21 @@ class PreviewUrlResource(DirectServeResource): # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. if "og:image" in og and og["og:image"]: - image_info = yield self._download_url( + image_info = await self._download_url( _rebase_url(og["og:image"], media_info["uri"]), user ) if _is_media(image_info["media_type"]): # TODO: make sure we don't choke on white-on-transparent images file_id = image_info["filesystem_id"] - dims = yield self.media_repo._generate_thumbnails( + dims = await self.media_repo._generate_thumbnails( None, file_id, file_id, image_info["media_type"], url_cache=True ) if dims: og["og:image:width"] = dims["width"] og["og:image:height"] = dims["height"] else: - logger.warn("Couldn't get dims for %s" % og["og:image"]) + logger.warning("Couldn't get dims for %s", og["og:image"]) og["og:image"] = "mxc://%s/%s" % ( self.server_name, @@ -265,15 +273,27 @@ class PreviewUrlResource(DirectServeResource): else: del og["og:image"] else: - logger.warn("Failed to find any OG data in %s", url) + logger.warning("Failed to find any OG data in %s", url) og = {} - logger.debug("Calculated OG for %s as %s" % (url, og)) + # filter out any stupidly long values + keys_to_remove = [] + for k, v in og.items(): + # values can be numeric as well as strings, hence the cast to str + if len(k) > OG_TAG_NAME_MAXLEN or len(str(v)) > OG_TAG_VALUE_MAXLEN: + logger.warning( + "Pruning overlong tag %s from OG data", k[:OG_TAG_NAME_MAXLEN] + ) + keys_to_remove.append(k) + for k in keys_to_remove: + del og[k] + + logger.debug("Calculated OG for %s as %s", url, og) - jsonog = json.dumps(og).encode("utf8") + jsonog = json.dumps(og) # store OG in history-aware DB cache - yield self.store.store_url_cache( + await self.store.store_url_cache( url, media_info["response_code"], media_info["etag"], @@ -283,10 +303,9 @@ class PreviewUrlResource(DirectServeResource): media_info["created_ts"], ) - return jsonog + return jsonog.encode("utf8") - @defer.inlineCallbacks - def _download_url(self, url, user): + async def _download_url(self, url, user): # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? 
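The tag-pruning loop above caps OG metadata before it is cached. A minimal standalone sketch of that behaviour, assuming the OG_TAG_NAME_MAXLEN / OG_TAG_VALUE_MAXLEN limits introduced earlier in this file; the helper name and sample data are illustrative only, and unlike the patch it returns a filtered copy rather than logging and deleting keys in place.

OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000


def prune_og_tags(og):
    # Drop any tag whose name or stringified value exceeds the limits, so a
    # hostile page cannot bloat the cached preview JSON.
    return {
        k: v
        for k, v in og.items()
        if len(k) <= OG_TAG_NAME_MAXLEN and len(str(v)) <= OG_TAG_VALUE_MAXLEN
    }


og = {"og:title": "Example", "og:description": "x" * 5000}
assert prune_og_tags(og) == {"og:title": "Example"}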
@@ -297,9 +316,12 @@ class PreviewUrlResource(DirectServeResource): with self.media_storage.store_into_file(file_info) as (f, fname, finish): try: - logger.debug("Trying to get url '%s'" % url) - length, headers, uri, code = yield self.client.get_file( - url, output_stream=f, max_size=self.max_spider_size + logger.debug("Trying to get preview for url '%s'", url) + length, headers, uri, code = await self.client.get_file( + url, + output_stream=f, + max_size=self.max_spider_size, + headers={"Accept-Language": self.url_preview_accept_language}, ) except SynapseError: # Pass SynapseErrors through directly, so that the servlet @@ -317,7 +339,7 @@ class PreviewUrlResource(DirectServeResource): ) except Exception as e: # FIXME: pass through 404s and other error messages nicely - logger.warn("Error downloading %s: %r", url, e) + logger.warning("Error downloading %s: %r", url, e) raise SynapseError( 500, @@ -325,7 +347,7 @@ class PreviewUrlResource(DirectServeResource): % (traceback.format_exception_only(sys.exc_info()[0], e),), Codes.UNKNOWN, ) - yield finish() + await finish() try: if b"Content-Type" in headers: @@ -336,7 +358,7 @@ class PreviewUrlResource(DirectServeResource): download_name = get_filename_from_headers(headers) - yield self.store.store_local_media( + await self.store.store_local_media( media_id=file_id, media_type=media_type, time_now_ms=self.clock.time_msec(), @@ -373,22 +395,21 @@ class PreviewUrlResource(DirectServeResource): "expire_url_cache_data", self._expire_url_cache_data ) - @defer.inlineCallbacks - def _expire_url_cache_data(self): + async def _expire_url_cache_data(self): """Clean up expired url cache content, media and thumbnails. """ # TODO: Delete from backup media store now = self.clock.time_msec() - logger.info("Running url preview cache expiry") + logger.debug("Running url preview cache expiry") - if not (yield self.store.has_completed_background_updates()): + if not (await self.store.db.updates.has_completed_background_updates()): logger.info("Still running DB updates; skipping expiry") return # First we delete expired url cache entries - media_ids = yield self.store.get_expired_url_cache(now) + media_ids = await self.store.get_expired_url_cache(now) removed_media = [] for media_id in media_ids: @@ -398,7 +419,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue removed_media.append(media_id) @@ -410,17 +431,19 @@ class PreviewUrlResource(DirectServeResource): except Exception: pass - yield self.store.delete_url_cache(removed_media) + await self.store.delete_url_cache(removed_media) if removed_media: logger.info("Deleted %d entries from url cache", len(removed_media)) + else: + logger.debug("No entries removed from url cache") # Now we delete old images associated with the url cache. # These may be cached for a bit on the client (i.e., they # may have a room open with a preview url thing open). # So we wait a couple of days before deleting, just in case. 
expire_before = now - 2 * 24 * 60 * 60 * 1000 - media_ids = yield self.store.get_url_cache_media_before(expire_before) + media_ids = await self.store.get_url_cache_media_before(expire_before) removed_media = [] for media_id in media_ids: @@ -430,7 +453,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue try: @@ -446,7 +469,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue removed_media.append(media_id) @@ -458,9 +481,12 @@ class PreviewUrlResource(DirectServeResource): except Exception: pass - yield self.store.delete_url_cache_media(removed_media) + await self.store.delete_url_cache_media(removed_media) - logger.info("Deleted %d media from url cache", len(removed_media)) + if removed_media: + logger.info("Deleted %d media from url cache", len(removed_media)) + else: + logger.debug("No media removed from url cache") def decode_and_calc_og(body, media_uri, request_encoding=None): @@ -499,9 +525,13 @@ def _calc_og(tree, media_uri): # "og:video:height" : "720", # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", - og = {} + og = {} # type: Dict[str, Optional[str]] for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): if "content" in tag.attrib: + # if we've got more than 50 tags, someone is taking the piss + if len(og) >= 50: + logger.warning("Skipping OG for page with too many 'og:' tags") + return {} og[tag.attrib["property"]] = tag.attrib["content"] # TODO: grab article: meta tags too, e.g.: diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 37687ea7f4..858680be26 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -77,6 +77,9 @@ class StorageProviderWrapper(StorageProvider): self.store_synchronous = store_synchronous self.store_remote = store_remote + def __str__(self): + return "StorageProviderWrapper[%s]" % (self.backend,) + def store_file(self, path, file_info): if not file_info.server_name and not self.store_local: return defer.succeed(None) @@ -114,6 +117,9 @@ class FileStorageProviderBackend(StorageProvider): self.cache_directory = hs.config.media_store_path self.base_directory = config + def __str__(self): + return "FileStorageProviderBackend[%s]" % (self.base_directory,) + def store_file(self, path, file_info): """See StorageProvider.store_file""" diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 08329884ac..0b87220234 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.http.server import ( DirectServeResource, set_cors_headers, @@ -79,11 +77,10 @@ class ThumbnailResource(DirectServeResource): ) self.media_repo.mark_recently_accessed(server_name, media_id) - @defer.inlineCallbacks - def _respond_local_thumbnail( + async def _respond_local_thumbnail( self, request, media_id, width, height, method, m_type ): - media_info = yield self.store.get_local_media(media_id) + media_info = await self.store.get_local_media(media_id) 
if not media_info: respond_404(request) @@ -93,7 +90,7 @@ class ThumbnailResource(DirectServeResource): respond_404(request) return - thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) + thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) if thumbnail_infos: thumbnail_info = self._select_thumbnail( @@ -114,14 +111,13 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = thumbnail_info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) - yield respond_with_responder(request, responder, t_type, t_length) + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder(request, responder, t_type, t_length) else: logger.info("Couldn't find any generated thumbnails") respond_404(request) - @defer.inlineCallbacks - def _select_or_generate_local_thumbnail( + async def _select_or_generate_local_thumbnail( self, request, media_id, @@ -130,7 +126,7 @@ class ThumbnailResource(DirectServeResource): desired_method, desired_type, ): - media_info = yield self.store.get_local_media(media_id) + media_info = await self.store.get_local_media(media_id) if not media_info: respond_404(request) @@ -140,7 +136,7 @@ class ThumbnailResource(DirectServeResource): respond_404(request) return - thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) + thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) for info in thumbnail_infos: t_w = info["thumbnail_width"] == desired_width t_h = info["thumbnail_height"] == desired_height @@ -162,15 +158,15 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) if responder: - yield respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder(request, responder, t_type, t_length) return logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. 
- file_path = yield self.media_repo.generate_local_exact_thumbnail( + file_path = await self.media_repo.generate_local_exact_thumbnail( media_id, desired_width, desired_height, @@ -180,13 +176,12 @@ class ThumbnailResource(DirectServeResource): ) if file_path: - yield respond_with_file(request, desired_type, file_path) + await respond_with_file(request, desired_type, file_path) else: - logger.warn("Failed to generate thumbnail") + logger.warning("Failed to generate thumbnail") respond_404(request) - @defer.inlineCallbacks - def _select_or_generate_remote_thumbnail( + async def _select_or_generate_remote_thumbnail( self, request, server_name, @@ -196,9 +191,9 @@ class ThumbnailResource(DirectServeResource): desired_method, desired_type, ): - media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) + media_info = await self.media_repo.get_remote_media_info(server_name, media_id) - thumbnail_infos = yield self.store.get_remote_media_thumbnails( + thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) @@ -224,15 +219,15 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) if responder: - yield respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder(request, responder, t_type, t_length) return logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. - file_path = yield self.media_repo.generate_remote_exact_thumbnail( + file_path = await self.media_repo.generate_remote_exact_thumbnail( server_name, file_id, media_id, @@ -243,21 +238,20 @@ class ThumbnailResource(DirectServeResource): ) if file_path: - yield respond_with_file(request, desired_type, file_path) + await respond_with_file(request, desired_type, file_path) else: - logger.warn("Failed to generate thumbnail") + logger.warning("Failed to generate thumbnail") respond_404(request) - @defer.inlineCallbacks - def _respond_remote_thumbnail( + async def _respond_remote_thumbnail( self, request, server_name, media_id, width, height, method, m_type ): # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. 
- media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) + media_info = await self.media_repo.get_remote_media_info(server_name, media_id) - thumbnail_infos = yield self.store.get_remote_media_thumbnails( + thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) @@ -278,8 +272,8 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = thumbnail_info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) - yield respond_with_responder(request, responder, t_type, t_length) + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder(request, responder, t_type, t_length) else: logger.info("Failed to find any generated thumbnails") respond_404(request) @@ -296,8 +290,8 @@ class ThumbnailResource(DirectServeResource): d_h = desired_height if desired_method.lower() == "crop": - info_list = [] - info_list2 = [] + crop_info_list = [] + crop_info_list2 = [] for info in thumbnail_infos: t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] @@ -309,7 +303,7 @@ class ThumbnailResource(DirectServeResource): type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] if t_w >= d_w or t_h >= d_h: - info_list.append( + crop_info_list.append( ( aspect_quality, min_quality, @@ -320,7 +314,7 @@ class ThumbnailResource(DirectServeResource): ) ) else: - info_list2.append( + crop_info_list2.append( ( aspect_quality, min_quality, @@ -330,10 +324,10 @@ class ThumbnailResource(DirectServeResource): info, ) ) - if info_list: - return min(info_list)[-1] + if crop_info_list: + return min(crop_info_list)[-1] else: - return min(info_list2)[-1] + return min(crop_info_list2)[-1] else: info_list = [] info_list2 = [] diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index c995d7e043..c234ea7421 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -82,13 +82,21 @@ class Thumbnailer(object): else: return (max_height * self.width) // self.height, max_height + def _resize(self, width, height): + # 1-bit or 8-bit color palette images need converting to RGB + # otherwise they will be scaled using nearest neighbour which + # looks awful + if self.image.mode in ["1", "P"]: + self.image = self.image.convert("RGB") + return self.image.resize((width, height), Image.ANTIALIAS) + def scale(self, width, height, output_type): """Rescales the image to the given dimensions. 
Returns: BytesIO: the bytes of the encoded image ready to be written to disk """ - scaled = self.image.resize((width, height), Image.ANTIALIAS) + scaled = self._resize(width, height) return self._encode_image(scaled, output_type) def crop(self, width, height, output_type): @@ -107,13 +115,13 @@ class Thumbnailer(object): """ if width * self.height > height * self.width: scaled_height = (width * self.height) // self.width - scaled_image = self.image.resize((width, scaled_height), Image.ANTIALIAS) + scaled_image = self._resize(width, scaled_height) crop_top = (scaled_height - height) // 2 crop_bottom = height + crop_top cropped = scaled_image.crop((0, crop_top, width, crop_bottom)) else: scaled_width = (height * self.width) // self.height - scaled_image = self.image.resize((scaled_width, height), Image.ANTIALIAS) + scaled_image = self._resize(scaled_width, height) crop_left = (scaled_width - width) // 2 crop_right = width + crop_left cropped = scaled_image.crop((crop_left, 0, crop_right, height)) @@ -121,5 +129,8 @@ class Thumbnailer(object): def _encode_image(self, output_image, output_type): output_bytes_io = BytesIO() - output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80) + fmt = self.FORMATS[output_type] + if fmt == "JPEG": + output_image = output_image.convert("RGB") + output_image.save(output_bytes_io, fmt, quality=80) return output_bytes_io diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 5d76bbdf68..83d005812d 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -17,7 +17,7 @@ import logging from twisted.web.server import NOT_DONE_YET -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError from synapse.http.server import ( DirectServeResource, respond_with_json, @@ -56,7 +56,11 @@ class UploadResource(DirectServeResource): if content_length is None: raise SynapseError(msg="Request must specify a Content-Length", code=400) if int(content_length) > self.max_upload_size: - raise SynapseError(msg="Upload request body is too large", code=413) + raise SynapseError( + msg="Upload request body is too large", + code=413, + errcode=Codes.TOO_LARGE, + ) upload_name = parse_string(request, b"filename", encoding=None) if upload_name: diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/oidc/__init__.py new file mode 100644 index 0000000000..d958dd65bb --- /dev/null +++ b/synapse/rest/oidc/__init__.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging + +from twisted.web.resource import Resource + +from synapse.rest.oidc.callback_resource import OIDCCallbackResource + +logger = logging.getLogger(__name__) + + +class OIDCResource(Resource): + def __init__(self, hs): + Resource.__init__(self) + self.putChild(b"callback", OIDCCallbackResource(hs)) diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py new file mode 100644 index 0000000000..c03194f001 --- /dev/null +++ b/synapse/rest/oidc/callback_resource.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.http.server import DirectServeResource, wrap_html_request_handler + +logger = logging.getLogger(__name__) + + +class OIDCCallbackResource(DirectServeResource): + isLeaf = 1 + + def __init__(self, hs): + super().__init__() + self._oidc_handler = hs.get_oidc_handler() + + @wrap_html_request_handler + async def _async_render_GET(self, request): + return await self._oidc_handler.handle_oidc_callback(request) diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py index 69ecc5e4b4..75e58043b4 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py @@ -13,8 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.python import failure -from synapse.http.server import DirectServeResource, wrap_html_request_handler +from synapse.api.errors import SynapseError +from synapse.http.server import DirectServeResource, return_html_error class SAML2ResponseResource(DirectServeResource): @@ -25,7 +27,21 @@ class SAML2ResponseResource(DirectServeResource): def __init__(self, hs): super().__init__() self._saml_handler = hs.get_saml_handler() + self._error_html_template = hs.config.saml2.saml2_error_html_template + + async def _async_render_GET(self, request): + # We're not expecting any GET request on that resource if everything goes right, + # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. + # In this case, just tell the user that something went wrong and they should + # try to authenticate again. 
+ f = failure.Failure( + SynapseError(400, "Unexpected GET request on /saml2/authn_response") + ) + return_html_error(f, request, self._error_html_template) - @wrap_html_request_handler async def _async_render_POST(self, request): - return await self._saml_handler.handle_saml_response(request) + try: + await self._saml_handler.handle_saml_response(request) + except Exception: + f = failure.Failure() + return_html_error(f, request, self._error_html_template) diff --git a/synapse/server.py b/synapse/server.py index 1fcc7375d3..fe94836a2c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -23,17 +23,18 @@ # Imports required for the default HomeServer() implementation import abc import logging +import os -from twisted.enterprise import adbapi from twisted.mail.smtp import sendmail -from twisted.web.client import BrowserLikePolicyForHTTPS from synapse.api.auth import Auth from synapse.api.filtering import Filtering from synapse.api.ratelimiting import Ratelimiter from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler +from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory +from synapse.crypto.context_factory import RegularPolicyForHTTPS from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory from synapse.events.spamcheck import SpamChecker @@ -49,22 +50,24 @@ from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.sender import FederationSender from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer -from synapse.groups.groups_server import GroupsServerHandler +from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler from synapse.handlers import Handlers from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.acme import AcmeHandler from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, MacaroonGenerator +from synapse.handlers.cas_handler import CasHandler from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler from synapse.handlers.events import EventHandler, EventStreamHandler -from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.message import EventCreationHandler, MessageHandler from synapse.handlers.pagination import PaginationHandler +from synapse.handlers.password_policy import PasswordPolicyHandler from synapse.handlers.presence import PresenceHandler from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler from synapse.handlers.read_marker import ReadMarkerHandler @@ -84,6 +87,10 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool +from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.tcp.handler import 
ReplicationCommandHandler +from synapse.replication.tcp.resource import ReplicationStreamer +from synapse.replication.tcp.streams import STREAMS_MAP from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, @@ -95,9 +102,11 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler +from synapse.storage import DataStores, Storage from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor +from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) @@ -130,7 +139,6 @@ class HomeServer(object): DEPENDENCIES = [ "http_client", - "db_pool", "federation_client", "federation_server", "handlers", @@ -167,6 +175,7 @@ class HomeServer(object): "filtering", "http_client_context_factory", "simple_http_client", + "proxied_http_client", "media_repository", "media_repository_resource", "federation_transport_client", @@ -194,8 +203,15 @@ class HomeServer(object): "sendmail", "registration_handler", "account_validity_handler", + "cas_handler", "saml_handler", + "oidc_handler", "event_client_serializer", + "password_policy_handler", + "storage", + "replication_streamer", + "replication_data_handler", + "replication_streams", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -205,36 +221,60 @@ class HomeServer(object): # instantiated during setup() for future return by get_datastore() DATASTORE_CLASS = abc.abstractproperty() - def __init__(self, hostname, reactor=None, **kwargs): + def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs): """ Args: hostname : The hostname for the server. + config: The full config for the homeserver. """ if not reactor: from twisted.internet import reactor self._reactor = reactor self.hostname = hostname + self.config = config self._building = {} self._listening_services = [] + self.start_time = None + + self._instance_id = random_string(5) + self._instance_name = config.worker_name or "master" self.clock = Clock(reactor) self.distributor = Distributor() - self.ratelimiter = Ratelimiter() - self.admin_redaction_ratelimiter = Ratelimiter() - self.registration_ratelimiter = Ratelimiter() - self.datastore = None + self.registration_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=config.rc_registration.per_second, + burst_count=config.rc_registration.burst_count, + ) + + self.datastores = None # Other kwargs are explicit dependencies for depname in kwargs: setattr(self, depname, kwargs[depname]) + def get_instance_id(self): + """A unique ID for this synapse process instance. + + This is used to distinguish running instances in worker-based + deployments. + """ + return self._instance_id + + def get_instance_name(self) -> str: + """A unique name for this synapse process. + + Used to identify the process over replication and in config. Does not + change over restarts. 
+ """ + return self._instance_name + def setup(self): logger.info("Setting up.") - with self.get_db_conn() as conn: - self.datastore = self.DATASTORE_CLASS(conn, self) - conn.commit() + self.start_time = int(self.get_clock().time()) + self.datastores = DataStores(self.DATASTORE_CLASS, self) logger.info("Finished setting up.") def setup_master(self): @@ -266,7 +306,10 @@ class HomeServer(object): return self.clock def get_datastore(self): - return self.datastore + return self.datastores.main + + def get_datastores(self): + return self.datastores def get_config(self): return self.config @@ -274,15 +317,9 @@ class HomeServer(object): def get_distributor(self): return self.distributor - def get_ratelimiter(self): - return self.ratelimiter - - def get_registration_ratelimiter(self): + def get_registration_ratelimiter(self) -> Ratelimiter: return self.registration_ratelimiter - def get_admin_redaction_ratelimiter(self): - return self.admin_redaction_ratelimiter - def build_federation_client(self): return FederationClient(self) @@ -302,12 +339,19 @@ class HomeServer(object): return ( InsecureInterceptableContextFactory() if self.config.use_insecure_ssl_client_just_for_testing_do_not_use - else BrowserLikePolicyForHTTPS() + else RegularPolicyForHTTPS() ) def build_simple_http_client(self): return SimpleHttpClient(self) + def build_proxied_http_client(self): + return SimpleHttpClient( + self, + http_proxy=os.getenvb(b"http_proxy"), + https_proxy=os.getenvb(b"HTTPS_PROXY"), + ) + def build_room_creation_handler(self): return RoomCreationHandler(self) @@ -405,36 +449,11 @@ class HomeServer(object): return PusherPool(self) def build_http_client(self): - tls_client_options_factory = context_factory.ClientTLSOptionsFactory( + tls_client_options_factory = context_factory.FederationPolicyForHTTPS( self.config ) return MatrixFederationHttpClient(self, tls_client_options_factory) - def build_db_pool(self): - name = self.db_config["name"] - - return adbapi.ConnectionPool( - name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {}) - ) - - def get_db_conn(self, run_new_connection=True): - """Makes a new connection to the database, skipping the db pool - - Returns: - Connection: a connection object implementing the PEP-249 spec - """ - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v - for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def build_media_repository_resource(self): # build the media repo resource. 
This indirects through the HomeServer # to ensure that we only have a single instance of @@ -461,7 +480,7 @@ class HomeServer(object): return ReadMarkerHandler(self) def build_tcp_replication(self): - raise NotImplementedError() + return ReplicationCommandHandler(self) def build_action_generator(self): return ActionGenerator(self) @@ -470,10 +489,16 @@ class HomeServer(object): return UserDirectoryHandler(self) def build_groups_local_handler(self): - return GroupsLocalHandler(self) + if self.config.worker_app: + return GroupsLocalWorkerHandler(self) + else: + return GroupsLocalHandler(self) def build_groups_server_handler(self): - return GroupsServerHandler(self) + if self.config.worker_app: + return GroupsServerWorkerHandler(self) + else: + return GroupsServerHandler(self) def build_groups_attestation_signing(self): return GroupAttestationSigning(self) @@ -529,14 +554,37 @@ class HomeServer(object): def build_account_validity_handler(self): return AccountValidityHandler(self) + def build_cas_handler(self): + return CasHandler(self) + def build_saml_handler(self): from synapse.handlers.saml_handler import SamlHandler return SamlHandler(self) + def build_oidc_handler(self): + from synapse.handlers.oidc_handler import OidcHandler + + return OidcHandler(self) + def build_event_client_serializer(self): return EventClientSerializer(self) + def build_password_policy_handler(self): + return PasswordPolicyHandler(self) + + def build_storage(self) -> Storage: + return Storage(self, self.datastores) + + def build_replication_streamer(self) -> ReplicationStreamer: + return ReplicationStreamer(self) + + def build_replication_data_handler(self): + return ReplicationDataHandler(self) + + def build_replication_streams(self): + return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()} + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) @@ -558,24 +606,22 @@ def _make_dependency_method(depname): try: builder = getattr(hs, "build_%s" % (depname)) except AttributeError: - builder = None + raise NotImplementedError( + "%s has no %s nor a builder for it" % (type(hs).__name__, depname) + ) - if builder: - # Prevent cyclic dependencies from deadlocking - if depname in hs._building: - raise ValueError("Cyclic dependency while building %s" % (depname,)) - hs._building[depname] = 1 + # Prevent cyclic dependencies from deadlocking + if depname in hs._building: + raise ValueError("Cyclic dependency while building %s" % (depname,)) + hs._building[depname] = 1 + try: dep = builder() setattr(hs, depname, dep) - + finally: del hs._building[depname] - return dep - - raise NotImplementedError( - "%s has no %s nor a builder for it" % (type(hs).__name__, depname) - ) + return dep setattr(HomeServer, "get_%s" % (depname), _get) diff --git a/synapse/server.pyi b/synapse/server.pyi index 16f8f6b573..fe8024d2d4 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,7 +1,12 @@ +from typing import Dict + +import twisted.internet + import synapse.api.auth import synapse.config.homeserver +import synapse.crypto.keyring +import synapse.federation.federation_server import synapse.federation.sender -import synapse.federation.transaction_queue import synapse.federation.transport.client import synapse.handlers import synapse.handlers.auth @@ -9,19 +14,31 @@ import synapse.handlers.deactivate_account import synapse.handlers.device import synapse.handlers.e2e_keys import synapse.handlers.message +import synapse.handlers.presence +import 
synapse.handlers.register import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password +import synapse.http.client +import synapse.notifier +import synapse.push.pusherpool +import synapse.replication.tcp.client +import synapse.replication.tcp.handler import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage +from synapse.events.builder import EventBuilderFactory +from synapse.replication.tcp.streams import Stream class HomeServer(object): @property def config(self) -> synapse.config.homeserver.HomeServerConfig: pass + @property + def hostname(self) -> str: + pass def get_auth(self) -> synapse.api.auth.Auth: pass def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: @@ -38,8 +55,16 @@ class HomeServer(object): pass def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: pass + def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient: + """Fetch an HTTP client implementation which doesn't do any blacklisting + or support any HTTP_PROXY settings""" + pass + def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient: + """Fetch an HTTP client implementation which doesn't do any blacklisting + but does support HTTP_PROXY settings""" + pass def get_deactivate_account_handler( - self + self, ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: pass def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler: @@ -47,32 +72,72 @@ class HomeServer(object): def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler: pass def get_event_creation_handler( - self + self, ) -> synapse.handlers.message.EventCreationHandler: pass def get_set_password_handler( - self + self, ) -> synapse.handlers.set_password.SetPasswordHandler: pass def get_federation_sender(self) -> synapse.federation.sender.FederationSender: pass def get_federation_transport_client( - self + self, ) -> synapse.federation.transport.client.TransportLayerClient: pass def get_media_repository_resource( - self + self, ) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: pass def get_media_repository( - self + self, ) -> synapse.rest.media.v1.media_repository.MediaRepository: pass def get_server_notices_manager( - self + self, ) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: pass def get_server_notices_sender( - self + self, ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: pass + def get_notifier(self) -> synapse.notifier.Notifier: + pass + def get_presence_handler(self) -> synapse.handlers.presence.BasePresenceHandler: + pass + def get_clock(self) -> synapse.util.Clock: + pass + def get_reactor(self) -> twisted.internet.base.ReactorBase: + pass + def get_keyring(self) -> synapse.crypto.keyring.Keyring: + pass + def get_tcp_replication( + self, + ) -> synapse.replication.tcp.handler.ReplicationCommandHandler: + pass + def get_replication_data_handler( + self, + ) -> synapse.replication.tcp.client.ReplicationDataHandler: + pass + def get_federation_registry( + self, + ) -> synapse.federation.federation_server.FederationHandlerRegistry: + pass + def is_mine_id(self, domain_id: str) -> bool: + pass + def get_instance_id(self) -> str: + pass + def get_instance_name(self) -> str: + pass + def get_event_builder_factory(self) -> EventBuilderFactory: + pass + def get_storage(self) -> 
synapse.storage.Storage: + pass + def get_registration_handler(self) -> synapse.handlers.register.RegistrationHandler: + pass + def get_macaroon_generator(self) -> synapse.handlers.auth.MacaroonGenerator: + pass + def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool: + pass + def get_replication_streams(self) -> Dict[str, Stream]: + pass diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 415e9c17d8..3bf330da49 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -16,8 +16,6 @@ import logging from six import iteritems, string_types -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.api.urls import ConsentURIBuilder from synapse.config import ConfigError @@ -54,13 +52,12 @@ class ConsentServerNotices(object): ) if "body" not in self._server_notice_content: raise ConfigError( - "user_consent server_notice_consent must contain a 'body' " "key." + "user_consent server_notice_consent must contain a 'body' key." ) self._consent_uri_builder = ConsentURIBuilder(hs.config) - @defer.inlineCallbacks - def maybe_send_server_notice_to_user(self, user_id): + async def maybe_send_server_notice_to_user(self, user_id): """Check if we need to send a notice to this user, and does so if so Args: @@ -78,7 +75,7 @@ class ConsentServerNotices(object): return self._users_in_progress.add(user_id) try: - u = yield self._store.get_user_by_id(user_id) + u = await self._store.get_user_by_id(user_id) if u["is_guest"] and not self._send_to_guests: # don't send to guests @@ -100,8 +97,8 @@ class ConsentServerNotices(object): content = copy_with_str_subst( self._server_notice_content, {"consent_uri": consent_uri} ) - yield self._server_notices_manager.send_notice(user_id, content) - yield self._store.user_set_consent_server_notice_sent( + await self._server_notices_manager.send_notice(user_id, content) + await self._store.user_set_consent_server_notice_sent( user_id, self._current_consent_version ) except SynapseError as e: diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 81c4aff496..73f2cedb5c 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -16,10 +16,9 @@ import logging from six import iteritems -from twisted.internet import defer - from synapse.api.constants import ( EventTypes, + LimitBlockingTypes, ServerNoticeLimitReached, ServerNoticeMsgType, ) @@ -49,8 +48,13 @@ class ResourceLimitsServerNotices(object): self._notifier = hs.get_notifier() - @defer.inlineCallbacks - def maybe_send_server_notice_to_user(self, user_id): + self._enabled = ( + hs.config.limit_usage_by_mau + and self._server_notices_manager.is_enabled() + and not hs.config.hs_disabled + ) + + async def maybe_send_server_notice_to_user(self, user_id): """Check if we need to send a notice to this user, this will be true in two cases. 1. 
The server has reached its limit does not reflect this @@ -63,77 +67,104 @@ class ResourceLimitsServerNotices(object): Returns: Deferred """ - if self._config.hs_disabled is True: - return - - if self._config.limit_usage_by_mau is False: + if not self._enabled: return - if not self._server_notices_manager.is_enabled(): - # Don't try and send server notices unles they've been enabled - return - - timestamp = yield self._store.user_last_seen_monthly_active(user_id) + timestamp = await self._store.user_last_seen_monthly_active(user_id) if timestamp is None: # This user will be blocked from receiving the notice anyway. # In practice, not sure we can ever get here return - # Determine current state of room - - room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id) + room_id = await self._server_notices_manager.get_or_create_notice_room_for_user( + user_id + ) if not room_id: - logger.warn("Failed to get server notices room") + logger.warning("Failed to get server notices room") return - yield self._check_and_set_tags(user_id, room_id) - currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id) + await self._check_and_set_tags(user_id, room_id) + + # Determine current state of room + currently_blocked, ref_events = await self._is_room_currently_blocked(room_id) + limit_msg = None + limit_type = None try: - # Normally should always pass in user_id if you have it, but in - # this case are checking what would happen to other users if they - # were to arrive. - try: - yield self._auth.check_auth_blocking() - is_auth_blocking = False - except ResourceLimitError as e: - is_auth_blocking = True - event_content = e.msg - event_limit_type = e.limit_type - - if currently_blocked and not is_auth_blocking: - # Room is notifying of a block, when it ought not to be. - # Remove block notification - content = {"pinned": ref_events} - yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Pinned, "" - ) + # Normally should always pass in user_id to check_auth_blocking + # if you have it, but in this case are checking what would happen + # to other users if they were to arrive. + await self._auth.check_auth_blocking() + except ResourceLimitError as e: + limit_msg = e.msg + limit_type = e.limit_type - elif not currently_blocked and is_auth_blocking: + try: + if ( + limit_type == LimitBlockingTypes.MONTHLY_ACTIVE_USER + and not self._config.mau_limit_alerting + ): + # We have hit the MAU limit, but MAU alerting is disabled: + # reset room if necessary and return + if currently_blocked: + await self._remove_limit_block_notification(user_id, ref_events) + return + + if currently_blocked and not limit_msg: + # Room is notifying of a block, when it ought not to be. + await self._remove_limit_block_notification(user_id, ref_events) + elif not currently_blocked and limit_msg: # Room is not notifying of a block, when it ought to be. 
- # Add block notification - content = { - "body": event_content, - "msgtype": ServerNoticeMsgType, - "server_notice_type": ServerNoticeLimitReached, - "admin_contact": self._config.admin_contact, - "limit_type": event_limit_type, - } - event = yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Message - ) - - content = {"pinned": [event.event_id]} - yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Pinned, "" + await self._apply_limit_block_notification( + user_id, limit_msg, limit_type ) - except SynapseError as e: logger.error("Error sending resource limits server notice: %s", e) - @defer.inlineCallbacks - def _check_and_set_tags(self, user_id, room_id): + async def _remove_limit_block_notification(self, user_id, ref_events): + """Utility method to remove limit block notifications from the server + notices room. + + Args: + user_id (str): user to notify + ref_events (list[str]): The event_ids of pinned events that are unrelated to + limit blocking and need to be preserved. + """ + content = {"pinned": ref_events} + await self._server_notices_manager.send_notice( + user_id, content, EventTypes.Pinned, "" + ) + + async def _apply_limit_block_notification( + self, user_id, event_body, event_limit_type + ): + """Utility method to apply limit block notifications in the server + notices room. + + Args: + user_id (str): user to notify + event_body(str): The human readable text that describes the block. + event_limit_type(str): Specifies the type of block e.g. monthly active user + limit has been exceeded. + """ + content = { + "body": event_body, + "msgtype": ServerNoticeMsgType, + "server_notice_type": ServerNoticeLimitReached, + "admin_contact": self._config.admin_contact, + "limit_type": event_limit_type, + } + event = await self._server_notices_manager.send_notice( + user_id, content, EventTypes.Message + ) + + content = {"pinned": [event.event_id]} + await self._server_notices_manager.send_notice( + user_id, content, EventTypes.Pinned, "" + ) + + async def _check_and_set_tags(self, user_id, room_id): """ Since server notices rooms were originally not with tags, important to check that tags have been set correctly @@ -141,20 +172,19 @@ class ResourceLimitsServerNotices(object): user_id(str): the user in question room_id(str): the server notices room for that user """ - tags = yield self._store.get_tags_for_room(user_id, room_id) + tags = await self._store.get_tags_for_room(user_id, room_id) need_to_set_tag = True if tags: if SERVER_NOTICE_ROOM_TAG in tags: # tag already present, nothing to do here need_to_set_tag = False if need_to_set_tag: - max_id = yield self._store.add_tag_to_room( + max_id = await self._store.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) - @defer.inlineCallbacks - def _is_room_currently_blocked(self, room_id): + async def _is_room_currently_blocked(self, room_id): """ Determines if the room is currently blocked @@ -162,7 +192,7 @@ class ResourceLimitsServerNotices(object): room_id(str): The room id of the server notices room Returns: - + Deferred[Tuple[bool, List]]: bool: Is the room currently blocked list: The list of pinned events that are unrelated to limit blocking This list can be used as a convenience in the case where the block @@ -172,7 +202,7 @@ class ResourceLimitsServerNotices(object): currently_blocked = False pinned_state_event = None try: - pinned_state_event = yield self._state.get_current_state( + 
pinned_state_event = await self._state.get_current_state( room_id, event_type=EventTypes.Pinned ) except AuthError: @@ -183,7 +213,7 @@ class ResourceLimitsServerNotices(object): if pinned_state_event is not None: referenced_events = list(pinned_state_event.content.get("pinned", [])) - events = yield self._store.get_events(referenced_events) + events = await self._store.get_events(referenced_events) for event_id, event in iteritems(events): if event.type != EventTypes.Message: continue diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 2dac90578c..bf2454c01c 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -14,11 +14,9 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership, RoomCreationPreset -from synapse.types import create_requester -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.types import UserID, create_requester +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -36,10 +34,12 @@ class ServerNoticesManager(object): self._store = hs.get_datastore() self._config = hs.config self._room_creation_handler = hs.get_room_creation_handler() + self._room_member_handler = hs.get_room_member_handler() self._event_creation_handler = hs.get_event_creation_handler() self._is_mine_id = hs.is_mine_id self._notifier = hs.get_notifier() + self.server_notices_mxid = self._config.server_notices_mxid def is_enabled(self): """Checks if server notices are enabled on this server. @@ -49,8 +49,7 @@ class ServerNoticesManager(object): """ return self._config.server_notices_mxid is not None - @defer.inlineCallbacks - def send_notice( + async def send_notice( self, user_id, event_content, type=EventTypes.Message, state_key=None ): """Send a notice to the given user @@ -66,7 +65,8 @@ class ServerNoticesManager(object): Returns: Deferred[FrozenEvent] """ - room_id = yield self.get_notice_room_for_user(user_id) + room_id = await self.get_or_create_notice_room_for_user(user_id) + await self.maybe_invite_user_to_room(user_id, room_id) system_mxid = self._config.server_notices_mxid requester = create_requester(system_mxid) @@ -83,16 +83,17 @@ class ServerNoticesManager(object): if state_key is not None: event_dict["state_key"] = state_key - res = yield self._event_creation_handler.create_and_send_nonmember_event( + event, _ = await self._event_creation_handler.create_and_send_nonmember_event( requester, event_dict, ratelimit=False ) - return res + return event - @cachedInlineCallbacks() - def get_notice_room_for_user(self, user_id): + @cached() + async def get_or_create_notice_room_for_user(self, user_id): """Get the room for notices for a given user - If we have not yet created a notice room for this user, create it + If we have not yet created a notice room for this user, create it, but don't + invite the user to it. 
Args: user_id (str): complete user id for the user we want a room for @@ -105,21 +106,24 @@ class ServerNoticesManager(object): assert self._is_mine_id(user_id), "Cannot send server notices to remote users" - rooms = yield self._store.get_rooms_for_user_where_membership_is( + rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) - system_mxid = self._config.server_notices_mxid for room in rooms: # it's worth noting that there is an asymmetry here in that we # expect the user to be invited or joined, but the system user must # be joined. This is kinda deliberate, in that if somebody somehow # manages to invite the system user to a room, that doesn't make it # the server notices room. - user_ids = yield self._store.get_users_in_room(room.room_id) - if system_mxid in user_ids: + user_ids = await self._store.get_users_in_room(room.room_id) + if self.server_notices_mxid in user_ids: # we found a room which our user shares with the system notice # user - logger.info("Using room %s", room.room_id) + logger.info( + "Using existing server notices room %s for user %s", + room.room_id, + user_id, + ) return room.room_id # apparently no existing notice room: create a new one @@ -138,24 +142,49 @@ class ServerNoticesManager(object): "avatar_url": self._config.server_notices_mxid_avatar_url, } - requester = create_requester(system_mxid) - info = yield self._room_creation_handler.create_room( + requester = create_requester(self.server_notices_mxid) + info, _ = await self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, "name": self._config.server_notices_room_name, "power_level_content_override": {"users_default": -10}, - "invite": (user_id,), }, ratelimit=False, creator_join_profile=join_profile, ) room_id = info["room_id"] - max_id = yield self._store.add_tag_to_room( + max_id = await self._store.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) logger.info("Created server notices room %s for %s", room_id, user_id) return room_id + + async def maybe_invite_user_to_room(self, user_id: str, room_id: str): + """Invite the given user to the given server room, unless the user has already + joined or been invited to it. + + Args: + user_id: The ID of the user to invite. + room_id: The ID of the room to invite the user to. + """ + requester = create_requester(self.server_notices_mxid) + + # Check whether the user has already joined or been invited to this room. If + # that's the case, there is no need to re-invite them. + joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is( + user_id, [Membership.INVITE, Membership.JOIN] + ) + for room in joined_rooms: + if room.room_id == room_id: + return + + await self._room_member_handler.update_membership( + requester=requester, + target=UserID.from_string(user_id), + room_id=room_id, + action="invite", + ) diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index 652bab58e3..be74e86641 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from twisted.internet import defer - from synapse.server_notices.consent_server_notices import ConsentServerNotices from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, @@ -36,18 +34,16 @@ class ServerNoticesSender(object): ResourceLimitsServerNotices(hs), ) - @defer.inlineCallbacks - def on_user_syncing(self, user_id): + async def on_user_syncing(self, user_id): """Called when the user performs a sync operation. Args: user_id (str): mxid of user who synced """ for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user(user_id) + await sn.maybe_send_server_notice_to_user(user_id) - @defer.inlineCallbacks - def on_user_ip(self, user_id): + async def on_user_ip(self, user_id): """Called on the master when a worker process saw a client request. Args: @@ -57,4 +53,4 @@ class ServerNoticesSender(object): # we check for notices to send to the user in on_user_ip as well as # in on_user_syncing for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user(user_id) + await sn.maybe_send_server_notice_to_user(user_id) diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py new file mode 100644 index 0000000000..9b78924d96 --- /dev/null +++ b/synapse/spam_checker_api/__init__.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.storage.state import StateFilter + +MYPY = False +if MYPY: + import synapse.server + +logger = logging.getLogger(__name__) + + +class SpamCheckerApi(object): + """A proxy object that gets passed to spam checkers so they can get + access to rooms and other relevant information. + """ + + def __init__(self, hs: "synapse.server.HomeServer"): + self.hs = hs + + self._store = hs.get_datastore() + + @defer.inlineCallbacks + def get_state_events_in_room(self, room_id: str, types: tuple) -> defer.Deferred: + """Gets state events for the given room. + + Args: + room_id: The room ID to get state events in. + types: The event type and state key (using None + to represent 'any') of the room state to acquire. + + Returns: + twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]: + The filtered state events in the room. 
+ """ + state_ids = yield self._store.get_filtered_current_state_ids( + room_id=room_id, state_filter=StateFilter.from_types(types) + ) + state = yield self._store.get_events(state_ids.values()) + return state.values() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 2b0f4c79ee..2fa529fcd0 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,6 +16,7 @@ import logging from collections import namedtuple +from typing import Dict, Iterable, List, Optional, Set from six import iteritems, itervalues @@ -27,13 +28,15 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.types import StateMap from synapse.util.async_helpers import Linearizer -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.metrics import Measure +from synapse.util.metrics import Measure, measure_func logger = logging.getLogger(__name__) @@ -49,7 +52,6 @@ state_groups_histogram = Histogram( KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) -SIZE_OF_CACHE = 100000 * get_cache_factor_for("state_cache") EVICTION_TIMEOUT_SECONDS = 60 * 60 @@ -103,6 +105,7 @@ class StateHandler(object): def __init__(self, hs): self.clock = hs.get_clock() self.store = hs.get_datastore() + self.state_store = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() @@ -191,24 +194,37 @@ class StateHandler(object): return joined_users @defer.inlineCallbacks - def get_current_hosts_in_room(self, room_id, latest_event_ids=None): - if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - logger.debug("calling resolve_state_groups from get_current_hosts_in_room") - entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) + def get_current_hosts_in_room(self, room_id): + event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + return (yield self.get_hosts_in_room_at_events(room_id, event_ids)) + + @defer.inlineCallbacks + def get_hosts_in_room_at_events(self, room_id, event_ids): + """Get the hosts that were in a room at the given event ids + + Args: + room_id (str): + event_ids (list[str]): + + Returns: + Deferred[list[str]]: the hosts in the room at the given events + """ + entry = yield self.resolve_state_groups_for_events(room_id, event_ids) joined_hosts = yield self.store.get_joined_hosts(room_id, entry) return joined_hosts @defer.inlineCallbacks - def compute_event_context(self, event, old_state=None): + def compute_event_context( + self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None + ): """Build an EventContext structure for the event. This works out what the current state should be for the event, and generates a new state group if necessary. Args: - event (synapse.events.EventBase): - old_state (dict|None): The state at the event if it can't be + event: + old_state: The state at the event if it can't be calculated from existing events. This is normally only specified when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. 
@@ -220,6 +236,9 @@ class StateHandler(object): # If this is an outlier, then we know it shouldn't have any current # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. + + # FIXME: why do we populate current_state_ids? I thought the point was + # that we weren't supposed to have any state for outliers? if old_state: prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} if event.is_state(): @@ -236,114 +255,105 @@ class StateHandler(object): # group for it. context = EventContext.with_state( state_group=None, + state_group_before_event=None, current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, ) return context + # + # first of all, figure out the state before the event + # + if old_state: - # We already have the state, so we don't need to calculate it. - # Let's just correctly fill out the context and create a - # new state group for it. - - prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} - - if event.is_state(): - key = (event.type, event.state_key) - if key in prev_state_ids: - replaces = prev_state_ids[key] - if replaces != event.event_id: # Paranoia check - event.unsigned["replaces_state"] = replaces - current_state_ids = dict(prev_state_ids) - current_state_ids[key] = event.event_id - else: - current_state_ids = prev_state_ids + # if we're given the state before the event, then we use that + state_ids_before_event = { + (s.type, s.state_key): s.event_id for s in old_state + } + state_group_before_event = None + state_group_before_event_prev_group = None + deltas_to_state_group_before_event = None - state_group = yield self.store.store_state_group( - event.event_id, - event.room_id, - prev_group=None, - delta_ids=None, - current_state_ids=current_state_ids, - ) + else: + # otherwise, we'll need to resolve the state across the prev_events. + logger.debug("calling resolve_state_groups from compute_event_context") - context = EventContext.with_state( - state_group=state_group, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, + entry = yield self.resolve_state_groups_for_events( + event.room_id, event.prev_event_ids() ) - return context + state_ids_before_event = entry.state + state_group_before_event = entry.state_group + state_group_before_event_prev_group = entry.prev_group + deltas_to_state_group_before_event = entry.delta_ids - logger.debug("calling resolve_state_groups from compute_event_context") + # + # make sure that we have a state group at that point. If it's not a state event, + # that will be the state group for the new event. If it *is* a state event, + # it might get rejected (in which case we'll need to persist it with the + # previous state group) + # - entry = yield self.resolve_state_groups_for_events( - event.room_id, event.prev_event_ids() - ) + if not state_group_before_event: + state_group_before_event = yield self.state_store.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + current_state_ids=state_ids_before_event, + ) - prev_state_ids = entry.state - prev_group = None - delta_ids = None + # XXX: can we update the state cache entry for the new state group? or + # could we set a flag on resolve_state_groups_for_events to tell it to + # always make a state group? 
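Before the hunk continues below, it may help to see the relationship between the "state before" and "state after" groups that the rewritten compute_event_context is juggling. A small runnable sketch with invented event IDs: a non-state event simply keeps the before-event group, while a state event replaces one entry and records a one-element delta:

# (type, state_key) -> event_id map for the room before the event.
state_ids_before_event = {
    ("m.room.create", ""): "$create",
    ("m.room.name", ""): "$old_name",
}

# For a state event, the after-event state overrides the event's own key...
key = ("m.room.name", "")
new_event_id = "$new_name"
state_ids_after_event = dict(state_ids_before_event)
state_ids_after_event[key] = new_event_id

# ...the delta between the two groups is just that entry, and the replaced
# event ID ends up in unsigned["replaces_state"].
delta_ids = {key: new_event_id}
replaces = state_ids_before_event.get(key)

print(state_ids_after_event[key], delta_ids, replaces)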
+ + # + # now if it's not a state event, we're done + # + + if not event.is_state(): + return EventContext.with_state( + state_group_before_event=state_group_before_event, + state_group=state_group_before_event, + current_state_ids=state_ids_before_event, + prev_state_ids=state_ids_before_event, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + ) - if event.is_state(): - # If this is a state event then we need to create a new state - # group for the state after this event. + # + # otherwise, we'll need to create a new state group for after the event + # - key = (event.type, event.state_key) - if key in prev_state_ids: - replaces = prev_state_ids[key] + key = (event.type, event.state_key) + if key in state_ids_before_event: + replaces = state_ids_before_event[key] + if replaces != event.event_id: event.unsigned["replaces_state"] = replaces - current_state_ids = dict(prev_state_ids) - current_state_ids[key] = event.event_id - - if entry.state_group: - # If the state at the event has a state group assigned then - # we can use that as the prev group - prev_group = entry.state_group - delta_ids = {key: event.event_id} - elif entry.prev_group: - # If the state at the event only has a prev group, then we can - # use that as a prev group too. - prev_group = entry.prev_group - delta_ids = dict(entry.delta_ids) - delta_ids[key] = event.event_id - - state_group = yield self.store.store_state_group( - event.event_id, - event.room_id, - prev_group=prev_group, - delta_ids=delta_ids, - current_state_ids=current_state_ids, - ) - else: - current_state_ids = prev_state_ids - prev_group = entry.prev_group - delta_ids = entry.delta_ids - - if entry.state_group is None: - entry.state_group = yield self.store.store_state_group( - event.event_id, - event.room_id, - prev_group=entry.prev_group, - delta_ids=entry.delta_ids, - current_state_ids=current_state_ids, - ) - entry.state_id = entry.state_group - - state_group = entry.state_group - - context = EventContext.with_state( - state_group=state_group, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, - prev_group=prev_group, + state_ids_after_event = dict(state_ids_before_event) + state_ids_after_event[key] = event.event_id + delta_ids = {key: event.event_id} + + state_group_after_event = yield self.state_store.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event, delta_ids=delta_ids, + current_state_ids=state_ids_after_event, ) - return context + return EventContext.with_state( + state_group=state_group_after_event, + state_group_before_event=state_group_before_event, + current_state_ids=state_ids_after_event, + prev_state_ids=state_ids_before_event, + prev_group=state_group_before_event, + delta_ids=delta_ids, + ) + @measure_func() @defer.inlineCallbacks def resolve_state_groups_for_events(self, room_id, event_ids): """ Given a list of event_ids this method fetches the state at each @@ -364,14 +374,16 @@ class StateHandler(object): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids) + state_groups_ids = yield self.state_store.get_state_groups_ids( + room_id, event_ids + ) if len(state_groups_ids) == 0: return _StateCacheEntry(state={}, state_group=None) elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() - prev_group, delta_ids = yield 
self.store.get_state_group_delta(name) + prev_group, delta_ids = yield self.state_store.get_state_group_delta(name) return _StateCacheEntry( state=state_list, @@ -380,7 +392,7 @@ class StateHandler(object): delta_ids=delta_ids, ) - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) result = yield self._state_resolution_handler.resolve_state_groups( room_id, @@ -404,6 +416,7 @@ class StateHandler(object): with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + event.room_id, room_version, state_set_ids, event_map=state_map, @@ -432,7 +445,7 @@ class StateResolutionHandler(object): self._state_cache = ExpiringCache( cache_name="state_cache", clock=self.clock, - max_len=SIZE_OF_CACHE, + max_len=100000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, iterable=True, reset_expiry_on_get=True, @@ -449,7 +462,7 @@ class StateResolutionHandler(object): not be called for a single state group Args: - room_id (str): room we are resolving for (used for logging) + room_id (str): room we are resolving for (used for logging and sanity checks) room_version (str): version of the room state_groups_ids (dict[int, dict[(str, str), str]]): map from state group id to the state in that state group @@ -505,6 +518,7 @@ class StateResolutionHandler(object): logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + room_id, room_version, list(itervalues(state_groups_ids)), event_map=event_map, @@ -576,36 +590,44 @@ def _make_state_cache_entry(new_state, state_groups_ids): ) -def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): +def resolve_events_with_store( + room_id: str, + room_version: str, + state_sets: List[StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "StateResolutionStore", +): """ Args: - room_version(str): Version of the room + room_id: the room we are working in - state_sets(list): List of dicts of (type, state_key) -> event_id, + room_version: Version of the room + + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing events will be requested via state_map_factory. - If None, all events will be fetched via state_map_factory. + If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store: a place to fetch events from - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. 
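Since resolve_events_with_store is the entry point for both resolution algorithms, a concrete picture of its inputs helps; this runnable fragment only illustrates the data shapes (the event IDs are invented) and the unconflicted/conflicted split, not real resolution:

# Two state groups ("state_sets") that disagree about the room topic.
state_sets = [
    {("m.room.create", ""): "$create", ("m.room.topic", ""): "$topic_a"},
    {("m.room.create", ""): "$create", ("m.room.topic", ""): "$topic_b"},
]

# Keys every set agrees on are unconflicted; the rest is what v1/v2 resolve.
all_keys = set().union(*state_sets)
unconflicted = {
    k: state_sets[0][k]
    for k in all_keys
    if len({s.get(k) for s in state_sets}) == 1
}
conflicted = all_keys - set(unconflicted)

print(unconflicted)  # {('m.room.create', ''): '$create'}
print(conflicted)    # {('m.room.topic', '')}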
""" v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: return v1.resolve_events_with_store( - state_sets, event_map, state_res_store.get_events + room_id, state_sets, event_map, state_res_store.get_events ) else: return v2.resolve_events_with_store( - room_version, state_sets, event_map, state_res_store + room_id, room_version, state_sets, event_map, state_res_store ) @@ -633,28 +655,21 @@ class StateResolutionStore(object): return self.store.get_events( event_ids, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=allow_rejected, ) - def get_auth_chain(self, event_ids): - """Gets the full auth chain for a set of events (including rejected - events). - - Includes the given event IDs in the result. + def get_auth_chain_difference(self, state_sets: List[Set[str]]): + """Given sets of state events figure out the auth chain difference (as + per state res v2 algorithm). - Note that: - 1. All events must be state events. - 2. For v1 rooms this may not have the full auth chain in the - presence of rejected events - - Args: - event_ids (list): The event IDs of the events to fetch the auth - chain for. Must be state events. + This equivalent to fetching the full auth chain for each set of state + and returning the events that don't appear in each and every auth + chain. Returns: - Deferred[list[str]]: List of event IDs of the auth chain. + Deferred[Set[str]]: Set of event IDs. """ - return self.store.get_auth_chain_ids(event_ids, include_given=True) + return self.store.get_auth_chain_difference(state_sets) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index a2f92d9ff9..9bf98d06f2 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,6 +15,7 @@ import hashlib import logging +from typing import Callable, Dict, List, Optional from six import iteritems, iterkeys, itervalues @@ -24,6 +25,8 @@ from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase +from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -32,13 +35,20 @@ POWER_KEY = (EventTypes.PowerLevels, "") @defer.inlineCallbacks -def resolve_events_with_store(state_sets, event_map, state_map_factory): +def resolve_events_with_store( + room_id: str, + state_sets: List[StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_map_factory: Callable, +): """ Args: - state_sets(list): List of dicts of (type, state_key) -> event_id, + room_id: the room we are working in + + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -46,11 +56,11 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): If None, all events will be fetched via state_map_factory. - state_map_factory(func): will be called + state_map_factory: will be called with a list of event_ids that are needed, and should return with a Deferred of dict of event_id to event. - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. 
""" @@ -59,9 +69,9 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): unconflicted_state, conflicted_state = _seperate(state_sets) - needed_events = set( + needed_events = { event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids - ) + } needed_event_count = len(needed_events) if event_map is not None: needed_events -= set(iterkeys(event_map)) @@ -76,6 +86,14 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): if event_map is not None: state_map.update(event_map) + # everything in the state map should be in the right room + for event in state_map.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + # get the ids of the auth events which allow us to authenticate the # conflicted state, picking only from the unconflicting state. # @@ -95,6 +113,13 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): ) state_map_new = yield state_map_factory(new_needed_events) + for event in state_map_new.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + state_map.update(state_map_new) return _resolve_with_state( @@ -236,11 +261,11 @@ def _resolve_state_events(conflicted_state, auth_events): def _resolve_auth_events(events, auth_events): - reverse = [i for i in reversed(_ordered_events(events))] + reverse = list(reversed(_ordered_events(events))) - auth_keys = set( + auth_keys = { key for event in events for key in event_auth.auth_types_for_event(event) - ) + } new_auth_events = {} for key in auth_keys: @@ -256,7 +281,7 @@ def _resolve_auth_events(events, auth_events): try: # The signatures have already been checked at this point event_auth.check( - RoomVersions.V1.identifier, + RoomVersions.V1, event, auth_events, do_sig_check=False, @@ -274,7 +299,7 @@ def _resolve_normal_events(events, auth_events): try: # The signatures have already been checked at this point event_auth.check( - RoomVersions.V1.identifier, + RoomVersions.V1, event, auth_events, do_sig_check=False, diff --git a/synapse/state/v2.py b/synapse/state/v2.py index b327c86f40..18484e2fa6 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -16,29 +16,42 @@ import heapq import itertools import logging +from typing import Dict, List, Optional from six import iteritems, itervalues from twisted.internet import defer +import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase +from synapse.types import StateMap logger = logging.getLogger(__name__) @defer.inlineCallbacks -def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): +def resolve_events_with_store( + room_id: str, + room_version: str, + state_sets: List[StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "synapse.state.StateResolutionStore", +): """Resolves the state using the v2 state resolution algorithm Args: - room_version (str): The room version + room_id: the room we are working in + + room_version: The room version - state_sets(list): List of dicts of (type, state_key) -> event_id, + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. 
- event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -46,9 +59,9 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store: - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ @@ -84,7 +97,15 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto ) event_map.update(events) - full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) + # everything in the event map should be in the right room + for event in event_map.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + + full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map} logger.debug("%d full_conflicted_set entries", len(full_conflicted_set)) @@ -94,13 +115,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto ) sorted_power_events = yield _reverse_topological_power_sort( - power_events, event_map, state_res_store, full_conflicted_set + room_id, power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one resolved_state = yield _iterative_auth_checks( + room_id, room_version, sorted_power_events, unconflicted_state, @@ -121,13 +143,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto pl = resolved_state.get((EventTypes.PowerLevels, ""), None) leftover_events = yield _mainline_sort( - leftover_events, pl, event_map, state_res_store + room_id, leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") resolved_state = yield _iterative_auth_checks( - room_version, leftover_events, resolved_state, event_map, state_res_store + room_id, + room_version, + leftover_events, + resolved_state, + event_map, + state_res_store, ) logger.debug("resolved") @@ -141,11 +168,12 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto @defer.inlineCallbacks -def _get_power_level_for_sender(event_id, event_map, state_res_store): +def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): """Return the power level of the sender of the given event according to their auth events. Args: + room_id (str) event_id (str) event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -153,20 +181,24 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store): Returns: Deferred[int] """ - event = yield _get_event(event_id, event_map, state_res_store) + event = yield _get_event(room_id, event_id, event_map, state_res_store) pl = None for aid in event.auth_event_ids(): - aev = yield _get_event(aid, event_map, state_res_store) - if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): + aev = yield _get_event( + room_id, aid, event_map, state_res_store, allow_none=True + ) + if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): pl = aev break if pl is None: # Couldn't find power level. 
Check if they're the creator of the room for aid in event.auth_event_ids(): - aev = yield _get_event(aid, event_map, state_res_store) - if (aev.type, aev.state_key) == (EventTypes.Create, ""): + aev = yield _get_event( + room_id, aid, event_map, state_res_store, allow_none=True + ) + if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""): if aev.content.get("creator") == event.sender: return 100 break @@ -195,36 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): Returns: Deferred[set[str]]: Set of event IDs """ - common = set(itervalues(state_sets[0])).intersection( - *(itervalues(s) for s in state_sets[1:]) - ) - - auth_sets = [] - for state_set in state_sets: - auth_ids = set( - eid - for key, eid in iteritems(state_set) - if ( - key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite) - or key - in ( - (EventTypes.PowerLevels, ""), - (EventTypes.Create, ""), - (EventTypes.JoinRules, ""), - ) - ) - and eid not in common - ) - auth_chain = yield state_res_store.get_auth_chain(auth_ids) - auth_ids.update(auth_chain) - - auth_sets.append(auth_ids) - - intersection = set(auth_sets[0]).intersection(*auth_sets[1:]) - union = set().union(*auth_sets) + difference = yield state_res_store.get_auth_chain_difference( + [set(state_set.values()) for state_set in state_sets] + ) - return union - intersection + return difference def _seperate(state_sets): @@ -243,7 +251,7 @@ def _seperate(state_sets): conflicted_state = {} for key in set(itertools.chain.from_iterable(state_sets)): - event_ids = set(state_set.get(key) for state_set in state_sets) + event_ids = {state_set.get(key) for state_set in state_sets} if len(event_ids) == 1: unconflicted_state[key] = event_ids.pop() else: @@ -279,7 +287,7 @@ def _is_power_event(event): @defer.inlineCallbacks def _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff + graph, room_id, event_id, event_map, state_res_store, auth_diff ): """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph @@ -287,6 +295,7 @@ def _add_event_and_auth_chain_to_graph( Args: graph (dict[str, set[str]]): A map from event ID to the events auth event IDs + room_id (str): the room we are working in event_id (str): Event to add to the graph event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -298,7 +307,7 @@ def _add_event_and_auth_chain_to_graph( eid = state.pop() graph.setdefault(eid, set()) - event = yield _get_event(eid, event_map, state_res_store) + event = yield _get_event(room_id, eid, event_map, state_res_store) for aid in event.auth_event_ids(): if aid in auth_diff: if aid not in graph: @@ -308,11 +317,14 @@ def _add_event_and_auth_chain_to_graph( @defer.inlineCallbacks -def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff): +def _reverse_topological_power_sort( + room_id, event_ids, event_map, state_res_store, auth_diff +): """Returns a list of the event_ids sorted by reverse topological ordering, and then by power level and origin_server_ts Args: + room_id (str): the room we are working in event_ids (list[str]): The events to sort event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -325,12 +337,14 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ graph = {} for event_id in event_ids: yield _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff + graph, room_id, event_id, 
event_map, state_res_store, auth_diff ) event_to_pl = {} for event_id in graph: - pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store) + pl = yield _get_power_level_for_sender( + room_id, event_id, event_map, state_res_store + ) event_to_pl[event_id] = pl def _get_power_order(event_id): @@ -348,44 +362,53 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ @defer.inlineCallbacks def _iterative_auth_checks( - room_version, event_ids, base_state, event_map, state_res_store + room_id, room_version, event_ids, base_state, event_map, state_res_store ): """Sequentially apply auth checks to each event in given list, updating the state as it goes along. Args: + room_id (str) room_version (str) event_ids (list[str]): Ordered list of events to apply auth checks to - base_state (dict[tuple[str, str], str]): The set of state to start with + base_state (StateMap[str]): The set of state to start with event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) Returns: - Deferred[dict[tuple[str, str], str]]: Returns the final updated state + Deferred[StateMap[str]]: Returns the final updated state """ resolved_state = base_state.copy() + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] for event_id in event_ids: event = event_map[event_id] auth_events = {} for aid in event.auth_event_ids(): - ev = yield _get_event(aid, event_map, state_res_store) + ev = yield _get_event( + room_id, aid, event_map, state_res_store, allow_none=True + ) - if ev.rejected_reason is None: - auth_events[(ev.type, ev.state_key)] = ev + if not ev: + logger.warning( + "auth_event id %s for event %s is missing", aid, event_id + ) + else: + if ev.rejected_reason is None: + auth_events[(ev.type, ev.state_key)] = ev for key in event_auth.auth_types_for_event(event): if key in resolved_state: ev_id = resolved_state[key] - ev = yield _get_event(ev_id, event_map, state_res_store) + ev = yield _get_event(room_id, ev_id, event_map, state_res_store) if ev.rejected_reason is None: auth_events[key] = event_map[ev_id] try: event_auth.check( - room_version, + room_version_obj, event, auth_events, do_sig_check=False, @@ -400,11 +423,14 @@ def _iterative_auth_checks( @defer.inlineCallbacks -def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store): +def _mainline_sort( + room_id, event_ids, resolved_power_event_id, event_map, state_res_store +): """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id Args: + room_id (str): room we're working in event_ids (list[str]): Events to sort resolved_power_event_id (str): The final resolved power level event ID event_map (dict[str,FrozenEvent]) @@ -417,12 +443,14 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_stor pl = resolved_power_event_id while pl: mainline.append(pl) - pl_ev = yield _get_event(pl, event_map, state_res_store) + pl_ev = yield _get_event(room_id, pl, event_map, state_res_store) auth_events = pl_ev.auth_event_ids() pl = None for aid in auth_events: - ev = yield _get_event(aid, event_map, state_res_store) - if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): + ev = yield _get_event( + room_id, aid, event_map, state_res_store, allow_none=True + ) + if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): pl = aid break @@ -457,6 +485,8 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor Deferred[int] """ + room_id = event.room_id + # We do an iterative 
search, replacing `event with the power level in its # auth events (if any) while event: @@ -468,8 +498,10 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor event = None for aid in auth_events: - aev = yield _get_event(aid, event_map, state_res_store) - if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): + aev = yield _get_event( + room_id, aid, event_map, state_res_store, allow_none=True + ) + if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): event = aev break @@ -478,22 +510,37 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor @defer.inlineCallbacks -def _get_event(event_id, event_map, state_res_store): +def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): """Helper function to look up event in event_map, falling back to looking it up in the store Args: + room_id (str) event_id (str) event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) + allow_none (bool): if the event is not found, return None rather than raising + an exception Returns: - Deferred[FrozenEvent] + Deferred[Optional[FrozenEvent]] """ if event_id not in event_map: events = yield state_res_store.get_events([event_id], allow_rejected=True) event_map.update(events) - return event_map[event_id] + event = event_map.get(event_id) + + if event is None: + if allow_none: + return None + raise Exception("Unknown event %s" % (event_id,)) + + if event.room_id != room_id: + raise Exception( + "In state res for room %s, event %s is in %s" + % (room_id, event_id, event.room_id) + ) + return event def lexicographical_topological_sort(graph, key): diff --git a/synapse/static/client/login/index.html b/synapse/static/client/login/index.html index bcb6bc6bb7..6fefdaaff7 100644 --- a/synapse/static/client/login/index.html +++ b/synapse/static/client/login/index.html @@ -1,15 +1,16 @@ +<!doctype html> <html> <head> <title> Login </title> <meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> <link rel="stylesheet" href="style.css"> -<script src="js/jquery-2.1.3.min.js"></script> +<script src="js/jquery-3.4.1.min.js"></script> <script src="js/login.js"></script> </head> <body onload="matrixLogin.onLoad()"> <center> <br/> - <h1>Log in with one of the following methods</h1> + <h1 id="title"></h1> <span id="feedback" style="color: #f00"></span> diff --git a/synapse/static/client/login/js/jquery-2.1.3.min.js b/synapse/static/client/login/js/jquery-2.1.3.min.js deleted file mode 100644 index 25714ed29a..0000000000 --- a/synapse/static/client/login/js/jquery-2.1.3.min.js +++ /dev/null @@ -1,4 +0,0 @@ -/*! jQuery v2.1.3 | (c) 2005, 2014 jQuery Foundation, Inc. 
| jquery.org/license */ -!function(a,b){"object"==typeof module&&"object"==typeof module.exports?module.exports=a.document?b(a,!0):function(a){if(!a.document)throw new Error("jQuery requires a window with a document");return b(a)}:b(a)}("undefined"!=typeof window?window:this,function(a,b){var c=[],d=c.slice,e=c.concat,f=c.push,g=c.indexOf,h={},i=h.toString,j=h.hasOwnProperty,k={},l=a.document,m="2.1.3",n=function(a,b){return new n.fn.init(a,b)},o=/^[\s\uFEFF\xA0]+|[\s\uFEFF\xA0]+$/g,p=/^-ms-/,q=/-([\da-z])/gi,r=function(a,b){return b.toUpperCase()};n.fn=n.prototype={jquery:m,constructor:n,selector:"",length:0,toArray:function(){return d.call(this)},get:function(a){return null!=a?0>a?this[a+this.length]:this[a]:d.call(this)},pushStack:function(a){var b=n.merge(this.constructor(),a);return b.prevObject=this,b.context=this.context,b},each:function(a,b){return n.each(this,a,b)},map:function(a){return this.pushStack(n.map(this,function(b,c){return a.call(b,c,b)}))},slice:function(){return this.pushStack(d.apply(this,arguments))},first:function(){return this.eq(0)},last:function(){return this.eq(-1)},eq:function(a){var b=this.length,c=+a+(0>a?b:0);return this.pushStack(c>=0&&b>c?[this[c]]:[])},end:function(){return this.prevObject||this.constructor(null)},push:f,sort:c.sort,splice:c.splice},n.extend=n.fn.extend=function(){var a,b,c,d,e,f,g=arguments[0]||{},h=1,i=arguments.length,j=!1;for("boolean"==typeof g&&(j=g,g=arguments[h]||{},h++),"object"==typeof g||n.isFunction(g)||(g={}),h===i&&(g=this,h--);i>h;h++)if(null!=(a=arguments[h]))for(b in a)c=g[b],d=a[b],g!==d&&(j&&d&&(n.isPlainObject(d)||(e=n.isArray(d)))?(e?(e=!1,f=c&&n.isArray(c)?c:[]):f=c&&n.isPlainObject(c)?c:{},g[b]=n.extend(j,f,d)):void 0!==d&&(g[b]=d));return g},n.extend({expando:"jQuery"+(m+Math.random()).replace(/\D/g,""),isReady:!0,error:function(a){throw new Error(a)},noop:function(){},isFunction:function(a){return"function"===n.type(a)},isArray:Array.isArray,isWindow:function(a){return null!=a&&a===a.window},isNumeric:function(a){return!n.isArray(a)&&a-parseFloat(a)+1>=0},isPlainObject:function(a){return"object"!==n.type(a)||a.nodeType||n.isWindow(a)?!1:a.constructor&&!j.call(a.constructor.prototype,"isPrototypeOf")?!1:!0},isEmptyObject:function(a){var b;for(b in a)return!1;return!0},type:function(a){return null==a?a+"":"object"==typeof a||"function"==typeof a?h[i.call(a)]||"object":typeof a},globalEval:function(a){var b,c=eval;a=n.trim(a),a&&(1===a.indexOf("use strict")?(b=l.createElement("script"),b.text=a,l.head.appendChild(b).parentNode.removeChild(b)):c(a))},camelCase:function(a){return a.replace(p,"ms-").replace(q,r)},nodeName:function(a,b){return a.nodeName&&a.nodeName.toLowerCase()===b.toLowerCase()},each:function(a,b,c){var d,e=0,f=a.length,g=s(a);if(c){if(g){for(;f>e;e++)if(d=b.apply(a[e],c),d===!1)break}else for(e in a)if(d=b.apply(a[e],c),d===!1)break}else if(g){for(;f>e;e++)if(d=b.call(a[e],e,a[e]),d===!1)break}else for(e in a)if(d=b.call(a[e],e,a[e]),d===!1)break;return a},trim:function(a){return null==a?"":(a+"").replace(o,"")},makeArray:function(a,b){var c=b||[];return null!=a&&(s(Object(a))?n.merge(c,"string"==typeof a?[a]:a):f.call(c,a)),c},inArray:function(a,b,c){return null==b?-1:g.call(b,a,c)},merge:function(a,b){for(var c=+b.length,d=0,e=a.length;c>d;d++)a[e++]=b[d];return a.length=e,a},grep:function(a,b,c){for(var d,e=[],f=0,g=a.length,h=!c;g>f;f++)d=!b(a[f],f),d!==h&&e.push(a[f]);return e},map:function(a,b,c){var d,f=0,g=a.length,h=s(a),i=[];if(h)for(;g>f;f++)d=b(a[f],f,c),null!=d&&i.push(d);else for(f in 
a)d=b(a[f],f,c),null!=d&&i.push(d);return e.apply([],i)},guid:1,proxy:function(a,b){var c,e,f;return"string"==typeof b&&(c=a[b],b=a,a=c),n.isFunction(a)?(e=d.call(arguments,2),f=function(){return a.apply(b||this,e.concat(d.call(arguments)))},f.guid=a.guid=a.guid||n.guid++,f):void 0},now:Date.now,support:k}),n.each("Boolean Number String Function Array Date RegExp Object Error".split(" "),function(a,b){h["[object "+b+"]"]=b.toLowerCase()});function s(a){var b=a.length,c=n.type(a);return"function"===c||n.isWindow(a)?!1:1===a.nodeType&&b?!0:"array"===c||0===b||"number"==typeof b&&b>0&&b-1 in a}var t=function(a){var b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u="sizzle"+1*new Date,v=a.document,w=0,x=0,y=hb(),z=hb(),A=hb(),B=function(a,b){return a===b&&(l=!0),0},C=1<<31,D={}.hasOwnProperty,E=[],F=E.pop,G=E.push,H=E.push,I=E.slice,J=function(a,b){for(var c=0,d=a.length;d>c;c++)if(a[c]===b)return c;return-1},K="checked|selected|async|autofocus|autoplay|controls|defer|disabled|hidden|ismap|loop|multiple|open|readonly|required|scoped",L="[\\x20\\t\\r\\n\\f]",M="(?:\\\\.|[\\w-]|[^\\x00-\\xa0])+",N=M.replace("w","w#"),O="\\["+L+"*("+M+")(?:"+L+"*([*^$|!~]?=)"+L+"*(?:'((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\"|("+N+"))|)"+L+"*\\]",P=":("+M+")(?:\\((('((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\")|((?:\\\\.|[^\\\\()[\\]]|"+O+")*)|.*)\\)|)",Q=new RegExp(L+"+","g"),R=new RegExp("^"+L+"+|((?:^|[^\\\\])(?:\\\\.)*)"+L+"+$","g"),S=new RegExp("^"+L+"*,"+L+"*"),T=new RegExp("^"+L+"*([>+~]|"+L+")"+L+"*"),U=new RegExp("="+L+"*([^\\]'\"]*?)"+L+"*\\]","g"),V=new RegExp(P),W=new RegExp("^"+N+"$"),X={ID:new RegExp("^#("+M+")"),CLASS:new RegExp("^\\.("+M+")"),TAG:new RegExp("^("+M.replace("w","w*")+")"),ATTR:new RegExp("^"+O),PSEUDO:new RegExp("^"+P),CHILD:new RegExp("^:(only|first|last|nth|nth-last)-(child|of-type)(?:\\("+L+"*(even|odd|(([+-]|)(\\d*)n|)"+L+"*(?:([+-]|)"+L+"*(\\d+)|))"+L+"*\\)|)","i"),bool:new RegExp("^(?:"+K+")$","i"),needsContext:new RegExp("^"+L+"*[>+~]|:(even|odd|eq|gt|lt|nth|first|last)(?:\\("+L+"*((?:-\\d)?\\d*)"+L+"*\\)|)(?=[^-]|$)","i")},Y=/^(?:input|select|textarea|button)$/i,Z=/^h\d$/i,$=/^[^{]+\{\s*\[native \w/,_=/^(?:#([\w-]+)|(\w+)|\.([\w-]+))$/,ab=/[+~]/,bb=/'|\\/g,cb=new RegExp("\\\\([\\da-f]{1,6}"+L+"?|("+L+")|.)","ig"),db=function(a,b,c){var d="0x"+b-65536;return d!==d||c?b:0>d?String.fromCharCode(d+65536):String.fromCharCode(d>>10|55296,1023&d|56320)},eb=function(){m()};try{H.apply(E=I.call(v.childNodes),v.childNodes),E[v.childNodes.length].nodeType}catch(fb){H={apply:E.length?function(a,b){G.apply(a,I.call(b))}:function(a,b){var c=a.length,d=0;while(a[c++]=b[d++]);a.length=c-1}}}function gb(a,b,d,e){var f,h,j,k,l,o,r,s,w,x;if((b?b.ownerDocument||b:v)!==n&&m(b),b=b||n,d=d||[],k=b.nodeType,"string"!=typeof a||!a||1!==k&&9!==k&&11!==k)return d;if(!e&&p){if(11!==k&&(f=_.exec(a)))if(j=f[1]){if(9===k){if(h=b.getElementById(j),!h||!h.parentNode)return d;if(h.id===j)return d.push(h),d}else if(b.ownerDocument&&(h=b.ownerDocument.getElementById(j))&&t(b,h)&&h.id===j)return d.push(h),d}else{if(f[2])return H.apply(d,b.getElementsByTagName(a)),d;if((j=f[3])&&c.getElementsByClassName)return H.apply(d,b.getElementsByClassName(j)),d}if(c.qsa&&(!q||!q.test(a))){if(s=r=u,w=b,x=1!==k&&a,1===k&&"object"!==b.nodeName.toLowerCase()){o=g(a),(r=b.getAttribute("id"))?s=r.replace(bb,"\\$&"):b.setAttribute("id",s),s="[id='"+s+"'] ",l=o.length;while(l--)o[l]=s+rb(o[l]);w=ab.test(a)&&pb(b.parentNode)||b,x=o.join(",")}if(x)try{return 
H.apply(d,w.querySelectorAll(x)),d}catch(y){}finally{r||b.removeAttribute("id")}}}return i(a.replace(R,"$1"),b,d,e)}function hb(){var a=[];function b(c,e){return a.push(c+" ")>d.cacheLength&&delete b[a.shift()],b[c+" "]=e}return b}function ib(a){return a[u]=!0,a}function jb(a){var b=n.createElement("div");try{return!!a(b)}catch(c){return!1}finally{b.parentNode&&b.parentNode.removeChild(b),b=null}}function kb(a,b){var c=a.split("|"),e=a.length;while(e--)d.attrHandle[c[e]]=b}function lb(a,b){var c=b&&a,d=c&&1===a.nodeType&&1===b.nodeType&&(~b.sourceIndex||C)-(~a.sourceIndex||C);if(d)return d;if(c)while(c=c.nextSibling)if(c===b)return-1;return a?1:-1}function mb(a){return function(b){var c=b.nodeName.toLowerCase();return"input"===c&&b.type===a}}function nb(a){return function(b){var c=b.nodeName.toLowerCase();return("input"===c||"button"===c)&&b.type===a}}function ob(a){return ib(function(b){return b=+b,ib(function(c,d){var e,f=a([],c.length,b),g=f.length;while(g--)c[e=f[g]]&&(c[e]=!(d[e]=c[e]))})})}function pb(a){return a&&"undefined"!=typeof a.getElementsByTagName&&a}c=gb.support={},f=gb.isXML=function(a){var b=a&&(a.ownerDocument||a).documentElement;return b?"HTML"!==b.nodeName:!1},m=gb.setDocument=function(a){var b,e,g=a?a.ownerDocument||a:v;return g!==n&&9===g.nodeType&&g.documentElement?(n=g,o=g.documentElement,e=g.defaultView,e&&e!==e.top&&(e.addEventListener?e.addEventListener("unload",eb,!1):e.attachEvent&&e.attachEvent("onunload",eb)),p=!f(g),c.attributes=jb(function(a){return a.className="i",!a.getAttribute("className")}),c.getElementsByTagName=jb(function(a){return a.appendChild(g.createComment("")),!a.getElementsByTagName("*").length}),c.getElementsByClassName=$.test(g.getElementsByClassName),c.getById=jb(function(a){return o.appendChild(a).id=u,!g.getElementsByName||!g.getElementsByName(u).length}),c.getById?(d.find.ID=function(a,b){if("undefined"!=typeof b.getElementById&&p){var c=b.getElementById(a);return c&&c.parentNode?[c]:[]}},d.filter.ID=function(a){var b=a.replace(cb,db);return function(a){return a.getAttribute("id")===b}}):(delete d.find.ID,d.filter.ID=function(a){var b=a.replace(cb,db);return function(a){var c="undefined"!=typeof a.getAttributeNode&&a.getAttributeNode("id");return c&&c.value===b}}),d.find.TAG=c.getElementsByTagName?function(a,b){return"undefined"!=typeof b.getElementsByTagName?b.getElementsByTagName(a):c.qsa?b.querySelectorAll(a):void 0}:function(a,b){var c,d=[],e=0,f=b.getElementsByTagName(a);if("*"===a){while(c=f[e++])1===c.nodeType&&d.push(c);return d}return f},d.find.CLASS=c.getElementsByClassName&&function(a,b){return p?b.getElementsByClassName(a):void 0},r=[],q=[],(c.qsa=$.test(g.querySelectorAll))&&(jb(function(a){o.appendChild(a).innerHTML="<a id='"+u+"'></a><select id='"+u+"-\f]' msallowcapture=''><option selected=''></option></select>",a.querySelectorAll("[msallowcapture^='']").length&&q.push("[*^$]="+L+"*(?:''|\"\")"),a.querySelectorAll("[selected]").length||q.push("\\["+L+"*(?:value|"+K+")"),a.querySelectorAll("[id~="+u+"-]").length||q.push("~="),a.querySelectorAll(":checked").length||q.push(":checked"),a.querySelectorAll("a#"+u+"+*").length||q.push(".#.+[+~]")}),jb(function(a){var 
b=g.createElement("input");b.setAttribute("type","hidden"),a.appendChild(b).setAttribute("name","D"),a.querySelectorAll("[name=d]").length&&q.push("name"+L+"*[*^$|!~]?="),a.querySelectorAll(":enabled").length||q.push(":enabled",":disabled"),a.querySelectorAll("*,:x"),q.push(",.*:")})),(c.matchesSelector=$.test(s=o.matches||o.webkitMatchesSelector||o.mozMatchesSelector||o.oMatchesSelector||o.msMatchesSelector))&&jb(function(a){c.disconnectedMatch=s.call(a,"div"),s.call(a,"[s!='']:x"),r.push("!=",P)}),q=q.length&&new RegExp(q.join("|")),r=r.length&&new RegExp(r.join("|")),b=$.test(o.compareDocumentPosition),t=b||$.test(o.contains)?function(a,b){var c=9===a.nodeType?a.documentElement:a,d=b&&b.parentNode;return a===d||!(!d||1!==d.nodeType||!(c.contains?c.contains(d):a.compareDocumentPosition&&16&a.compareDocumentPosition(d)))}:function(a,b){if(b)while(b=b.parentNode)if(b===a)return!0;return!1},B=b?function(a,b){if(a===b)return l=!0,0;var d=!a.compareDocumentPosition-!b.compareDocumentPosition;return d?d:(d=(a.ownerDocument||a)===(b.ownerDocument||b)?a.compareDocumentPosition(b):1,1&d||!c.sortDetached&&b.compareDocumentPosition(a)===d?a===g||a.ownerDocument===v&&t(v,a)?-1:b===g||b.ownerDocument===v&&t(v,b)?1:k?J(k,a)-J(k,b):0:4&d?-1:1)}:function(a,b){if(a===b)return l=!0,0;var c,d=0,e=a.parentNode,f=b.parentNode,h=[a],i=[b];if(!e||!f)return a===g?-1:b===g?1:e?-1:f?1:k?J(k,a)-J(k,b):0;if(e===f)return lb(a,b);c=a;while(c=c.parentNode)h.unshift(c);c=b;while(c=c.parentNode)i.unshift(c);while(h[d]===i[d])d++;return d?lb(h[d],i[d]):h[d]===v?-1:i[d]===v?1:0},g):n},gb.matches=function(a,b){return gb(a,null,null,b)},gb.matchesSelector=function(a,b){if((a.ownerDocument||a)!==n&&m(a),b=b.replace(U,"='$1']"),!(!c.matchesSelector||!p||r&&r.test(b)||q&&q.test(b)))try{var d=s.call(a,b);if(d||c.disconnectedMatch||a.document&&11!==a.document.nodeType)return d}catch(e){}return gb(b,n,null,[a]).length>0},gb.contains=function(a,b){return(a.ownerDocument||a)!==n&&m(a),t(a,b)},gb.attr=function(a,b){(a.ownerDocument||a)!==n&&m(a);var e=d.attrHandle[b.toLowerCase()],f=e&&D.call(d.attrHandle,b.toLowerCase())?e(a,b,!p):void 0;return void 0!==f?f:c.attributes||!p?a.getAttribute(b):(f=a.getAttributeNode(b))&&f.specified?f.value:null},gb.error=function(a){throw new Error("Syntax error, unrecognized expression: "+a)},gb.uniqueSort=function(a){var b,d=[],e=0,f=0;if(l=!c.detectDuplicates,k=!c.sortStable&&a.slice(0),a.sort(B),l){while(b=a[f++])b===a[f]&&(e=d.push(f));while(e--)a.splice(d[e],1)}return k=null,a},e=gb.getText=function(a){var b,c="",d=0,f=a.nodeType;if(f){if(1===f||9===f||11===f){if("string"==typeof a.textContent)return a.textContent;for(a=a.firstChild;a;a=a.nextSibling)c+=e(a)}else if(3===f||4===f)return a.nodeValue}else while(b=a[d++])c+=e(b);return c},d=gb.selectors={cacheLength:50,createPseudo:ib,match:X,attrHandle:{},find:{},relative:{">":{dir:"parentNode",first:!0}," ":{dir:"parentNode"},"+":{dir:"previousSibling",first:!0},"~":{dir:"previousSibling"}},preFilter:{ATTR:function(a){return a[1]=a[1].replace(cb,db),a[3]=(a[3]||a[4]||a[5]||"").replace(cb,db),"~="===a[2]&&(a[3]=" "+a[3]+" "),a.slice(0,4)},CHILD:function(a){return a[1]=a[1].toLowerCase(),"nth"===a[1].slice(0,3)?(a[3]||gb.error(a[0]),a[4]=+(a[4]?a[5]+(a[6]||1):2*("even"===a[3]||"odd"===a[3])),a[5]=+(a[7]+a[8]||"odd"===a[3])):a[3]&&gb.error(a[0]),a},PSEUDO:function(a){var b,c=!a[6]&&a[2];return 
X.CHILD.test(a[0])?null:(a[3]?a[2]=a[4]||a[5]||"":c&&V.test(c)&&(b=g(c,!0))&&(b=c.indexOf(")",c.length-b)-c.length)&&(a[0]=a[0].slice(0,b),a[2]=c.slice(0,b)),a.slice(0,3))}},filter:{TAG:function(a){var b=a.replace(cb,db).toLowerCase();return"*"===a?function(){return!0}:function(a){return a.nodeName&&a.nodeName.toLowerCase()===b}},CLASS:function(a){var b=y[a+" "];return b||(b=new RegExp("(^|"+L+")"+a+"("+L+"|$)"))&&y(a,function(a){return b.test("string"==typeof a.className&&a.className||"undefined"!=typeof a.getAttribute&&a.getAttribute("class")||"")})},ATTR:function(a,b,c){return function(d){var e=gb.attr(d,a);return null==e?"!="===b:b?(e+="","="===b?e===c:"!="===b?e!==c:"^="===b?c&&0===e.indexOf(c):"*="===b?c&&e.indexOf(c)>-1:"$="===b?c&&e.slice(-c.length)===c:"~="===b?(" "+e.replace(Q," ")+" ").indexOf(c)>-1:"|="===b?e===c||e.slice(0,c.length+1)===c+"-":!1):!0}},CHILD:function(a,b,c,d,e){var f="nth"!==a.slice(0,3),g="last"!==a.slice(-4),h="of-type"===b;return 1===d&&0===e?function(a){return!!a.parentNode}:function(b,c,i){var j,k,l,m,n,o,p=f!==g?"nextSibling":"previousSibling",q=b.parentNode,r=h&&b.nodeName.toLowerCase(),s=!i&&!h;if(q){if(f){while(p){l=b;while(l=l[p])if(h?l.nodeName.toLowerCase()===r:1===l.nodeType)return!1;o=p="only"===a&&!o&&"nextSibling"}return!0}if(o=[g?q.firstChild:q.lastChild],g&&s){k=q[u]||(q[u]={}),j=k[a]||[],n=j[0]===w&&j[1],m=j[0]===w&&j[2],l=n&&q.childNodes[n];while(l=++n&&l&&l[p]||(m=n=0)||o.pop())if(1===l.nodeType&&++m&&l===b){k[a]=[w,n,m];break}}else if(s&&(j=(b[u]||(b[u]={}))[a])&&j[0]===w)m=j[1];else while(l=++n&&l&&l[p]||(m=n=0)||o.pop())if((h?l.nodeName.toLowerCase()===r:1===l.nodeType)&&++m&&(s&&((l[u]||(l[u]={}))[a]=[w,m]),l===b))break;return m-=e,m===d||m%d===0&&m/d>=0}}},PSEUDO:function(a,b){var c,e=d.pseudos[a]||d.setFilters[a.toLowerCase()]||gb.error("unsupported pseudo: "+a);return e[u]?e(b):e.length>1?(c=[a,a,"",b],d.setFilters.hasOwnProperty(a.toLowerCase())?ib(function(a,c){var d,f=e(a,b),g=f.length;while(g--)d=J(a,f[g]),a[d]=!(c[d]=f[g])}):function(a){return e(a,0,c)}):e}},pseudos:{not:ib(function(a){var b=[],c=[],d=h(a.replace(R,"$1"));return d[u]?ib(function(a,b,c,e){var f,g=d(a,null,e,[]),h=a.length;while(h--)(f=g[h])&&(a[h]=!(b[h]=f))}):function(a,e,f){return b[0]=a,d(b,null,f,c),b[0]=null,!c.pop()}}),has:ib(function(a){return function(b){return gb(a,b).length>0}}),contains:ib(function(a){return a=a.replace(cb,db),function(b){return(b.textContent||b.innerText||e(b)).indexOf(a)>-1}}),lang:ib(function(a){return W.test(a||"")||gb.error("unsupported lang: "+a),a=a.replace(cb,db).toLowerCase(),function(b){var c;do if(c=p?b.lang:b.getAttribute("xml:lang")||b.getAttribute("lang"))return c=c.toLowerCase(),c===a||0===c.indexOf(a+"-");while((b=b.parentNode)&&1===b.nodeType);return!1}}),target:function(b){var c=a.location&&a.location.hash;return c&&c.slice(1)===b.id},root:function(a){return a===o},focus:function(a){return a===n.activeElement&&(!n.hasFocus||n.hasFocus())&&!!(a.type||a.href||~a.tabIndex)},enabled:function(a){return a.disabled===!1},disabled:function(a){return a.disabled===!0},checked:function(a){var b=a.nodeName.toLowerCase();return"input"===b&&!!a.checked||"option"===b&&!!a.selected},selected:function(a){return a.parentNode&&a.parentNode.selectedIndex,a.selected===!0},empty:function(a){for(a=a.firstChild;a;a=a.nextSibling)if(a.nodeType<6)return!1;return!0},parent:function(a){return!d.pseudos.empty(a)},header:function(a){return Z.test(a.nodeName)},input:function(a){return Y.test(a.nodeName)},button:function(a){var 
b=a.nodeName.toLowerCase();return"input"===b&&"button"===a.type||"button"===b},text:function(a){var b;return"input"===a.nodeName.toLowerCase()&&"text"===a.type&&(null==(b=a.getAttribute("type"))||"text"===b.toLowerCase())},first:ob(function(){return[0]}),last:ob(function(a,b){return[b-1]}),eq:ob(function(a,b,c){return[0>c?c+b:c]}),even:ob(function(a,b){for(var c=0;b>c;c+=2)a.push(c);return a}),odd:ob(function(a,b){for(var c=1;b>c;c+=2)a.push(c);return a}),lt:ob(function(a,b,c){for(var d=0>c?c+b:c;--d>=0;)a.push(d);return a}),gt:ob(function(a,b,c){for(var d=0>c?c+b:c;++d<b;)a.push(d);return a})}},d.pseudos.nth=d.pseudos.eq;for(b in{radio:!0,checkbox:!0,file:!0,password:!0,image:!0})d.pseudos[b]=mb(b);for(b in{submit:!0,reset:!0})d.pseudos[b]=nb(b);function qb(){}qb.prototype=d.filters=d.pseudos,d.setFilters=new qb,g=gb.tokenize=function(a,b){var c,e,f,g,h,i,j,k=z[a+" "];if(k)return b?0:k.slice(0);h=a,i=[],j=d.preFilter;while(h){(!c||(e=S.exec(h)))&&(e&&(h=h.slice(e[0].length)||h),i.push(f=[])),c=!1,(e=T.exec(h))&&(c=e.shift(),f.push({value:c,type:e[0].replace(R," ")}),h=h.slice(c.length));for(g in d.filter)!(e=X[g].exec(h))||j[g]&&!(e=j[g](e))||(c=e.shift(),f.push({value:c,type:g,matches:e}),h=h.slice(c.length));if(!c)break}return b?h.length:h?gb.error(a):z(a,i).slice(0)};function rb(a){for(var b=0,c=a.length,d="";c>b;b++)d+=a[b].value;return d}function sb(a,b,c){var d=b.dir,e=c&&"parentNode"===d,f=x++;return b.first?function(b,c,f){while(b=b[d])if(1===b.nodeType||e)return a(b,c,f)}:function(b,c,g){var h,i,j=[w,f];if(g){while(b=b[d])if((1===b.nodeType||e)&&a(b,c,g))return!0}else while(b=b[d])if(1===b.nodeType||e){if(i=b[u]||(b[u]={}),(h=i[d])&&h[0]===w&&h[1]===f)return j[2]=h[2];if(i[d]=j,j[2]=a(b,c,g))return!0}}}function tb(a){return a.length>1?function(b,c,d){var e=a.length;while(e--)if(!a[e](b,c,d))return!1;return!0}:a[0]}function ub(a,b,c){for(var d=0,e=b.length;e>d;d++)gb(a,b[d],c);return c}function vb(a,b,c,d,e){for(var f,g=[],h=0,i=a.length,j=null!=b;i>h;h++)(f=a[h])&&(!c||c(f,d,e))&&(g.push(f),j&&b.push(h));return g}function wb(a,b,c,d,e,f){return d&&!d[u]&&(d=wb(d)),e&&!e[u]&&(e=wb(e,f)),ib(function(f,g,h,i){var j,k,l,m=[],n=[],o=g.length,p=f||ub(b||"*",h.nodeType?[h]:h,[]),q=!a||!f&&b?p:vb(p,m,a,h,i),r=c?e||(f?a:o||d)?[]:g:q;if(c&&c(q,r,h,i),d){j=vb(r,n),d(j,[],h,i),k=j.length;while(k--)(l=j[k])&&(r[n[k]]=!(q[n[k]]=l))}if(f){if(e||a){if(e){j=[],k=r.length;while(k--)(l=r[k])&&j.push(q[k]=l);e(null,r=[],j,i)}k=r.length;while(k--)(l=r[k])&&(j=e?J(f,l):m[k])>-1&&(f[j]=!(g[j]=l))}}else r=vb(r===g?r.splice(o,r.length):r),e?e(null,g,r,i):H.apply(g,r)})}function xb(a){for(var b,c,e,f=a.length,g=d.relative[a[0].type],h=g||d.relative[" "],i=g?1:0,k=sb(function(a){return a===b},h,!0),l=sb(function(a){return J(b,a)>-1},h,!0),m=[function(a,c,d){var e=!g&&(d||c!==j)||((b=c).nodeType?k(a,c,d):l(a,c,d));return b=null,e}];f>i;i++)if(c=d.relative[a[i].type])m=[sb(tb(m),c)];else{if(c=d.filter[a[i].type].apply(null,a[i].matches),c[u]){for(e=++i;f>e;e++)if(d.relative[a[e].type])break;return wb(i>1&&tb(m),i>1&&rb(a.slice(0,i-1).concat({value:" "===a[i-2].type?"*":""})).replace(R,"$1"),c,e>i&&xb(a.slice(i,e)),f>e&&xb(a=a.slice(e)),f>e&&rb(a))}m.push(c)}return tb(m)}function yb(a,b){var c=b.length>0,e=a.length>0,f=function(f,g,h,i,k){var 
l,m,o,p=0,q="0",r=f&&[],s=[],t=j,u=f||e&&d.find.TAG("*",k),v=w+=null==t?1:Math.random()||.1,x=u.length;for(k&&(j=g!==n&&g);q!==x&&null!=(l=u[q]);q++){if(e&&l){m=0;while(o=a[m++])if(o(l,g,h)){i.push(l);break}k&&(w=v)}c&&((l=!o&&l)&&p--,f&&r.push(l))}if(p+=q,c&&q!==p){m=0;while(o=b[m++])o(r,s,g,h);if(f){if(p>0)while(q--)r[q]||s[q]||(s[q]=F.call(i));s=vb(s)}H.apply(i,s),k&&!f&&s.length>0&&p+b.length>1&&gb.uniqueSort(i)}return k&&(w=v,j=t),r};return c?ib(f):f}return h=gb.compile=function(a,b){var c,d=[],e=[],f=A[a+" "];if(!f){b||(b=g(a)),c=b.length;while(c--)f=xb(b[c]),f[u]?d.push(f):e.push(f);f=A(a,yb(e,d)),f.selector=a}return f},i=gb.select=function(a,b,e,f){var i,j,k,l,m,n="function"==typeof a&&a,o=!f&&g(a=n.selector||a);if(e=e||[],1===o.length){if(j=o[0]=o[0].slice(0),j.length>2&&"ID"===(k=j[0]).type&&c.getById&&9===b.nodeType&&p&&d.relative[j[1].type]){if(b=(d.find.ID(k.matches[0].replace(cb,db),b)||[])[0],!b)return e;n&&(b=b.parentNode),a=a.slice(j.shift().value.length)}i=X.needsContext.test(a)?0:j.length;while(i--){if(k=j[i],d.relative[l=k.type])break;if((m=d.find[l])&&(f=m(k.matches[0].replace(cb,db),ab.test(j[0].type)&&pb(b.parentNode)||b))){if(j.splice(i,1),a=f.length&&rb(j),!a)return H.apply(e,f),e;break}}}return(n||h(a,o))(f,b,!p,e,ab.test(a)&&pb(b.parentNode)||b),e},c.sortStable=u.split("").sort(B).join("")===u,c.detectDuplicates=!!l,m(),c.sortDetached=jb(function(a){return 1&a.compareDocumentPosition(n.createElement("div"))}),jb(function(a){return a.innerHTML="<a href='#'></a>","#"===a.firstChild.getAttribute("href")})||kb("type|href|height|width",function(a,b,c){return c?void 0:a.getAttribute(b,"type"===b.toLowerCase()?1:2)}),c.attributes&&jb(function(a){return a.innerHTML="<input/>",a.firstChild.setAttribute("value",""),""===a.firstChild.getAttribute("value")})||kb("value",function(a,b,c){return c||"input"!==a.nodeName.toLowerCase()?void 0:a.defaultValue}),jb(function(a){return null==a.getAttribute("disabled")})||kb(K,function(a,b,c){var d;return c?void 0:a[b]===!0?b.toLowerCase():(d=a.getAttributeNode(b))&&d.specified?d.value:null}),gb}(a);n.find=t,n.expr=t.selectors,n.expr[":"]=n.expr.pseudos,n.unique=t.uniqueSort,n.text=t.getText,n.isXMLDoc=t.isXML,n.contains=t.contains;var u=n.expr.match.needsContext,v=/^<(\w+)\s*\/?>(?:<\/\1>|)$/,w=/^.[^:#\[\.,]*$/;function x(a,b,c){if(n.isFunction(b))return n.grep(a,function(a,d){return!!b.call(a,d,a)!==c});if(b.nodeType)return n.grep(a,function(a){return a===b!==c});if("string"==typeof b){if(w.test(b))return n.filter(b,a,c);b=n.filter(b,a)}return n.grep(a,function(a){return g.call(b,a)>=0!==c})}n.filter=function(a,b,c){var d=b[0];return c&&(a=":not("+a+")"),1===b.length&&1===d.nodeType?n.find.matchesSelector(d,a)?[d]:[]:n.find.matches(a,n.grep(b,function(a){return 1===a.nodeType}))},n.fn.extend({find:function(a){var b,c=this.length,d=[],e=this;if("string"!=typeof a)return this.pushStack(n(a).filter(function(){for(b=0;c>b;b++)if(n.contains(e[b],this))return!0}));for(b=0;c>b;b++)n.find(a,e[b],d);return d=this.pushStack(c>1?n.unique(d):d),d.selector=this.selector?this.selector+" "+a:a,d},filter:function(a){return this.pushStack(x(this,a||[],!1))},not:function(a){return this.pushStack(x(this,a||[],!0))},is:function(a){return!!x(this,"string"==typeof a&&u.test(a)?n(a):a||[],!1).length}});var y,z=/^(?:\s*(<[\w\W]+>)[^>]*|#([\w-]*))$/,A=n.fn.init=function(a,b){var c,d;if(!a)return this;if("string"==typeof 
a){if(c="<"===a[0]&&">"===a[a.length-1]&&a.length>=3?[null,a,null]:z.exec(a),!c||!c[1]&&b)return!b||b.jquery?(b||y).find(a):this.constructor(b).find(a);if(c[1]){if(b=b instanceof n?b[0]:b,n.merge(this,n.parseHTML(c[1],b&&b.nodeType?b.ownerDocument||b:l,!0)),v.test(c[1])&&n.isPlainObject(b))for(c in b)n.isFunction(this[c])?this[c](b[c]):this.attr(c,b[c]);return this}return d=l.getElementById(c[2]),d&&d.parentNode&&(this.length=1,this[0]=d),this.context=l,this.selector=a,this}return a.nodeType?(this.context=this[0]=a,this.length=1,this):n.isFunction(a)?"undefined"!=typeof y.ready?y.ready(a):a(n):(void 0!==a.selector&&(this.selector=a.selector,this.context=a.context),n.makeArray(a,this))};A.prototype=n.fn,y=n(l);var B=/^(?:parents|prev(?:Until|All))/,C={children:!0,contents:!0,next:!0,prev:!0};n.extend({dir:function(a,b,c){var d=[],e=void 0!==c;while((a=a[b])&&9!==a.nodeType)if(1===a.nodeType){if(e&&n(a).is(c))break;d.push(a)}return d},sibling:function(a,b){for(var c=[];a;a=a.nextSibling)1===a.nodeType&&a!==b&&c.push(a);return c}}),n.fn.extend({has:function(a){var b=n(a,this),c=b.length;return this.filter(function(){for(var a=0;c>a;a++)if(n.contains(this,b[a]))return!0})},closest:function(a,b){for(var c,d=0,e=this.length,f=[],g=u.test(a)||"string"!=typeof a?n(a,b||this.context):0;e>d;d++)for(c=this[d];c&&c!==b;c=c.parentNode)if(c.nodeType<11&&(g?g.index(c)>-1:1===c.nodeType&&n.find.matchesSelector(c,a))){f.push(c);break}return this.pushStack(f.length>1?n.unique(f):f)},index:function(a){return a?"string"==typeof a?g.call(n(a),this[0]):g.call(this,a.jquery?a[0]:a):this[0]&&this[0].parentNode?this.first().prevAll().length:-1},add:function(a,b){return this.pushStack(n.unique(n.merge(this.get(),n(a,b))))},addBack:function(a){return this.add(null==a?this.prevObject:this.prevObject.filter(a))}});function D(a,b){while((a=a[b])&&1!==a.nodeType);return a}n.each({parent:function(a){var b=a.parentNode;return b&&11!==b.nodeType?b:null},parents:function(a){return n.dir(a,"parentNode")},parentsUntil:function(a,b,c){return n.dir(a,"parentNode",c)},next:function(a){return D(a,"nextSibling")},prev:function(a){return D(a,"previousSibling")},nextAll:function(a){return n.dir(a,"nextSibling")},prevAll:function(a){return n.dir(a,"previousSibling")},nextUntil:function(a,b,c){return n.dir(a,"nextSibling",c)},prevUntil:function(a,b,c){return n.dir(a,"previousSibling",c)},siblings:function(a){return n.sibling((a.parentNode||{}).firstChild,a)},children:function(a){return n.sibling(a.firstChild)},contents:function(a){return a.contentDocument||n.merge([],a.childNodes)}},function(a,b){n.fn[a]=function(c,d){var e=n.map(this,b,c);return"Until"!==a.slice(-5)&&(d=c),d&&"string"==typeof d&&(e=n.filter(d,e)),this.length>1&&(C[a]||n.unique(e),B.test(a)&&e.reverse()),this.pushStack(e)}});var E=/\S+/g,F={};function G(a){var b=F[a]={};return n.each(a.match(E)||[],function(a,c){b[c]=!0}),b}n.Callbacks=function(a){a="string"==typeof a?F[a]||G(a):n.extend({},a);var b,c,d,e,f,g,h=[],i=!a.once&&[],j=function(l){for(b=a.memory&&l,c=!0,g=e||0,e=0,f=h.length,d=!0;h&&f>g;g++)if(h[g].apply(l[0],l[1])===!1&&a.stopOnFalse){b=!1;break}d=!1,h&&(i?i.length&&j(i.shift()):b?h=[]:k.disable())},k={add:function(){if(h){var c=h.length;!function g(b){n.each(b,function(b,c){var d=n.type(c);"function"===d?a.unique&&k.has(c)||h.push(c):c&&c.length&&"string"!==d&&g(c)})}(arguments),d?f=h.length:b&&(e=c,j(b))}return this},remove:function(){return h&&n.each(arguments,function(a,b){var 
c;while((c=n.inArray(b,h,c))>-1)h.splice(c,1),d&&(f>=c&&f--,g>=c&&g--)}),this},has:function(a){return a?n.inArray(a,h)>-1:!(!h||!h.length)},empty:function(){return h=[],f=0,this},disable:function(){return h=i=b=void 0,this},disabled:function(){return!h},lock:function(){return i=void 0,b||k.disable(),this},locked:function(){return!i},fireWith:function(a,b){return!h||c&&!i||(b=b||[],b=[a,b.slice?b.slice():b],d?i.push(b):j(b)),this},fire:function(){return k.fireWith(this,arguments),this},fired:function(){return!!c}};return k},n.extend({Deferred:function(a){var b=[["resolve","done",n.Callbacks("once memory"),"resolved"],["reject","fail",n.Callbacks("once memory"),"rejected"],["notify","progress",n.Callbacks("memory")]],c="pending",d={state:function(){return c},always:function(){return e.done(arguments).fail(arguments),this},then:function(){var a=arguments;return n.Deferred(function(c){n.each(b,function(b,f){var g=n.isFunction(a[b])&&a[b];e[f[1]](function(){var a=g&&g.apply(this,arguments);a&&n.isFunction(a.promise)?a.promise().done(c.resolve).fail(c.reject).progress(c.notify):c[f[0]+"With"](this===d?c.promise():this,g?[a]:arguments)})}),a=null}).promise()},promise:function(a){return null!=a?n.extend(a,d):d}},e={};return d.pipe=d.then,n.each(b,function(a,f){var g=f[2],h=f[3];d[f[1]]=g.add,h&&g.add(function(){c=h},b[1^a][2].disable,b[2][2].lock),e[f[0]]=function(){return e[f[0]+"With"](this===e?d:this,arguments),this},e[f[0]+"With"]=g.fireWith}),d.promise(e),a&&a.call(e,e),e},when:function(a){var b=0,c=d.call(arguments),e=c.length,f=1!==e||a&&n.isFunction(a.promise)?e:0,g=1===f?a:n.Deferred(),h=function(a,b,c){return function(e){b[a]=this,c[a]=arguments.length>1?d.call(arguments):e,c===i?g.notifyWith(b,c):--f||g.resolveWith(b,c)}},i,j,k;if(e>1)for(i=new Array(e),j=new Array(e),k=new Array(e);e>b;b++)c[b]&&n.isFunction(c[b].promise)?c[b].promise().done(h(b,k,c)).fail(g.reject).progress(h(b,j,i)):--f;return f||g.resolveWith(k,c),g.promise()}});var H;n.fn.ready=function(a){return n.ready.promise().done(a),this},n.extend({isReady:!1,readyWait:1,holdReady:function(a){a?n.readyWait++:n.ready(!0)},ready:function(a){(a===!0?--n.readyWait:n.isReady)||(n.isReady=!0,a!==!0&&--n.readyWait>0||(H.resolveWith(l,[n]),n.fn.triggerHandler&&(n(l).triggerHandler("ready"),n(l).off("ready"))))}});function I(){l.removeEventListener("DOMContentLoaded",I,!1),a.removeEventListener("load",I,!1),n.ready()}n.ready.promise=function(b){return H||(H=n.Deferred(),"complete"===l.readyState?setTimeout(n.ready):(l.addEventListener("DOMContentLoaded",I,!1),a.addEventListener("load",I,!1))),H.promise(b)},n.ready.promise();var J=n.access=function(a,b,c,d,e,f,g){var h=0,i=a.length,j=null==c;if("object"===n.type(c)){e=!0;for(h in c)n.access(a,b,h,c[h],!0,f,g)}else if(void 0!==d&&(e=!0,n.isFunction(d)||(g=!0),j&&(g?(b.call(a,d),b=null):(j=b,b=function(a,b,c){return j.call(n(a),c)})),b))for(;i>h;h++)b(a[h],c,g?d:d.call(a[h],h,b(a[h],c)));return e?a:j?b.call(a):i?b(a[0],c):f};n.acceptData=function(a){return 1===a.nodeType||9===a.nodeType||!+a.nodeType};function K(){Object.defineProperty(this.cache={},0,{get:function(){return{}}}),this.expando=n.expando+K.uid++}K.uid=1,K.accepts=n.acceptData,K.prototype={key:function(a){if(!K.accepts(a))return 0;var b={},c=a[this.expando];if(!c){c=K.uid++;try{b[this.expando]={value:c},Object.defineProperties(a,b)}catch(d){b[this.expando]=c,n.extend(a,b)}}return this.cache[c]||(this.cache[c]={}),c},set:function(a,b,c){var d,e=this.key(a),f=this.cache[e];if("string"==typeof b)f[b]=c;else 
if(n.isEmptyObject(f))n.extend(this.cache[e],b);else for(d in b)f[d]=b[d];return f},get:function(a,b){var c=this.cache[this.key(a)];return void 0===b?c:c[b]},access:function(a,b,c){var d;return void 0===b||b&&"string"==typeof b&&void 0===c?(d=this.get(a,b),void 0!==d?d:this.get(a,n.camelCase(b))):(this.set(a,b,c),void 0!==c?c:b)},remove:function(a,b){var c,d,e,f=this.key(a),g=this.cache[f];if(void 0===b)this.cache[f]={};else{n.isArray(b)?d=b.concat(b.map(n.camelCase)):(e=n.camelCase(b),b in g?d=[b,e]:(d=e,d=d in g?[d]:d.match(E)||[])),c=d.length;while(c--)delete g[d[c]]}},hasData:function(a){return!n.isEmptyObject(this.cache[a[this.expando]]||{})},discard:function(a){a[this.expando]&&delete this.cache[a[this.expando]]}};var L=new K,M=new K,N=/^(?:\{[\w\W]*\}|\[[\w\W]*\])$/,O=/([A-Z])/g;function P(a,b,c){var d;if(void 0===c&&1===a.nodeType)if(d="data-"+b.replace(O,"-$1").toLowerCase(),c=a.getAttribute(d),"string"==typeof c){try{c="true"===c?!0:"false"===c?!1:"null"===c?null:+c+""===c?+c:N.test(c)?n.parseJSON(c):c}catch(e){}M.set(a,b,c)}else c=void 0;return c}n.extend({hasData:function(a){return M.hasData(a)||L.hasData(a)},data:function(a,b,c){return M.access(a,b,c) -},removeData:function(a,b){M.remove(a,b)},_data:function(a,b,c){return L.access(a,b,c)},_removeData:function(a,b){L.remove(a,b)}}),n.fn.extend({data:function(a,b){var c,d,e,f=this[0],g=f&&f.attributes;if(void 0===a){if(this.length&&(e=M.get(f),1===f.nodeType&&!L.get(f,"hasDataAttrs"))){c=g.length;while(c--)g[c]&&(d=g[c].name,0===d.indexOf("data-")&&(d=n.camelCase(d.slice(5)),P(f,d,e[d])));L.set(f,"hasDataAttrs",!0)}return e}return"object"==typeof a?this.each(function(){M.set(this,a)}):J(this,function(b){var c,d=n.camelCase(a);if(f&&void 0===b){if(c=M.get(f,a),void 0!==c)return c;if(c=M.get(f,d),void 0!==c)return c;if(c=P(f,d,void 0),void 0!==c)return c}else this.each(function(){var c=M.get(this,d);M.set(this,d,b),-1!==a.indexOf("-")&&void 0!==c&&M.set(this,a,b)})},null,b,arguments.length>1,null,!0)},removeData:function(a){return this.each(function(){M.remove(this,a)})}}),n.extend({queue:function(a,b,c){var d;return a?(b=(b||"fx")+"queue",d=L.get(a,b),c&&(!d||n.isArray(c)?d=L.access(a,b,n.makeArray(c)):d.push(c)),d||[]):void 0},dequeue:function(a,b){b=b||"fx";var c=n.queue(a,b),d=c.length,e=c.shift(),f=n._queueHooks(a,b),g=function(){n.dequeue(a,b)};"inprogress"===e&&(e=c.shift(),d--),e&&("fx"===b&&c.unshift("inprogress"),delete f.stop,e.call(a,g,f)),!d&&f&&f.empty.fire()},_queueHooks:function(a,b){var c=b+"queueHooks";return L.get(a,c)||L.access(a,c,{empty:n.Callbacks("once memory").add(function(){L.remove(a,[b+"queue",c])})})}}),n.fn.extend({queue:function(a,b){var c=2;return"string"!=typeof a&&(b=a,a="fx",c--),arguments.length<c?n.queue(this[0],a):void 0===b?this:this.each(function(){var c=n.queue(this,a,b);n._queueHooks(this,a),"fx"===a&&"inprogress"!==c[0]&&n.dequeue(this,a)})},dequeue:function(a){return this.each(function(){n.dequeue(this,a)})},clearQueue:function(a){return this.queue(a||"fx",[])},promise:function(a,b){var c,d=1,e=n.Deferred(),f=this,g=this.length,h=function(){--d||e.resolveWith(f,[f])};"string"!=typeof a&&(b=a,a=void 0),a=a||"fx";while(g--)c=L.get(f[g],a+"queueHooks"),c&&c.empty&&(d++,c.empty.add(h));return h(),e.promise(b)}});var Q=/[+-]?(?:\d*\.|)\d+(?:[eE][+-]?\d+|)/.source,R=["Top","Right","Bottom","Left"],S=function(a,b){return a=b||a,"none"===n.css(a,"display")||!n.contains(a.ownerDocument,a)},T=/^(?:checkbox|radio)$/i;!function(){var 
a=l.createDocumentFragment(),b=a.appendChild(l.createElement("div")),c=l.createElement("input");c.setAttribute("type","radio"),c.setAttribute("checked","checked"),c.setAttribute("name","t"),b.appendChild(c),k.checkClone=b.cloneNode(!0).cloneNode(!0).lastChild.checked,b.innerHTML="<textarea>x</textarea>",k.noCloneChecked=!!b.cloneNode(!0).lastChild.defaultValue}();var U="undefined";k.focusinBubbles="onfocusin"in a;var V=/^key/,W=/^(?:mouse|pointer|contextmenu)|click/,X=/^(?:focusinfocus|focusoutblur)$/,Y=/^([^.]*)(?:\.(.+)|)$/;function Z(){return!0}function $(){return!1}function _(){try{return l.activeElement}catch(a){}}n.event={global:{},add:function(a,b,c,d,e){var f,g,h,i,j,k,l,m,o,p,q,r=L.get(a);if(r){c.handler&&(f=c,c=f.handler,e=f.selector),c.guid||(c.guid=n.guid++),(i=r.events)||(i=r.events={}),(g=r.handle)||(g=r.handle=function(b){return typeof n!==U&&n.event.triggered!==b.type?n.event.dispatch.apply(a,arguments):void 0}),b=(b||"").match(E)||[""],j=b.length;while(j--)h=Y.exec(b[j])||[],o=q=h[1],p=(h[2]||"").split(".").sort(),o&&(l=n.event.special[o]||{},o=(e?l.delegateType:l.bindType)||o,l=n.event.special[o]||{},k=n.extend({type:o,origType:q,data:d,handler:c,guid:c.guid,selector:e,needsContext:e&&n.expr.match.needsContext.test(e),namespace:p.join(".")},f),(m=i[o])||(m=i[o]=[],m.delegateCount=0,l.setup&&l.setup.call(a,d,p,g)!==!1||a.addEventListener&&a.addEventListener(o,g,!1)),l.add&&(l.add.call(a,k),k.handler.guid||(k.handler.guid=c.guid)),e?m.splice(m.delegateCount++,0,k):m.push(k),n.event.global[o]=!0)}},remove:function(a,b,c,d,e){var f,g,h,i,j,k,l,m,o,p,q,r=L.hasData(a)&&L.get(a);if(r&&(i=r.events)){b=(b||"").match(E)||[""],j=b.length;while(j--)if(h=Y.exec(b[j])||[],o=q=h[1],p=(h[2]||"").split(".").sort(),o){l=n.event.special[o]||{},o=(d?l.delegateType:l.bindType)||o,m=i[o]||[],h=h[2]&&new RegExp("(^|\\.)"+p.join("\\.(?:.*\\.|)")+"(\\.|$)"),g=f=m.length;while(f--)k=m[f],!e&&q!==k.origType||c&&c.guid!==k.guid||h&&!h.test(k.namespace)||d&&d!==k.selector&&("**"!==d||!k.selector)||(m.splice(f,1),k.selector&&m.delegateCount--,l.remove&&l.remove.call(a,k));g&&!m.length&&(l.teardown&&l.teardown.call(a,p,r.handle)!==!1||n.removeEvent(a,o,r.handle),delete i[o])}else for(o in i)n.event.remove(a,o+b[j],c,d,!0);n.isEmptyObject(i)&&(delete r.handle,L.remove(a,"events"))}},trigger:function(b,c,d,e){var f,g,h,i,k,m,o,p=[d||l],q=j.call(b,"type")?b.type:b,r=j.call(b,"namespace")?b.namespace.split("."):[];if(g=h=d=d||l,3!==d.nodeType&&8!==d.nodeType&&!X.test(q+n.event.triggered)&&(q.indexOf(".")>=0&&(r=q.split("."),q=r.shift(),r.sort()),k=q.indexOf(":")<0&&"on"+q,b=b[n.expando]?b:new n.Event(q,"object"==typeof b&&b),b.isTrigger=e?2:3,b.namespace=r.join("."),b.namespace_re=b.namespace?new RegExp("(^|\\.)"+r.join("\\.(?:.*\\.|)")+"(\\.|$)"):null,b.result=void 0,b.target||(b.target=d),c=null==c?[b]:n.makeArray(c,[b]),o=n.event.special[q]||{},e||!o.trigger||o.trigger.apply(d,c)!==!1)){if(!e&&!o.noBubble&&!n.isWindow(d)){for(i=o.delegateType||q,X.test(i+q)||(g=g.parentNode);g;g=g.parentNode)p.push(g),h=g;h===(d.ownerDocument||l)&&p.push(h.defaultView||h.parentWindow||a)}f=0;while((g=p[f++])&&!b.isPropagationStopped())b.type=f>1?i:o.bindType||q,m=(L.get(g,"events")||{})[b.type]&&L.get(g,"handle"),m&&m.apply(g,c),m=k&&g[k],m&&m.apply&&n.acceptData(g)&&(b.result=m.apply(g,c),b.result===!1&&b.preventDefault());return 
b.type=q,e||b.isDefaultPrevented()||o._default&&o._default.apply(p.pop(),c)!==!1||!n.acceptData(d)||k&&n.isFunction(d[q])&&!n.isWindow(d)&&(h=d[k],h&&(d[k]=null),n.event.triggered=q,d[q](),n.event.triggered=void 0,h&&(d[k]=h)),b.result}},dispatch:function(a){a=n.event.fix(a);var b,c,e,f,g,h=[],i=d.call(arguments),j=(L.get(this,"events")||{})[a.type]||[],k=n.event.special[a.type]||{};if(i[0]=a,a.delegateTarget=this,!k.preDispatch||k.preDispatch.call(this,a)!==!1){h=n.event.handlers.call(this,a,j),b=0;while((f=h[b++])&&!a.isPropagationStopped()){a.currentTarget=f.elem,c=0;while((g=f.handlers[c++])&&!a.isImmediatePropagationStopped())(!a.namespace_re||a.namespace_re.test(g.namespace))&&(a.handleObj=g,a.data=g.data,e=((n.event.special[g.origType]||{}).handle||g.handler).apply(f.elem,i),void 0!==e&&(a.result=e)===!1&&(a.preventDefault(),a.stopPropagation()))}return k.postDispatch&&k.postDispatch.call(this,a),a.result}},handlers:function(a,b){var c,d,e,f,g=[],h=b.delegateCount,i=a.target;if(h&&i.nodeType&&(!a.button||"click"!==a.type))for(;i!==this;i=i.parentNode||this)if(i.disabled!==!0||"click"!==a.type){for(d=[],c=0;h>c;c++)f=b[c],e=f.selector+" ",void 0===d[e]&&(d[e]=f.needsContext?n(e,this).index(i)>=0:n.find(e,this,null,[i]).length),d[e]&&d.push(f);d.length&&g.push({elem:i,handlers:d})}return h<b.length&&g.push({elem:this,handlers:b.slice(h)}),g},props:"altKey bubbles cancelable ctrlKey currentTarget eventPhase metaKey relatedTarget shiftKey target timeStamp view which".split(" "),fixHooks:{},keyHooks:{props:"char charCode key keyCode".split(" "),filter:function(a,b){return null==a.which&&(a.which=null!=b.charCode?b.charCode:b.keyCode),a}},mouseHooks:{props:"button buttons clientX clientY offsetX offsetY pageX pageY screenX screenY toElement".split(" "),filter:function(a,b){var c,d,e,f=b.button;return null==a.pageX&&null!=b.clientX&&(c=a.target.ownerDocument||l,d=c.documentElement,e=c.body,a.pageX=b.clientX+(d&&d.scrollLeft||e&&e.scrollLeft||0)-(d&&d.clientLeft||e&&e.clientLeft||0),a.pageY=b.clientY+(d&&d.scrollTop||e&&e.scrollTop||0)-(d&&d.clientTop||e&&e.clientTop||0)),a.which||void 0===f||(a.which=1&f?1:2&f?3:4&f?2:0),a}},fix:function(a){if(a[n.expando])return a;var b,c,d,e=a.type,f=a,g=this.fixHooks[e];g||(this.fixHooks[e]=g=W.test(e)?this.mouseHooks:V.test(e)?this.keyHooks:{}),d=g.props?this.props.concat(g.props):this.props,a=new n.Event(f),b=d.length;while(b--)c=d[b],a[c]=f[c];return a.target||(a.target=l),3===a.target.nodeType&&(a.target=a.target.parentNode),g.filter?g.filter(a,f):a},special:{load:{noBubble:!0},focus:{trigger:function(){return this!==_()&&this.focus?(this.focus(),!1):void 0},delegateType:"focusin"},blur:{trigger:function(){return this===_()&&this.blur?(this.blur(),!1):void 0},delegateType:"focusout"},click:{trigger:function(){return"checkbox"===this.type&&this.click&&n.nodeName(this,"input")?(this.click(),!1):void 0},_default:function(a){return n.nodeName(a.target,"a")}},beforeunload:{postDispatch:function(a){void 0!==a.result&&a.originalEvent&&(a.originalEvent.returnValue=a.result)}}},simulate:function(a,b,c,d){var e=n.extend(new n.Event,c,{type:a,isSimulated:!0,originalEvent:{}});d?n.event.trigger(e,null,b):n.event.dispatch.call(b,e),e.isDefaultPrevented()&&c.preventDefault()}},n.removeEvent=function(a,b,c){a.removeEventListener&&a.removeEventListener(b,c,!1)},n.Event=function(a,b){return this instanceof n.Event?(a&&a.type?(this.originalEvent=a,this.type=a.type,this.isDefaultPrevented=a.defaultPrevented||void 
0===a.defaultPrevented&&a.returnValue===!1?Z:$):this.type=a,b&&n.extend(this,b),this.timeStamp=a&&a.timeStamp||n.now(),void(this[n.expando]=!0)):new n.Event(a,b)},n.Event.prototype={isDefaultPrevented:$,isPropagationStopped:$,isImmediatePropagationStopped:$,preventDefault:function(){var a=this.originalEvent;this.isDefaultPrevented=Z,a&&a.preventDefault&&a.preventDefault()},stopPropagation:function(){var a=this.originalEvent;this.isPropagationStopped=Z,a&&a.stopPropagation&&a.stopPropagation()},stopImmediatePropagation:function(){var a=this.originalEvent;this.isImmediatePropagationStopped=Z,a&&a.stopImmediatePropagation&&a.stopImmediatePropagation(),this.stopPropagation()}},n.each({mouseenter:"mouseover",mouseleave:"mouseout",pointerenter:"pointerover",pointerleave:"pointerout"},function(a,b){n.event.special[a]={delegateType:b,bindType:b,handle:function(a){var c,d=this,e=a.relatedTarget,f=a.handleObj;return(!e||e!==d&&!n.contains(d,e))&&(a.type=f.origType,c=f.handler.apply(this,arguments),a.type=b),c}}}),k.focusinBubbles||n.each({focus:"focusin",blur:"focusout"},function(a,b){var c=function(a){n.event.simulate(b,a.target,n.event.fix(a),!0)};n.event.special[b]={setup:function(){var d=this.ownerDocument||this,e=L.access(d,b);e||d.addEventListener(a,c,!0),L.access(d,b,(e||0)+1)},teardown:function(){var d=this.ownerDocument||this,e=L.access(d,b)-1;e?L.access(d,b,e):(d.removeEventListener(a,c,!0),L.remove(d,b))}}}),n.fn.extend({on:function(a,b,c,d,e){var f,g;if("object"==typeof a){"string"!=typeof b&&(c=c||b,b=void 0);for(g in a)this.on(g,b,c,a[g],e);return this}if(null==c&&null==d?(d=b,c=b=void 0):null==d&&("string"==typeof b?(d=c,c=void 0):(d=c,c=b,b=void 0)),d===!1)d=$;else if(!d)return this;return 1===e&&(f=d,d=function(a){return n().off(a),f.apply(this,arguments)},d.guid=f.guid||(f.guid=n.guid++)),this.each(function(){n.event.add(this,a,d,c,b)})},one:function(a,b,c,d){return this.on(a,b,c,d,1)},off:function(a,b,c){var d,e;if(a&&a.preventDefault&&a.handleObj)return d=a.handleObj,n(a.delegateTarget).off(d.namespace?d.origType+"."+d.namespace:d.origType,d.selector,d.handler),this;if("object"==typeof a){for(e in a)this.off(e,b,a[e]);return this}return(b===!1||"function"==typeof b)&&(c=b,b=void 0),c===!1&&(c=$),this.each(function(){n.event.remove(this,a,c,b)})},trigger:function(a,b){return this.each(function(){n.event.trigger(a,b,this)})},triggerHandler:function(a,b){var c=this[0];return c?n.event.trigger(a,b,c,!0):void 0}});var ab=/<(?!area|br|col|embed|hr|img|input|link|meta|param)(([\w:]+)[^>]*)\/>/gi,bb=/<([\w:]+)/,cb=/<|&#?\w+;/,db=/<(?:script|style|link)/i,eb=/checked\s*(?:[^=]|=\s*.checked.)/i,fb=/^$|\/(?:java|ecma)script/i,gb=/^true\/(.*)/,hb=/^\s*<!(?:\[CDATA\[|--)|(?:\]\]|--)>\s*$/g,ib={option:[1,"<select multiple='multiple'>","</select>"],thead:[1,"<table>","</table>"],col:[2,"<table><colgroup>","</colgroup></table>"],tr:[2,"<table><tbody>","</tbody></table>"],td:[3,"<table><tbody><tr>","</tr></tbody></table>"],_default:[0,"",""]};ib.optgroup=ib.option,ib.tbody=ib.tfoot=ib.colgroup=ib.caption=ib.thead,ib.th=ib.td;function jb(a,b){return n.nodeName(a,"table")&&n.nodeName(11!==b.nodeType?b:b.firstChild,"tr")?a.getElementsByTagName("tbody")[0]||a.appendChild(a.ownerDocument.createElement("tbody")):a}function kb(a){return a.type=(null!==a.getAttribute("type"))+"/"+a.type,a}function lb(a){var b=gb.exec(a.type);return b?a.type=b[1]:a.removeAttribute("type"),a}function mb(a,b){for(var c=0,d=a.length;d>c;c++)L.set(a[c],"globalEval",!b||L.get(b[c],"globalEval"))}function nb(a,b){var 
c,d,e,f,g,h,i,j;if(1===b.nodeType){if(L.hasData(a)&&(f=L.access(a),g=L.set(b,f),j=f.events)){delete g.handle,g.events={};for(e in j)for(c=0,d=j[e].length;d>c;c++)n.event.add(b,e,j[e][c])}M.hasData(a)&&(h=M.access(a),i=n.extend({},h),M.set(b,i))}}function ob(a,b){var c=a.getElementsByTagName?a.getElementsByTagName(b||"*"):a.querySelectorAll?a.querySelectorAll(b||"*"):[];return void 0===b||b&&n.nodeName(a,b)?n.merge([a],c):c}function pb(a,b){var c=b.nodeName.toLowerCase();"input"===c&&T.test(a.type)?b.checked=a.checked:("input"===c||"textarea"===c)&&(b.defaultValue=a.defaultValue)}n.extend({clone:function(a,b,c){var d,e,f,g,h=a.cloneNode(!0),i=n.contains(a.ownerDocument,a);if(!(k.noCloneChecked||1!==a.nodeType&&11!==a.nodeType||n.isXMLDoc(a)))for(g=ob(h),f=ob(a),d=0,e=f.length;e>d;d++)pb(f[d],g[d]);if(b)if(c)for(f=f||ob(a),g=g||ob(h),d=0,e=f.length;e>d;d++)nb(f[d],g[d]);else nb(a,h);return g=ob(h,"script"),g.length>0&&mb(g,!i&&ob(a,"script")),h},buildFragment:function(a,b,c,d){for(var e,f,g,h,i,j,k=b.createDocumentFragment(),l=[],m=0,o=a.length;o>m;m++)if(e=a[m],e||0===e)if("object"===n.type(e))n.merge(l,e.nodeType?[e]:e);else if(cb.test(e)){f=f||k.appendChild(b.createElement("div")),g=(bb.exec(e)||["",""])[1].toLowerCase(),h=ib[g]||ib._default,f.innerHTML=h[1]+e.replace(ab,"<$1></$2>")+h[2],j=h[0];while(j--)f=f.lastChild;n.merge(l,f.childNodes),f=k.firstChild,f.textContent=""}else l.push(b.createTextNode(e));k.textContent="",m=0;while(e=l[m++])if((!d||-1===n.inArray(e,d))&&(i=n.contains(e.ownerDocument,e),f=ob(k.appendChild(e),"script"),i&&mb(f),c)){j=0;while(e=f[j++])fb.test(e.type||"")&&c.push(e)}return k},cleanData:function(a){for(var b,c,d,e,f=n.event.special,g=0;void 0!==(c=a[g]);g++){if(n.acceptData(c)&&(e=c[L.expando],e&&(b=L.cache[e]))){if(b.events)for(d in b.events)f[d]?n.event.remove(c,d):n.removeEvent(c,d,b.handle);L.cache[e]&&delete L.cache[e]}delete M.cache[c[M.expando]]}}}),n.fn.extend({text:function(a){return J(this,function(a){return void 0===a?n.text(this):this.empty().each(function(){(1===this.nodeType||11===this.nodeType||9===this.nodeType)&&(this.textContent=a)})},null,a,arguments.length)},append:function(){return this.domManip(arguments,function(a){if(1===this.nodeType||11===this.nodeType||9===this.nodeType){var b=jb(this,a);b.appendChild(a)}})},prepend:function(){return this.domManip(arguments,function(a){if(1===this.nodeType||11===this.nodeType||9===this.nodeType){var b=jb(this,a);b.insertBefore(a,b.firstChild)}})},before:function(){return this.domManip(arguments,function(a){this.parentNode&&this.parentNode.insertBefore(a,this)})},after:function(){return this.domManip(arguments,function(a){this.parentNode&&this.parentNode.insertBefore(a,this.nextSibling)})},remove:function(a,b){for(var c,d=a?n.filter(a,this):this,e=0;null!=(c=d[e]);e++)b||1!==c.nodeType||n.cleanData(ob(c)),c.parentNode&&(b&&n.contains(c.ownerDocument,c)&&mb(ob(c,"script")),c.parentNode.removeChild(c));return this},empty:function(){for(var a,b=0;null!=(a=this[b]);b++)1===a.nodeType&&(n.cleanData(ob(a,!1)),a.textContent="");return this},clone:function(a,b){return a=null==a?!1:a,b=null==b?a:b,this.map(function(){return n.clone(this,a,b)})},html:function(a){return J(this,function(a){var b=this[0]||{},c=0,d=this.length;if(void 0===a&&1===b.nodeType)return b.innerHTML;if("string"==typeof 
a&&!db.test(a)&&!ib[(bb.exec(a)||["",""])[1].toLowerCase()]){a=a.replace(ab,"<$1></$2>");try{for(;d>c;c++)b=this[c]||{},1===b.nodeType&&(n.cleanData(ob(b,!1)),b.innerHTML=a);b=0}catch(e){}}b&&this.empty().append(a)},null,a,arguments.length)},replaceWith:function(){var a=arguments[0];return this.domManip(arguments,function(b){a=this.parentNode,n.cleanData(ob(this)),a&&a.replaceChild(b,this)}),a&&(a.length||a.nodeType)?this:this.remove()},detach:function(a){return this.remove(a,!0)},domManip:function(a,b){a=e.apply([],a);var c,d,f,g,h,i,j=0,l=this.length,m=this,o=l-1,p=a[0],q=n.isFunction(p);if(q||l>1&&"string"==typeof p&&!k.checkClone&&eb.test(p))return this.each(function(c){var d=m.eq(c);q&&(a[0]=p.call(this,c,d.html())),d.domManip(a,b)});if(l&&(c=n.buildFragment(a,this[0].ownerDocument,!1,this),d=c.firstChild,1===c.childNodes.length&&(c=d),d)){for(f=n.map(ob(c,"script"),kb),g=f.length;l>j;j++)h=c,j!==o&&(h=n.clone(h,!0,!0),g&&n.merge(f,ob(h,"script"))),b.call(this[j],h,j);if(g)for(i=f[f.length-1].ownerDocument,n.map(f,lb),j=0;g>j;j++)h=f[j],fb.test(h.type||"")&&!L.access(h,"globalEval")&&n.contains(i,h)&&(h.src?n._evalUrl&&n._evalUrl(h.src):n.globalEval(h.textContent.replace(hb,"")))}return this}}),n.each({appendTo:"append",prependTo:"prepend",insertBefore:"before",insertAfter:"after",replaceAll:"replaceWith"},function(a,b){n.fn[a]=function(a){for(var c,d=[],e=n(a),g=e.length-1,h=0;g>=h;h++)c=h===g?this:this.clone(!0),n(e[h])[b](c),f.apply(d,c.get());return this.pushStack(d)}});var qb,rb={};function sb(b,c){var d,e=n(c.createElement(b)).appendTo(c.body),f=a.getDefaultComputedStyle&&(d=a.getDefaultComputedStyle(e[0]))?d.display:n.css(e[0],"display");return e.detach(),f}function tb(a){var b=l,c=rb[a];return c||(c=sb(a,b),"none"!==c&&c||(qb=(qb||n("<iframe frameborder='0' width='0' height='0'/>")).appendTo(b.documentElement),b=qb[0].contentDocument,b.write(),b.close(),c=sb(a,b),qb.detach()),rb[a]=c),c}var ub=/^margin/,vb=new RegExp("^("+Q+")(?!px)[a-z%]+$","i"),wb=function(b){return b.ownerDocument.defaultView.opener?b.ownerDocument.defaultView.getComputedStyle(b,null):a.getComputedStyle(b,null)};function xb(a,b,c){var d,e,f,g,h=a.style;return c=c||wb(a),c&&(g=c.getPropertyValue(b)||c[b]),c&&(""!==g||n.contains(a.ownerDocument,a)||(g=n.style(a,b)),vb.test(g)&&ub.test(b)&&(d=h.width,e=h.minWidth,f=h.maxWidth,h.minWidth=h.maxWidth=h.width=g,g=c.width,h.width=d,h.minWidth=e,h.maxWidth=f)),void 0!==g?g+"":g}function yb(a,b){return{get:function(){return a()?void delete this.get:(this.get=b).apply(this,arguments)}}}!function(){var b,c,d=l.documentElement,e=l.createElement("div"),f=l.createElement("div");if(f.style){f.style.backgroundClip="content-box",f.cloneNode(!0).style.backgroundClip="",k.clearCloneStyle="content-box"===f.style.backgroundClip,e.style.cssText="border:0;width:0;height:0;top:0;left:-9999px;margin-top:1px;position:absolute",e.appendChild(f);function g(){f.style.cssText="-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;display:block;margin-top:1%;top:1%;border:1px;padding:1px;width:4px;position:absolute",f.innerHTML="",d.appendChild(e);var g=a.getComputedStyle(f,null);b="1%"!==g.top,c="4px"===g.width,d.removeChild(e)}a.getComputedStyle&&n.extend(k,{pixelPosition:function(){return g(),b},boxSizingReliable:function(){return null==c&&g(),c},reliableMarginRight:function(){var b,c=f.appendChild(l.createElement("div"));return 
c.style.cssText=f.style.cssText="-webkit-box-sizing:content-box;-moz-box-sizing:content-box;box-sizing:content-box;display:block;margin:0;border:0;padding:0",c.style.marginRight=c.style.width="0",f.style.width="1px",d.appendChild(e),b=!parseFloat(a.getComputedStyle(c,null).marginRight),d.removeChild(e),f.removeChild(c),b}})}}(),n.swap=function(a,b,c,d){var e,f,g={};for(f in b)g[f]=a.style[f],a.style[f]=b[f];e=c.apply(a,d||[]);for(f in b)a.style[f]=g[f];return e};var zb=/^(none|table(?!-c[ea]).+)/,Ab=new RegExp("^("+Q+")(.*)$","i"),Bb=new RegExp("^([+-])=("+Q+")","i"),Cb={position:"absolute",visibility:"hidden",display:"block"},Db={letterSpacing:"0",fontWeight:"400"},Eb=["Webkit","O","Moz","ms"];function Fb(a,b){if(b in a)return b;var c=b[0].toUpperCase()+b.slice(1),d=b,e=Eb.length;while(e--)if(b=Eb[e]+c,b in a)return b;return d}function Gb(a,b,c){var d=Ab.exec(b);return d?Math.max(0,d[1]-(c||0))+(d[2]||"px"):b}function Hb(a,b,c,d,e){for(var f=c===(d?"border":"content")?4:"width"===b?1:0,g=0;4>f;f+=2)"margin"===c&&(g+=n.css(a,c+R[f],!0,e)),d?("content"===c&&(g-=n.css(a,"padding"+R[f],!0,e)),"margin"!==c&&(g-=n.css(a,"border"+R[f]+"Width",!0,e))):(g+=n.css(a,"padding"+R[f],!0,e),"padding"!==c&&(g+=n.css(a,"border"+R[f]+"Width",!0,e)));return g}function Ib(a,b,c){var d=!0,e="width"===b?a.offsetWidth:a.offsetHeight,f=wb(a),g="border-box"===n.css(a,"boxSizing",!1,f);if(0>=e||null==e){if(e=xb(a,b,f),(0>e||null==e)&&(e=a.style[b]),vb.test(e))return e;d=g&&(k.boxSizingReliable()||e===a.style[b]),e=parseFloat(e)||0}return e+Hb(a,b,c||(g?"border":"content"),d,f)+"px"}function Jb(a,b){for(var c,d,e,f=[],g=0,h=a.length;h>g;g++)d=a[g],d.style&&(f[g]=L.get(d,"olddisplay"),c=d.style.display,b?(f[g]||"none"!==c||(d.style.display=""),""===d.style.display&&S(d)&&(f[g]=L.access(d,"olddisplay",tb(d.nodeName)))):(e=S(d),"none"===c&&e||L.set(d,"olddisplay",e?c:n.css(d,"display"))));for(g=0;h>g;g++)d=a[g],d.style&&(b&&"none"!==d.style.display&&""!==d.style.display||(d.style.display=b?f[g]||"":"none"));return a}n.extend({cssHooks:{opacity:{get:function(a,b){if(b){var c=xb(a,"opacity");return""===c?"1":c}}}},cssNumber:{columnCount:!0,fillOpacity:!0,flexGrow:!0,flexShrink:!0,fontWeight:!0,lineHeight:!0,opacity:!0,order:!0,orphans:!0,widows:!0,zIndex:!0,zoom:!0},cssProps:{"float":"cssFloat"},style:function(a,b,c,d){if(a&&3!==a.nodeType&&8!==a.nodeType&&a.style){var e,f,g,h=n.camelCase(b),i=a.style;return b=n.cssProps[h]||(n.cssProps[h]=Fb(i,h)),g=n.cssHooks[b]||n.cssHooks[h],void 0===c?g&&"get"in g&&void 0!==(e=g.get(a,!1,d))?e:i[b]:(f=typeof c,"string"===f&&(e=Bb.exec(c))&&(c=(e[1]+1)*e[2]+parseFloat(n.css(a,b)),f="number"),null!=c&&c===c&&("number"!==f||n.cssNumber[h]||(c+="px"),k.clearCloneStyle||""!==c||0!==b.indexOf("background")||(i[b]="inherit"),g&&"set"in g&&void 0===(c=g.set(a,c,d))||(i[b]=c)),void 0)}},css:function(a,b,c,d){var e,f,g,h=n.camelCase(b);return b=n.cssProps[h]||(n.cssProps[h]=Fb(a.style,h)),g=n.cssHooks[b]||n.cssHooks[h],g&&"get"in g&&(e=g.get(a,!0,c)),void 0===e&&(e=xb(a,b,d)),"normal"===e&&b in Db&&(e=Db[b]),""===c||c?(f=parseFloat(e),c===!0||n.isNumeric(f)?f||0:e):e}}),n.each(["height","width"],function(a,b){n.cssHooks[b]={get:function(a,c,d){return c?zb.test(n.css(a,"display"))&&0===a.offsetWidth?n.swap(a,Cb,function(){return Ib(a,b,d)}):Ib(a,b,d):void 0},set:function(a,c,d){var e=d&&wb(a);return Gb(a,c,d?Hb(a,b,d,"border-box"===n.css(a,"boxSizing",!1,e),e):0)}}}),n.cssHooks.marginRight=yb(k.reliableMarginRight,function(a,b){return 
b?n.swap(a,{display:"inline-block"},xb,[a,"marginRight"]):void 0}),n.each({margin:"",padding:"",border:"Width"},function(a,b){n.cssHooks[a+b]={expand:function(c){for(var d=0,e={},f="string"==typeof c?c.split(" "):[c];4>d;d++)e[a+R[d]+b]=f[d]||f[d-2]||f[0];return e}},ub.test(a)||(n.cssHooks[a+b].set=Gb)}),n.fn.extend({css:function(a,b){return J(this,function(a,b,c){var d,e,f={},g=0;if(n.isArray(b)){for(d=wb(a),e=b.length;e>g;g++)f[b[g]]=n.css(a,b[g],!1,d);return f}return void 0!==c?n.style(a,b,c):n.css(a,b)},a,b,arguments.length>1)},show:function(){return Jb(this,!0)},hide:function(){return Jb(this)},toggle:function(a){return"boolean"==typeof a?a?this.show():this.hide():this.each(function(){S(this)?n(this).show():n(this).hide()})}});function Kb(a,b,c,d,e){return new Kb.prototype.init(a,b,c,d,e)}n.Tween=Kb,Kb.prototype={constructor:Kb,init:function(a,b,c,d,e,f){this.elem=a,this.prop=c,this.easing=e||"swing",this.options=b,this.start=this.now=this.cur(),this.end=d,this.unit=f||(n.cssNumber[c]?"":"px")},cur:function(){var a=Kb.propHooks[this.prop];return a&&a.get?a.get(this):Kb.propHooks._default.get(this)},run:function(a){var b,c=Kb.propHooks[this.prop];return this.pos=b=this.options.duration?n.easing[this.easing](a,this.options.duration*a,0,1,this.options.duration):a,this.now=(this.end-this.start)*b+this.start,this.options.step&&this.options.step.call(this.elem,this.now,this),c&&c.set?c.set(this):Kb.propHooks._default.set(this),this}},Kb.prototype.init.prototype=Kb.prototype,Kb.propHooks={_default:{get:function(a){var b;return null==a.elem[a.prop]||a.elem.style&&null!=a.elem.style[a.prop]?(b=n.css(a.elem,a.prop,""),b&&"auto"!==b?b:0):a.elem[a.prop]},set:function(a){n.fx.step[a.prop]?n.fx.step[a.prop](a):a.elem.style&&(null!=a.elem.style[n.cssProps[a.prop]]||n.cssHooks[a.prop])?n.style(a.elem,a.prop,a.now+a.unit):a.elem[a.prop]=a.now}}},Kb.propHooks.scrollTop=Kb.propHooks.scrollLeft={set:function(a){a.elem.nodeType&&a.elem.parentNode&&(a.elem[a.prop]=a.now)}},n.easing={linear:function(a){return a},swing:function(a){return.5-Math.cos(a*Math.PI)/2}},n.fx=Kb.prototype.init,n.fx.step={};var Lb,Mb,Nb=/^(?:toggle|show|hide)$/,Ob=new RegExp("^(?:([+-])=|)("+Q+")([a-z%]*)$","i"),Pb=/queueHooks$/,Qb=[Vb],Rb={"*":[function(a,b){var c=this.createTween(a,b),d=c.cur(),e=Ob.exec(b),f=e&&e[3]||(n.cssNumber[a]?"":"px"),g=(n.cssNumber[a]||"px"!==f&&+d)&&Ob.exec(n.css(c.elem,a)),h=1,i=20;if(g&&g[3]!==f){f=f||g[3],e=e||[],g=+d||1;do h=h||".5",g/=h,n.style(c.elem,a,g+f);while(h!==(h=c.cur()/d)&&1!==h&&--i)}return e&&(g=c.start=+g||+d||0,c.unit=f,c.end=e[1]?g+(e[1]+1)*e[2]:+e[2]),c}]};function Sb(){return setTimeout(function(){Lb=void 0}),Lb=n.now()}function Tb(a,b){var c,d=0,e={height:a};for(b=b?1:0;4>d;d+=2-b)c=R[d],e["margin"+c]=e["padding"+c]=a;return b&&(e.opacity=e.width=a),e}function Ub(a,b,c){for(var d,e=(Rb[b]||[]).concat(Rb["*"]),f=0,g=e.length;g>f;f++)if(d=e[f].call(c,b,a))return d}function Vb(a,b,c){var d,e,f,g,h,i,j,k,l=this,m={},o=a.style,p=a.nodeType&&S(a),q=L.get(a,"fxshow");c.queue||(h=n._queueHooks(a,"fx"),null==h.unqueued&&(h.unqueued=0,i=h.empty.fire,h.empty.fire=function(){h.unqueued||i()}),h.unqueued++,l.always(function(){l.always(function(){h.unqueued--,n.queue(a,"fx").length||h.empty.fire()})})),1===a.nodeType&&("height"in b||"width"in 
b)&&(c.overflow=[o.overflow,o.overflowX,o.overflowY],j=n.css(a,"display"),k="none"===j?L.get(a,"olddisplay")||tb(a.nodeName):j,"inline"===k&&"none"===n.css(a,"float")&&(o.display="inline-block")),c.overflow&&(o.overflow="hidden",l.always(function(){o.overflow=c.overflow[0],o.overflowX=c.overflow[1],o.overflowY=c.overflow[2]}));for(d in b)if(e=b[d],Nb.exec(e)){if(delete b[d],f=f||"toggle"===e,e===(p?"hide":"show")){if("show"!==e||!q||void 0===q[d])continue;p=!0}m[d]=q&&q[d]||n.style(a,d)}else j=void 0;if(n.isEmptyObject(m))"inline"===("none"===j?tb(a.nodeName):j)&&(o.display=j);else{q?"hidden"in q&&(p=q.hidden):q=L.access(a,"fxshow",{}),f&&(q.hidden=!p),p?n(a).show():l.done(function(){n(a).hide()}),l.done(function(){var b;L.remove(a,"fxshow");for(b in m)n.style(a,b,m[b])});for(d in m)g=Ub(p?q[d]:0,d,l),d in q||(q[d]=g.start,p&&(g.end=g.start,g.start="width"===d||"height"===d?1:0))}}function Wb(a,b){var c,d,e,f,g;for(c in a)if(d=n.camelCase(c),e=b[d],f=a[c],n.isArray(f)&&(e=f[1],f=a[c]=f[0]),c!==d&&(a[d]=f,delete a[c]),g=n.cssHooks[d],g&&"expand"in g){f=g.expand(f),delete a[d];for(c in f)c in a||(a[c]=f[c],b[c]=e)}else b[d]=e}function Xb(a,b,c){var d,e,f=0,g=Qb.length,h=n.Deferred().always(function(){delete i.elem}),i=function(){if(e)return!1;for(var b=Lb||Sb(),c=Math.max(0,j.startTime+j.duration-b),d=c/j.duration||0,f=1-d,g=0,i=j.tweens.length;i>g;g++)j.tweens[g].run(f);return h.notifyWith(a,[j,f,c]),1>f&&i?c:(h.resolveWith(a,[j]),!1)},j=h.promise({elem:a,props:n.extend({},b),opts:n.extend(!0,{specialEasing:{}},c),originalProperties:b,originalOptions:c,startTime:Lb||Sb(),duration:c.duration,tweens:[],createTween:function(b,c){var d=n.Tween(a,j.opts,b,c,j.opts.specialEasing[b]||j.opts.easing);return j.tweens.push(d),d},stop:function(b){var c=0,d=b?j.tweens.length:0;if(e)return this;for(e=!0;d>c;c++)j.tweens[c].run(1);return b?h.resolveWith(a,[j,b]):h.rejectWith(a,[j,b]),this}}),k=j.props;for(Wb(k,j.opts.specialEasing);g>f;f++)if(d=Qb[f].call(j,a,k,j.opts))return d;return n.map(k,Ub,j),n.isFunction(j.opts.start)&&j.opts.start.call(a,j),n.fx.timer(n.extend(i,{elem:a,anim:j,queue:j.opts.queue})),j.progress(j.opts.progress).done(j.opts.done,j.opts.complete).fail(j.opts.fail).always(j.opts.always)}n.Animation=n.extend(Xb,{tweener:function(a,b){n.isFunction(a)?(b=a,a=["*"]):a=a.split(" ");for(var c,d=0,e=a.length;e>d;d++)c=a[d],Rb[c]=Rb[c]||[],Rb[c].unshift(b)},prefilter:function(a,b){b?Qb.unshift(a):Qb.push(a)}}),n.speed=function(a,b,c){var d=a&&"object"==typeof a?n.extend({},a):{complete:c||!c&&b||n.isFunction(a)&&a,duration:a,easing:c&&b||b&&!n.isFunction(b)&&b};return d.duration=n.fx.off?0:"number"==typeof d.duration?d.duration:d.duration in n.fx.speeds?n.fx.speeds[d.duration]:n.fx.speeds._default,(null==d.queue||d.queue===!0)&&(d.queue="fx"),d.old=d.complete,d.complete=function(){n.isFunction(d.old)&&d.old.call(this),d.queue&&n.dequeue(this,d.queue)},d},n.fn.extend({fadeTo:function(a,b,c,d){return this.filter(S).css("opacity",0).show().end().animate({opacity:b},a,c,d)},animate:function(a,b,c,d){var e=n.isEmptyObject(a),f=n.speed(b,c,d),g=function(){var b=Xb(this,n.extend({},a),f);(e||L.get(this,"finish"))&&b.stop(!0)};return g.finish=g,e||f.queue===!1?this.each(g):this.queue(f.queue,g)},stop:function(a,b,c){var d=function(a){var b=a.stop;delete a.stop,b(c)};return"string"!=typeof a&&(c=b,b=a,a=void 0),b&&a!==!1&&this.queue(a||"fx",[]),this.each(function(){var b=!0,e=null!=a&&a+"queueHooks",f=n.timers,g=L.get(this);if(e)g[e]&&g[e].stop&&d(g[e]);else for(e in 
g)g[e]&&g[e].stop&&Pb.test(e)&&d(g[e]);for(e=f.length;e--;)f[e].elem!==this||null!=a&&f[e].queue!==a||(f[e].anim.stop(c),b=!1,f.splice(e,1));(b||!c)&&n.dequeue(this,a)})},finish:function(a){return a!==!1&&(a=a||"fx"),this.each(function(){var b,c=L.get(this),d=c[a+"queue"],e=c[a+"queueHooks"],f=n.timers,g=d?d.length:0;for(c.finish=!0,n.queue(this,a,[]),e&&e.stop&&e.stop.call(this,!0),b=f.length;b--;)f[b].elem===this&&f[b].queue===a&&(f[b].anim.stop(!0),f.splice(b,1));for(b=0;g>b;b++)d[b]&&d[b].finish&&d[b].finish.call(this);delete c.finish})}}),n.each(["toggle","show","hide"],function(a,b){var c=n.fn[b];n.fn[b]=function(a,d,e){return null==a||"boolean"==typeof a?c.apply(this,arguments):this.animate(Tb(b,!0),a,d,e)}}),n.each({slideDown:Tb("show"),slideUp:Tb("hide"),slideToggle:Tb("toggle"),fadeIn:{opacity:"show"},fadeOut:{opacity:"hide"},fadeToggle:{opacity:"toggle"}},function(a,b){n.fn[a]=function(a,c,d){return this.animate(b,a,c,d)}}),n.timers=[],n.fx.tick=function(){var a,b=0,c=n.timers;for(Lb=n.now();b<c.length;b++)a=c[b],a()||c[b]!==a||c.splice(b--,1);c.length||n.fx.stop(),Lb=void 0},n.fx.timer=function(a){n.timers.push(a),a()?n.fx.start():n.timers.pop()},n.fx.interval=13,n.fx.start=function(){Mb||(Mb=setInterval(n.fx.tick,n.fx.interval))},n.fx.stop=function(){clearInterval(Mb),Mb=null},n.fx.speeds={slow:600,fast:200,_default:400},n.fn.delay=function(a,b){return a=n.fx?n.fx.speeds[a]||a:a,b=b||"fx",this.queue(b,function(b,c){var d=setTimeout(b,a);c.stop=function(){clearTimeout(d)}})},function(){var a=l.createElement("input"),b=l.createElement("select"),c=b.appendChild(l.createElement("option"));a.type="checkbox",k.checkOn=""!==a.value,k.optSelected=c.selected,b.disabled=!0,k.optDisabled=!c.disabled,a=l.createElement("input"),a.value="t",a.type="radio",k.radioValue="t"===a.value}();var Yb,Zb,$b=n.expr.attrHandle;n.fn.extend({attr:function(a,b){return J(this,n.attr,a,b,arguments.length>1)},removeAttr:function(a){return this.each(function(){n.removeAttr(this,a)})}}),n.extend({attr:function(a,b,c){var d,e,f=a.nodeType;if(a&&3!==f&&8!==f&&2!==f)return typeof a.getAttribute===U?n.prop(a,b,c):(1===f&&n.isXMLDoc(a)||(b=b.toLowerCase(),d=n.attrHooks[b]||(n.expr.match.bool.test(b)?Zb:Yb)),void 0===c?d&&"get"in d&&null!==(e=d.get(a,b))?e:(e=n.find.attr(a,b),null==e?void 0:e):null!==c?d&&"set"in d&&void 0!==(e=d.set(a,c,b))?e:(a.setAttribute(b,c+""),c):void n.removeAttr(a,b)) -},removeAttr:function(a,b){var c,d,e=0,f=b&&b.match(E);if(f&&1===a.nodeType)while(c=f[e++])d=n.propFix[c]||c,n.expr.match.bool.test(c)&&(a[d]=!1),a.removeAttribute(c)},attrHooks:{type:{set:function(a,b){if(!k.radioValue&&"radio"===b&&n.nodeName(a,"input")){var c=a.value;return a.setAttribute("type",b),c&&(a.value=c),b}}}}}),Zb={set:function(a,b,c){return b===!1?n.removeAttr(a,c):a.setAttribute(c,c),c}},n.each(n.expr.match.bool.source.match(/\w+/g),function(a,b){var c=$b[b]||n.find.attr;$b[b]=function(a,b,d){var e,f;return d||(f=$b[b],$b[b]=e,e=null!=c(a,b,d)?b.toLowerCase():null,$b[b]=f),e}});var _b=/^(?:input|select|textarea|button)$/i;n.fn.extend({prop:function(a,b){return J(this,n.prop,a,b,arguments.length>1)},removeProp:function(a){return this.each(function(){delete this[n.propFix[a]||a]})}}),n.extend({propFix:{"for":"htmlFor","class":"className"},prop:function(a,b,c){var d,e,f,g=a.nodeType;if(a&&3!==g&&8!==g&&2!==g)return f=1!==g||!n.isXMLDoc(a),f&&(b=n.propFix[b]||b,e=n.propHooks[b]),void 0!==c?e&&"set"in e&&void 0!==(d=e.set(a,c,b))?d:a[b]=c:e&&"get"in 
e&&null!==(d=e.get(a,b))?d:a[b]},propHooks:{tabIndex:{get:function(a){return a.hasAttribute("tabindex")||_b.test(a.nodeName)||a.href?a.tabIndex:-1}}}}),k.optSelected||(n.propHooks.selected={get:function(a){var b=a.parentNode;return b&&b.parentNode&&b.parentNode.selectedIndex,null}}),n.each(["tabIndex","readOnly","maxLength","cellSpacing","cellPadding","rowSpan","colSpan","useMap","frameBorder","contentEditable"],function(){n.propFix[this.toLowerCase()]=this});var ac=/[\t\r\n\f]/g;n.fn.extend({addClass:function(a){var b,c,d,e,f,g,h="string"==typeof a&&a,i=0,j=this.length;if(n.isFunction(a))return this.each(function(b){n(this).addClass(a.call(this,b,this.className))});if(h)for(b=(a||"").match(E)||[];j>i;i++)if(c=this[i],d=1===c.nodeType&&(c.className?(" "+c.className+" ").replace(ac," "):" ")){f=0;while(e=b[f++])d.indexOf(" "+e+" ")<0&&(d+=e+" ");g=n.trim(d),c.className!==g&&(c.className=g)}return this},removeClass:function(a){var b,c,d,e,f,g,h=0===arguments.length||"string"==typeof a&&a,i=0,j=this.length;if(n.isFunction(a))return this.each(function(b){n(this).removeClass(a.call(this,b,this.className))});if(h)for(b=(a||"").match(E)||[];j>i;i++)if(c=this[i],d=1===c.nodeType&&(c.className?(" "+c.className+" ").replace(ac," "):"")){f=0;while(e=b[f++])while(d.indexOf(" "+e+" ")>=0)d=d.replace(" "+e+" "," ");g=a?n.trim(d):"",c.className!==g&&(c.className=g)}return this},toggleClass:function(a,b){var c=typeof a;return"boolean"==typeof b&&"string"===c?b?this.addClass(a):this.removeClass(a):this.each(n.isFunction(a)?function(c){n(this).toggleClass(a.call(this,c,this.className,b),b)}:function(){if("string"===c){var b,d=0,e=n(this),f=a.match(E)||[];while(b=f[d++])e.hasClass(b)?e.removeClass(b):e.addClass(b)}else(c===U||"boolean"===c)&&(this.className&&L.set(this,"__className__",this.className),this.className=this.className||a===!1?"":L.get(this,"__className__")||"")})},hasClass:function(a){for(var b=" "+a+" ",c=0,d=this.length;d>c;c++)if(1===this[c].nodeType&&(" "+this[c].className+" ").replace(ac," ").indexOf(b)>=0)return!0;return!1}});var bc=/\r/g;n.fn.extend({val:function(a){var b,c,d,e=this[0];{if(arguments.length)return d=n.isFunction(a),this.each(function(c){var e;1===this.nodeType&&(e=d?a.call(this,c,n(this).val()):a,null==e?e="":"number"==typeof e?e+="":n.isArray(e)&&(e=n.map(e,function(a){return null==a?"":a+""})),b=n.valHooks[this.type]||n.valHooks[this.nodeName.toLowerCase()],b&&"set"in b&&void 0!==b.set(this,e,"value")||(this.value=e))});if(e)return b=n.valHooks[e.type]||n.valHooks[e.nodeName.toLowerCase()],b&&"get"in b&&void 0!==(c=b.get(e,"value"))?c:(c=e.value,"string"==typeof c?c.replace(bc,""):null==c?"":c)}}}),n.extend({valHooks:{option:{get:function(a){var b=n.find.attr(a,"value");return null!=b?b:n.trim(n.text(a))}},select:{get:function(a){for(var b,c,d=a.options,e=a.selectedIndex,f="select-one"===a.type||0>e,g=f?null:[],h=f?e+1:d.length,i=0>e?h:f?e:0;h>i;i++)if(c=d[i],!(!c.selected&&i!==e||(k.optDisabled?c.disabled:null!==c.getAttribute("disabled"))||c.parentNode.disabled&&n.nodeName(c.parentNode,"optgroup"))){if(b=n(c).val(),f)return b;g.push(b)}return g},set:function(a,b){var c,d,e=a.options,f=n.makeArray(b),g=e.length;while(g--)d=e[g],(d.selected=n.inArray(d.value,f)>=0)&&(c=!0);return c||(a.selectedIndex=-1),f}}}}),n.each(["radio","checkbox"],function(){n.valHooks[this]={set:function(a,b){return n.isArray(b)?a.checked=n.inArray(n(a).val(),b)>=0:void 0}},k.checkOn||(n.valHooks[this].get=function(a){return null===a.getAttribute("value")?"on":a.value})}),n.each("blur focus 
focusin focusout load resize scroll unload click dblclick mousedown mouseup mousemove mouseover mouseout mouseenter mouseleave change select submit keydown keypress keyup error contextmenu".split(" "),function(a,b){n.fn[b]=function(a,c){return arguments.length>0?this.on(b,null,a,c):this.trigger(b)}}),n.fn.extend({hover:function(a,b){return this.mouseenter(a).mouseleave(b||a)},bind:function(a,b,c){return this.on(a,null,b,c)},unbind:function(a,b){return this.off(a,null,b)},delegate:function(a,b,c,d){return this.on(b,a,c,d)},undelegate:function(a,b,c){return 1===arguments.length?this.off(a,"**"):this.off(b,a||"**",c)}});var cc=n.now(),dc=/\?/;n.parseJSON=function(a){return JSON.parse(a+"")},n.parseXML=function(a){var b,c;if(!a||"string"!=typeof a)return null;try{c=new DOMParser,b=c.parseFromString(a,"text/xml")}catch(d){b=void 0}return(!b||b.getElementsByTagName("parsererror").length)&&n.error("Invalid XML: "+a),b};var ec=/#.*$/,fc=/([?&])_=[^&]*/,gc=/^(.*?):[ \t]*([^\r\n]*)$/gm,hc=/^(?:about|app|app-storage|.+-extension|file|res|widget):$/,ic=/^(?:GET|HEAD)$/,jc=/^\/\//,kc=/^([\w.+-]+:)(?:\/\/(?:[^\/?#]*@|)([^\/?#:]*)(?::(\d+)|)|)/,lc={},mc={},nc="*/".concat("*"),oc=a.location.href,pc=kc.exec(oc.toLowerCase())||[];function qc(a){return function(b,c){"string"!=typeof b&&(c=b,b="*");var d,e=0,f=b.toLowerCase().match(E)||[];if(n.isFunction(c))while(d=f[e++])"+"===d[0]?(d=d.slice(1)||"*",(a[d]=a[d]||[]).unshift(c)):(a[d]=a[d]||[]).push(c)}}function rc(a,b,c,d){var e={},f=a===mc;function g(h){var i;return e[h]=!0,n.each(a[h]||[],function(a,h){var j=h(b,c,d);return"string"!=typeof j||f||e[j]?f?!(i=j):void 0:(b.dataTypes.unshift(j),g(j),!1)}),i}return g(b.dataTypes[0])||!e["*"]&&g("*")}function sc(a,b){var c,d,e=n.ajaxSettings.flatOptions||{};for(c in b)void 0!==b[c]&&((e[c]?a:d||(d={}))[c]=b[c]);return d&&n.extend(!0,a,d),a}function tc(a,b,c){var d,e,f,g,h=a.contents,i=a.dataTypes;while("*"===i[0])i.shift(),void 0===d&&(d=a.mimeType||b.getResponseHeader("Content-Type"));if(d)for(e in h)if(h[e]&&h[e].test(d)){i.unshift(e);break}if(i[0]in c)f=i[0];else{for(e in c){if(!i[0]||a.converters[e+" "+i[0]]){f=e;break}g||(g=e)}f=f||g}return f?(f!==i[0]&&i.unshift(f),c[f]):void 0}function uc(a,b,c,d){var e,f,g,h,i,j={},k=a.dataTypes.slice();if(k[1])for(g in a.converters)j[g.toLowerCase()]=a.converters[g];f=k.shift();while(f)if(a.responseFields[f]&&(c[a.responseFields[f]]=b),!i&&d&&a.dataFilter&&(b=a.dataFilter(b,a.dataType)),i=f,f=k.shift())if("*"===f)f=i;else if("*"!==i&&i!==f){if(g=j[i+" "+f]||j["* "+f],!g)for(e in j)if(h=e.split(" "),h[1]===f&&(g=j[i+" "+h[0]]||j["* "+h[0]])){g===!0?g=j[e]:j[e]!==!0&&(f=h[0],k.unshift(h[1]));break}if(g!==!0)if(g&&a["throws"])b=g(b);else try{b=g(b)}catch(l){return{state:"parsererror",error:g?l:"No conversion from "+i+" to "+f}}}return{state:"success",data:b}}n.extend({active:0,lastModified:{},etag:{},ajaxSettings:{url:oc,type:"GET",isLocal:hc.test(pc[1]),global:!0,processData:!0,async:!0,contentType:"application/x-www-form-urlencoded; charset=UTF-8",accepts:{"*":nc,text:"text/plain",html:"text/html",xml:"application/xml, text/xml",json:"application/json, text/javascript"},contents:{xml:/xml/,html:/html/,json:/json/},responseFields:{xml:"responseXML",text:"responseText",json:"responseJSON"},converters:{"* text":String,"text html":!0,"text json":n.parseJSON,"text xml":n.parseXML},flatOptions:{url:!0,context:!0}},ajaxSetup:function(a,b){return b?sc(sc(a,n.ajaxSettings),b):sc(n.ajaxSettings,a)},ajaxPrefilter:qc(lc),ajaxTransport:qc(mc),ajax:function(a,b){"object"==typeof 
a&&(b=a,a=void 0),b=b||{};var c,d,e,f,g,h,i,j,k=n.ajaxSetup({},b),l=k.context||k,m=k.context&&(l.nodeType||l.jquery)?n(l):n.event,o=n.Deferred(),p=n.Callbacks("once memory"),q=k.statusCode||{},r={},s={},t=0,u="canceled",v={readyState:0,getResponseHeader:function(a){var b;if(2===t){if(!f){f={};while(b=gc.exec(e))f[b[1].toLowerCase()]=b[2]}b=f[a.toLowerCase()]}return null==b?null:b},getAllResponseHeaders:function(){return 2===t?e:null},setRequestHeader:function(a,b){var c=a.toLowerCase();return t||(a=s[c]=s[c]||a,r[a]=b),this},overrideMimeType:function(a){return t||(k.mimeType=a),this},statusCode:function(a){var b;if(a)if(2>t)for(b in a)q[b]=[q[b],a[b]];else v.always(a[v.status]);return this},abort:function(a){var b=a||u;return c&&c.abort(b),x(0,b),this}};if(o.promise(v).complete=p.add,v.success=v.done,v.error=v.fail,k.url=((a||k.url||oc)+"").replace(ec,"").replace(jc,pc[1]+"//"),k.type=b.method||b.type||k.method||k.type,k.dataTypes=n.trim(k.dataType||"*").toLowerCase().match(E)||[""],null==k.crossDomain&&(h=kc.exec(k.url.toLowerCase()),k.crossDomain=!(!h||h[1]===pc[1]&&h[2]===pc[2]&&(h[3]||("http:"===h[1]?"80":"443"))===(pc[3]||("http:"===pc[1]?"80":"443")))),k.data&&k.processData&&"string"!=typeof k.data&&(k.data=n.param(k.data,k.traditional)),rc(lc,k,b,v),2===t)return v;i=n.event&&k.global,i&&0===n.active++&&n.event.trigger("ajaxStart"),k.type=k.type.toUpperCase(),k.hasContent=!ic.test(k.type),d=k.url,k.hasContent||(k.data&&(d=k.url+=(dc.test(d)?"&":"?")+k.data,delete k.data),k.cache===!1&&(k.url=fc.test(d)?d.replace(fc,"$1_="+cc++):d+(dc.test(d)?"&":"?")+"_="+cc++)),k.ifModified&&(n.lastModified[d]&&v.setRequestHeader("If-Modified-Since",n.lastModified[d]),n.etag[d]&&v.setRequestHeader("If-None-Match",n.etag[d])),(k.data&&k.hasContent&&k.contentType!==!1||b.contentType)&&v.setRequestHeader("Content-Type",k.contentType),v.setRequestHeader("Accept",k.dataTypes[0]&&k.accepts[k.dataTypes[0]]?k.accepts[k.dataTypes[0]]+("*"!==k.dataTypes[0]?", "+nc+"; q=0.01":""):k.accepts["*"]);for(j in k.headers)v.setRequestHeader(j,k.headers[j]);if(k.beforeSend&&(k.beforeSend.call(l,v,k)===!1||2===t))return v.abort();u="abort";for(j in{success:1,error:1,complete:1})v[j](k[j]);if(c=rc(mc,k,b,v)){v.readyState=1,i&&m.trigger("ajaxSend",[v,k]),k.async&&k.timeout>0&&(g=setTimeout(function(){v.abort("timeout")},k.timeout));try{t=1,c.send(r,x)}catch(w){if(!(2>t))throw w;x(-1,w)}}else x(-1,"No Transport");function x(a,b,f,h){var j,r,s,u,w,x=b;2!==t&&(t=2,g&&clearTimeout(g),c=void 0,e=h||"",v.readyState=a>0?4:0,j=a>=200&&300>a||304===a,f&&(u=tc(k,v,f)),u=uc(k,u,v,j),j?(k.ifModified&&(w=v.getResponseHeader("Last-Modified"),w&&(n.lastModified[d]=w),w=v.getResponseHeader("etag"),w&&(n.etag[d]=w)),204===a||"HEAD"===k.type?x="nocontent":304===a?x="notmodified":(x=u.state,r=u.data,s=u.error,j=!s)):(s=x,(a||!x)&&(x="error",0>a&&(a=0))),v.status=a,v.statusText=(b||x)+"",j?o.resolveWith(l,[r,x,v]):o.rejectWith(l,[v,x,s]),v.statusCode(q),q=void 0,i&&m.trigger(j?"ajaxSuccess":"ajaxError",[v,k,j?r:s]),p.fireWith(l,[v,x]),i&&(m.trigger("ajaxComplete",[v,k]),--n.active||n.event.trigger("ajaxStop")))}return v},getJSON:function(a,b,c){return n.get(a,b,c,"json")},getScript:function(a,b){return n.get(a,void 0,b,"script")}}),n.each(["get","post"],function(a,b){n[b]=function(a,c,d,e){return n.isFunction(c)&&(e=e||d,d=c,c=void 0),n.ajax({url:a,type:b,dataType:e,data:c,success:d})}}),n._evalUrl=function(a){return n.ajax({url:a,type:"GET",dataType:"script",async:!1,global:!1,"throws":!0})},n.fn.extend({wrapAll:function(a){var b;return 
n.isFunction(a)?this.each(function(b){n(this).wrapAll(a.call(this,b))}):(this[0]&&(b=n(a,this[0].ownerDocument).eq(0).clone(!0),this[0].parentNode&&b.insertBefore(this[0]),b.map(function(){var a=this;while(a.firstElementChild)a=a.firstElementChild;return a}).append(this)),this)},wrapInner:function(a){return this.each(n.isFunction(a)?function(b){n(this).wrapInner(a.call(this,b))}:function(){var b=n(this),c=b.contents();c.length?c.wrapAll(a):b.append(a)})},wrap:function(a){var b=n.isFunction(a);return this.each(function(c){n(this).wrapAll(b?a.call(this,c):a)})},unwrap:function(){return this.parent().each(function(){n.nodeName(this,"body")||n(this).replaceWith(this.childNodes)}).end()}}),n.expr.filters.hidden=function(a){return a.offsetWidth<=0&&a.offsetHeight<=0},n.expr.filters.visible=function(a){return!n.expr.filters.hidden(a)};var vc=/%20/g,wc=/\[\]$/,xc=/\r?\n/g,yc=/^(?:submit|button|image|reset|file)$/i,zc=/^(?:input|select|textarea|keygen)/i;function Ac(a,b,c,d){var e;if(n.isArray(b))n.each(b,function(b,e){c||wc.test(a)?d(a,e):Ac(a+"["+("object"==typeof e?b:"")+"]",e,c,d)});else if(c||"object"!==n.type(b))d(a,b);else for(e in b)Ac(a+"["+e+"]",b[e],c,d)}n.param=function(a,b){var c,d=[],e=function(a,b){b=n.isFunction(b)?b():null==b?"":b,d[d.length]=encodeURIComponent(a)+"="+encodeURIComponent(b)};if(void 0===b&&(b=n.ajaxSettings&&n.ajaxSettings.traditional),n.isArray(a)||a.jquery&&!n.isPlainObject(a))n.each(a,function(){e(this.name,this.value)});else for(c in a)Ac(c,a[c],b,e);return d.join("&").replace(vc,"+")},n.fn.extend({serialize:function(){return n.param(this.serializeArray())},serializeArray:function(){return this.map(function(){var a=n.prop(this,"elements");return a?n.makeArray(a):this}).filter(function(){var a=this.type;return this.name&&!n(this).is(":disabled")&&zc.test(this.nodeName)&&!yc.test(a)&&(this.checked||!T.test(a))}).map(function(a,b){var c=n(this).val();return null==c?null:n.isArray(c)?n.map(c,function(a){return{name:b.name,value:a.replace(xc,"\r\n")}}):{name:b.name,value:c.replace(xc,"\r\n")}}).get()}}),n.ajaxSettings.xhr=function(){try{return new XMLHttpRequest}catch(a){}};var Bc=0,Cc={},Dc={0:200,1223:204},Ec=n.ajaxSettings.xhr();a.attachEvent&&a.attachEvent("onunload",function(){for(var a in Cc)Cc[a]()}),k.cors=!!Ec&&"withCredentials"in Ec,k.ajax=Ec=!!Ec,n.ajaxTransport(function(a){var b;return k.cors||Ec&&!a.crossDomain?{send:function(c,d){var e,f=a.xhr(),g=++Bc;if(f.open(a.type,a.url,a.async,a.username,a.password),a.xhrFields)for(e in a.xhrFields)f[e]=a.xhrFields[e];a.mimeType&&f.overrideMimeType&&f.overrideMimeType(a.mimeType),a.crossDomain||c["X-Requested-With"]||(c["X-Requested-With"]="XMLHttpRequest");for(e in c)f.setRequestHeader(e,c[e]);b=function(a){return function(){b&&(delete Cc[g],b=f.onload=f.onerror=null,"abort"===a?f.abort():"error"===a?d(f.status,f.statusText):d(Dc[f.status]||f.status,f.statusText,"string"==typeof f.responseText?{text:f.responseText}:void 0,f.getAllResponseHeaders()))}},f.onload=b(),f.onerror=b("error"),b=Cc[g]=b("abort");try{f.send(a.hasContent&&a.data||null)}catch(h){if(b)throw h}},abort:function(){b&&b()}}:void 0}),n.ajaxSetup({accepts:{script:"text/javascript, application/javascript, application/ecmascript, application/x-ecmascript"},contents:{script:/(?:java|ecma)script/},converters:{"text script":function(a){return n.globalEval(a),a}}}),n.ajaxPrefilter("script",function(a){void 0===a.cache&&(a.cache=!1),a.crossDomain&&(a.type="GET")}),n.ajaxTransport("script",function(a){if(a.crossDomain){var 
b,c;return{send:function(d,e){b=n("<script>").prop({async:!0,charset:a.scriptCharset,src:a.url}).on("load error",c=function(a){b.remove(),c=null,a&&e("error"===a.type?404:200,a.type)}),l.head.appendChild(b[0])},abort:function(){c&&c()}}}});var Fc=[],Gc=/(=)\?(?=&|$)|\?\?/;n.ajaxSetup({jsonp:"callback",jsonpCallback:function(){var a=Fc.pop()||n.expando+"_"+cc++;return this[a]=!0,a}}),n.ajaxPrefilter("json jsonp",function(b,c,d){var e,f,g,h=b.jsonp!==!1&&(Gc.test(b.url)?"url":"string"==typeof b.data&&!(b.contentType||"").indexOf("application/x-www-form-urlencoded")&&Gc.test(b.data)&&"data");return h||"jsonp"===b.dataTypes[0]?(e=b.jsonpCallback=n.isFunction(b.jsonpCallback)?b.jsonpCallback():b.jsonpCallback,h?b[h]=b[h].replace(Gc,"$1"+e):b.jsonp!==!1&&(b.url+=(dc.test(b.url)?"&":"?")+b.jsonp+"="+e),b.converters["script json"]=function(){return g||n.error(e+" was not called"),g[0]},b.dataTypes[0]="json",f=a[e],a[e]=function(){g=arguments},d.always(function(){a[e]=f,b[e]&&(b.jsonpCallback=c.jsonpCallback,Fc.push(e)),g&&n.isFunction(f)&&f(g[0]),g=f=void 0}),"script"):void 0}),n.parseHTML=function(a,b,c){if(!a||"string"!=typeof a)return null;"boolean"==typeof b&&(c=b,b=!1),b=b||l;var d=v.exec(a),e=!c&&[];return d?[b.createElement(d[1])]:(d=n.buildFragment([a],b,e),e&&e.length&&n(e).remove(),n.merge([],d.childNodes))};var Hc=n.fn.load;n.fn.load=function(a,b,c){if("string"!=typeof a&&Hc)return Hc.apply(this,arguments);var d,e,f,g=this,h=a.indexOf(" ");return h>=0&&(d=n.trim(a.slice(h)),a=a.slice(0,h)),n.isFunction(b)?(c=b,b=void 0):b&&"object"==typeof b&&(e="POST"),g.length>0&&n.ajax({url:a,type:e,dataType:"html",data:b}).done(function(a){f=arguments,g.html(d?n("<div>").append(n.parseHTML(a)).find(d):a)}).complete(c&&function(a,b){g.each(c,f||[a.responseText,b,a])}),this},n.each(["ajaxStart","ajaxStop","ajaxComplete","ajaxError","ajaxSuccess","ajaxSend"],function(a,b){n.fn[b]=function(a){return this.on(b,a)}}),n.expr.filters.animated=function(a){return n.grep(n.timers,function(b){return a===b.elem}).length};var Ic=a.document.documentElement;function Jc(a){return n.isWindow(a)?a:9===a.nodeType&&a.defaultView}n.offset={setOffset:function(a,b,c){var d,e,f,g,h,i,j,k=n.css(a,"position"),l=n(a),m={};"static"===k&&(a.style.position="relative"),h=l.offset(),f=n.css(a,"top"),i=n.css(a,"left"),j=("absolute"===k||"fixed"===k)&&(f+i).indexOf("auto")>-1,j?(d=l.position(),g=d.top,e=d.left):(g=parseFloat(f)||0,e=parseFloat(i)||0),n.isFunction(b)&&(b=b.call(a,c,h)),null!=b.top&&(m.top=b.top-h.top+g),null!=b.left&&(m.left=b.left-h.left+e),"using"in b?b.using.call(a,m):l.css(m)}},n.fn.extend({offset:function(a){if(arguments.length)return void 0===a?this:this.each(function(b){n.offset.setOffset(this,a,b)});var b,c,d=this[0],e={top:0,left:0},f=d&&d.ownerDocument;if(f)return b=f.documentElement,n.contains(b,d)?(typeof d.getBoundingClientRect!==U&&(e=d.getBoundingClientRect()),c=Jc(f),{top:e.top+c.pageYOffset-b.clientTop,left:e.left+c.pageXOffset-b.clientLeft}):e},position:function(){if(this[0]){var a,b,c=this[0],d={top:0,left:0};return"fixed"===n.css(c,"position")?b=c.getBoundingClientRect():(a=this.offsetParent(),b=this.offset(),n.nodeName(a[0],"html")||(d=a.offset()),d.top+=n.css(a[0],"borderTopWidth",!0),d.left+=n.css(a[0],"borderLeftWidth",!0)),{top:b.top-d.top-n.css(c,"marginTop",!0),left:b.left-d.left-n.css(c,"marginLeft",!0)}}},offsetParent:function(){return this.map(function(){var a=this.offsetParent||Ic;while(a&&!n.nodeName(a,"html")&&"static"===n.css(a,"position"))a=a.offsetParent;return 
a||Ic})}}),n.each({scrollLeft:"pageXOffset",scrollTop:"pageYOffset"},function(b,c){var d="pageYOffset"===c;n.fn[b]=function(e){return J(this,function(b,e,f){var g=Jc(b);return void 0===f?g?g[c]:b[e]:void(g?g.scrollTo(d?a.pageXOffset:f,d?f:a.pageYOffset):b[e]=f)},b,e,arguments.length,null)}}),n.each(["top","left"],function(a,b){n.cssHooks[b]=yb(k.pixelPosition,function(a,c){return c?(c=xb(a,b),vb.test(c)?n(a).position()[b]+"px":c):void 0})}),n.each({Height:"height",Width:"width"},function(a,b){n.each({padding:"inner"+a,content:b,"":"outer"+a},function(c,d){n.fn[d]=function(d,e){var f=arguments.length&&(c||"boolean"!=typeof d),g=c||(d===!0||e===!0?"margin":"border");return J(this,function(b,c,d){var e;return n.isWindow(b)?b.document.documentElement["client"+a]:9===b.nodeType?(e=b.documentElement,Math.max(b.body["scroll"+a],e["scroll"+a],b.body["offset"+a],e["offset"+a],e["client"+a])):void 0===d?n.css(b,c,g):n.style(b,c,d,g)},b,f?d:void 0,f,null)}})}),n.fn.size=function(){return this.length},n.fn.andSelf=n.fn.addBack,"function"==typeof define&&define.amd&&define("jquery",[],function(){return n});var Kc=a.jQuery,Lc=a.$;return n.noConflict=function(b){return a.$===n&&(a.$=Lc),b&&a.jQuery===n&&(a.jQuery=Kc),n},typeof b===U&&(a.jQuery=a.$=n),n}); diff --git a/synapse/static/client/login/js/jquery-3.4.1.min.js b/synapse/static/client/login/js/jquery-3.4.1.min.js new file mode 100644 index 0000000000..a1c07fd803 --- /dev/null +++ b/synapse/static/client/login/js/jquery-3.4.1.min.js @@ -0,0 +1,2 @@ +/*! jQuery v3.4.1 | (c) JS Foundation and other contributors | jquery.org/license */ +!function(e,t){"use strict";"object"==typeof module&&"object"==typeof module.exports?module.exports=e.document?t(e,!0):function(e){if(!e.document)throw new Error("jQuery requires a window with a document");return t(e)}:t(e)}("undefined"!=typeof window?window:this,function(C,e){"use strict";var t=[],E=C.document,r=Object.getPrototypeOf,s=t.slice,g=t.concat,u=t.push,i=t.indexOf,n={},o=n.toString,v=n.hasOwnProperty,a=v.toString,l=a.call(Object),y={},m=function(e){return"function"==typeof e&&"number"!=typeof e.nodeType},x=function(e){return null!=e&&e===e.window},c={type:!0,src:!0,nonce:!0,noModule:!0};function b(e,t,n){var r,i,o=(n=n||E).createElement("script");if(o.text=e,t)for(r in c)(i=t[r]||t.getAttribute&&t.getAttribute(r))&&o.setAttribute(r,i);n.head.appendChild(o).parentNode.removeChild(o)}function w(e){return null==e?e+"":"object"==typeof e||"function"==typeof e?n[o.call(e)]||"object":typeof e}var f="3.4.1",k=function(e,t){return new k.fn.init(e,t)},p=/^[\s\uFEFF\xA0]+|[\s\uFEFF\xA0]+$/g;function d(e){var t=!!e&&"length"in e&&e.length,n=w(e);return!m(e)&&!x(e)&&("array"===n||0===t||"number"==typeof t&&0<t&&t-1 in e)}k.fn=k.prototype={jquery:f,constructor:k,length:0,toArray:function(){return s.call(this)},get:function(e){return null==e?s.call(this):e<0?this[e+this.length]:this[e]},pushStack:function(e){var t=k.merge(this.constructor(),e);return t.prevObject=this,t},each:function(e){return k.each(this,e)},map:function(n){return this.pushStack(k.map(this,function(e,t){return n.call(e,t,e)}))},slice:function(){return this.pushStack(s.apply(this,arguments))},first:function(){return this.eq(0)},last:function(){return this.eq(-1)},eq:function(e){var t=this.length,n=+e+(e<0?t:0);return this.pushStack(0<=n&&n<t?[this[n]]:[])},end:function(){return this.prevObject||this.constructor()},push:u,sort:t.sort,splice:t.splice},k.extend=k.fn.extend=function(){var 
e,t,n,r,i,o,a=arguments[0]||{},s=1,u=arguments.length,l=!1;for("boolean"==typeof a&&(l=a,a=arguments[s]||{},s++),"object"==typeof a||m(a)||(a={}),s===u&&(a=this,s--);s<u;s++)if(null!=(e=arguments[s]))for(t in e)r=e[t],"__proto__"!==t&&a!==r&&(l&&r&&(k.isPlainObject(r)||(i=Array.isArray(r)))?(n=a[t],o=i&&!Array.isArray(n)?[]:i||k.isPlainObject(n)?n:{},i=!1,a[t]=k.extend(l,o,r)):void 0!==r&&(a[t]=r));return a},k.extend({expando:"jQuery"+(f+Math.random()).replace(/\D/g,""),isReady:!0,error:function(e){throw new Error(e)},noop:function(){},isPlainObject:function(e){var t,n;return!(!e||"[object Object]"!==o.call(e))&&(!(t=r(e))||"function"==typeof(n=v.call(t,"constructor")&&t.constructor)&&a.call(n)===l)},isEmptyObject:function(e){var t;for(t in e)return!1;return!0},globalEval:function(e,t){b(e,{nonce:t&&t.nonce})},each:function(e,t){var n,r=0;if(d(e)){for(n=e.length;r<n;r++)if(!1===t.call(e[r],r,e[r]))break}else for(r in e)if(!1===t.call(e[r],r,e[r]))break;return e},trim:function(e){return null==e?"":(e+"").replace(p,"")},makeArray:function(e,t){var n=t||[];return null!=e&&(d(Object(e))?k.merge(n,"string"==typeof e?[e]:e):u.call(n,e)),n},inArray:function(e,t,n){return null==t?-1:i.call(t,e,n)},merge:function(e,t){for(var n=+t.length,r=0,i=e.length;r<n;r++)e[i++]=t[r];return e.length=i,e},grep:function(e,t,n){for(var r=[],i=0,o=e.length,a=!n;i<o;i++)!t(e[i],i)!==a&&r.push(e[i]);return r},map:function(e,t,n){var r,i,o=0,a=[];if(d(e))for(r=e.length;o<r;o++)null!=(i=t(e[o],o,n))&&a.push(i);else for(o in e)null!=(i=t(e[o],o,n))&&a.push(i);return g.apply([],a)},guid:1,support:y}),"function"==typeof Symbol&&(k.fn[Symbol.iterator]=t[Symbol.iterator]),k.each("Boolean Number String Function Array Date RegExp Object Error Symbol".split(" "),function(e,t){n["[object "+t+"]"]=t.toLowerCase()});var h=function(n){var e,d,b,o,i,h,f,g,w,u,l,T,C,a,E,v,s,c,y,k="sizzle"+1*new Date,m=n.document,S=0,r=0,p=ue(),x=ue(),N=ue(),A=ue(),D=function(e,t){return e===t&&(l=!0),0},j={}.hasOwnProperty,t=[],q=t.pop,L=t.push,H=t.push,O=t.slice,P=function(e,t){for(var n=0,r=e.length;n<r;n++)if(e[n]===t)return n;return-1},R="checked|selected|async|autofocus|autoplay|controls|defer|disabled|hidden|ismap|loop|multiple|open|readonly|required|scoped",M="[\\x20\\t\\r\\n\\f]",I="(?:\\\\.|[\\w-]|[^\0-\\xa0])+",W="\\["+M+"*("+I+")(?:"+M+"*([*^$|!~]?=)"+M+"*(?:'((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\"|("+I+"))|)"+M+"*\\]",$=":("+I+")(?:\\((('((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\")|((?:\\\\.|[^\\\\()[\\]]|"+W+")*)|.*)\\)|)",F=new RegExp(M+"+","g"),B=new RegExp("^"+M+"+|((?:^|[^\\\\])(?:\\\\.)*)"+M+"+$","g"),_=new RegExp("^"+M+"*,"+M+"*"),z=new RegExp("^"+M+"*([>+~]|"+M+")"+M+"*"),U=new RegExp(M+"|>"),X=new RegExp($),V=new RegExp("^"+I+"$"),G={ID:new RegExp("^#("+I+")"),CLASS:new RegExp("^\\.("+I+")"),TAG:new RegExp("^("+I+"|[*])"),ATTR:new RegExp("^"+W),PSEUDO:new RegExp("^"+$),CHILD:new RegExp("^:(only|first|last|nth|nth-last)-(child|of-type)(?:\\("+M+"*(even|odd|(([+-]|)(\\d*)n|)"+M+"*(?:([+-]|)"+M+"*(\\d+)|))"+M+"*\\)|)","i"),bool:new RegExp("^(?:"+R+")$","i"),needsContext:new RegExp("^"+M+"*[>+~]|:(even|odd|eq|gt|lt|nth|first|last)(?:\\("+M+"*((?:-\\d)?\\d*)"+M+"*\\)|)(?=[^-]|$)","i")},Y=/HTML$/i,Q=/^(?:input|select|textarea|button)$/i,J=/^h\d$/i,K=/^[^{]+\{\s*\[native \w/,Z=/^(?:#([\w-]+)|(\w+)|\.([\w-]+))$/,ee=/[+~]/,te=new RegExp("\\\\([\\da-f]{1,6}"+M+"?|("+M+")|.)","ig"),ne=function(e,t,n){var r="0x"+t-65536;return 
r!=r||n?t:r<0?String.fromCharCode(r+65536):String.fromCharCode(r>>10|55296,1023&r|56320)},re=/([\0-\x1f\x7f]|^-?\d)|^-$|[^\0-\x1f\x7f-\uFFFF\w-]/g,ie=function(e,t){return t?"\0"===e?"\ufffd":e.slice(0,-1)+"\\"+e.charCodeAt(e.length-1).toString(16)+" ":"\\"+e},oe=function(){T()},ae=be(function(e){return!0===e.disabled&&"fieldset"===e.nodeName.toLowerCase()},{dir:"parentNode",next:"legend"});try{H.apply(t=O.call(m.childNodes),m.childNodes),t[m.childNodes.length].nodeType}catch(e){H={apply:t.length?function(e,t){L.apply(e,O.call(t))}:function(e,t){var n=e.length,r=0;while(e[n++]=t[r++]);e.length=n-1}}}function se(t,e,n,r){var i,o,a,s,u,l,c,f=e&&e.ownerDocument,p=e?e.nodeType:9;if(n=n||[],"string"!=typeof t||!t||1!==p&&9!==p&&11!==p)return n;if(!r&&((e?e.ownerDocument||e:m)!==C&&T(e),e=e||C,E)){if(11!==p&&(u=Z.exec(t)))if(i=u[1]){if(9===p){if(!(a=e.getElementById(i)))return n;if(a.id===i)return n.push(a),n}else if(f&&(a=f.getElementById(i))&&y(e,a)&&a.id===i)return n.push(a),n}else{if(u[2])return H.apply(n,e.getElementsByTagName(t)),n;if((i=u[3])&&d.getElementsByClassName&&e.getElementsByClassName)return H.apply(n,e.getElementsByClassName(i)),n}if(d.qsa&&!A[t+" "]&&(!v||!v.test(t))&&(1!==p||"object"!==e.nodeName.toLowerCase())){if(c=t,f=e,1===p&&U.test(t)){(s=e.getAttribute("id"))?s=s.replace(re,ie):e.setAttribute("id",s=k),o=(l=h(t)).length;while(o--)l[o]="#"+s+" "+xe(l[o]);c=l.join(","),f=ee.test(t)&&ye(e.parentNode)||e}try{return H.apply(n,f.querySelectorAll(c)),n}catch(e){A(t,!0)}finally{s===k&&e.removeAttribute("id")}}}return g(t.replace(B,"$1"),e,n,r)}function ue(){var r=[];return function e(t,n){return r.push(t+" ")>b.cacheLength&&delete e[r.shift()],e[t+" "]=n}}function le(e){return e[k]=!0,e}function ce(e){var t=C.createElement("fieldset");try{return!!e(t)}catch(e){return!1}finally{t.parentNode&&t.parentNode.removeChild(t),t=null}}function fe(e,t){var n=e.split("|"),r=n.length;while(r--)b.attrHandle[n[r]]=t}function pe(e,t){var n=t&&e,r=n&&1===e.nodeType&&1===t.nodeType&&e.sourceIndex-t.sourceIndex;if(r)return r;if(n)while(n=n.nextSibling)if(n===t)return-1;return e?1:-1}function de(t){return function(e){return"input"===e.nodeName.toLowerCase()&&e.type===t}}function he(n){return function(e){var t=e.nodeName.toLowerCase();return("input"===t||"button"===t)&&e.type===n}}function ge(t){return function(e){return"form"in e?e.parentNode&&!1===e.disabled?"label"in e?"label"in e.parentNode?e.parentNode.disabled===t:e.disabled===t:e.isDisabled===t||e.isDisabled!==!t&&ae(e)===t:e.disabled===t:"label"in e&&e.disabled===t}}function ve(a){return le(function(o){return o=+o,le(function(e,t){var n,r=a([],e.length,o),i=r.length;while(i--)e[n=r[i]]&&(e[n]=!(t[n]=e[n]))})})}function ye(e){return e&&"undefined"!=typeof e.getElementsByTagName&&e}for(e in d=se.support={},i=se.isXML=function(e){var t=e.namespaceURI,n=(e.ownerDocument||e).documentElement;return!Y.test(t||n&&n.nodeName||"HTML")},T=se.setDocument=function(e){var t,n,r=e?e.ownerDocument||e:m;return r!==C&&9===r.nodeType&&r.documentElement&&(a=(C=r).documentElement,E=!i(C),m!==C&&(n=C.defaultView)&&n.top!==n&&(n.addEventListener?n.addEventListener("unload",oe,!1):n.attachEvent&&n.attachEvent("onunload",oe)),d.attributes=ce(function(e){return e.className="i",!e.getAttribute("className")}),d.getElementsByTagName=ce(function(e){return e.appendChild(C.createComment("")),!e.getElementsByTagName("*").length}),d.getElementsByClassName=K.test(C.getElementsByClassName),d.getById=ce(function(e){return 
a.appendChild(e).id=k,!C.getElementsByName||!C.getElementsByName(k).length}),d.getById?(b.filter.ID=function(e){var t=e.replace(te,ne);return function(e){return e.getAttribute("id")===t}},b.find.ID=function(e,t){if("undefined"!=typeof t.getElementById&&E){var n=t.getElementById(e);return n?[n]:[]}}):(b.filter.ID=function(e){var n=e.replace(te,ne);return function(e){var t="undefined"!=typeof e.getAttributeNode&&e.getAttributeNode("id");return t&&t.value===n}},b.find.ID=function(e,t){if("undefined"!=typeof t.getElementById&&E){var n,r,i,o=t.getElementById(e);if(o){if((n=o.getAttributeNode("id"))&&n.value===e)return[o];i=t.getElementsByName(e),r=0;while(o=i[r++])if((n=o.getAttributeNode("id"))&&n.value===e)return[o]}return[]}}),b.find.TAG=d.getElementsByTagName?function(e,t){return"undefined"!=typeof t.getElementsByTagName?t.getElementsByTagName(e):d.qsa?t.querySelectorAll(e):void 0}:function(e,t){var n,r=[],i=0,o=t.getElementsByTagName(e);if("*"===e){while(n=o[i++])1===n.nodeType&&r.push(n);return r}return o},b.find.CLASS=d.getElementsByClassName&&function(e,t){if("undefined"!=typeof t.getElementsByClassName&&E)return t.getElementsByClassName(e)},s=[],v=[],(d.qsa=K.test(C.querySelectorAll))&&(ce(function(e){a.appendChild(e).innerHTML="<a id='"+k+"'></a><select id='"+k+"-\r\\' msallowcapture=''><option selected=''></option></select>",e.querySelectorAll("[msallowcapture^='']").length&&v.push("[*^$]="+M+"*(?:''|\"\")"),e.querySelectorAll("[selected]").length||v.push("\\["+M+"*(?:value|"+R+")"),e.querySelectorAll("[id~="+k+"-]").length||v.push("~="),e.querySelectorAll(":checked").length||v.push(":checked"),e.querySelectorAll("a#"+k+"+*").length||v.push(".#.+[+~]")}),ce(function(e){e.innerHTML="<a href='' disabled='disabled'></a><select disabled='disabled'><option/></select>";var t=C.createElement("input");t.setAttribute("type","hidden"),e.appendChild(t).setAttribute("name","D"),e.querySelectorAll("[name=d]").length&&v.push("name"+M+"*[*^$|!~]?="),2!==e.querySelectorAll(":enabled").length&&v.push(":enabled",":disabled"),a.appendChild(e).disabled=!0,2!==e.querySelectorAll(":disabled").length&&v.push(":enabled",":disabled"),e.querySelectorAll("*,:x"),v.push(",.*:")})),(d.matchesSelector=K.test(c=a.matches||a.webkitMatchesSelector||a.mozMatchesSelector||a.oMatchesSelector||a.msMatchesSelector))&&ce(function(e){d.disconnectedMatch=c.call(e,"*"),c.call(e,"[s!='']:x"),s.push("!=",$)}),v=v.length&&new RegExp(v.join("|")),s=s.length&&new RegExp(s.join("|")),t=K.test(a.compareDocumentPosition),y=t||K.test(a.contains)?function(e,t){var n=9===e.nodeType?e.documentElement:e,r=t&&t.parentNode;return e===r||!(!r||1!==r.nodeType||!(n.contains?n.contains(r):e.compareDocumentPosition&&16&e.compareDocumentPosition(r)))}:function(e,t){if(t)while(t=t.parentNode)if(t===e)return!0;return!1},D=t?function(e,t){if(e===t)return l=!0,0;var n=!e.compareDocumentPosition-!t.compareDocumentPosition;return n||(1&(n=(e.ownerDocument||e)===(t.ownerDocument||t)?e.compareDocumentPosition(t):1)||!d.sortDetached&&t.compareDocumentPosition(e)===n?e===C||e.ownerDocument===m&&y(m,e)?-1:t===C||t.ownerDocument===m&&y(m,t)?1:u?P(u,e)-P(u,t):0:4&n?-1:1)}:function(e,t){if(e===t)return l=!0,0;var n,r=0,i=e.parentNode,o=t.parentNode,a=[e],s=[t];if(!i||!o)return e===C?-1:t===C?1:i?-1:o?1:u?P(u,e)-P(u,t):0;if(i===o)return pe(e,t);n=e;while(n=n.parentNode)a.unshift(n);n=t;while(n=n.parentNode)s.unshift(n);while(a[r]===s[r])r++;return r?pe(a[r],s[r]):a[r]===m?-1:s[r]===m?1:0}),C},se.matches=function(e,t){return 
se(e,null,null,t)},se.matchesSelector=function(e,t){if((e.ownerDocument||e)!==C&&T(e),d.matchesSelector&&E&&!A[t+" "]&&(!s||!s.test(t))&&(!v||!v.test(t)))try{var n=c.call(e,t);if(n||d.disconnectedMatch||e.document&&11!==e.document.nodeType)return n}catch(e){A(t,!0)}return 0<se(t,C,null,[e]).length},se.contains=function(e,t){return(e.ownerDocument||e)!==C&&T(e),y(e,t)},se.attr=function(e,t){(e.ownerDocument||e)!==C&&T(e);var n=b.attrHandle[t.toLowerCase()],r=n&&j.call(b.attrHandle,t.toLowerCase())?n(e,t,!E):void 0;return void 0!==r?r:d.attributes||!E?e.getAttribute(t):(r=e.getAttributeNode(t))&&r.specified?r.value:null},se.escape=function(e){return(e+"").replace(re,ie)},se.error=function(e){throw new Error("Syntax error, unrecognized expression: "+e)},se.uniqueSort=function(e){var t,n=[],r=0,i=0;if(l=!d.detectDuplicates,u=!d.sortStable&&e.slice(0),e.sort(D),l){while(t=e[i++])t===e[i]&&(r=n.push(i));while(r--)e.splice(n[r],1)}return u=null,e},o=se.getText=function(e){var t,n="",r=0,i=e.nodeType;if(i){if(1===i||9===i||11===i){if("string"==typeof e.textContent)return e.textContent;for(e=e.firstChild;e;e=e.nextSibling)n+=o(e)}else if(3===i||4===i)return e.nodeValue}else while(t=e[r++])n+=o(t);return n},(b=se.selectors={cacheLength:50,createPseudo:le,match:G,attrHandle:{},find:{},relative:{">":{dir:"parentNode",first:!0}," ":{dir:"parentNode"},"+":{dir:"previousSibling",first:!0},"~":{dir:"previousSibling"}},preFilter:{ATTR:function(e){return e[1]=e[1].replace(te,ne),e[3]=(e[3]||e[4]||e[5]||"").replace(te,ne),"~="===e[2]&&(e[3]=" "+e[3]+" "),e.slice(0,4)},CHILD:function(e){return e[1]=e[1].toLowerCase(),"nth"===e[1].slice(0,3)?(e[3]||se.error(e[0]),e[4]=+(e[4]?e[5]+(e[6]||1):2*("even"===e[3]||"odd"===e[3])),e[5]=+(e[7]+e[8]||"odd"===e[3])):e[3]&&se.error(e[0]),e},PSEUDO:function(e){var t,n=!e[6]&&e[2];return G.CHILD.test(e[0])?null:(e[3]?e[2]=e[4]||e[5]||"":n&&X.test(n)&&(t=h(n,!0))&&(t=n.indexOf(")",n.length-t)-n.length)&&(e[0]=e[0].slice(0,t),e[2]=n.slice(0,t)),e.slice(0,3))}},filter:{TAG:function(e){var t=e.replace(te,ne).toLowerCase();return"*"===e?function(){return!0}:function(e){return e.nodeName&&e.nodeName.toLowerCase()===t}},CLASS:function(e){var t=p[e+" "];return t||(t=new RegExp("(^|"+M+")"+e+"("+M+"|$)"))&&p(e,function(e){return t.test("string"==typeof e.className&&e.className||"undefined"!=typeof e.getAttribute&&e.getAttribute("class")||"")})},ATTR:function(n,r,i){return function(e){var t=se.attr(e,n);return null==t?"!="===r:!r||(t+="","="===r?t===i:"!="===r?t!==i:"^="===r?i&&0===t.indexOf(i):"*="===r?i&&-1<t.indexOf(i):"$="===r?i&&t.slice(-i.length)===i:"~="===r?-1<(" "+t.replace(F," ")+" ").indexOf(i):"|="===r&&(t===i||t.slice(0,i.length+1)===i+"-"))}},CHILD:function(h,e,t,g,v){var y="nth"!==h.slice(0,3),m="last"!==h.slice(-4),x="of-type"===e;return 1===g&&0===v?function(e){return!!e.parentNode}:function(e,t,n){var r,i,o,a,s,u,l=y!==m?"nextSibling":"previousSibling",c=e.parentNode,f=x&&e.nodeName.toLowerCase(),p=!n&&!x,d=!1;if(c){if(y){while(l){a=e;while(a=a[l])if(x?a.nodeName.toLowerCase()===f:1===a.nodeType)return!1;u=l="only"===h&&!u&&"nextSibling"}return!0}if(u=[m?c.firstChild:c.lastChild],m&&p){d=(s=(r=(i=(o=(a=c)[k]||(a[k]={}))[a.uniqueID]||(o[a.uniqueID]={}))[h]||[])[0]===S&&r[1])&&r[2],a=s&&c.childNodes[s];while(a=++s&&a&&a[l]||(d=s=0)||u.pop())if(1===a.nodeType&&++d&&a===e){i[h]=[S,s,d];break}}else 
if(p&&(d=s=(r=(i=(o=(a=e)[k]||(a[k]={}))[a.uniqueID]||(o[a.uniqueID]={}))[h]||[])[0]===S&&r[1]),!1===d)while(a=++s&&a&&a[l]||(d=s=0)||u.pop())if((x?a.nodeName.toLowerCase()===f:1===a.nodeType)&&++d&&(p&&((i=(o=a[k]||(a[k]={}))[a.uniqueID]||(o[a.uniqueID]={}))[h]=[S,d]),a===e))break;return(d-=v)===g||d%g==0&&0<=d/g}}},PSEUDO:function(e,o){var t,a=b.pseudos[e]||b.setFilters[e.toLowerCase()]||se.error("unsupported pseudo: "+e);return a[k]?a(o):1<a.length?(t=[e,e,"",o],b.setFilters.hasOwnProperty(e.toLowerCase())?le(function(e,t){var n,r=a(e,o),i=r.length;while(i--)e[n=P(e,r[i])]=!(t[n]=r[i])}):function(e){return a(e,0,t)}):a}},pseudos:{not:le(function(e){var r=[],i=[],s=f(e.replace(B,"$1"));return s[k]?le(function(e,t,n,r){var i,o=s(e,null,r,[]),a=e.length;while(a--)(i=o[a])&&(e[a]=!(t[a]=i))}):function(e,t,n){return r[0]=e,s(r,null,n,i),r[0]=null,!i.pop()}}),has:le(function(t){return function(e){return 0<se(t,e).length}}),contains:le(function(t){return t=t.replace(te,ne),function(e){return-1<(e.textContent||o(e)).indexOf(t)}}),lang:le(function(n){return V.test(n||"")||se.error("unsupported lang: "+n),n=n.replace(te,ne).toLowerCase(),function(e){var t;do{if(t=E?e.lang:e.getAttribute("xml:lang")||e.getAttribute("lang"))return(t=t.toLowerCase())===n||0===t.indexOf(n+"-")}while((e=e.parentNode)&&1===e.nodeType);return!1}}),target:function(e){var t=n.location&&n.location.hash;return t&&t.slice(1)===e.id},root:function(e){return e===a},focus:function(e){return e===C.activeElement&&(!C.hasFocus||C.hasFocus())&&!!(e.type||e.href||~e.tabIndex)},enabled:ge(!1),disabled:ge(!0),checked:function(e){var t=e.nodeName.toLowerCase();return"input"===t&&!!e.checked||"option"===t&&!!e.selected},selected:function(e){return e.parentNode&&e.parentNode.selectedIndex,!0===e.selected},empty:function(e){for(e=e.firstChild;e;e=e.nextSibling)if(e.nodeType<6)return!1;return!0},parent:function(e){return!b.pseudos.empty(e)},header:function(e){return J.test(e.nodeName)},input:function(e){return Q.test(e.nodeName)},button:function(e){var t=e.nodeName.toLowerCase();return"input"===t&&"button"===e.type||"button"===t},text:function(e){var t;return"input"===e.nodeName.toLowerCase()&&"text"===e.type&&(null==(t=e.getAttribute("type"))||"text"===t.toLowerCase())},first:ve(function(){return[0]}),last:ve(function(e,t){return[t-1]}),eq:ve(function(e,t,n){return[n<0?n+t:n]}),even:ve(function(e,t){for(var n=0;n<t;n+=2)e.push(n);return e}),odd:ve(function(e,t){for(var n=1;n<t;n+=2)e.push(n);return e}),lt:ve(function(e,t,n){for(var r=n<0?n+t:t<n?t:n;0<=--r;)e.push(r);return e}),gt:ve(function(e,t,n){for(var r=n<0?n+t:n;++r<t;)e.push(r);return e})}}).pseudos.nth=b.pseudos.eq,{radio:!0,checkbox:!0,file:!0,password:!0,image:!0})b.pseudos[e]=de(e);for(e in{submit:!0,reset:!0})b.pseudos[e]=he(e);function me(){}function xe(e){for(var t=0,n=e.length,r="";t<n;t++)r+=e[t].value;return r}function be(s,e,t){var u=e.dir,l=e.next,c=l||u,f=t&&"parentNode"===c,p=r++;return e.first?function(e,t,n){while(e=e[u])if(1===e.nodeType||f)return s(e,t,n);return!1}:function(e,t,n){var r,i,o,a=[S,p];if(n){while(e=e[u])if((1===e.nodeType||f)&&s(e,t,n))return!0}else while(e=e[u])if(1===e.nodeType||f)if(i=(o=e[k]||(e[k]={}))[e.uniqueID]||(o[e.uniqueID]={}),l&&l===e.nodeName.toLowerCase())e=e[u]||e;else{if((r=i[c])&&r[0]===S&&r[1]===p)return a[2]=r[2];if((i[c]=a)[2]=s(e,t,n))return!0}return!1}}function we(i){return 1<i.length?function(e,t,n){var r=i.length;while(r--)if(!i[r](e,t,n))return!1;return!0}:i[0]}function Te(e,t,n,r,i){for(var 
o,a=[],s=0,u=e.length,l=null!=t;s<u;s++)(o=e[s])&&(n&&!n(o,r,i)||(a.push(o),l&&t.push(s)));return a}function Ce(d,h,g,v,y,e){return v&&!v[k]&&(v=Ce(v)),y&&!y[k]&&(y=Ce(y,e)),le(function(e,t,n,r){var i,o,a,s=[],u=[],l=t.length,c=e||function(e,t,n){for(var r=0,i=t.length;r<i;r++)se(e,t[r],n);return n}(h||"*",n.nodeType?[n]:n,[]),f=!d||!e&&h?c:Te(c,s,d,n,r),p=g?y||(e?d:l||v)?[]:t:f;if(g&&g(f,p,n,r),v){i=Te(p,u),v(i,[],n,r),o=i.length;while(o--)(a=i[o])&&(p[u[o]]=!(f[u[o]]=a))}if(e){if(y||d){if(y){i=[],o=p.length;while(o--)(a=p[o])&&i.push(f[o]=a);y(null,p=[],i,r)}o=p.length;while(o--)(a=p[o])&&-1<(i=y?P(e,a):s[o])&&(e[i]=!(t[i]=a))}}else p=Te(p===t?p.splice(l,p.length):p),y?y(null,t,p,r):H.apply(t,p)})}function Ee(e){for(var i,t,n,r=e.length,o=b.relative[e[0].type],a=o||b.relative[" "],s=o?1:0,u=be(function(e){return e===i},a,!0),l=be(function(e){return-1<P(i,e)},a,!0),c=[function(e,t,n){var r=!o&&(n||t!==w)||((i=t).nodeType?u(e,t,n):l(e,t,n));return i=null,r}];s<r;s++)if(t=b.relative[e[s].type])c=[be(we(c),t)];else{if((t=b.filter[e[s].type].apply(null,e[s].matches))[k]){for(n=++s;n<r;n++)if(b.relative[e[n].type])break;return Ce(1<s&&we(c),1<s&&xe(e.slice(0,s-1).concat({value:" "===e[s-2].type?"*":""})).replace(B,"$1"),t,s<n&&Ee(e.slice(s,n)),n<r&&Ee(e=e.slice(n)),n<r&&xe(e))}c.push(t)}return we(c)}return me.prototype=b.filters=b.pseudos,b.setFilters=new me,h=se.tokenize=function(e,t){var n,r,i,o,a,s,u,l=x[e+" "];if(l)return t?0:l.slice(0);a=e,s=[],u=b.preFilter;while(a){for(o in n&&!(r=_.exec(a))||(r&&(a=a.slice(r[0].length)||a),s.push(i=[])),n=!1,(r=z.exec(a))&&(n=r.shift(),i.push({value:n,type:r[0].replace(B," ")}),a=a.slice(n.length)),b.filter)!(r=G[o].exec(a))||u[o]&&!(r=u[o](r))||(n=r.shift(),i.push({value:n,type:o,matches:r}),a=a.slice(n.length));if(!n)break}return t?a.length:a?se.error(e):x(e,s).slice(0)},f=se.compile=function(e,t){var n,v,y,m,x,r,i=[],o=[],a=N[e+" "];if(!a){t||(t=h(e)),n=t.length;while(n--)(a=Ee(t[n]))[k]?i.push(a):o.push(a);(a=N(e,(v=o,m=0<(y=i).length,x=0<v.length,r=function(e,t,n,r,i){var o,a,s,u=0,l="0",c=e&&[],f=[],p=w,d=e||x&&b.find.TAG("*",i),h=S+=null==p?1:Math.random()||.1,g=d.length;for(i&&(w=t===C||t||i);l!==g&&null!=(o=d[l]);l++){if(x&&o){a=0,t||o.ownerDocument===C||(T(o),n=!E);while(s=v[a++])if(s(o,t||C,n)){r.push(o);break}i&&(S=h)}m&&((o=!s&&o)&&u--,e&&c.push(o))}if(u+=l,m&&l!==u){a=0;while(s=y[a++])s(c,f,t,n);if(e){if(0<u)while(l--)c[l]||f[l]||(f[l]=q.call(r));f=Te(f)}H.apply(r,f),i&&!e&&0<f.length&&1<u+y.length&&se.uniqueSort(r)}return i&&(S=h,w=p),c},m?le(r):r))).selector=e}return a},g=se.select=function(e,t,n,r){var i,o,a,s,u,l="function"==typeof e&&e,c=!r&&h(e=l.selector||e);if(n=n||[],1===c.length){if(2<(o=c[0]=c[0].slice(0)).length&&"ID"===(a=o[0]).type&&9===t.nodeType&&E&&b.relative[o[1].type]){if(!(t=(b.find.ID(a.matches[0].replace(te,ne),t)||[])[0]))return n;l&&(t=t.parentNode),e=e.slice(o.shift().value.length)}i=G.needsContext.test(e)?0:o.length;while(i--){if(a=o[i],b.relative[s=a.type])break;if((u=b.find[s])&&(r=u(a.matches[0].replace(te,ne),ee.test(o[0].type)&&ye(t.parentNode)||t))){if(o.splice(i,1),!(e=r.length&&xe(o)))return H.apply(n,r),n;break}}}return(l||f(e,c))(r,t,!E,n,!t||ee.test(e)&&ye(t.parentNode)||t),n},d.sortStable=k.split("").sort(D).join("")===k,d.detectDuplicates=!!l,T(),d.sortDetached=ce(function(e){return 1&e.compareDocumentPosition(C.createElement("fieldset"))}),ce(function(e){return e.innerHTML="<a href='#'></a>","#"===e.firstChild.getAttribute("href")})||fe("type|href|height|width",function(e,t,n){if(!n)return 
e.getAttribute(t,"type"===t.toLowerCase()?1:2)}),d.attributes&&ce(function(e){return e.innerHTML="<input/>",e.firstChild.setAttribute("value",""),""===e.firstChild.getAttribute("value")})||fe("value",function(e,t,n){if(!n&&"input"===e.nodeName.toLowerCase())return e.defaultValue}),ce(function(e){return null==e.getAttribute("disabled")})||fe(R,function(e,t,n){var r;if(!n)return!0===e[t]?t.toLowerCase():(r=e.getAttributeNode(t))&&r.specified?r.value:null}),se}(C);k.find=h,k.expr=h.selectors,k.expr[":"]=k.expr.pseudos,k.uniqueSort=k.unique=h.uniqueSort,k.text=h.getText,k.isXMLDoc=h.isXML,k.contains=h.contains,k.escapeSelector=h.escape;var T=function(e,t,n){var r=[],i=void 0!==n;while((e=e[t])&&9!==e.nodeType)if(1===e.nodeType){if(i&&k(e).is(n))break;r.push(e)}return r},S=function(e,t){for(var n=[];e;e=e.nextSibling)1===e.nodeType&&e!==t&&n.push(e);return n},N=k.expr.match.needsContext;function A(e,t){return e.nodeName&&e.nodeName.toLowerCase()===t.toLowerCase()}var D=/^<([a-z][^\/\0>:\x20\t\r\n\f]*)[\x20\t\r\n\f]*\/?>(?:<\/\1>|)$/i;function j(e,n,r){return m(n)?k.grep(e,function(e,t){return!!n.call(e,t,e)!==r}):n.nodeType?k.grep(e,function(e){return e===n!==r}):"string"!=typeof n?k.grep(e,function(e){return-1<i.call(n,e)!==r}):k.filter(n,e,r)}k.filter=function(e,t,n){var r=t[0];return n&&(e=":not("+e+")"),1===t.length&&1===r.nodeType?k.find.matchesSelector(r,e)?[r]:[]:k.find.matches(e,k.grep(t,function(e){return 1===e.nodeType}))},k.fn.extend({find:function(e){var t,n,r=this.length,i=this;if("string"!=typeof e)return this.pushStack(k(e).filter(function(){for(t=0;t<r;t++)if(k.contains(i[t],this))return!0}));for(n=this.pushStack([]),t=0;t<r;t++)k.find(e,i[t],n);return 1<r?k.uniqueSort(n):n},filter:function(e){return this.pushStack(j(this,e||[],!1))},not:function(e){return this.pushStack(j(this,e||[],!0))},is:function(e){return!!j(this,"string"==typeof e&&N.test(e)?k(e):e||[],!1).length}});var q,L=/^(?:\s*(<[\w\W]+>)[^>]*|#([\w-]+))$/;(k.fn.init=function(e,t,n){var r,i;if(!e)return this;if(n=n||q,"string"==typeof e){if(!(r="<"===e[0]&&">"===e[e.length-1]&&3<=e.length?[null,e,null]:L.exec(e))||!r[1]&&t)return!t||t.jquery?(t||n).find(e):this.constructor(t).find(e);if(r[1]){if(t=t instanceof k?t[0]:t,k.merge(this,k.parseHTML(r[1],t&&t.nodeType?t.ownerDocument||t:E,!0)),D.test(r[1])&&k.isPlainObject(t))for(r in t)m(this[r])?this[r](t[r]):this.attr(r,t[r]);return this}return(i=E.getElementById(r[2]))&&(this[0]=i,this.length=1),this}return e.nodeType?(this[0]=e,this.length=1,this):m(e)?void 0!==n.ready?n.ready(e):e(k):k.makeArray(e,this)}).prototype=k.fn,q=k(E);var H=/^(?:parents|prev(?:Until|All))/,O={children:!0,contents:!0,next:!0,prev:!0};function P(e,t){while((e=e[t])&&1!==e.nodeType);return e}k.fn.extend({has:function(e){var t=k(e,this),n=t.length;return this.filter(function(){for(var e=0;e<n;e++)if(k.contains(this,t[e]))return!0})},closest:function(e,t){var n,r=0,i=this.length,o=[],a="string"!=typeof e&&k(e);if(!N.test(e))for(;r<i;r++)for(n=this[r];n&&n!==t;n=n.parentNode)if(n.nodeType<11&&(a?-1<a.index(n):1===n.nodeType&&k.find.matchesSelector(n,e))){o.push(n);break}return this.pushStack(1<o.length?k.uniqueSort(o):o)},index:function(e){return e?"string"==typeof e?i.call(k(e),this[0]):i.call(this,e.jquery?e[0]:e):this[0]&&this[0].parentNode?this.first().prevAll().length:-1},add:function(e,t){return this.pushStack(k.uniqueSort(k.merge(this.get(),k(e,t))))},addBack:function(e){return this.add(null==e?this.prevObject:this.prevObject.filter(e))}}),k.each({parent:function(e){var 
t=e.parentNode;return t&&11!==t.nodeType?t:null},parents:function(e){return T(e,"parentNode")},parentsUntil:function(e,t,n){return T(e,"parentNode",n)},next:function(e){return P(e,"nextSibling")},prev:function(e){return P(e,"previousSibling")},nextAll:function(e){return T(e,"nextSibling")},prevAll:function(e){return T(e,"previousSibling")},nextUntil:function(e,t,n){return T(e,"nextSibling",n)},prevUntil:function(e,t,n){return T(e,"previousSibling",n)},siblings:function(e){return S((e.parentNode||{}).firstChild,e)},children:function(e){return S(e.firstChild)},contents:function(e){return"undefined"!=typeof e.contentDocument?e.contentDocument:(A(e,"template")&&(e=e.content||e),k.merge([],e.childNodes))}},function(r,i){k.fn[r]=function(e,t){var n=k.map(this,i,e);return"Until"!==r.slice(-5)&&(t=e),t&&"string"==typeof t&&(n=k.filter(t,n)),1<this.length&&(O[r]||k.uniqueSort(n),H.test(r)&&n.reverse()),this.pushStack(n)}});var R=/[^\x20\t\r\n\f]+/g;function M(e){return e}function I(e){throw e}function W(e,t,n,r){var i;try{e&&m(i=e.promise)?i.call(e).done(t).fail(n):e&&m(i=e.then)?i.call(e,t,n):t.apply(void 0,[e].slice(r))}catch(e){n.apply(void 0,[e])}}k.Callbacks=function(r){var e,n;r="string"==typeof r?(e=r,n={},k.each(e.match(R)||[],function(e,t){n[t]=!0}),n):k.extend({},r);var i,t,o,a,s=[],u=[],l=-1,c=function(){for(a=a||r.once,o=i=!0;u.length;l=-1){t=u.shift();while(++l<s.length)!1===s[l].apply(t[0],t[1])&&r.stopOnFalse&&(l=s.length,t=!1)}r.memory||(t=!1),i=!1,a&&(s=t?[]:"")},f={add:function(){return s&&(t&&!i&&(l=s.length-1,u.push(t)),function n(e){k.each(e,function(e,t){m(t)?r.unique&&f.has(t)||s.push(t):t&&t.length&&"string"!==w(t)&&n(t)})}(arguments),t&&!i&&c()),this},remove:function(){return k.each(arguments,function(e,t){var n;while(-1<(n=k.inArray(t,s,n)))s.splice(n,1),n<=l&&l--}),this},has:function(e){return e?-1<k.inArray(e,s):0<s.length},empty:function(){return s&&(s=[]),this},disable:function(){return a=u=[],s=t="",this},disabled:function(){return!s},lock:function(){return a=u=[],t||i||(s=t=""),this},locked:function(){return!!a},fireWith:function(e,t){return a||(t=[e,(t=t||[]).slice?t.slice():t],u.push(t),i||c()),this},fire:function(){return f.fireWith(this,arguments),this},fired:function(){return!!o}};return f},k.extend({Deferred:function(e){var o=[["notify","progress",k.Callbacks("memory"),k.Callbacks("memory"),2],["resolve","done",k.Callbacks("once memory"),k.Callbacks("once memory"),0,"resolved"],["reject","fail",k.Callbacks("once memory"),k.Callbacks("once memory"),1,"rejected"]],i="pending",a={state:function(){return i},always:function(){return s.done(arguments).fail(arguments),this},"catch":function(e){return a.then(null,e)},pipe:function(){var i=arguments;return k.Deferred(function(r){k.each(o,function(e,t){var n=m(i[t[4]])&&i[t[4]];s[t[1]](function(){var e=n&&n.apply(this,arguments);e&&m(e.promise)?e.promise().progress(r.notify).done(r.resolve).fail(r.reject):r[t[0]+"With"](this,n?[e]:arguments)})}),i=null}).promise()},then:function(t,n,r){var u=0;function l(i,o,a,s){return function(){var n=this,r=arguments,e=function(){var e,t;if(!(i<u)){if((e=a.apply(n,r))===o.promise())throw new TypeError("Thenable self-resolution");t=e&&("object"==typeof e||"function"==typeof e)&&e.then,m(t)?s?t.call(e,l(u,o,M,s),l(u,o,I,s)):(u++,t.call(e,l(u,o,M,s),l(u,o,I,s),l(u,o,M,o.notifyWith))):(a!==M&&(n=void 0,r=[e]),(s||o.resolveWith)(n,r))}},t=s?e:function(){try{e()}catch(e){k.Deferred.exceptionHook&&k.Deferred.exceptionHook(e,t.stackTrace),u<=i+1&&(a!==I&&(n=void 
0,r=[e]),o.rejectWith(n,r))}};i?t():(k.Deferred.getStackHook&&(t.stackTrace=k.Deferred.getStackHook()),C.setTimeout(t))}}return k.Deferred(function(e){o[0][3].add(l(0,e,m(r)?r:M,e.notifyWith)),o[1][3].add(l(0,e,m(t)?t:M)),o[2][3].add(l(0,e,m(n)?n:I))}).promise()},promise:function(e){return null!=e?k.extend(e,a):a}},s={};return k.each(o,function(e,t){var n=t[2],r=t[5];a[t[1]]=n.add,r&&n.add(function(){i=r},o[3-e][2].disable,o[3-e][3].disable,o[0][2].lock,o[0][3].lock),n.add(t[3].fire),s[t[0]]=function(){return s[t[0]+"With"](this===s?void 0:this,arguments),this},s[t[0]+"With"]=n.fireWith}),a.promise(s),e&&e.call(s,s),s},when:function(e){var n=arguments.length,t=n,r=Array(t),i=s.call(arguments),o=k.Deferred(),a=function(t){return function(e){r[t]=this,i[t]=1<arguments.length?s.call(arguments):e,--n||o.resolveWith(r,i)}};if(n<=1&&(W(e,o.done(a(t)).resolve,o.reject,!n),"pending"===o.state()||m(i[t]&&i[t].then)))return o.then();while(t--)W(i[t],a(t),o.reject);return o.promise()}});var $=/^(Eval|Internal|Range|Reference|Syntax|Type|URI)Error$/;k.Deferred.exceptionHook=function(e,t){C.console&&C.console.warn&&e&&$.test(e.name)&&C.console.warn("jQuery.Deferred exception: "+e.message,e.stack,t)},k.readyException=function(e){C.setTimeout(function(){throw e})};var F=k.Deferred();function B(){E.removeEventListener("DOMContentLoaded",B),C.removeEventListener("load",B),k.ready()}k.fn.ready=function(e){return F.then(e)["catch"](function(e){k.readyException(e)}),this},k.extend({isReady:!1,readyWait:1,ready:function(e){(!0===e?--k.readyWait:k.isReady)||(k.isReady=!0)!==e&&0<--k.readyWait||F.resolveWith(E,[k])}}),k.ready.then=F.then,"complete"===E.readyState||"loading"!==E.readyState&&!E.documentElement.doScroll?C.setTimeout(k.ready):(E.addEventListener("DOMContentLoaded",B),C.addEventListener("load",B));var _=function(e,t,n,r,i,o,a){var s=0,u=e.length,l=null==n;if("object"===w(n))for(s in i=!0,n)_(e,t,s,n[s],!0,o,a);else if(void 0!==r&&(i=!0,m(r)||(a=!0),l&&(a?(t.call(e,r),t=null):(l=t,t=function(e,t,n){return l.call(k(e),n)})),t))for(;s<u;s++)t(e[s],n,a?r:r.call(e[s],s,t(e[s],n)));return i?e:l?t.call(e):u?t(e[0],n):o},z=/^-ms-/,U=/-([a-z])/g;function X(e,t){return t.toUpperCase()}function V(e){return e.replace(z,"ms-").replace(U,X)}var G=function(e){return 1===e.nodeType||9===e.nodeType||!+e.nodeType};function Y(){this.expando=k.expando+Y.uid++}Y.uid=1,Y.prototype={cache:function(e){var t=e[this.expando];return t||(t={},G(e)&&(e.nodeType?e[this.expando]=t:Object.defineProperty(e,this.expando,{value:t,configurable:!0}))),t},set:function(e,t,n){var r,i=this.cache(e);if("string"==typeof t)i[V(t)]=n;else for(r in t)i[V(r)]=t[r];return i},get:function(e,t){return void 0===t?this.cache(e):e[this.expando]&&e[this.expando][V(t)]},access:function(e,t,n){return void 0===t||t&&"string"==typeof t&&void 0===n?this.get(e,t):(this.set(e,t,n),void 0!==n?n:t)},remove:function(e,t){var n,r=e[this.expando];if(void 0!==r){if(void 0!==t){n=(t=Array.isArray(t)?t.map(V):(t=V(t))in r?[t]:t.match(R)||[]).length;while(n--)delete r[t[n]]}(void 0===t||k.isEmptyObject(r))&&(e.nodeType?e[this.expando]=void 0:delete e[this.expando])}},hasData:function(e){var t=e[this.expando];return void 0!==t&&!k.isEmptyObject(t)}};var Q=new Y,J=new Y,K=/^(?:\{[\w\W]*\}|\[[\w\W]*\])$/,Z=/[A-Z]/g;function ee(e,t,n){var r,i;if(void 
0===n&&1===e.nodeType)if(r="data-"+t.replace(Z,"-$&").toLowerCase(),"string"==typeof(n=e.getAttribute(r))){try{n="true"===(i=n)||"false"!==i&&("null"===i?null:i===+i+""?+i:K.test(i)?JSON.parse(i):i)}catch(e){}J.set(e,t,n)}else n=void 0;return n}k.extend({hasData:function(e){return J.hasData(e)||Q.hasData(e)},data:function(e,t,n){return J.access(e,t,n)},removeData:function(e,t){J.remove(e,t)},_data:function(e,t,n){return Q.access(e,t,n)},_removeData:function(e,t){Q.remove(e,t)}}),k.fn.extend({data:function(n,e){var t,r,i,o=this[0],a=o&&o.attributes;if(void 0===n){if(this.length&&(i=J.get(o),1===o.nodeType&&!Q.get(o,"hasDataAttrs"))){t=a.length;while(t--)a[t]&&0===(r=a[t].name).indexOf("data-")&&(r=V(r.slice(5)),ee(o,r,i[r]));Q.set(o,"hasDataAttrs",!0)}return i}return"object"==typeof n?this.each(function(){J.set(this,n)}):_(this,function(e){var t;if(o&&void 0===e)return void 0!==(t=J.get(o,n))?t:void 0!==(t=ee(o,n))?t:void 0;this.each(function(){J.set(this,n,e)})},null,e,1<arguments.length,null,!0)},removeData:function(e){return this.each(function(){J.remove(this,e)})}}),k.extend({queue:function(e,t,n){var r;if(e)return t=(t||"fx")+"queue",r=Q.get(e,t),n&&(!r||Array.isArray(n)?r=Q.access(e,t,k.makeArray(n)):r.push(n)),r||[]},dequeue:function(e,t){t=t||"fx";var n=k.queue(e,t),r=n.length,i=n.shift(),o=k._queueHooks(e,t);"inprogress"===i&&(i=n.shift(),r--),i&&("fx"===t&&n.unshift("inprogress"),delete o.stop,i.call(e,function(){k.dequeue(e,t)},o)),!r&&o&&o.empty.fire()},_queueHooks:function(e,t){var n=t+"queueHooks";return Q.get(e,n)||Q.access(e,n,{empty:k.Callbacks("once memory").add(function(){Q.remove(e,[t+"queue",n])})})}}),k.fn.extend({queue:function(t,n){var e=2;return"string"!=typeof t&&(n=t,t="fx",e--),arguments.length<e?k.queue(this[0],t):void 0===n?this:this.each(function(){var e=k.queue(this,t,n);k._queueHooks(this,t),"fx"===t&&"inprogress"!==e[0]&&k.dequeue(this,t)})},dequeue:function(e){return this.each(function(){k.dequeue(this,e)})},clearQueue:function(e){return this.queue(e||"fx",[])},promise:function(e,t){var n,r=1,i=k.Deferred(),o=this,a=this.length,s=function(){--r||i.resolveWith(o,[o])};"string"!=typeof e&&(t=e,e=void 0),e=e||"fx";while(a--)(n=Q.get(o[a],e+"queueHooks"))&&n.empty&&(r++,n.empty.add(s));return s(),i.promise(t)}});var te=/[+-]?(?:\d*\.|)\d+(?:[eE][+-]?\d+|)/.source,ne=new RegExp("^(?:([+-])=|)("+te+")([a-z%]*)$","i"),re=["Top","Right","Bottom","Left"],ie=E.documentElement,oe=function(e){return k.contains(e.ownerDocument,e)},ae={composed:!0};ie.getRootNode&&(oe=function(e){return k.contains(e.ownerDocument,e)||e.getRootNode(ae)===e.ownerDocument});var se=function(e,t){return"none"===(e=t||e).style.display||""===e.style.display&&oe(e)&&"none"===k.css(e,"display")},ue=function(e,t,n,r){var i,o,a={};for(o in t)a[o]=e.style[o],e.style[o]=t[o];for(o in i=n.apply(e,r||[]),t)e.style[o]=a[o];return i};function le(e,t,n,r){var i,o,a=20,s=r?function(){return r.cur()}:function(){return k.css(e,t,"")},u=s(),l=n&&n[3]||(k.cssNumber[t]?"":"px"),c=e.nodeType&&(k.cssNumber[t]||"px"!==l&&+u)&&ne.exec(k.css(e,t));if(c&&c[3]!==l){u/=2,l=l||c[3],c=+u||1;while(a--)k.style(e,t,c+l),(1-o)*(1-(o=s()/u||.5))<=0&&(a=0),c/=o;c*=2,k.style(e,t,c+l),n=n||[]}return n&&(c=+c||+u||0,i=n[1]?c+(n[1]+1)*n[2]:+n[2],r&&(r.unit=l,r.start=c,r.end=i)),i}var ce={};function fe(e,t){for(var n,r,i,o,a,s,u,l=[],c=0,f=e.length;c<f;c++)(r=e[c]).style&&(n=r.style.display,t?("none"===n&&(l[c]=Q.get(r,"display")||null,l[c]||(r.style.display="")),""===r.style.display&&se(r)&&(l[c]=(u=a=o=void 
0,a=(i=r).ownerDocument,s=i.nodeName,(u=ce[s])||(o=a.body.appendChild(a.createElement(s)),u=k.css(o,"display"),o.parentNode.removeChild(o),"none"===u&&(u="block"),ce[s]=u)))):"none"!==n&&(l[c]="none",Q.set(r,"display",n)));for(c=0;c<f;c++)null!=l[c]&&(e[c].style.display=l[c]);return e}k.fn.extend({show:function(){return fe(this,!0)},hide:function(){return fe(this)},toggle:function(e){return"boolean"==typeof e?e?this.show():this.hide():this.each(function(){se(this)?k(this).show():k(this).hide()})}});var pe=/^(?:checkbox|radio)$/i,de=/<([a-z][^\/\0>\x20\t\r\n\f]*)/i,he=/^$|^module$|\/(?:java|ecma)script/i,ge={option:[1,"<select multiple='multiple'>","</select>"],thead:[1,"<table>","</table>"],col:[2,"<table><colgroup>","</colgroup></table>"],tr:[2,"<table><tbody>","</tbody></table>"],td:[3,"<table><tbody><tr>","</tr></tbody></table>"],_default:[0,"",""]};function ve(e,t){var n;return n="undefined"!=typeof e.getElementsByTagName?e.getElementsByTagName(t||"*"):"undefined"!=typeof e.querySelectorAll?e.querySelectorAll(t||"*"):[],void 0===t||t&&A(e,t)?k.merge([e],n):n}function ye(e,t){for(var n=0,r=e.length;n<r;n++)Q.set(e[n],"globalEval",!t||Q.get(t[n],"globalEval"))}ge.optgroup=ge.option,ge.tbody=ge.tfoot=ge.colgroup=ge.caption=ge.thead,ge.th=ge.td;var me,xe,be=/<|&#?\w+;/;function we(e,t,n,r,i){for(var o,a,s,u,l,c,f=t.createDocumentFragment(),p=[],d=0,h=e.length;d<h;d++)if((o=e[d])||0===o)if("object"===w(o))k.merge(p,o.nodeType?[o]:o);else if(be.test(o)){a=a||f.appendChild(t.createElement("div")),s=(de.exec(o)||["",""])[1].toLowerCase(),u=ge[s]||ge._default,a.innerHTML=u[1]+k.htmlPrefilter(o)+u[2],c=u[0];while(c--)a=a.lastChild;k.merge(p,a.childNodes),(a=f.firstChild).textContent=""}else p.push(t.createTextNode(o));f.textContent="",d=0;while(o=p[d++])if(r&&-1<k.inArray(o,r))i&&i.push(o);else if(l=oe(o),a=ve(f.appendChild(o),"script"),l&&ye(a),n){c=0;while(o=a[c++])he.test(o.type||"")&&n.push(o)}return f}me=E.createDocumentFragment().appendChild(E.createElement("div")),(xe=E.createElement("input")).setAttribute("type","radio"),xe.setAttribute("checked","checked"),xe.setAttribute("name","t"),me.appendChild(xe),y.checkClone=me.cloneNode(!0).cloneNode(!0).lastChild.checked,me.innerHTML="<textarea>x</textarea>",y.noCloneChecked=!!me.cloneNode(!0).lastChild.defaultValue;var Te=/^key/,Ce=/^(?:mouse|pointer|contextmenu|drag|drop)|click/,Ee=/^([^.]*)(?:\.(.+)|)/;function ke(){return!0}function Se(){return!1}function Ne(e,t){return e===function(){try{return E.activeElement}catch(e){}}()==("focus"===t)}function Ae(e,t,n,r,i,o){var a,s;if("object"==typeof t){for(s in"string"!=typeof n&&(r=r||n,n=void 0),t)Ae(e,s,n,r,t[s],o);return e}if(null==r&&null==i?(i=n,r=n=void 0):null==i&&("string"==typeof n?(i=r,r=void 0):(i=r,r=n,n=void 0)),!1===i)i=Se;else if(!i)return e;return 1===o&&(a=i,(i=function(e){return k().off(e),a.apply(this,arguments)}).guid=a.guid||(a.guid=k.guid++)),e.each(function(){k.event.add(this,t,i,r,n)})}function De(e,i,o){o?(Q.set(e,i,!1),k.event.add(e,i,{namespace:!1,handler:function(e){var t,n,r=Q.get(this,i);if(1&e.isTrigger&&this[i]){if(r.length)(k.event.special[i]||{}).delegateType&&e.stopPropagation();else if(r=s.call(arguments),Q.set(this,i,r),t=o(this,i),this[i](),r!==(n=Q.get(this,i))||t?Q.set(this,i,!1):n={},r!==n)return e.stopImmediatePropagation(),e.preventDefault(),n.value}else r.length&&(Q.set(this,i,{value:k.event.trigger(k.extend(r[0],k.Event.prototype),r.slice(1),this)}),e.stopImmediatePropagation())}})):void 
0===Q.get(e,i)&&k.event.add(e,i,ke)}k.event={global:{},add:function(t,e,n,r,i){var o,a,s,u,l,c,f,p,d,h,g,v=Q.get(t);if(v){n.handler&&(n=(o=n).handler,i=o.selector),i&&k.find.matchesSelector(ie,i),n.guid||(n.guid=k.guid++),(u=v.events)||(u=v.events={}),(a=v.handle)||(a=v.handle=function(e){return"undefined"!=typeof k&&k.event.triggered!==e.type?k.event.dispatch.apply(t,arguments):void 0}),l=(e=(e||"").match(R)||[""]).length;while(l--)d=g=(s=Ee.exec(e[l])||[])[1],h=(s[2]||"").split(".").sort(),d&&(f=k.event.special[d]||{},d=(i?f.delegateType:f.bindType)||d,f=k.event.special[d]||{},c=k.extend({type:d,origType:g,data:r,handler:n,guid:n.guid,selector:i,needsContext:i&&k.expr.match.needsContext.test(i),namespace:h.join(".")},o),(p=u[d])||((p=u[d]=[]).delegateCount=0,f.setup&&!1!==f.setup.call(t,r,h,a)||t.addEventListener&&t.addEventListener(d,a)),f.add&&(f.add.call(t,c),c.handler.guid||(c.handler.guid=n.guid)),i?p.splice(p.delegateCount++,0,c):p.push(c),k.event.global[d]=!0)}},remove:function(e,t,n,r,i){var o,a,s,u,l,c,f,p,d,h,g,v=Q.hasData(e)&&Q.get(e);if(v&&(u=v.events)){l=(t=(t||"").match(R)||[""]).length;while(l--)if(d=g=(s=Ee.exec(t[l])||[])[1],h=(s[2]||"").split(".").sort(),d){f=k.event.special[d]||{},p=u[d=(r?f.delegateType:f.bindType)||d]||[],s=s[2]&&new RegExp("(^|\\.)"+h.join("\\.(?:.*\\.|)")+"(\\.|$)"),a=o=p.length;while(o--)c=p[o],!i&&g!==c.origType||n&&n.guid!==c.guid||s&&!s.test(c.namespace)||r&&r!==c.selector&&("**"!==r||!c.selector)||(p.splice(o,1),c.selector&&p.delegateCount--,f.remove&&f.remove.call(e,c));a&&!p.length&&(f.teardown&&!1!==f.teardown.call(e,h,v.handle)||k.removeEvent(e,d,v.handle),delete u[d])}else for(d in u)k.event.remove(e,d+t[l],n,r,!0);k.isEmptyObject(u)&&Q.remove(e,"handle events")}},dispatch:function(e){var t,n,r,i,o,a,s=k.event.fix(e),u=new Array(arguments.length),l=(Q.get(this,"events")||{})[s.type]||[],c=k.event.special[s.type]||{};for(u[0]=s,t=1;t<arguments.length;t++)u[t]=arguments[t];if(s.delegateTarget=this,!c.preDispatch||!1!==c.preDispatch.call(this,s)){a=k.event.handlers.call(this,s,l),t=0;while((i=a[t++])&&!s.isPropagationStopped()){s.currentTarget=i.elem,n=0;while((o=i.handlers[n++])&&!s.isImmediatePropagationStopped())s.rnamespace&&!1!==o.namespace&&!s.rnamespace.test(o.namespace)||(s.handleObj=o,s.data=o.data,void 0!==(r=((k.event.special[o.origType]||{}).handle||o.handler).apply(i.elem,u))&&!1===(s.result=r)&&(s.preventDefault(),s.stopPropagation()))}return c.postDispatch&&c.postDispatch.call(this,s),s.result}},handlers:function(e,t){var n,r,i,o,a,s=[],u=t.delegateCount,l=e.target;if(u&&l.nodeType&&!("click"===e.type&&1<=e.button))for(;l!==this;l=l.parentNode||this)if(1===l.nodeType&&("click"!==e.type||!0!==l.disabled)){for(o=[],a={},n=0;n<u;n++)void 0===a[i=(r=t[n]).selector+" "]&&(a[i]=r.needsContext?-1<k(i,this).index(l):k.find(i,this,null,[l]).length),a[i]&&o.push(r);o.length&&s.push({elem:l,handlers:o})}return l=this,u<t.length&&s.push({elem:l,handlers:t.slice(u)}),s},addProp:function(t,e){Object.defineProperty(k.Event.prototype,t,{enumerable:!0,configurable:!0,get:m(e)?function(){if(this.originalEvent)return e(this.originalEvent)}:function(){if(this.originalEvent)return this.originalEvent[t]},set:function(e){Object.defineProperty(this,t,{enumerable:!0,configurable:!0,writable:!0,value:e})}})},fix:function(e){return e[k.expando]?e:new k.Event(e)},special:{load:{noBubble:!0},click:{setup:function(e){var t=this||e;return pe.test(t.type)&&t.click&&A(t,"input")&&De(t,"click",ke),!1},trigger:function(e){var t=this||e;return 
pe.test(t.type)&&t.click&&A(t,"input")&&De(t,"click"),!0},_default:function(e){var t=e.target;return pe.test(t.type)&&t.click&&A(t,"input")&&Q.get(t,"click")||A(t,"a")}},beforeunload:{postDispatch:function(e){void 0!==e.result&&e.originalEvent&&(e.originalEvent.returnValue=e.result)}}}},k.removeEvent=function(e,t,n){e.removeEventListener&&e.removeEventListener(t,n)},k.Event=function(e,t){if(!(this instanceof k.Event))return new k.Event(e,t);e&&e.type?(this.originalEvent=e,this.type=e.type,this.isDefaultPrevented=e.defaultPrevented||void 0===e.defaultPrevented&&!1===e.returnValue?ke:Se,this.target=e.target&&3===e.target.nodeType?e.target.parentNode:e.target,this.currentTarget=e.currentTarget,this.relatedTarget=e.relatedTarget):this.type=e,t&&k.extend(this,t),this.timeStamp=e&&e.timeStamp||Date.now(),this[k.expando]=!0},k.Event.prototype={constructor:k.Event,isDefaultPrevented:Se,isPropagationStopped:Se,isImmediatePropagationStopped:Se,isSimulated:!1,preventDefault:function(){var e=this.originalEvent;this.isDefaultPrevented=ke,e&&!this.isSimulated&&e.preventDefault()},stopPropagation:function(){var e=this.originalEvent;this.isPropagationStopped=ke,e&&!this.isSimulated&&e.stopPropagation()},stopImmediatePropagation:function(){var e=this.originalEvent;this.isImmediatePropagationStopped=ke,e&&!this.isSimulated&&e.stopImmediatePropagation(),this.stopPropagation()}},k.each({altKey:!0,bubbles:!0,cancelable:!0,changedTouches:!0,ctrlKey:!0,detail:!0,eventPhase:!0,metaKey:!0,pageX:!0,pageY:!0,shiftKey:!0,view:!0,"char":!0,code:!0,charCode:!0,key:!0,keyCode:!0,button:!0,buttons:!0,clientX:!0,clientY:!0,offsetX:!0,offsetY:!0,pointerId:!0,pointerType:!0,screenX:!0,screenY:!0,targetTouches:!0,toElement:!0,touches:!0,which:function(e){var t=e.button;return null==e.which&&Te.test(e.type)?null!=e.charCode?e.charCode:e.keyCode:!e.which&&void 0!==t&&Ce.test(e.type)?1&t?1:2&t?3:4&t?2:0:e.which}},k.event.addProp),k.each({focus:"focusin",blur:"focusout"},function(e,t){k.event.special[e]={setup:function(){return De(this,e,Ne),!1},trigger:function(){return De(this,e),!0},delegateType:t}}),k.each({mouseenter:"mouseover",mouseleave:"mouseout",pointerenter:"pointerover",pointerleave:"pointerout"},function(e,i){k.event.special[e]={delegateType:i,bindType:i,handle:function(e){var t,n=e.relatedTarget,r=e.handleObj;return n&&(n===this||k.contains(this,n))||(e.type=r.origType,t=r.handler.apply(this,arguments),e.type=i),t}}}),k.fn.extend({on:function(e,t,n,r){return Ae(this,e,t,n,r)},one:function(e,t,n,r){return Ae(this,e,t,n,r,1)},off:function(e,t,n){var r,i;if(e&&e.preventDefault&&e.handleObj)return r=e.handleObj,k(e.delegateTarget).off(r.namespace?r.origType+"."+r.namespace:r.origType,r.selector,r.handler),this;if("object"==typeof e){for(i in e)this.off(i,t,e[i]);return this}return!1!==t&&"function"!=typeof t||(n=t,t=void 0),!1===n&&(n=Se),this.each(function(){k.event.remove(this,e,n,t)})}});var je=/<(?!area|br|col|embed|hr|img|input|link|meta|param)(([a-z][^\/\0>\x20\t\r\n\f]*)[^>]*)\/>/gi,qe=/<script|<style|<link/i,Le=/checked\s*(?:[^=]|=\s*.checked.)/i,He=/^\s*<!(?:\[CDATA\[|--)|(?:\]\]|--)>\s*$/g;function Oe(e,t){return A(e,"table")&&A(11!==t.nodeType?t:t.firstChild,"tr")&&k(e).children("tbody")[0]||e}function Pe(e){return e.type=(null!==e.getAttribute("type"))+"/"+e.type,e}function Re(e){return"true/"===(e.type||"").slice(0,5)?e.type=e.type.slice(5):e.removeAttribute("type"),e}function Me(e,t){var n,r,i,o,a,s,u,l;if(1===t.nodeType){if(Q.hasData(e)&&(o=Q.access(e),a=Q.set(t,o),l=o.events))for(i in delete 
a.handle,a.events={},l)for(n=0,r=l[i].length;n<r;n++)k.event.add(t,i,l[i][n]);J.hasData(e)&&(s=J.access(e),u=k.extend({},s),J.set(t,u))}}function Ie(n,r,i,o){r=g.apply([],r);var e,t,a,s,u,l,c=0,f=n.length,p=f-1,d=r[0],h=m(d);if(h||1<f&&"string"==typeof d&&!y.checkClone&&Le.test(d))return n.each(function(e){var t=n.eq(e);h&&(r[0]=d.call(this,e,t.html())),Ie(t,r,i,o)});if(f&&(t=(e=we(r,n[0].ownerDocument,!1,n,o)).firstChild,1===e.childNodes.length&&(e=t),t||o)){for(s=(a=k.map(ve(e,"script"),Pe)).length;c<f;c++)u=e,c!==p&&(u=k.clone(u,!0,!0),s&&k.merge(a,ve(u,"script"))),i.call(n[c],u,c);if(s)for(l=a[a.length-1].ownerDocument,k.map(a,Re),c=0;c<s;c++)u=a[c],he.test(u.type||"")&&!Q.access(u,"globalEval")&&k.contains(l,u)&&(u.src&&"module"!==(u.type||"").toLowerCase()?k._evalUrl&&!u.noModule&&k._evalUrl(u.src,{nonce:u.nonce||u.getAttribute("nonce")}):b(u.textContent.replace(He,""),u,l))}return n}function We(e,t,n){for(var r,i=t?k.filter(t,e):e,o=0;null!=(r=i[o]);o++)n||1!==r.nodeType||k.cleanData(ve(r)),r.parentNode&&(n&&oe(r)&&ye(ve(r,"script")),r.parentNode.removeChild(r));return e}k.extend({htmlPrefilter:function(e){return e.replace(je,"<$1></$2>")},clone:function(e,t,n){var r,i,o,a,s,u,l,c=e.cloneNode(!0),f=oe(e);if(!(y.noCloneChecked||1!==e.nodeType&&11!==e.nodeType||k.isXMLDoc(e)))for(a=ve(c),r=0,i=(o=ve(e)).length;r<i;r++)s=o[r],u=a[r],void 0,"input"===(l=u.nodeName.toLowerCase())&&pe.test(s.type)?u.checked=s.checked:"input"!==l&&"textarea"!==l||(u.defaultValue=s.defaultValue);if(t)if(n)for(o=o||ve(e),a=a||ve(c),r=0,i=o.length;r<i;r++)Me(o[r],a[r]);else Me(e,c);return 0<(a=ve(c,"script")).length&&ye(a,!f&&ve(e,"script")),c},cleanData:function(e){for(var t,n,r,i=k.event.special,o=0;void 0!==(n=e[o]);o++)if(G(n)){if(t=n[Q.expando]){if(t.events)for(r in t.events)i[r]?k.event.remove(n,r):k.removeEvent(n,r,t.handle);n[Q.expando]=void 0}n[J.expando]&&(n[J.expando]=void 0)}}}),k.fn.extend({detach:function(e){return We(this,e,!0)},remove:function(e){return We(this,e)},text:function(e){return _(this,function(e){return void 0===e?k.text(this):this.empty().each(function(){1!==this.nodeType&&11!==this.nodeType&&9!==this.nodeType||(this.textContent=e)})},null,e,arguments.length)},append:function(){return Ie(this,arguments,function(e){1!==this.nodeType&&11!==this.nodeType&&9!==this.nodeType||Oe(this,e).appendChild(e)})},prepend:function(){return Ie(this,arguments,function(e){if(1===this.nodeType||11===this.nodeType||9===this.nodeType){var t=Oe(this,e);t.insertBefore(e,t.firstChild)}})},before:function(){return Ie(this,arguments,function(e){this.parentNode&&this.parentNode.insertBefore(e,this)})},after:function(){return Ie(this,arguments,function(e){this.parentNode&&this.parentNode.insertBefore(e,this.nextSibling)})},empty:function(){for(var e,t=0;null!=(e=this[t]);t++)1===e.nodeType&&(k.cleanData(ve(e,!1)),e.textContent="");return this},clone:function(e,t){return e=null!=e&&e,t=null==t?e:t,this.map(function(){return k.clone(this,e,t)})},html:function(e){return _(this,function(e){var t=this[0]||{},n=0,r=this.length;if(void 0===e&&1===t.nodeType)return t.innerHTML;if("string"==typeof e&&!qe.test(e)&&!ge[(de.exec(e)||["",""])[1].toLowerCase()]){e=k.htmlPrefilter(e);try{for(;n<r;n++)1===(t=this[n]||{}).nodeType&&(k.cleanData(ve(t,!1)),t.innerHTML=e);t=0}catch(e){}}t&&this.empty().append(e)},null,e,arguments.length)},replaceWith:function(){var n=[];return Ie(this,arguments,function(e){var 
t=this.parentNode;k.inArray(this,n)<0&&(k.cleanData(ve(this)),t&&t.replaceChild(e,this))},n)}}),k.each({appendTo:"append",prependTo:"prepend",insertBefore:"before",insertAfter:"after",replaceAll:"replaceWith"},function(e,a){k.fn[e]=function(e){for(var t,n=[],r=k(e),i=r.length-1,o=0;o<=i;o++)t=o===i?this:this.clone(!0),k(r[o])[a](t),u.apply(n,t.get());return this.pushStack(n)}});var $e=new RegExp("^("+te+")(?!px)[a-z%]+$","i"),Fe=function(e){var t=e.ownerDocument.defaultView;return t&&t.opener||(t=C),t.getComputedStyle(e)},Be=new RegExp(re.join("|"),"i");function _e(e,t,n){var r,i,o,a,s=e.style;return(n=n||Fe(e))&&(""!==(a=n.getPropertyValue(t)||n[t])||oe(e)||(a=k.style(e,t)),!y.pixelBoxStyles()&&$e.test(a)&&Be.test(t)&&(r=s.width,i=s.minWidth,o=s.maxWidth,s.minWidth=s.maxWidth=s.width=a,a=n.width,s.width=r,s.minWidth=i,s.maxWidth=o)),void 0!==a?a+"":a}function ze(e,t){return{get:function(){if(!e())return(this.get=t).apply(this,arguments);delete this.get}}}!function(){function e(){if(u){s.style.cssText="position:absolute;left:-11111px;width:60px;margin-top:1px;padding:0;border:0",u.style.cssText="position:relative;display:block;box-sizing:border-box;overflow:scroll;margin:auto;border:1px;padding:1px;width:60%;top:1%",ie.appendChild(s).appendChild(u);var e=C.getComputedStyle(u);n="1%"!==e.top,a=12===t(e.marginLeft),u.style.right="60%",o=36===t(e.right),r=36===t(e.width),u.style.position="absolute",i=12===t(u.offsetWidth/3),ie.removeChild(s),u=null}}function t(e){return Math.round(parseFloat(e))}var n,r,i,o,a,s=E.createElement("div"),u=E.createElement("div");u.style&&(u.style.backgroundClip="content-box",u.cloneNode(!0).style.backgroundClip="",y.clearCloneStyle="content-box"===u.style.backgroundClip,k.extend(y,{boxSizingReliable:function(){return e(),r},pixelBoxStyles:function(){return e(),o},pixelPosition:function(){return e(),n},reliableMarginLeft:function(){return e(),a},scrollboxSize:function(){return e(),i}}))}();var Ue=["Webkit","Moz","ms"],Xe=E.createElement("div").style,Ve={};function Ge(e){var t=k.cssProps[e]||Ve[e];return t||(e in Xe?e:Ve[e]=function(e){var t=e[0].toUpperCase()+e.slice(1),n=Ue.length;while(n--)if((e=Ue[n]+t)in Xe)return e}(e)||e)}var Ye=/^(none|table(?!-c[ea]).+)/,Qe=/^--/,Je={position:"absolute",visibility:"hidden",display:"block"},Ke={letterSpacing:"0",fontWeight:"400"};function Ze(e,t,n){var r=ne.exec(t);return r?Math.max(0,r[2]-(n||0))+(r[3]||"px"):t}function et(e,t,n,r,i,o){var a="width"===t?1:0,s=0,u=0;if(n===(r?"border":"content"))return 0;for(;a<4;a+=2)"margin"===n&&(u+=k.css(e,n+re[a],!0,i)),r?("content"===n&&(u-=k.css(e,"padding"+re[a],!0,i)),"margin"!==n&&(u-=k.css(e,"border"+re[a]+"Width",!0,i))):(u+=k.css(e,"padding"+re[a],!0,i),"padding"!==n?u+=k.css(e,"border"+re[a]+"Width",!0,i):s+=k.css(e,"border"+re[a]+"Width",!0,i));return!r&&0<=o&&(u+=Math.max(0,Math.ceil(e["offset"+t[0].toUpperCase()+t.slice(1)]-o-u-s-.5))||0),u}function tt(e,t,n){var r=Fe(e),i=(!y.boxSizingReliable()||n)&&"border-box"===k.css(e,"boxSizing",!1,r),o=i,a=_e(e,t,r),s="offset"+t[0].toUpperCase()+t.slice(1);if($e.test(a)){if(!n)return a;a="auto"}return(!y.boxSizingReliable()&&i||"auto"===a||!parseFloat(a)&&"inline"===k.css(e,"display",!1,r))&&e.getClientRects().length&&(i="border-box"===k.css(e,"boxSizing",!1,r),(o=s in e)&&(a=e[s])),(a=parseFloat(a)||0)+et(e,t,n||(i?"border":"content"),o,r,a)+"px"}function nt(e,t,n,r,i){return new nt.prototype.init(e,t,n,r,i)}k.extend({cssHooks:{opacity:{get:function(e,t){if(t){var 
n=_e(e,"opacity");return""===n?"1":n}}}},cssNumber:{animationIterationCount:!0,columnCount:!0,fillOpacity:!0,flexGrow:!0,flexShrink:!0,fontWeight:!0,gridArea:!0,gridColumn:!0,gridColumnEnd:!0,gridColumnStart:!0,gridRow:!0,gridRowEnd:!0,gridRowStart:!0,lineHeight:!0,opacity:!0,order:!0,orphans:!0,widows:!0,zIndex:!0,zoom:!0},cssProps:{},style:function(e,t,n,r){if(e&&3!==e.nodeType&&8!==e.nodeType&&e.style){var i,o,a,s=V(t),u=Qe.test(t),l=e.style;if(u||(t=Ge(s)),a=k.cssHooks[t]||k.cssHooks[s],void 0===n)return a&&"get"in a&&void 0!==(i=a.get(e,!1,r))?i:l[t];"string"===(o=typeof n)&&(i=ne.exec(n))&&i[1]&&(n=le(e,t,i),o="number"),null!=n&&n==n&&("number"!==o||u||(n+=i&&i[3]||(k.cssNumber[s]?"":"px")),y.clearCloneStyle||""!==n||0!==t.indexOf("background")||(l[t]="inherit"),a&&"set"in a&&void 0===(n=a.set(e,n,r))||(u?l.setProperty(t,n):l[t]=n))}},css:function(e,t,n,r){var i,o,a,s=V(t);return Qe.test(t)||(t=Ge(s)),(a=k.cssHooks[t]||k.cssHooks[s])&&"get"in a&&(i=a.get(e,!0,n)),void 0===i&&(i=_e(e,t,r)),"normal"===i&&t in Ke&&(i=Ke[t]),""===n||n?(o=parseFloat(i),!0===n||isFinite(o)?o||0:i):i}}),k.each(["height","width"],function(e,u){k.cssHooks[u]={get:function(e,t,n){if(t)return!Ye.test(k.css(e,"display"))||e.getClientRects().length&&e.getBoundingClientRect().width?tt(e,u,n):ue(e,Je,function(){return tt(e,u,n)})},set:function(e,t,n){var r,i=Fe(e),o=!y.scrollboxSize()&&"absolute"===i.position,a=(o||n)&&"border-box"===k.css(e,"boxSizing",!1,i),s=n?et(e,u,n,a,i):0;return a&&o&&(s-=Math.ceil(e["offset"+u[0].toUpperCase()+u.slice(1)]-parseFloat(i[u])-et(e,u,"border",!1,i)-.5)),s&&(r=ne.exec(t))&&"px"!==(r[3]||"px")&&(e.style[u]=t,t=k.css(e,u)),Ze(0,t,s)}}}),k.cssHooks.marginLeft=ze(y.reliableMarginLeft,function(e,t){if(t)return(parseFloat(_e(e,"marginLeft"))||e.getBoundingClientRect().left-ue(e,{marginLeft:0},function(){return e.getBoundingClientRect().left}))+"px"}),k.each({margin:"",padding:"",border:"Width"},function(i,o){k.cssHooks[i+o]={expand:function(e){for(var t=0,n={},r="string"==typeof e?e.split(" "):[e];t<4;t++)n[i+re[t]+o]=r[t]||r[t-2]||r[0];return n}},"margin"!==i&&(k.cssHooks[i+o].set=Ze)}),k.fn.extend({css:function(e,t){return _(this,function(e,t,n){var r,i,o={},a=0;if(Array.isArray(t)){for(r=Fe(e),i=t.length;a<i;a++)o[t[a]]=k.css(e,t[a],!1,r);return o}return void 0!==n?k.style(e,t,n):k.css(e,t)},e,t,1<arguments.length)}}),((k.Tween=nt).prototype={constructor:nt,init:function(e,t,n,r,i,o){this.elem=e,this.prop=n,this.easing=i||k.easing._default,this.options=t,this.start=this.now=this.cur(),this.end=r,this.unit=o||(k.cssNumber[n]?"":"px")},cur:function(){var e=nt.propHooks[this.prop];return e&&e.get?e.get(this):nt.propHooks._default.get(this)},run:function(e){var t,n=nt.propHooks[this.prop];return this.options.duration?this.pos=t=k.easing[this.easing](e,this.options.duration*e,0,1,this.options.duration):this.pos=t=e,this.now=(this.end-this.start)*t+this.start,this.options.step&&this.options.step.call(this.elem,this.now,this),n&&n.set?n.set(this):nt.propHooks._default.set(this),this}}).init.prototype=nt.prototype,(nt.propHooks={_default:{get:function(e){var t;return 
1!==e.elem.nodeType||null!=e.elem[e.prop]&&null==e.elem.style[e.prop]?e.elem[e.prop]:(t=k.css(e.elem,e.prop,""))&&"auto"!==t?t:0},set:function(e){k.fx.step[e.prop]?k.fx.step[e.prop](e):1!==e.elem.nodeType||!k.cssHooks[e.prop]&&null==e.elem.style[Ge(e.prop)]?e.elem[e.prop]=e.now:k.style(e.elem,e.prop,e.now+e.unit)}}}).scrollTop=nt.propHooks.scrollLeft={set:function(e){e.elem.nodeType&&e.elem.parentNode&&(e.elem[e.prop]=e.now)}},k.easing={linear:function(e){return e},swing:function(e){return.5-Math.cos(e*Math.PI)/2},_default:"swing"},k.fx=nt.prototype.init,k.fx.step={};var rt,it,ot,at,st=/^(?:toggle|show|hide)$/,ut=/queueHooks$/;function lt(){it&&(!1===E.hidden&&C.requestAnimationFrame?C.requestAnimationFrame(lt):C.setTimeout(lt,k.fx.interval),k.fx.tick())}function ct(){return C.setTimeout(function(){rt=void 0}),rt=Date.now()}function ft(e,t){var n,r=0,i={height:e};for(t=t?1:0;r<4;r+=2-t)i["margin"+(n=re[r])]=i["padding"+n]=e;return t&&(i.opacity=i.width=e),i}function pt(e,t,n){for(var r,i=(dt.tweeners[t]||[]).concat(dt.tweeners["*"]),o=0,a=i.length;o<a;o++)if(r=i[o].call(n,t,e))return r}function dt(o,e,t){var n,a,r=0,i=dt.prefilters.length,s=k.Deferred().always(function(){delete u.elem}),u=function(){if(a)return!1;for(var e=rt||ct(),t=Math.max(0,l.startTime+l.duration-e),n=1-(t/l.duration||0),r=0,i=l.tweens.length;r<i;r++)l.tweens[r].run(n);return s.notifyWith(o,[l,n,t]),n<1&&i?t:(i||s.notifyWith(o,[l,1,0]),s.resolveWith(o,[l]),!1)},l=s.promise({elem:o,props:k.extend({},e),opts:k.extend(!0,{specialEasing:{},easing:k.easing._default},t),originalProperties:e,originalOptions:t,startTime:rt||ct(),duration:t.duration,tweens:[],createTween:function(e,t){var n=k.Tween(o,l.opts,e,t,l.opts.specialEasing[e]||l.opts.easing);return l.tweens.push(n),n},stop:function(e){var t=0,n=e?l.tweens.length:0;if(a)return this;for(a=!0;t<n;t++)l.tweens[t].run(1);return e?(s.notifyWith(o,[l,1,0]),s.resolveWith(o,[l,e])):s.rejectWith(o,[l,e]),this}}),c=l.props;for(!function(e,t){var n,r,i,o,a;for(n in e)if(i=t[r=V(n)],o=e[n],Array.isArray(o)&&(i=o[1],o=e[n]=o[0]),n!==r&&(e[r]=o,delete e[n]),(a=k.cssHooks[r])&&"expand"in a)for(n in o=a.expand(o),delete e[r],o)n in e||(e[n]=o[n],t[n]=i);else t[r]=i}(c,l.opts.specialEasing);r<i;r++)if(n=dt.prefilters[r].call(l,o,c,l.opts))return m(n.stop)&&(k._queueHooks(l.elem,l.opts.queue).stop=n.stop.bind(n)),n;return k.map(c,pt,l),m(l.opts.start)&&l.opts.start.call(o,l),l.progress(l.opts.progress).done(l.opts.done,l.opts.complete).fail(l.opts.fail).always(l.opts.always),k.fx.timer(k.extend(u,{elem:o,anim:l,queue:l.opts.queue})),l}k.Animation=k.extend(dt,{tweeners:{"*":[function(e,t){var n=this.createTween(e,t);return le(n.elem,e,ne.exec(t),n),n}]},tweener:function(e,t){m(e)?(t=e,e=["*"]):e=e.match(R);for(var n,r=0,i=e.length;r<i;r++)n=e[r],dt.tweeners[n]=dt.tweeners[n]||[],dt.tweeners[n].unshift(t)},prefilters:[function(e,t,n){var r,i,o,a,s,u,l,c,f="width"in t||"height"in t,p=this,d={},h=e.style,g=e.nodeType&&se(e),v=Q.get(e,"fxshow");for(r in n.queue||(null==(a=k._queueHooks(e,"fx")).unqueued&&(a.unqueued=0,s=a.empty.fire,a.empty.fire=function(){a.unqueued||s()}),a.unqueued++,p.always(function(){p.always(function(){a.unqueued--,k.queue(e,"fx").length||a.empty.fire()})})),t)if(i=t[r],st.test(i)){if(delete t[r],o=o||"toggle"===i,i===(g?"hide":"show")){if("show"!==i||!v||void 0===v[r])continue;g=!0}d[r]=v&&v[r]||k.style(e,r)}if((u=!k.isEmptyObject(t))||!k.isEmptyObject(d))for(r in 
f&&1===e.nodeType&&(n.overflow=[h.overflow,h.overflowX,h.overflowY],null==(l=v&&v.display)&&(l=Q.get(e,"display")),"none"===(c=k.css(e,"display"))&&(l?c=l:(fe([e],!0),l=e.style.display||l,c=k.css(e,"display"),fe([e]))),("inline"===c||"inline-block"===c&&null!=l)&&"none"===k.css(e,"float")&&(u||(p.done(function(){h.display=l}),null==l&&(c=h.display,l="none"===c?"":c)),h.display="inline-block")),n.overflow&&(h.overflow="hidden",p.always(function(){h.overflow=n.overflow[0],h.overflowX=n.overflow[1],h.overflowY=n.overflow[2]})),u=!1,d)u||(v?"hidden"in v&&(g=v.hidden):v=Q.access(e,"fxshow",{display:l}),o&&(v.hidden=!g),g&&fe([e],!0),p.done(function(){for(r in g||fe([e]),Q.remove(e,"fxshow"),d)k.style(e,r,d[r])})),u=pt(g?v[r]:0,r,p),r in v||(v[r]=u.start,g&&(u.end=u.start,u.start=0))}],prefilter:function(e,t){t?dt.prefilters.unshift(e):dt.prefilters.push(e)}}),k.speed=function(e,t,n){var r=e&&"object"==typeof e?k.extend({},e):{complete:n||!n&&t||m(e)&&e,duration:e,easing:n&&t||t&&!m(t)&&t};return k.fx.off?r.duration=0:"number"!=typeof r.duration&&(r.duration in k.fx.speeds?r.duration=k.fx.speeds[r.duration]:r.duration=k.fx.speeds._default),null!=r.queue&&!0!==r.queue||(r.queue="fx"),r.old=r.complete,r.complete=function(){m(r.old)&&r.old.call(this),r.queue&&k.dequeue(this,r.queue)},r},k.fn.extend({fadeTo:function(e,t,n,r){return this.filter(se).css("opacity",0).show().end().animate({opacity:t},e,n,r)},animate:function(t,e,n,r){var i=k.isEmptyObject(t),o=k.speed(e,n,r),a=function(){var e=dt(this,k.extend({},t),o);(i||Q.get(this,"finish"))&&e.stop(!0)};return a.finish=a,i||!1===o.queue?this.each(a):this.queue(o.queue,a)},stop:function(i,e,o){var a=function(e){var t=e.stop;delete e.stop,t(o)};return"string"!=typeof i&&(o=e,e=i,i=void 0),e&&!1!==i&&this.queue(i||"fx",[]),this.each(function(){var e=!0,t=null!=i&&i+"queueHooks",n=k.timers,r=Q.get(this);if(t)r[t]&&r[t].stop&&a(r[t]);else for(t in r)r[t]&&r[t].stop&&ut.test(t)&&a(r[t]);for(t=n.length;t--;)n[t].elem!==this||null!=i&&n[t].queue!==i||(n[t].anim.stop(o),e=!1,n.splice(t,1));!e&&o||k.dequeue(this,i)})},finish:function(a){return!1!==a&&(a=a||"fx"),this.each(function(){var e,t=Q.get(this),n=t[a+"queue"],r=t[a+"queueHooks"],i=k.timers,o=n?n.length:0;for(t.finish=!0,k.queue(this,a,[]),r&&r.stop&&r.stop.call(this,!0),e=i.length;e--;)i[e].elem===this&&i[e].queue===a&&(i[e].anim.stop(!0),i.splice(e,1));for(e=0;e<o;e++)n[e]&&n[e].finish&&n[e].finish.call(this);delete t.finish})}}),k.each(["toggle","show","hide"],function(e,r){var i=k.fn[r];k.fn[r]=function(e,t,n){return null==e||"boolean"==typeof e?i.apply(this,arguments):this.animate(ft(r,!0),e,t,n)}}),k.each({slideDown:ft("show"),slideUp:ft("hide"),slideToggle:ft("toggle"),fadeIn:{opacity:"show"},fadeOut:{opacity:"hide"},fadeToggle:{opacity:"toggle"}},function(e,r){k.fn[e]=function(e,t,n){return this.animate(r,e,t,n)}}),k.timers=[],k.fx.tick=function(){var e,t=0,n=k.timers;for(rt=Date.now();t<n.length;t++)(e=n[t])()||n[t]!==e||n.splice(t--,1);n.length||k.fx.stop(),rt=void 0},k.fx.timer=function(e){k.timers.push(e),k.fx.start()},k.fx.interval=13,k.fx.start=function(){it||(it=!0,lt())},k.fx.stop=function(){it=null},k.fx.speeds={slow:600,fast:200,_default:400},k.fn.delay=function(r,e){return r=k.fx&&k.fx.speeds[r]||r,e=e||"fx",this.queue(e,function(e,t){var 
n=C.setTimeout(e,r);t.stop=function(){C.clearTimeout(n)}})},ot=E.createElement("input"),at=E.createElement("select").appendChild(E.createElement("option")),ot.type="checkbox",y.checkOn=""!==ot.value,y.optSelected=at.selected,(ot=E.createElement("input")).value="t",ot.type="radio",y.radioValue="t"===ot.value;var ht,gt=k.expr.attrHandle;k.fn.extend({attr:function(e,t){return _(this,k.attr,e,t,1<arguments.length)},removeAttr:function(e){return this.each(function(){k.removeAttr(this,e)})}}),k.extend({attr:function(e,t,n){var r,i,o=e.nodeType;if(3!==o&&8!==o&&2!==o)return"undefined"==typeof e.getAttribute?k.prop(e,t,n):(1===o&&k.isXMLDoc(e)||(i=k.attrHooks[t.toLowerCase()]||(k.expr.match.bool.test(t)?ht:void 0)),void 0!==n?null===n?void k.removeAttr(e,t):i&&"set"in i&&void 0!==(r=i.set(e,n,t))?r:(e.setAttribute(t,n+""),n):i&&"get"in i&&null!==(r=i.get(e,t))?r:null==(r=k.find.attr(e,t))?void 0:r)},attrHooks:{type:{set:function(e,t){if(!y.radioValue&&"radio"===t&&A(e,"input")){var n=e.value;return e.setAttribute("type",t),n&&(e.value=n),t}}}},removeAttr:function(e,t){var n,r=0,i=t&&t.match(R);if(i&&1===e.nodeType)while(n=i[r++])e.removeAttribute(n)}}),ht={set:function(e,t,n){return!1===t?k.removeAttr(e,n):e.setAttribute(n,n),n}},k.each(k.expr.match.bool.source.match(/\w+/g),function(e,t){var a=gt[t]||k.find.attr;gt[t]=function(e,t,n){var r,i,o=t.toLowerCase();return n||(i=gt[o],gt[o]=r,r=null!=a(e,t,n)?o:null,gt[o]=i),r}});var vt=/^(?:input|select|textarea|button)$/i,yt=/^(?:a|area)$/i;function mt(e){return(e.match(R)||[]).join(" ")}function xt(e){return e.getAttribute&&e.getAttribute("class")||""}function bt(e){return Array.isArray(e)?e:"string"==typeof e&&e.match(R)||[]}k.fn.extend({prop:function(e,t){return _(this,k.prop,e,t,1<arguments.length)},removeProp:function(e){return this.each(function(){delete this[k.propFix[e]||e]})}}),k.extend({prop:function(e,t,n){var r,i,o=e.nodeType;if(3!==o&&8!==o&&2!==o)return 1===o&&k.isXMLDoc(e)||(t=k.propFix[t]||t,i=k.propHooks[t]),void 0!==n?i&&"set"in i&&void 0!==(r=i.set(e,n,t))?r:e[t]=n:i&&"get"in i&&null!==(r=i.get(e,t))?r:e[t]},propHooks:{tabIndex:{get:function(e){var t=k.find.attr(e,"tabindex");return t?parseInt(t,10):vt.test(e.nodeName)||yt.test(e.nodeName)&&e.href?0:-1}}},propFix:{"for":"htmlFor","class":"className"}}),y.optSelected||(k.propHooks.selected={get:function(e){var t=e.parentNode;return t&&t.parentNode&&t.parentNode.selectedIndex,null},set:function(e){var t=e.parentNode;t&&(t.selectedIndex,t.parentNode&&t.parentNode.selectedIndex)}}),k.each(["tabIndex","readOnly","maxLength","cellSpacing","cellPadding","rowSpan","colSpan","useMap","frameBorder","contentEditable"],function(){k.propFix[this.toLowerCase()]=this}),k.fn.extend({addClass:function(t){var e,n,r,i,o,a,s,u=0;if(m(t))return this.each(function(e){k(this).addClass(t.call(this,e,xt(this)))});if((e=bt(t)).length)while(n=this[u++])if(i=xt(n),r=1===n.nodeType&&" "+mt(i)+" "){a=0;while(o=e[a++])r.indexOf(" "+o+" ")<0&&(r+=o+" ");i!==(s=mt(r))&&n.setAttribute("class",s)}return this},removeClass:function(t){var e,n,r,i,o,a,s,u=0;if(m(t))return this.each(function(e){k(this).removeClass(t.call(this,e,xt(this)))});if(!arguments.length)return this.attr("class","");if((e=bt(t)).length)while(n=this[u++])if(i=xt(n),r=1===n.nodeType&&" "+mt(i)+" "){a=0;while(o=e[a++])while(-1<r.indexOf(" "+o+" "))r=r.replace(" "+o+" "," ");i!==(s=mt(r))&&n.setAttribute("class",s)}return this},toggleClass:function(i,t){var o=typeof i,a="string"===o||Array.isArray(i);return"boolean"==typeof 
t&&a?t?this.addClass(i):this.removeClass(i):m(i)?this.each(function(e){k(this).toggleClass(i.call(this,e,xt(this),t),t)}):this.each(function(){var e,t,n,r;if(a){t=0,n=k(this),r=bt(i);while(e=r[t++])n.hasClass(e)?n.removeClass(e):n.addClass(e)}else void 0!==i&&"boolean"!==o||((e=xt(this))&&Q.set(this,"__className__",e),this.setAttribute&&this.setAttribute("class",e||!1===i?"":Q.get(this,"__className__")||""))})},hasClass:function(e){var t,n,r=0;t=" "+e+" ";while(n=this[r++])if(1===n.nodeType&&-1<(" "+mt(xt(n))+" ").indexOf(t))return!0;return!1}});var wt=/\r/g;k.fn.extend({val:function(n){var r,e,i,t=this[0];return arguments.length?(i=m(n),this.each(function(e){var t;1===this.nodeType&&(null==(t=i?n.call(this,e,k(this).val()):n)?t="":"number"==typeof t?t+="":Array.isArray(t)&&(t=k.map(t,function(e){return null==e?"":e+""})),(r=k.valHooks[this.type]||k.valHooks[this.nodeName.toLowerCase()])&&"set"in r&&void 0!==r.set(this,t,"value")||(this.value=t))})):t?(r=k.valHooks[t.type]||k.valHooks[t.nodeName.toLowerCase()])&&"get"in r&&void 0!==(e=r.get(t,"value"))?e:"string"==typeof(e=t.value)?e.replace(wt,""):null==e?"":e:void 0}}),k.extend({valHooks:{option:{get:function(e){var t=k.find.attr(e,"value");return null!=t?t:mt(k.text(e))}},select:{get:function(e){var t,n,r,i=e.options,o=e.selectedIndex,a="select-one"===e.type,s=a?null:[],u=a?o+1:i.length;for(r=o<0?u:a?o:0;r<u;r++)if(((n=i[r]).selected||r===o)&&!n.disabled&&(!n.parentNode.disabled||!A(n.parentNode,"optgroup"))){if(t=k(n).val(),a)return t;s.push(t)}return s},set:function(e,t){var n,r,i=e.options,o=k.makeArray(t),a=i.length;while(a--)((r=i[a]).selected=-1<k.inArray(k.valHooks.option.get(r),o))&&(n=!0);return n||(e.selectedIndex=-1),o}}}}),k.each(["radio","checkbox"],function(){k.valHooks[this]={set:function(e,t){if(Array.isArray(t))return e.checked=-1<k.inArray(k(e).val(),t)}},y.checkOn||(k.valHooks[this].get=function(e){return null===e.getAttribute("value")?"on":e.value})}),y.focusin="onfocusin"in C;var Tt=/^(?:focusinfocus|focusoutblur)$/,Ct=function(e){e.stopPropagation()};k.extend(k.event,{trigger:function(e,t,n,r){var i,o,a,s,u,l,c,f,p=[n||E],d=v.call(e,"type")?e.type:e,h=v.call(e,"namespace")?e.namespace.split("."):[];if(o=f=a=n=n||E,3!==n.nodeType&&8!==n.nodeType&&!Tt.test(d+k.event.triggered)&&(-1<d.indexOf(".")&&(d=(h=d.split(".")).shift(),h.sort()),u=d.indexOf(":")<0&&"on"+d,(e=e[k.expando]?e:new k.Event(d,"object"==typeof e&&e)).isTrigger=r?2:3,e.namespace=h.join("."),e.rnamespace=e.namespace?new RegExp("(^|\\.)"+h.join("\\.(?:.*\\.|)")+"(\\.|$)"):null,e.result=void 0,e.target||(e.target=n),t=null==t?[e]:k.makeArray(t,[e]),c=k.event.special[d]||{},r||!c.trigger||!1!==c.trigger.apply(n,t))){if(!r&&!c.noBubble&&!x(n)){for(s=c.delegateType||d,Tt.test(s+d)||(o=o.parentNode);o;o=o.parentNode)p.push(o),a=o;a===(n.ownerDocument||E)&&p.push(a.defaultView||a.parentWindow||C)}i=0;while((o=p[i++])&&!e.isPropagationStopped())f=o,e.type=1<i?s:c.bindType||d,(l=(Q.get(o,"events")||{})[e.type]&&Q.get(o,"handle"))&&l.apply(o,t),(l=u&&o[u])&&l.apply&&G(o)&&(e.result=l.apply(o,t),!1===e.result&&e.preventDefault());return e.type=d,r||e.isDefaultPrevented()||c._default&&!1!==c._default.apply(p.pop(),t)||!G(n)||u&&m(n[d])&&!x(n)&&((a=n[u])&&(n[u]=null),k.event.triggered=d,e.isPropagationStopped()&&f.addEventListener(d,Ct),n[d](),e.isPropagationStopped()&&f.removeEventListener(d,Ct),k.event.triggered=void 0,a&&(n[u]=a)),e.result}},simulate:function(e,t,n){var r=k.extend(new 
k.Event,n,{type:e,isSimulated:!0});k.event.trigger(r,null,t)}}),k.fn.extend({trigger:function(e,t){return this.each(function(){k.event.trigger(e,t,this)})},triggerHandler:function(e,t){var n=this[0];if(n)return k.event.trigger(e,t,n,!0)}}),y.focusin||k.each({focus:"focusin",blur:"focusout"},function(n,r){var i=function(e){k.event.simulate(r,e.target,k.event.fix(e))};k.event.special[r]={setup:function(){var e=this.ownerDocument||this,t=Q.access(e,r);t||e.addEventListener(n,i,!0),Q.access(e,r,(t||0)+1)},teardown:function(){var e=this.ownerDocument||this,t=Q.access(e,r)-1;t?Q.access(e,r,t):(e.removeEventListener(n,i,!0),Q.remove(e,r))}}});var Et=C.location,kt=Date.now(),St=/\?/;k.parseXML=function(e){var t;if(!e||"string"!=typeof e)return null;try{t=(new C.DOMParser).parseFromString(e,"text/xml")}catch(e){t=void 0}return t&&!t.getElementsByTagName("parsererror").length||k.error("Invalid XML: "+e),t};var Nt=/\[\]$/,At=/\r?\n/g,Dt=/^(?:submit|button|image|reset|file)$/i,jt=/^(?:input|select|textarea|keygen)/i;function qt(n,e,r,i){var t;if(Array.isArray(e))k.each(e,function(e,t){r||Nt.test(n)?i(n,t):qt(n+"["+("object"==typeof t&&null!=t?e:"")+"]",t,r,i)});else if(r||"object"!==w(e))i(n,e);else for(t in e)qt(n+"["+t+"]",e[t],r,i)}k.param=function(e,t){var n,r=[],i=function(e,t){var n=m(t)?t():t;r[r.length]=encodeURIComponent(e)+"="+encodeURIComponent(null==n?"":n)};if(null==e)return"";if(Array.isArray(e)||e.jquery&&!k.isPlainObject(e))k.each(e,function(){i(this.name,this.value)});else for(n in e)qt(n,e[n],t,i);return r.join("&")},k.fn.extend({serialize:function(){return k.param(this.serializeArray())},serializeArray:function(){return this.map(function(){var e=k.prop(this,"elements");return e?k.makeArray(e):this}).filter(function(){var e=this.type;return this.name&&!k(this).is(":disabled")&&jt.test(this.nodeName)&&!Dt.test(e)&&(this.checked||!pe.test(e))}).map(function(e,t){var n=k(this).val();return null==n?null:Array.isArray(n)?k.map(n,function(e){return{name:t.name,value:e.replace(At,"\r\n")}}):{name:t.name,value:n.replace(At,"\r\n")}}).get()}});var Lt=/%20/g,Ht=/#.*$/,Ot=/([?&])_=[^&]*/,Pt=/^(.*?):[ \t]*([^\r\n]*)$/gm,Rt=/^(?:GET|HEAD)$/,Mt=/^\/\//,It={},Wt={},$t="*/".concat("*"),Ft=E.createElement("a");function Bt(o){return function(e,t){"string"!=typeof e&&(t=e,e="*");var n,r=0,i=e.toLowerCase().match(R)||[];if(m(t))while(n=i[r++])"+"===n[0]?(n=n.slice(1)||"*",(o[n]=o[n]||[]).unshift(t)):(o[n]=o[n]||[]).push(t)}}function _t(t,i,o,a){var s={},u=t===Wt;function l(e){var r;return s[e]=!0,k.each(t[e]||[],function(e,t){var n=t(i,o,a);return"string"!=typeof n||u||s[n]?u?!(r=n):void 0:(i.dataTypes.unshift(n),l(n),!1)}),r}return l(i.dataTypes[0])||!s["*"]&&l("*")}function zt(e,t){var n,r,i=k.ajaxSettings.flatOptions||{};for(n in t)void 0!==t[n]&&((i[n]?e:r||(r={}))[n]=t[n]);return r&&k.extend(!0,e,r),e}Ft.href=Et.href,k.extend({active:0,lastModified:{},etag:{},ajaxSettings:{url:Et.href,type:"GET",isLocal:/^(?:about|app|app-storage|.+-extension|file|res|widget):$/.test(Et.protocol),global:!0,processData:!0,async:!0,contentType:"application/x-www-form-urlencoded; charset=UTF-8",accepts:{"*":$t,text:"text/plain",html:"text/html",xml:"application/xml, text/xml",json:"application/json, text/javascript"},contents:{xml:/\bxml\b/,html:/\bhtml/,json:/\bjson\b/},responseFields:{xml:"responseXML",text:"responseText",json:"responseJSON"},converters:{"* text":String,"text html":!0,"text json":JSON.parse,"text xml":k.parseXML},flatOptions:{url:!0,context:!0}},ajaxSetup:function(e,t){return 
t?zt(zt(e,k.ajaxSettings),t):zt(k.ajaxSettings,e)},ajaxPrefilter:Bt(It),ajaxTransport:Bt(Wt),ajax:function(e,t){"object"==typeof e&&(t=e,e=void 0),t=t||{};var c,f,p,n,d,r,h,g,i,o,v=k.ajaxSetup({},t),y=v.context||v,m=v.context&&(y.nodeType||y.jquery)?k(y):k.event,x=k.Deferred(),b=k.Callbacks("once memory"),w=v.statusCode||{},a={},s={},u="canceled",T={readyState:0,getResponseHeader:function(e){var t;if(h){if(!n){n={};while(t=Pt.exec(p))n[t[1].toLowerCase()+" "]=(n[t[1].toLowerCase()+" "]||[]).concat(t[2])}t=n[e.toLowerCase()+" "]}return null==t?null:t.join(", ")},getAllResponseHeaders:function(){return h?p:null},setRequestHeader:function(e,t){return null==h&&(e=s[e.toLowerCase()]=s[e.toLowerCase()]||e,a[e]=t),this},overrideMimeType:function(e){return null==h&&(v.mimeType=e),this},statusCode:function(e){var t;if(e)if(h)T.always(e[T.status]);else for(t in e)w[t]=[w[t],e[t]];return this},abort:function(e){var t=e||u;return c&&c.abort(t),l(0,t),this}};if(x.promise(T),v.url=((e||v.url||Et.href)+"").replace(Mt,Et.protocol+"//"),v.type=t.method||t.type||v.method||v.type,v.dataTypes=(v.dataType||"*").toLowerCase().match(R)||[""],null==v.crossDomain){r=E.createElement("a");try{r.href=v.url,r.href=r.href,v.crossDomain=Ft.protocol+"//"+Ft.host!=r.protocol+"//"+r.host}catch(e){v.crossDomain=!0}}if(v.data&&v.processData&&"string"!=typeof v.data&&(v.data=k.param(v.data,v.traditional)),_t(It,v,t,T),h)return T;for(i in(g=k.event&&v.global)&&0==k.active++&&k.event.trigger("ajaxStart"),v.type=v.type.toUpperCase(),v.hasContent=!Rt.test(v.type),f=v.url.replace(Ht,""),v.hasContent?v.data&&v.processData&&0===(v.contentType||"").indexOf("application/x-www-form-urlencoded")&&(v.data=v.data.replace(Lt,"+")):(o=v.url.slice(f.length),v.data&&(v.processData||"string"==typeof v.data)&&(f+=(St.test(f)?"&":"?")+v.data,delete v.data),!1===v.cache&&(f=f.replace(Ot,"$1"),o=(St.test(f)?"&":"?")+"_="+kt+++o),v.url=f+o),v.ifModified&&(k.lastModified[f]&&T.setRequestHeader("If-Modified-Since",k.lastModified[f]),k.etag[f]&&T.setRequestHeader("If-None-Match",k.etag[f])),(v.data&&v.hasContent&&!1!==v.contentType||t.contentType)&&T.setRequestHeader("Content-Type",v.contentType),T.setRequestHeader("Accept",v.dataTypes[0]&&v.accepts[v.dataTypes[0]]?v.accepts[v.dataTypes[0]]+("*"!==v.dataTypes[0]?", "+$t+"; q=0.01":""):v.accepts["*"]),v.headers)T.setRequestHeader(i,v.headers[i]);if(v.beforeSend&&(!1===v.beforeSend.call(y,T,v)||h))return T.abort();if(u="abort",b.add(v.complete),T.done(v.success),T.fail(v.error),c=_t(Wt,v,t,T)){if(T.readyState=1,g&&m.trigger("ajaxSend",[T,v]),h)return T;v.async&&0<v.timeout&&(d=C.setTimeout(function(){T.abort("timeout")},v.timeout));try{h=!1,c.send(a,l)}catch(e){if(h)throw e;l(-1,e)}}else l(-1,"No Transport");function l(e,t,n,r){var i,o,a,s,u,l=t;h||(h=!0,d&&C.clearTimeout(d),c=void 0,p=r||"",T.readyState=0<e?4:0,i=200<=e&&e<300||304===e,n&&(s=function(e,t,n){var r,i,o,a,s=e.contents,u=e.dataTypes;while("*"===u[0])u.shift(),void 0===r&&(r=e.mimeType||t.getResponseHeader("Content-Type"));if(r)for(i in s)if(s[i]&&s[i].test(r)){u.unshift(i);break}if(u[0]in n)o=u[0];else{for(i in n){if(!u[0]||e.converters[i+" "+u[0]]){o=i;break}a||(a=i)}o=o||a}if(o)return o!==u[0]&&u.unshift(o),n[o]}(v,T,n)),s=function(e,t,n,r){var i,o,a,s,u,l={},c=e.dataTypes.slice();if(c[1])for(a in e.converters)l[a.toLowerCase()]=e.converters[a];o=c.shift();while(o)if(e.responseFields[o]&&(n[e.responseFields[o]]=t),!u&&r&&e.dataFilter&&(t=e.dataFilter(t,e.dataType)),u=o,o=c.shift())if("*"===o)o=u;else if("*"!==u&&u!==o){if(!(a=l[u+" 
"+o]||l["* "+o]))for(i in l)if((s=i.split(" "))[1]===o&&(a=l[u+" "+s[0]]||l["* "+s[0]])){!0===a?a=l[i]:!0!==l[i]&&(o=s[0],c.unshift(s[1]));break}if(!0!==a)if(a&&e["throws"])t=a(t);else try{t=a(t)}catch(e){return{state:"parsererror",error:a?e:"No conversion from "+u+" to "+o}}}return{state:"success",data:t}}(v,s,T,i),i?(v.ifModified&&((u=T.getResponseHeader("Last-Modified"))&&(k.lastModified[f]=u),(u=T.getResponseHeader("etag"))&&(k.etag[f]=u)),204===e||"HEAD"===v.type?l="nocontent":304===e?l="notmodified":(l=s.state,o=s.data,i=!(a=s.error))):(a=l,!e&&l||(l="error",e<0&&(e=0))),T.status=e,T.statusText=(t||l)+"",i?x.resolveWith(y,[o,l,T]):x.rejectWith(y,[T,l,a]),T.statusCode(w),w=void 0,g&&m.trigger(i?"ajaxSuccess":"ajaxError",[T,v,i?o:a]),b.fireWith(y,[T,l]),g&&(m.trigger("ajaxComplete",[T,v]),--k.active||k.event.trigger("ajaxStop")))}return T},getJSON:function(e,t,n){return k.get(e,t,n,"json")},getScript:function(e,t){return k.get(e,void 0,t,"script")}}),k.each(["get","post"],function(e,i){k[i]=function(e,t,n,r){return m(t)&&(r=r||n,n=t,t=void 0),k.ajax(k.extend({url:e,type:i,dataType:r,data:t,success:n},k.isPlainObject(e)&&e))}}),k._evalUrl=function(e,t){return k.ajax({url:e,type:"GET",dataType:"script",cache:!0,async:!1,global:!1,converters:{"text script":function(){}},dataFilter:function(e){k.globalEval(e,t)}})},k.fn.extend({wrapAll:function(e){var t;return this[0]&&(m(e)&&(e=e.call(this[0])),t=k(e,this[0].ownerDocument).eq(0).clone(!0),this[0].parentNode&&t.insertBefore(this[0]),t.map(function(){var e=this;while(e.firstElementChild)e=e.firstElementChild;return e}).append(this)),this},wrapInner:function(n){return m(n)?this.each(function(e){k(this).wrapInner(n.call(this,e))}):this.each(function(){var e=k(this),t=e.contents();t.length?t.wrapAll(n):e.append(n)})},wrap:function(t){var n=m(t);return this.each(function(e){k(this).wrapAll(n?t.call(this,e):t)})},unwrap:function(e){return this.parent(e).not("body").each(function(){k(this).replaceWith(this.childNodes)}),this}}),k.expr.pseudos.hidden=function(e){return!k.expr.pseudos.visible(e)},k.expr.pseudos.visible=function(e){return!!(e.offsetWidth||e.offsetHeight||e.getClientRects().length)},k.ajaxSettings.xhr=function(){try{return new C.XMLHttpRequest}catch(e){}};var Ut={0:200,1223:204},Xt=k.ajaxSettings.xhr();y.cors=!!Xt&&"withCredentials"in Xt,y.ajax=Xt=!!Xt,k.ajaxTransport(function(i){var o,a;if(y.cors||Xt&&!i.crossDomain)return{send:function(e,t){var n,r=i.xhr();if(r.open(i.type,i.url,i.async,i.username,i.password),i.xhrFields)for(n in i.xhrFields)r[n]=i.xhrFields[n];for(n in i.mimeType&&r.overrideMimeType&&r.overrideMimeType(i.mimeType),i.crossDomain||e["X-Requested-With"]||(e["X-Requested-With"]="XMLHttpRequest"),e)r.setRequestHeader(n,e[n]);o=function(e){return function(){o&&(o=a=r.onload=r.onerror=r.onabort=r.ontimeout=r.onreadystatechange=null,"abort"===e?r.abort():"error"===e?"number"!=typeof r.status?t(0,"error"):t(r.status,r.statusText):t(Ut[r.status]||r.status,r.statusText,"text"!==(r.responseType||"text")||"string"!=typeof r.responseText?{binary:r.response}:{text:r.responseText},r.getAllResponseHeaders()))}},r.onload=o(),a=r.onerror=r.ontimeout=o("error"),void 0!==r.onabort?r.onabort=a:r.onreadystatechange=function(){4===r.readyState&&C.setTimeout(function(){o&&a()})},o=o("abort");try{r.send(i.hasContent&&i.data||null)}catch(e){if(o)throw e}},abort:function(){o&&o()}}}),k.ajaxPrefilter(function(e){e.crossDomain&&(e.contents.script=!1)}),k.ajaxSetup({accepts:{script:"text/javascript, application/javascript, 
application/ecmascript, application/x-ecmascript"},contents:{script:/\b(?:java|ecma)script\b/},converters:{"text script":function(e){return k.globalEval(e),e}}}),k.ajaxPrefilter("script",function(e){void 0===e.cache&&(e.cache=!1),e.crossDomain&&(e.type="GET")}),k.ajaxTransport("script",function(n){var r,i;if(n.crossDomain||n.scriptAttrs)return{send:function(e,t){r=k("<script>").attr(n.scriptAttrs||{}).prop({charset:n.scriptCharset,src:n.url}).on("load error",i=function(e){r.remove(),i=null,e&&t("error"===e.type?404:200,e.type)}),E.head.appendChild(r[0])},abort:function(){i&&i()}}});var Vt,Gt=[],Yt=/(=)\?(?=&|$)|\?\?/;k.ajaxSetup({jsonp:"callback",jsonpCallback:function(){var e=Gt.pop()||k.expando+"_"+kt++;return this[e]=!0,e}}),k.ajaxPrefilter("json jsonp",function(e,t,n){var r,i,o,a=!1!==e.jsonp&&(Yt.test(e.url)?"url":"string"==typeof e.data&&0===(e.contentType||"").indexOf("application/x-www-form-urlencoded")&&Yt.test(e.data)&&"data");if(a||"jsonp"===e.dataTypes[0])return r=e.jsonpCallback=m(e.jsonpCallback)?e.jsonpCallback():e.jsonpCallback,a?e[a]=e[a].replace(Yt,"$1"+r):!1!==e.jsonp&&(e.url+=(St.test(e.url)?"&":"?")+e.jsonp+"="+r),e.converters["script json"]=function(){return o||k.error(r+" was not called"),o[0]},e.dataTypes[0]="json",i=C[r],C[r]=function(){o=arguments},n.always(function(){void 0===i?k(C).removeProp(r):C[r]=i,e[r]&&(e.jsonpCallback=t.jsonpCallback,Gt.push(r)),o&&m(i)&&i(o[0]),o=i=void 0}),"script"}),y.createHTMLDocument=((Vt=E.implementation.createHTMLDocument("").body).innerHTML="<form></form><form></form>",2===Vt.childNodes.length),k.parseHTML=function(e,t,n){return"string"!=typeof e?[]:("boolean"==typeof t&&(n=t,t=!1),t||(y.createHTMLDocument?((r=(t=E.implementation.createHTMLDocument("")).createElement("base")).href=E.location.href,t.head.appendChild(r)):t=E),o=!n&&[],(i=D.exec(e))?[t.createElement(i[1])]:(i=we([e],t,o),o&&o.length&&k(o).remove(),k.merge([],i.childNodes)));var r,i,o},k.fn.load=function(e,t,n){var r,i,o,a=this,s=e.indexOf(" ");return-1<s&&(r=mt(e.slice(s)),e=e.slice(0,s)),m(t)?(n=t,t=void 0):t&&"object"==typeof t&&(i="POST"),0<a.length&&k.ajax({url:e,type:i||"GET",dataType:"html",data:t}).done(function(e){o=arguments,a.html(r?k("<div>").append(k.parseHTML(e)).find(r):e)}).always(n&&function(e,t){a.each(function(){n.apply(this,o||[e.responseText,t,e])})}),this},k.each(["ajaxStart","ajaxStop","ajaxComplete","ajaxError","ajaxSuccess","ajaxSend"],function(e,t){k.fn[t]=function(e){return this.on(t,e)}}),k.expr.pseudos.animated=function(t){return k.grep(k.timers,function(e){return t===e.elem}).length},k.offset={setOffset:function(e,t,n){var r,i,o,a,s,u,l=k.css(e,"position"),c=k(e),f={};"static"===l&&(e.style.position="relative"),s=c.offset(),o=k.css(e,"top"),u=k.css(e,"left"),("absolute"===l||"fixed"===l)&&-1<(o+u).indexOf("auto")?(a=(r=c.position()).top,i=r.left):(a=parseFloat(o)||0,i=parseFloat(u)||0),m(t)&&(t=t.call(e,n,k.extend({},s))),null!=t.top&&(f.top=t.top-s.top+a),null!=t.left&&(f.left=t.left-s.left+i),"using"in t?t.using.call(e,f):c.css(f)}},k.fn.extend({offset:function(t){if(arguments.length)return void 0===t?this:this.each(function(e){k.offset.setOffset(this,t,e)});var e,n,r=this[0];return r?r.getClientRects().length?(e=r.getBoundingClientRect(),n=r.ownerDocument.defaultView,{top:e.top+n.pageYOffset,left:e.left+n.pageXOffset}):{top:0,left:0}:void 0},position:function(){if(this[0]){var 
e,t,n,r=this[0],i={top:0,left:0};if("fixed"===k.css(r,"position"))t=r.getBoundingClientRect();else{t=this.offset(),n=r.ownerDocument,e=r.offsetParent||n.documentElement;while(e&&(e===n.body||e===n.documentElement)&&"static"===k.css(e,"position"))e=e.parentNode;e&&e!==r&&1===e.nodeType&&((i=k(e).offset()).top+=k.css(e,"borderTopWidth",!0),i.left+=k.css(e,"borderLeftWidth",!0))}return{top:t.top-i.top-k.css(r,"marginTop",!0),left:t.left-i.left-k.css(r,"marginLeft",!0)}}},offsetParent:function(){return this.map(function(){var e=this.offsetParent;while(e&&"static"===k.css(e,"position"))e=e.offsetParent;return e||ie})}}),k.each({scrollLeft:"pageXOffset",scrollTop:"pageYOffset"},function(t,i){var o="pageYOffset"===i;k.fn[t]=function(e){return _(this,function(e,t,n){var r;if(x(e)?r=e:9===e.nodeType&&(r=e.defaultView),void 0===n)return r?r[i]:e[t];r?r.scrollTo(o?r.pageXOffset:n,o?n:r.pageYOffset):e[t]=n},t,e,arguments.length)}}),k.each(["top","left"],function(e,n){k.cssHooks[n]=ze(y.pixelPosition,function(e,t){if(t)return t=_e(e,n),$e.test(t)?k(e).position()[n]+"px":t})}),k.each({Height:"height",Width:"width"},function(a,s){k.each({padding:"inner"+a,content:s,"":"outer"+a},function(r,o){k.fn[o]=function(e,t){var n=arguments.length&&(r||"boolean"!=typeof e),i=r||(!0===e||!0===t?"margin":"border");return _(this,function(e,t,n){var r;return x(e)?0===o.indexOf("outer")?e["inner"+a]:e.document.documentElement["client"+a]:9===e.nodeType?(r=e.documentElement,Math.max(e.body["scroll"+a],r["scroll"+a],e.body["offset"+a],r["offset"+a],r["client"+a])):void 0===n?k.css(e,t,i):k.style(e,t,n,i)},s,n?e:void 0,n)}})}),k.each("blur focus focusin focusout resize scroll click dblclick mousedown mouseup mousemove mouseover mouseout mouseenter mouseleave change select submit keydown keypress keyup contextmenu".split(" "),function(e,n){k.fn[n]=function(e,t){return 0<arguments.length?this.on(n,null,e,t):this.trigger(n)}}),k.fn.extend({hover:function(e,t){return this.mouseenter(e).mouseleave(t||e)}}),k.fn.extend({bind:function(e,t,n){return this.on(e,null,t,n)},unbind:function(e,t){return this.off(e,null,t)},delegate:function(e,t,n,r){return this.on(t,e,n,r)},undelegate:function(e,t,n){return 1===arguments.length?this.off(e,"**"):this.off(t,e||"**",n)}}),k.proxy=function(e,t){var n,r,i;if("string"==typeof t&&(n=e[t],t=e,e=n),m(e))return r=s.call(arguments,2),(i=function(){return e.apply(t||this,r.concat(s.call(arguments)))}).guid=e.guid=e.guid||k.guid++,i},k.holdReady=function(e){e?k.readyWait++:k.ready(!0)},k.isArray=Array.isArray,k.parseJSON=JSON.parse,k.nodeName=A,k.isFunction=m,k.isWindow=x,k.camelCase=V,k.type=w,k.now=Date.now,k.isNumeric=function(e){var t=k.type(e);return("number"===t||"string"===t)&&!isNaN(e-parseFloat(e))},"function"==typeof define&&define.amd&&define("jquery",[],function(){return k});var Qt=C.jQuery,Jt=C.$;return k.noConflict=function(e){return C.$===k&&(C.$=Jt),e&&C.jQuery===k&&(C.jQuery=Qt),k},e||(C.jQuery=C.$=k),k}); diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js index 276c271bbe..ba8048b23f 100644 --- a/synapse/static/client/login/js/login.js +++ b/synapse/static/client/login/js/login.js @@ -1,37 +1,52 @@ window.matrixLogin = { endpoint: location.origin + "/_matrix/client/r0/login", serverAcceptsPassword: false, - serverAcceptsCas: false, serverAcceptsSso: false, }; -var submitPassword = function(user, pwd) { - console.log("Logging in with password..."); - var data = { - type: "m.login.password", - user: user, - password: pwd, - }; - 
$.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { - show_login(); - matrixLogin.onLogin(response); - }).error(errorFunc); -}; +// Titles get updated through the process to give users feedback. +var TITLE_PRE_AUTH = "Log in with one of the following methods"; +var TITLE_POST_AUTH = "Logging in..."; + +// The cookie used to store the original query parameters when using SSO. +var COOKIE_KEY = "synapse_login_fallback_qs"; + +/* + * Submit a login request. + * + * type: The login type as a string (e.g. "m.login.foo"). + * data: An object of data specific to the login type. + * extra: (Optional) An object to search for extra information to send with the + * login request, e.g. device_id. + * callback: (Optional) Function to call on successful login. + */ +var submitLogin = function(type, data, extra, callback) { + console.log("Logging in with " + type); + set_title(TITLE_POST_AUTH); + + // Add the login type. + data.type = type; + + // Add the device information, if it was provided. + if (extra.device_id) { + data.device_id = extra.device_id; + } + if (extra.initial_device_display_name) { + data.initial_device_display_name = extra.initial_device_display_name; + } -var submitToken = function(loginToken) { - console.log("Logging in with login token..."); - var data = { - type: "m.login.token", - token: loginToken - }; $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { - show_login(); + if (callback) { + callback(); + } matrixLogin.onLogin(response); - }).error(errorFunc); + }).fail(errorFunc); }; var errorFunc = function(err) { - show_login(); + // We want to show the error to the user rather than redirecting immediately to the + // SSO portal (if SSO is the only login option), so we inhibit the redirect. + show_login(true); if (err.responseJSON && err.responseJSON.error) { setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")"); @@ -45,26 +60,40 @@ var setFeedbackString = function(text) { $("#feedback").text(text); }; -var show_login = function() { - $("#loading").hide(); - +var show_login = function(inhibit_redirect) { + // Set the redirect to come back to this page, a login token will get added + // and handled after the redirect. var this_page = window.location.origin + window.location.pathname; $("#sso_redirect_url").val(this_page); - if (matrixLogin.serverAcceptsPassword) { - $("#password_flow").show(); - } - + // If inhibit_redirect is false, and SSO is the only supported login method, + // we can redirect straight to the SSO page. if (matrixLogin.serverAcceptsSso) { + // Before submitting SSO, set the current query parameters into a cookie + // for retrieval later. 
+ var qs = parseQsFromUrl(); + setCookie(COOKIE_KEY, JSON.stringify(qs)); + + if (!inhibit_redirect && !matrixLogin.serverAcceptsPassword) { + $("#sso_form").submit(); + return; + } + + // Otherwise, show the SSO form $("#sso_flow").show(); - } else if (matrixLogin.serverAcceptsCas) { - $("#sso_form").attr("action", "/_matrix/client/r0/login/cas/redirect"); - $("#sso_flow").show(); } - if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsCas && !matrixLogin.serverAcceptsSso) { + if (matrixLogin.serverAcceptsPassword) { + $("#password_flow").show(); + } + + if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsSso) { $("#no_login_types").show(); } + + set_title(TITLE_PRE_AUTH); + + $("#loading").hide(); }; var show_spinner = function() { @@ -74,17 +103,15 @@ var show_spinner = function() { $("#loading").show(); }; +var set_title = function(title) { + $("#title").text(title); +}; var fetch_info = function(cb) { $.get(matrixLogin.endpoint, function(response) { var serverAcceptsPassword = false; - var serverAcceptsCas = false; for (var i=0; i<response.flows.length; i++) { var flow = response.flows[i]; - if ("m.login.cas" === flow.type) { - matrixLogin.serverAcceptsCas = true; - console.log("Server accepts CAS"); - } if ("m.login.sso" === flow.type) { matrixLogin.serverAcceptsSso = true; console.log("Server accepts SSO"); @@ -96,13 +123,13 @@ var fetch_info = function(cb) { } cb(); - }).error(errorFunc); + }).fail(errorFunc); } matrixLogin.onLoad = function() { fetch_info(function() { if (!try_token()) { - show_login(); + show_login(false); } }); }; @@ -114,16 +141,27 @@ matrixLogin.password_login = function() { setFeedbackString(""); show_spinner(); - submitPassword(user, pwd); + submitLogin( + "m.login.password", + {user: user, password: pwd}, + parseQsFromUrl()); }; matrixLogin.onLogin = function(response) { // clobber this function - console.log("onLogin - This function should be replaced to proceed."); - console.log(response); + console.warn("onLogin - This function should be replaced to proceed."); }; -var parseQsFromUrl = function(query) { +/* + * Process the query parameters from the current URL into an object. + */ +var parseQsFromUrl = function() { + var pos = window.location.href.indexOf("?"); + if (pos == -1) { + return {}; + } + var query = window.location.href.substr(pos + 1); + var result = {}; query.split("&").forEach(function(part) { var item = part.split("="); @@ -133,25 +171,80 @@ var parseQsFromUrl = function(query) { if (val) { val = decodeURIComponent(val); } - result[key] = val + result[key] = val; + }); + return result; +}; + +/* + * Process the cookies and return an object. + */ +var parseCookies = function() { + var allCookies = document.cookie; + var result = {}; + allCookies.split(";").forEach(function(part) { + var item = part.split("="); + // Cookies might have arbitrary whitespace between them. + var key = item[0].trim(); + // You can end up with a broken cookie that doesn't have an equals sign + // in it. Set to an empty value. + var val = (item[1] || "").trim(); + // Values might be URI encoded. + if (val) { + val = decodeURIComponent(val); + } + result[key] = val; }); return result; }; +/* + * Set a cookie that is valid for 1 hour. + */ +var setCookie = function(key, value) { + // The maximum age is set in seconds. + var maxAge = 60 * 60; + // Set the cookie, this defaults to the current domain and path. 
+ document.cookie = key + "=" + encodeURIComponent(value) + ";max-age=" + maxAge + ";sameSite=lax"; +}; + +/* + * Removes a cookie by key. + */ +var deleteCookie = function(key) { + // Delete a cookie by setting the expiration to 0. (Note that the value + // doesn't matter.) + document.cookie = key + "=deleted;expires=0"; +}; + +/* + * Submits the login token if one is found in the query parameters. Returns a + * boolean of whether the login token was found or not. + */ var try_token = function() { - var pos = window.location.href.indexOf("?"); - if (pos == -1) { - return false; - } - var qs = parseQsFromUrl(window.location.href.substr(pos+1)); + // Check if the login token is in the query parameters. + var qs = parseQsFromUrl(); var loginToken = qs.loginToken; - if (!loginToken) { return false; } - submitToken(loginToken); + // Retrieve the original query parameters (from before the SSO redirect). + // They are stored as JSON in a cookie. + var cookies = parseCookies(); + var original_query_params = JSON.parse(cookies[COOKIE_KEY] || "{}") + + // If the login is successful, delete the cookie. + var callback = function() { + deleteCookie(COOKIE_KEY); + } + + submitLogin( + "m.login.token", + {token: loginToken}, + original_query_params, + callback); return true; }; diff --git a/synapse/static/client/register/index.html b/synapse/static/client/register/index.html index 6edc4deb03..140653574d 100644 --- a/synapse/static/client/register/index.html +++ b/synapse/static/client/register/index.html @@ -1,9 +1,10 @@ +<!doctype html> <html> <head> <title> Registration </title> <meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> <link rel="stylesheet" href="style.css"> -<script src="js/jquery-2.1.3.min.js"></script> +<script src="js/jquery-3.4.1.min.js"></script> <script src="https://www.recaptcha.net/recaptcha/api/js/recaptcha_ajax.js"></script> <script src="register_config.js"></script> <script src="js/register.js"></script> diff --git a/synapse/static/client/register/js/jquery-2.1.3.min.js b/synapse/static/client/register/js/jquery-2.1.3.min.js deleted file mode 100644 index 25714ed29a..0000000000 --- a/synapse/static/client/register/js/jquery-2.1.3.min.js +++ /dev/null @@ -1,4 +0,0 @@ -/*! jQuery v2.1.3 | (c) 2005, 2014 jQuery Foundation, Inc. 
| jquery.org/license */ -!function(a,b){"object"==typeof module&&"object"==typeof module.exports?module.exports=a.document?b(a,!0):function(a){if(!a.document)throw new Error("jQuery requires a window with a document");return b(a)}:b(a)}("undefined"!=typeof window?window:this,function(a,b){var c=[],d=c.slice,e=c.concat,f=c.push,g=c.indexOf,h={},i=h.toString,j=h.hasOwnProperty,k={},l=a.document,m="2.1.3",n=function(a,b){return new n.fn.init(a,b)},o=/^[\s\uFEFF\xA0]+|[\s\uFEFF\xA0]+$/g,p=/^-ms-/,q=/-([\da-z])/gi,r=function(a,b){return b.toUpperCase()};n.fn=n.prototype={jquery:m,constructor:n,selector:"",length:0,toArray:function(){return d.call(this)},get:function(a){return null!=a?0>a?this[a+this.length]:this[a]:d.call(this)},pushStack:function(a){var b=n.merge(this.constructor(),a);return b.prevObject=this,b.context=this.context,b},each:function(a,b){return n.each(this,a,b)},map:function(a){return this.pushStack(n.map(this,function(b,c){return a.call(b,c,b)}))},slice:function(){return this.pushStack(d.apply(this,arguments))},first:function(){return this.eq(0)},last:function(){return this.eq(-1)},eq:function(a){var b=this.length,c=+a+(0>a?b:0);return this.pushStack(c>=0&&b>c?[this[c]]:[])},end:function(){return this.prevObject||this.constructor(null)},push:f,sort:c.sort,splice:c.splice},n.extend=n.fn.extend=function(){var a,b,c,d,e,f,g=arguments[0]||{},h=1,i=arguments.length,j=!1;for("boolean"==typeof g&&(j=g,g=arguments[h]||{},h++),"object"==typeof g||n.isFunction(g)||(g={}),h===i&&(g=this,h--);i>h;h++)if(null!=(a=arguments[h]))for(b in a)c=g[b],d=a[b],g!==d&&(j&&d&&(n.isPlainObject(d)||(e=n.isArray(d)))?(e?(e=!1,f=c&&n.isArray(c)?c:[]):f=c&&n.isPlainObject(c)?c:{},g[b]=n.extend(j,f,d)):void 0!==d&&(g[b]=d));return g},n.extend({expando:"jQuery"+(m+Math.random()).replace(/\D/g,""),isReady:!0,error:function(a){throw new Error(a)},noop:function(){},isFunction:function(a){return"function"===n.type(a)},isArray:Array.isArray,isWindow:function(a){return null!=a&&a===a.window},isNumeric:function(a){return!n.isArray(a)&&a-parseFloat(a)+1>=0},isPlainObject:function(a){return"object"!==n.type(a)||a.nodeType||n.isWindow(a)?!1:a.constructor&&!j.call(a.constructor.prototype,"isPrototypeOf")?!1:!0},isEmptyObject:function(a){var b;for(b in a)return!1;return!0},type:function(a){return null==a?a+"":"object"==typeof a||"function"==typeof a?h[i.call(a)]||"object":typeof a},globalEval:function(a){var b,c=eval;a=n.trim(a),a&&(1===a.indexOf("use strict")?(b=l.createElement("script"),b.text=a,l.head.appendChild(b).parentNode.removeChild(b)):c(a))},camelCase:function(a){return a.replace(p,"ms-").replace(q,r)},nodeName:function(a,b){return a.nodeName&&a.nodeName.toLowerCase()===b.toLowerCase()},each:function(a,b,c){var d,e=0,f=a.length,g=s(a);if(c){if(g){for(;f>e;e++)if(d=b.apply(a[e],c),d===!1)break}else for(e in a)if(d=b.apply(a[e],c),d===!1)break}else if(g){for(;f>e;e++)if(d=b.call(a[e],e,a[e]),d===!1)break}else for(e in a)if(d=b.call(a[e],e,a[e]),d===!1)break;return a},trim:function(a){return null==a?"":(a+"").replace(o,"")},makeArray:function(a,b){var c=b||[];return null!=a&&(s(Object(a))?n.merge(c,"string"==typeof a?[a]:a):f.call(c,a)),c},inArray:function(a,b,c){return null==b?-1:g.call(b,a,c)},merge:function(a,b){for(var c=+b.length,d=0,e=a.length;c>d;d++)a[e++]=b[d];return a.length=e,a},grep:function(a,b,c){for(var d,e=[],f=0,g=a.length,h=!c;g>f;f++)d=!b(a[f],f),d!==h&&e.push(a[f]);return e},map:function(a,b,c){var d,f=0,g=a.length,h=s(a),i=[];if(h)for(;g>f;f++)d=b(a[f],f,c),null!=d&&i.push(d);else for(f in 
a)d=b(a[f],f,c),null!=d&&i.push(d);return e.apply([],i)},guid:1,proxy:function(a,b){var c,e,f;return"string"==typeof b&&(c=a[b],b=a,a=c),n.isFunction(a)?(e=d.call(arguments,2),f=function(){return a.apply(b||this,e.concat(d.call(arguments)))},f.guid=a.guid=a.guid||n.guid++,f):void 0},now:Date.now,support:k}),n.each("Boolean Number String Function Array Date RegExp Object Error".split(" "),function(a,b){h["[object "+b+"]"]=b.toLowerCase()});function s(a){var b=a.length,c=n.type(a);return"function"===c||n.isWindow(a)?!1:1===a.nodeType&&b?!0:"array"===c||0===b||"number"==typeof b&&b>0&&b-1 in a}var t=function(a){var b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u="sizzle"+1*new Date,v=a.document,w=0,x=0,y=hb(),z=hb(),A=hb(),B=function(a,b){return a===b&&(l=!0),0},C=1<<31,D={}.hasOwnProperty,E=[],F=E.pop,G=E.push,H=E.push,I=E.slice,J=function(a,b){for(var c=0,d=a.length;d>c;c++)if(a[c]===b)return c;return-1},K="checked|selected|async|autofocus|autoplay|controls|defer|disabled|hidden|ismap|loop|multiple|open|readonly|required|scoped",L="[\\x20\\t\\r\\n\\f]",M="(?:\\\\.|[\\w-]|[^\\x00-\\xa0])+",N=M.replace("w","w#"),O="\\["+L+"*("+M+")(?:"+L+"*([*^$|!~]?=)"+L+"*(?:'((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\"|("+N+"))|)"+L+"*\\]",P=":("+M+")(?:\\((('((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\")|((?:\\\\.|[^\\\\()[\\]]|"+O+")*)|.*)\\)|)",Q=new RegExp(L+"+","g"),R=new RegExp("^"+L+"+|((?:^|[^\\\\])(?:\\\\.)*)"+L+"+$","g"),S=new RegExp("^"+L+"*,"+L+"*"),T=new RegExp("^"+L+"*([>+~]|"+L+")"+L+"*"),U=new RegExp("="+L+"*([^\\]'\"]*?)"+L+"*\\]","g"),V=new RegExp(P),W=new RegExp("^"+N+"$"),X={ID:new RegExp("^#("+M+")"),CLASS:new RegExp("^\\.("+M+")"),TAG:new RegExp("^("+M.replace("w","w*")+")"),ATTR:new RegExp("^"+O),PSEUDO:new RegExp("^"+P),CHILD:new RegExp("^:(only|first|last|nth|nth-last)-(child|of-type)(?:\\("+L+"*(even|odd|(([+-]|)(\\d*)n|)"+L+"*(?:([+-]|)"+L+"*(\\d+)|))"+L+"*\\)|)","i"),bool:new RegExp("^(?:"+K+")$","i"),needsContext:new RegExp("^"+L+"*[>+~]|:(even|odd|eq|gt|lt|nth|first|last)(?:\\("+L+"*((?:-\\d)?\\d*)"+L+"*\\)|)(?=[^-]|$)","i")},Y=/^(?:input|select|textarea|button)$/i,Z=/^h\d$/i,$=/^[^{]+\{\s*\[native \w/,_=/^(?:#([\w-]+)|(\w+)|\.([\w-]+))$/,ab=/[+~]/,bb=/'|\\/g,cb=new RegExp("\\\\([\\da-f]{1,6}"+L+"?|("+L+")|.)","ig"),db=function(a,b,c){var d="0x"+b-65536;return d!==d||c?b:0>d?String.fromCharCode(d+65536):String.fromCharCode(d>>10|55296,1023&d|56320)},eb=function(){m()};try{H.apply(E=I.call(v.childNodes),v.childNodes),E[v.childNodes.length].nodeType}catch(fb){H={apply:E.length?function(a,b){G.apply(a,I.call(b))}:function(a,b){var c=a.length,d=0;while(a[c++]=b[d++]);a.length=c-1}}}function gb(a,b,d,e){var f,h,j,k,l,o,r,s,w,x;if((b?b.ownerDocument||b:v)!==n&&m(b),b=b||n,d=d||[],k=b.nodeType,"string"!=typeof a||!a||1!==k&&9!==k&&11!==k)return d;if(!e&&p){if(11!==k&&(f=_.exec(a)))if(j=f[1]){if(9===k){if(h=b.getElementById(j),!h||!h.parentNode)return d;if(h.id===j)return d.push(h),d}else if(b.ownerDocument&&(h=b.ownerDocument.getElementById(j))&&t(b,h)&&h.id===j)return d.push(h),d}else{if(f[2])return H.apply(d,b.getElementsByTagName(a)),d;if((j=f[3])&&c.getElementsByClassName)return H.apply(d,b.getElementsByClassName(j)),d}if(c.qsa&&(!q||!q.test(a))){if(s=r=u,w=b,x=1!==k&&a,1===k&&"object"!==b.nodeName.toLowerCase()){o=g(a),(r=b.getAttribute("id"))?s=r.replace(bb,"\\$&"):b.setAttribute("id",s),s="[id='"+s+"'] ",l=o.length;while(l--)o[l]=s+rb(o[l]);w=ab.test(a)&&pb(b.parentNode)||b,x=o.join(",")}if(x)try{return 
H.apply(d,w.querySelectorAll(x)),d}catch(y){}finally{r||b.removeAttribute("id")}}}return i(a.replace(R,"$1"),b,d,e)}function hb(){var a=[];function b(c,e){return a.push(c+" ")>d.cacheLength&&delete b[a.shift()],b[c+" "]=e}return b}function ib(a){return a[u]=!0,a}function jb(a){var b=n.createElement("div");try{return!!a(b)}catch(c){return!1}finally{b.parentNode&&b.parentNode.removeChild(b),b=null}}function kb(a,b){var c=a.split("|"),e=a.length;while(e--)d.attrHandle[c[e]]=b}function lb(a,b){var c=b&&a,d=c&&1===a.nodeType&&1===b.nodeType&&(~b.sourceIndex||C)-(~a.sourceIndex||C);if(d)return d;if(c)while(c=c.nextSibling)if(c===b)return-1;return a?1:-1}function mb(a){return function(b){var c=b.nodeName.toLowerCase();return"input"===c&&b.type===a}}function nb(a){return function(b){var c=b.nodeName.toLowerCase();return("input"===c||"button"===c)&&b.type===a}}function ob(a){return ib(function(b){return b=+b,ib(function(c,d){var e,f=a([],c.length,b),g=f.length;while(g--)c[e=f[g]]&&(c[e]=!(d[e]=c[e]))})})}function pb(a){return a&&"undefined"!=typeof a.getElementsByTagName&&a}c=gb.support={},f=gb.isXML=function(a){var b=a&&(a.ownerDocument||a).documentElement;return b?"HTML"!==b.nodeName:!1},m=gb.setDocument=function(a){var b,e,g=a?a.ownerDocument||a:v;return g!==n&&9===g.nodeType&&g.documentElement?(n=g,o=g.documentElement,e=g.defaultView,e&&e!==e.top&&(e.addEventListener?e.addEventListener("unload",eb,!1):e.attachEvent&&e.attachEvent("onunload",eb)),p=!f(g),c.attributes=jb(function(a){return a.className="i",!a.getAttribute("className")}),c.getElementsByTagName=jb(function(a){return a.appendChild(g.createComment("")),!a.getElementsByTagName("*").length}),c.getElementsByClassName=$.test(g.getElementsByClassName),c.getById=jb(function(a){return o.appendChild(a).id=u,!g.getElementsByName||!g.getElementsByName(u).length}),c.getById?(d.find.ID=function(a,b){if("undefined"!=typeof b.getElementById&&p){var c=b.getElementById(a);return c&&c.parentNode?[c]:[]}},d.filter.ID=function(a){var b=a.replace(cb,db);return function(a){return a.getAttribute("id")===b}}):(delete d.find.ID,d.filter.ID=function(a){var b=a.replace(cb,db);return function(a){var c="undefined"!=typeof a.getAttributeNode&&a.getAttributeNode("id");return c&&c.value===b}}),d.find.TAG=c.getElementsByTagName?function(a,b){return"undefined"!=typeof b.getElementsByTagName?b.getElementsByTagName(a):c.qsa?b.querySelectorAll(a):void 0}:function(a,b){var c,d=[],e=0,f=b.getElementsByTagName(a);if("*"===a){while(c=f[e++])1===c.nodeType&&d.push(c);return d}return f},d.find.CLASS=c.getElementsByClassName&&function(a,b){return p?b.getElementsByClassName(a):void 0},r=[],q=[],(c.qsa=$.test(g.querySelectorAll))&&(jb(function(a){o.appendChild(a).innerHTML="<a id='"+u+"'></a><select id='"+u+"-\f]' msallowcapture=''><option selected=''></option></select>",a.querySelectorAll("[msallowcapture^='']").length&&q.push("[*^$]="+L+"*(?:''|\"\")"),a.querySelectorAll("[selected]").length||q.push("\\["+L+"*(?:value|"+K+")"),a.querySelectorAll("[id~="+u+"-]").length||q.push("~="),a.querySelectorAll(":checked").length||q.push(":checked"),a.querySelectorAll("a#"+u+"+*").length||q.push(".#.+[+~]")}),jb(function(a){var 
b=g.createElement("input");b.setAttribute("type","hidden"),a.appendChild(b).setAttribute("name","D"),a.querySelectorAll("[name=d]").length&&q.push("name"+L+"*[*^$|!~]?="),a.querySelectorAll(":enabled").length||q.push(":enabled",":disabled"),a.querySelectorAll("*,:x"),q.push(",.*:")})),(c.matchesSelector=$.test(s=o.matches||o.webkitMatchesSelector||o.mozMatchesSelector||o.oMatchesSelector||o.msMatchesSelector))&&jb(function(a){c.disconnectedMatch=s.call(a,"div"),s.call(a,"[s!='']:x"),r.push("!=",P)}),q=q.length&&new RegExp(q.join("|")),r=r.length&&new RegExp(r.join("|")),b=$.test(o.compareDocumentPosition),t=b||$.test(o.contains)?function(a,b){var c=9===a.nodeType?a.documentElement:a,d=b&&b.parentNode;return a===d||!(!d||1!==d.nodeType||!(c.contains?c.contains(d):a.compareDocumentPosition&&16&a.compareDocumentPosition(d)))}:function(a,b){if(b)while(b=b.parentNode)if(b===a)return!0;return!1},B=b?function(a,b){if(a===b)return l=!0,0;var d=!a.compareDocumentPosition-!b.compareDocumentPosition;return d?d:(d=(a.ownerDocument||a)===(b.ownerDocument||b)?a.compareDocumentPosition(b):1,1&d||!c.sortDetached&&b.compareDocumentPosition(a)===d?a===g||a.ownerDocument===v&&t(v,a)?-1:b===g||b.ownerDocument===v&&t(v,b)?1:k?J(k,a)-J(k,b):0:4&d?-1:1)}:function(a,b){if(a===b)return l=!0,0;var c,d=0,e=a.parentNode,f=b.parentNode,h=[a],i=[b];if(!e||!f)return a===g?-1:b===g?1:e?-1:f?1:k?J(k,a)-J(k,b):0;if(e===f)return lb(a,b);c=a;while(c=c.parentNode)h.unshift(c);c=b;while(c=c.parentNode)i.unshift(c);while(h[d]===i[d])d++;return d?lb(h[d],i[d]):h[d]===v?-1:i[d]===v?1:0},g):n},gb.matches=function(a,b){return gb(a,null,null,b)},gb.matchesSelector=function(a,b){if((a.ownerDocument||a)!==n&&m(a),b=b.replace(U,"='$1']"),!(!c.matchesSelector||!p||r&&r.test(b)||q&&q.test(b)))try{var d=s.call(a,b);if(d||c.disconnectedMatch||a.document&&11!==a.document.nodeType)return d}catch(e){}return gb(b,n,null,[a]).length>0},gb.contains=function(a,b){return(a.ownerDocument||a)!==n&&m(a),t(a,b)},gb.attr=function(a,b){(a.ownerDocument||a)!==n&&m(a);var e=d.attrHandle[b.toLowerCase()],f=e&&D.call(d.attrHandle,b.toLowerCase())?e(a,b,!p):void 0;return void 0!==f?f:c.attributes||!p?a.getAttribute(b):(f=a.getAttributeNode(b))&&f.specified?f.value:null},gb.error=function(a){throw new Error("Syntax error, unrecognized expression: "+a)},gb.uniqueSort=function(a){var b,d=[],e=0,f=0;if(l=!c.detectDuplicates,k=!c.sortStable&&a.slice(0),a.sort(B),l){while(b=a[f++])b===a[f]&&(e=d.push(f));while(e--)a.splice(d[e],1)}return k=null,a},e=gb.getText=function(a){var b,c="",d=0,f=a.nodeType;if(f){if(1===f||9===f||11===f){if("string"==typeof a.textContent)return a.textContent;for(a=a.firstChild;a;a=a.nextSibling)c+=e(a)}else if(3===f||4===f)return a.nodeValue}else while(b=a[d++])c+=e(b);return c},d=gb.selectors={cacheLength:50,createPseudo:ib,match:X,attrHandle:{},find:{},relative:{">":{dir:"parentNode",first:!0}," ":{dir:"parentNode"},"+":{dir:"previousSibling",first:!0},"~":{dir:"previousSibling"}},preFilter:{ATTR:function(a){return a[1]=a[1].replace(cb,db),a[3]=(a[3]||a[4]||a[5]||"").replace(cb,db),"~="===a[2]&&(a[3]=" "+a[3]+" "),a.slice(0,4)},CHILD:function(a){return a[1]=a[1].toLowerCase(),"nth"===a[1].slice(0,3)?(a[3]||gb.error(a[0]),a[4]=+(a[4]?a[5]+(a[6]||1):2*("even"===a[3]||"odd"===a[3])),a[5]=+(a[7]+a[8]||"odd"===a[3])):a[3]&&gb.error(a[0]),a},PSEUDO:function(a){var b,c=!a[6]&&a[2];return 
X.CHILD.test(a[0])?null:(a[3]?a[2]=a[4]||a[5]||"":c&&V.test(c)&&(b=g(c,!0))&&(b=c.indexOf(")",c.length-b)-c.length)&&(a[0]=a[0].slice(0,b),a[2]=c.slice(0,b)),a.slice(0,3))}},filter:{TAG:function(a){var b=a.replace(cb,db).toLowerCase();return"*"===a?function(){return!0}:function(a){return a.nodeName&&a.nodeName.toLowerCase()===b}},CLASS:function(a){var b=y[a+" "];return b||(b=new RegExp("(^|"+L+")"+a+"("+L+"|$)"))&&y(a,function(a){return b.test("string"==typeof a.className&&a.className||"undefined"!=typeof a.getAttribute&&a.getAttribute("class")||"")})},ATTR:function(a,b,c){return function(d){var e=gb.attr(d,a);return null==e?"!="===b:b?(e+="","="===b?e===c:"!="===b?e!==c:"^="===b?c&&0===e.indexOf(c):"*="===b?c&&e.indexOf(c)>-1:"$="===b?c&&e.slice(-c.length)===c:"~="===b?(" "+e.replace(Q," ")+" ").indexOf(c)>-1:"|="===b?e===c||e.slice(0,c.length+1)===c+"-":!1):!0}},CHILD:function(a,b,c,d,e){var f="nth"!==a.slice(0,3),g="last"!==a.slice(-4),h="of-type"===b;return 1===d&&0===e?function(a){return!!a.parentNode}:function(b,c,i){var j,k,l,m,n,o,p=f!==g?"nextSibling":"previousSibling",q=b.parentNode,r=h&&b.nodeName.toLowerCase(),s=!i&&!h;if(q){if(f){while(p){l=b;while(l=l[p])if(h?l.nodeName.toLowerCase()===r:1===l.nodeType)return!1;o=p="only"===a&&!o&&"nextSibling"}return!0}if(o=[g?q.firstChild:q.lastChild],g&&s){k=q[u]||(q[u]={}),j=k[a]||[],n=j[0]===w&&j[1],m=j[0]===w&&j[2],l=n&&q.childNodes[n];while(l=++n&&l&&l[p]||(m=n=0)||o.pop())if(1===l.nodeType&&++m&&l===b){k[a]=[w,n,m];break}}else if(s&&(j=(b[u]||(b[u]={}))[a])&&j[0]===w)m=j[1];else while(l=++n&&l&&l[p]||(m=n=0)||o.pop())if((h?l.nodeName.toLowerCase()===r:1===l.nodeType)&&++m&&(s&&((l[u]||(l[u]={}))[a]=[w,m]),l===b))break;return m-=e,m===d||m%d===0&&m/d>=0}}},PSEUDO:function(a,b){var c,e=d.pseudos[a]||d.setFilters[a.toLowerCase()]||gb.error("unsupported pseudo: "+a);return e[u]?e(b):e.length>1?(c=[a,a,"",b],d.setFilters.hasOwnProperty(a.toLowerCase())?ib(function(a,c){var d,f=e(a,b),g=f.length;while(g--)d=J(a,f[g]),a[d]=!(c[d]=f[g])}):function(a){return e(a,0,c)}):e}},pseudos:{not:ib(function(a){var b=[],c=[],d=h(a.replace(R,"$1"));return d[u]?ib(function(a,b,c,e){var f,g=d(a,null,e,[]),h=a.length;while(h--)(f=g[h])&&(a[h]=!(b[h]=f))}):function(a,e,f){return b[0]=a,d(b,null,f,c),b[0]=null,!c.pop()}}),has:ib(function(a){return function(b){return gb(a,b).length>0}}),contains:ib(function(a){return a=a.replace(cb,db),function(b){return(b.textContent||b.innerText||e(b)).indexOf(a)>-1}}),lang:ib(function(a){return W.test(a||"")||gb.error("unsupported lang: "+a),a=a.replace(cb,db).toLowerCase(),function(b){var c;do if(c=p?b.lang:b.getAttribute("xml:lang")||b.getAttribute("lang"))return c=c.toLowerCase(),c===a||0===c.indexOf(a+"-");while((b=b.parentNode)&&1===b.nodeType);return!1}}),target:function(b){var c=a.location&&a.location.hash;return c&&c.slice(1)===b.id},root:function(a){return a===o},focus:function(a){return a===n.activeElement&&(!n.hasFocus||n.hasFocus())&&!!(a.type||a.href||~a.tabIndex)},enabled:function(a){return a.disabled===!1},disabled:function(a){return a.disabled===!0},checked:function(a){var b=a.nodeName.toLowerCase();return"input"===b&&!!a.checked||"option"===b&&!!a.selected},selected:function(a){return a.parentNode&&a.parentNode.selectedIndex,a.selected===!0},empty:function(a){for(a=a.firstChild;a;a=a.nextSibling)if(a.nodeType<6)return!1;return!0},parent:function(a){return!d.pseudos.empty(a)},header:function(a){return Z.test(a.nodeName)},input:function(a){return Y.test(a.nodeName)},button:function(a){var 
b=a.nodeName.toLowerCase();return"input"===b&&"button"===a.type||"button"===b},text:function(a){var b;return"input"===a.nodeName.toLowerCase()&&"text"===a.type&&(null==(b=a.getAttribute("type"))||"text"===b.toLowerCase())},first:ob(function(){return[0]}),last:ob(function(a,b){return[b-1]}),eq:ob(function(a,b,c){return[0>c?c+b:c]}),even:ob(function(a,b){for(var c=0;b>c;c+=2)a.push(c);return a}),odd:ob(function(a,b){for(var c=1;b>c;c+=2)a.push(c);return a}),lt:ob(function(a,b,c){for(var d=0>c?c+b:c;--d>=0;)a.push(d);return a}),gt:ob(function(a,b,c){for(var d=0>c?c+b:c;++d<b;)a.push(d);return a})}},d.pseudos.nth=d.pseudos.eq;for(b in{radio:!0,checkbox:!0,file:!0,password:!0,image:!0})d.pseudos[b]=mb(b);for(b in{submit:!0,reset:!0})d.pseudos[b]=nb(b);function qb(){}qb.prototype=d.filters=d.pseudos,d.setFilters=new qb,g=gb.tokenize=function(a,b){var c,e,f,g,h,i,j,k=z[a+" "];if(k)return b?0:k.slice(0);h=a,i=[],j=d.preFilter;while(h){(!c||(e=S.exec(h)))&&(e&&(h=h.slice(e[0].length)||h),i.push(f=[])),c=!1,(e=T.exec(h))&&(c=e.shift(),f.push({value:c,type:e[0].replace(R," ")}),h=h.slice(c.length));for(g in d.filter)!(e=X[g].exec(h))||j[g]&&!(e=j[g](e))||(c=e.shift(),f.push({value:c,type:g,matches:e}),h=h.slice(c.length));if(!c)break}return b?h.length:h?gb.error(a):z(a,i).slice(0)};function rb(a){for(var b=0,c=a.length,d="";c>b;b++)d+=a[b].value;return d}function sb(a,b,c){var d=b.dir,e=c&&"parentNode"===d,f=x++;return b.first?function(b,c,f){while(b=b[d])if(1===b.nodeType||e)return a(b,c,f)}:function(b,c,g){var h,i,j=[w,f];if(g){while(b=b[d])if((1===b.nodeType||e)&&a(b,c,g))return!0}else while(b=b[d])if(1===b.nodeType||e){if(i=b[u]||(b[u]={}),(h=i[d])&&h[0]===w&&h[1]===f)return j[2]=h[2];if(i[d]=j,j[2]=a(b,c,g))return!0}}}function tb(a){return a.length>1?function(b,c,d){var e=a.length;while(e--)if(!a[e](b,c,d))return!1;return!0}:a[0]}function ub(a,b,c){for(var d=0,e=b.length;e>d;d++)gb(a,b[d],c);return c}function vb(a,b,c,d,e){for(var f,g=[],h=0,i=a.length,j=null!=b;i>h;h++)(f=a[h])&&(!c||c(f,d,e))&&(g.push(f),j&&b.push(h));return g}function wb(a,b,c,d,e,f){return d&&!d[u]&&(d=wb(d)),e&&!e[u]&&(e=wb(e,f)),ib(function(f,g,h,i){var j,k,l,m=[],n=[],o=g.length,p=f||ub(b||"*",h.nodeType?[h]:h,[]),q=!a||!f&&b?p:vb(p,m,a,h,i),r=c?e||(f?a:o||d)?[]:g:q;if(c&&c(q,r,h,i),d){j=vb(r,n),d(j,[],h,i),k=j.length;while(k--)(l=j[k])&&(r[n[k]]=!(q[n[k]]=l))}if(f){if(e||a){if(e){j=[],k=r.length;while(k--)(l=r[k])&&j.push(q[k]=l);e(null,r=[],j,i)}k=r.length;while(k--)(l=r[k])&&(j=e?J(f,l):m[k])>-1&&(f[j]=!(g[j]=l))}}else r=vb(r===g?r.splice(o,r.length):r),e?e(null,g,r,i):H.apply(g,r)})}function xb(a){for(var b,c,e,f=a.length,g=d.relative[a[0].type],h=g||d.relative[" "],i=g?1:0,k=sb(function(a){return a===b},h,!0),l=sb(function(a){return J(b,a)>-1},h,!0),m=[function(a,c,d){var e=!g&&(d||c!==j)||((b=c).nodeType?k(a,c,d):l(a,c,d));return b=null,e}];f>i;i++)if(c=d.relative[a[i].type])m=[sb(tb(m),c)];else{if(c=d.filter[a[i].type].apply(null,a[i].matches),c[u]){for(e=++i;f>e;e++)if(d.relative[a[e].type])break;return wb(i>1&&tb(m),i>1&&rb(a.slice(0,i-1).concat({value:" "===a[i-2].type?"*":""})).replace(R,"$1"),c,e>i&&xb(a.slice(i,e)),f>e&&xb(a=a.slice(e)),f>e&&rb(a))}m.push(c)}return tb(m)}function yb(a,b){var c=b.length>0,e=a.length>0,f=function(f,g,h,i,k){var 
l,m,o,p=0,q="0",r=f&&[],s=[],t=j,u=f||e&&d.find.TAG("*",k),v=w+=null==t?1:Math.random()||.1,x=u.length;for(k&&(j=g!==n&&g);q!==x&&null!=(l=u[q]);q++){if(e&&l){m=0;while(o=a[m++])if(o(l,g,h)){i.push(l);break}k&&(w=v)}c&&((l=!o&&l)&&p--,f&&r.push(l))}if(p+=q,c&&q!==p){m=0;while(o=b[m++])o(r,s,g,h);if(f){if(p>0)while(q--)r[q]||s[q]||(s[q]=F.call(i));s=vb(s)}H.apply(i,s),k&&!f&&s.length>0&&p+b.length>1&&gb.uniqueSort(i)}return k&&(w=v,j=t),r};return c?ib(f):f}return h=gb.compile=function(a,b){var c,d=[],e=[],f=A[a+" "];if(!f){b||(b=g(a)),c=b.length;while(c--)f=xb(b[c]),f[u]?d.push(f):e.push(f);f=A(a,yb(e,d)),f.selector=a}return f},i=gb.select=function(a,b,e,f){var i,j,k,l,m,n="function"==typeof a&&a,o=!f&&g(a=n.selector||a);if(e=e||[],1===o.length){if(j=o[0]=o[0].slice(0),j.length>2&&"ID"===(k=j[0]).type&&c.getById&&9===b.nodeType&&p&&d.relative[j[1].type]){if(b=(d.find.ID(k.matches[0].replace(cb,db),b)||[])[0],!b)return e;n&&(b=b.parentNode),a=a.slice(j.shift().value.length)}i=X.needsContext.test(a)?0:j.length;while(i--){if(k=j[i],d.relative[l=k.type])break;if((m=d.find[l])&&(f=m(k.matches[0].replace(cb,db),ab.test(j[0].type)&&pb(b.parentNode)||b))){if(j.splice(i,1),a=f.length&&rb(j),!a)return H.apply(e,f),e;break}}}return(n||h(a,o))(f,b,!p,e,ab.test(a)&&pb(b.parentNode)||b),e},c.sortStable=u.split("").sort(B).join("")===u,c.detectDuplicates=!!l,m(),c.sortDetached=jb(function(a){return 1&a.compareDocumentPosition(n.createElement("div"))}),jb(function(a){return a.innerHTML="<a href='#'></a>","#"===a.firstChild.getAttribute("href")})||kb("type|href|height|width",function(a,b,c){return c?void 0:a.getAttribute(b,"type"===b.toLowerCase()?1:2)}),c.attributes&&jb(function(a){return a.innerHTML="<input/>",a.firstChild.setAttribute("value",""),""===a.firstChild.getAttribute("value")})||kb("value",function(a,b,c){return c||"input"!==a.nodeName.toLowerCase()?void 0:a.defaultValue}),jb(function(a){return null==a.getAttribute("disabled")})||kb(K,function(a,b,c){var d;return c?void 0:a[b]===!0?b.toLowerCase():(d=a.getAttributeNode(b))&&d.specified?d.value:null}),gb}(a);n.find=t,n.expr=t.selectors,n.expr[":"]=n.expr.pseudos,n.unique=t.uniqueSort,n.text=t.getText,n.isXMLDoc=t.isXML,n.contains=t.contains;var u=n.expr.match.needsContext,v=/^<(\w+)\s*\/?>(?:<\/\1>|)$/,w=/^.[^:#\[\.,]*$/;function x(a,b,c){if(n.isFunction(b))return n.grep(a,function(a,d){return!!b.call(a,d,a)!==c});if(b.nodeType)return n.grep(a,function(a){return a===b!==c});if("string"==typeof b){if(w.test(b))return n.filter(b,a,c);b=n.filter(b,a)}return n.grep(a,function(a){return g.call(b,a)>=0!==c})}n.filter=function(a,b,c){var d=b[0];return c&&(a=":not("+a+")"),1===b.length&&1===d.nodeType?n.find.matchesSelector(d,a)?[d]:[]:n.find.matches(a,n.grep(b,function(a){return 1===a.nodeType}))},n.fn.extend({find:function(a){var b,c=this.length,d=[],e=this;if("string"!=typeof a)return this.pushStack(n(a).filter(function(){for(b=0;c>b;b++)if(n.contains(e[b],this))return!0}));for(b=0;c>b;b++)n.find(a,e[b],d);return d=this.pushStack(c>1?n.unique(d):d),d.selector=this.selector?this.selector+" "+a:a,d},filter:function(a){return this.pushStack(x(this,a||[],!1))},not:function(a){return this.pushStack(x(this,a||[],!0))},is:function(a){return!!x(this,"string"==typeof a&&u.test(a)?n(a):a||[],!1).length}});var y,z=/^(?:\s*(<[\w\W]+>)[^>]*|#([\w-]*))$/,A=n.fn.init=function(a,b){var c,d;if(!a)return this;if("string"==typeof 
a){if(c="<"===a[0]&&">"===a[a.length-1]&&a.length>=3?[null,a,null]:z.exec(a),!c||!c[1]&&b)return!b||b.jquery?(b||y).find(a):this.constructor(b).find(a);if(c[1]){if(b=b instanceof n?b[0]:b,n.merge(this,n.parseHTML(c[1],b&&b.nodeType?b.ownerDocument||b:l,!0)),v.test(c[1])&&n.isPlainObject(b))for(c in b)n.isFunction(this[c])?this[c](b[c]):this.attr(c,b[c]);return this}return d=l.getElementById(c[2]),d&&d.parentNode&&(this.length=1,this[0]=d),this.context=l,this.selector=a,this}return a.nodeType?(this.context=this[0]=a,this.length=1,this):n.isFunction(a)?"undefined"!=typeof y.ready?y.ready(a):a(n):(void 0!==a.selector&&(this.selector=a.selector,this.context=a.context),n.makeArray(a,this))};A.prototype=n.fn,y=n(l);var B=/^(?:parents|prev(?:Until|All))/,C={children:!0,contents:!0,next:!0,prev:!0};n.extend({dir:function(a,b,c){var d=[],e=void 0!==c;while((a=a[b])&&9!==a.nodeType)if(1===a.nodeType){if(e&&n(a).is(c))break;d.push(a)}return d},sibling:function(a,b){for(var c=[];a;a=a.nextSibling)1===a.nodeType&&a!==b&&c.push(a);return c}}),n.fn.extend({has:function(a){var b=n(a,this),c=b.length;return this.filter(function(){for(var a=0;c>a;a++)if(n.contains(this,b[a]))return!0})},closest:function(a,b){for(var c,d=0,e=this.length,f=[],g=u.test(a)||"string"!=typeof a?n(a,b||this.context):0;e>d;d++)for(c=this[d];c&&c!==b;c=c.parentNode)if(c.nodeType<11&&(g?g.index(c)>-1:1===c.nodeType&&n.find.matchesSelector(c,a))){f.push(c);break}return this.pushStack(f.length>1?n.unique(f):f)},index:function(a){return a?"string"==typeof a?g.call(n(a),this[0]):g.call(this,a.jquery?a[0]:a):this[0]&&this[0].parentNode?this.first().prevAll().length:-1},add:function(a,b){return this.pushStack(n.unique(n.merge(this.get(),n(a,b))))},addBack:function(a){return this.add(null==a?this.prevObject:this.prevObject.filter(a))}});function D(a,b){while((a=a[b])&&1!==a.nodeType);return a}n.each({parent:function(a){var b=a.parentNode;return b&&11!==b.nodeType?b:null},parents:function(a){return n.dir(a,"parentNode")},parentsUntil:function(a,b,c){return n.dir(a,"parentNode",c)},next:function(a){return D(a,"nextSibling")},prev:function(a){return D(a,"previousSibling")},nextAll:function(a){return n.dir(a,"nextSibling")},prevAll:function(a){return n.dir(a,"previousSibling")},nextUntil:function(a,b,c){return n.dir(a,"nextSibling",c)},prevUntil:function(a,b,c){return n.dir(a,"previousSibling",c)},siblings:function(a){return n.sibling((a.parentNode||{}).firstChild,a)},children:function(a){return n.sibling(a.firstChild)},contents:function(a){return a.contentDocument||n.merge([],a.childNodes)}},function(a,b){n.fn[a]=function(c,d){var e=n.map(this,b,c);return"Until"!==a.slice(-5)&&(d=c),d&&"string"==typeof d&&(e=n.filter(d,e)),this.length>1&&(C[a]||n.unique(e),B.test(a)&&e.reverse()),this.pushStack(e)}});var E=/\S+/g,F={};function G(a){var b=F[a]={};return n.each(a.match(E)||[],function(a,c){b[c]=!0}),b}n.Callbacks=function(a){a="string"==typeof a?F[a]||G(a):n.extend({},a);var b,c,d,e,f,g,h=[],i=!a.once&&[],j=function(l){for(b=a.memory&&l,c=!0,g=e||0,e=0,f=h.length,d=!0;h&&f>g;g++)if(h[g].apply(l[0],l[1])===!1&&a.stopOnFalse){b=!1;break}d=!1,h&&(i?i.length&&j(i.shift()):b?h=[]:k.disable())},k={add:function(){if(h){var c=h.length;!function g(b){n.each(b,function(b,c){var d=n.type(c);"function"===d?a.unique&&k.has(c)||h.push(c):c&&c.length&&"string"!==d&&g(c)})}(arguments),d?f=h.length:b&&(e=c,j(b))}return this},remove:function(){return h&&n.each(arguments,function(a,b){var 
c;while((c=n.inArray(b,h,c))>-1)h.splice(c,1),d&&(f>=c&&f--,g>=c&&g--)}),this},has:function(a){return a?n.inArray(a,h)>-1:!(!h||!h.length)},empty:function(){return h=[],f=0,this},disable:function(){return h=i=b=void 0,this},disabled:function(){return!h},lock:function(){return i=void 0,b||k.disable(),this},locked:function(){return!i},fireWith:function(a,b){return!h||c&&!i||(b=b||[],b=[a,b.slice?b.slice():b],d?i.push(b):j(b)),this},fire:function(){return k.fireWith(this,arguments),this},fired:function(){return!!c}};return k},n.extend({Deferred:function(a){var b=[["resolve","done",n.Callbacks("once memory"),"resolved"],["reject","fail",n.Callbacks("once memory"),"rejected"],["notify","progress",n.Callbacks("memory")]],c="pending",d={state:function(){return c},always:function(){return e.done(arguments).fail(arguments),this},then:function(){var a=arguments;return n.Deferred(function(c){n.each(b,function(b,f){var g=n.isFunction(a[b])&&a[b];e[f[1]](function(){var a=g&&g.apply(this,arguments);a&&n.isFunction(a.promise)?a.promise().done(c.resolve).fail(c.reject).progress(c.notify):c[f[0]+"With"](this===d?c.promise():this,g?[a]:arguments)})}),a=null}).promise()},promise:function(a){return null!=a?n.extend(a,d):d}},e={};return d.pipe=d.then,n.each(b,function(a,f){var g=f[2],h=f[3];d[f[1]]=g.add,h&&g.add(function(){c=h},b[1^a][2].disable,b[2][2].lock),e[f[0]]=function(){return e[f[0]+"With"](this===e?d:this,arguments),this},e[f[0]+"With"]=g.fireWith}),d.promise(e),a&&a.call(e,e),e},when:function(a){var b=0,c=d.call(arguments),e=c.length,f=1!==e||a&&n.isFunction(a.promise)?e:0,g=1===f?a:n.Deferred(),h=function(a,b,c){return function(e){b[a]=this,c[a]=arguments.length>1?d.call(arguments):e,c===i?g.notifyWith(b,c):--f||g.resolveWith(b,c)}},i,j,k;if(e>1)for(i=new Array(e),j=new Array(e),k=new Array(e);e>b;b++)c[b]&&n.isFunction(c[b].promise)?c[b].promise().done(h(b,k,c)).fail(g.reject).progress(h(b,j,i)):--f;return f||g.resolveWith(k,c),g.promise()}});var H;n.fn.ready=function(a){return n.ready.promise().done(a),this},n.extend({isReady:!1,readyWait:1,holdReady:function(a){a?n.readyWait++:n.ready(!0)},ready:function(a){(a===!0?--n.readyWait:n.isReady)||(n.isReady=!0,a!==!0&&--n.readyWait>0||(H.resolveWith(l,[n]),n.fn.triggerHandler&&(n(l).triggerHandler("ready"),n(l).off("ready"))))}});function I(){l.removeEventListener("DOMContentLoaded",I,!1),a.removeEventListener("load",I,!1),n.ready()}n.ready.promise=function(b){return H||(H=n.Deferred(),"complete"===l.readyState?setTimeout(n.ready):(l.addEventListener("DOMContentLoaded",I,!1),a.addEventListener("load",I,!1))),H.promise(b)},n.ready.promise();var J=n.access=function(a,b,c,d,e,f,g){var h=0,i=a.length,j=null==c;if("object"===n.type(c)){e=!0;for(h in c)n.access(a,b,h,c[h],!0,f,g)}else if(void 0!==d&&(e=!0,n.isFunction(d)||(g=!0),j&&(g?(b.call(a,d),b=null):(j=b,b=function(a,b,c){return j.call(n(a),c)})),b))for(;i>h;h++)b(a[h],c,g?d:d.call(a[h],h,b(a[h],c)));return e?a:j?b.call(a):i?b(a[0],c):f};n.acceptData=function(a){return 1===a.nodeType||9===a.nodeType||!+a.nodeType};function K(){Object.defineProperty(this.cache={},0,{get:function(){return{}}}),this.expando=n.expando+K.uid++}K.uid=1,K.accepts=n.acceptData,K.prototype={key:function(a){if(!K.accepts(a))return 0;var b={},c=a[this.expando];if(!c){c=K.uid++;try{b[this.expando]={value:c},Object.defineProperties(a,b)}catch(d){b[this.expando]=c,n.extend(a,b)}}return this.cache[c]||(this.cache[c]={}),c},set:function(a,b,c){var d,e=this.key(a),f=this.cache[e];if("string"==typeof b)f[b]=c;else 
if(n.isEmptyObject(f))n.extend(this.cache[e],b);else for(d in b)f[d]=b[d];return f},get:function(a,b){var c=this.cache[this.key(a)];return void 0===b?c:c[b]},access:function(a,b,c){var d;return void 0===b||b&&"string"==typeof b&&void 0===c?(d=this.get(a,b),void 0!==d?d:this.get(a,n.camelCase(b))):(this.set(a,b,c),void 0!==c?c:b)},remove:function(a,b){var c,d,e,f=this.key(a),g=this.cache[f];if(void 0===b)this.cache[f]={};else{n.isArray(b)?d=b.concat(b.map(n.camelCase)):(e=n.camelCase(b),b in g?d=[b,e]:(d=e,d=d in g?[d]:d.match(E)||[])),c=d.length;while(c--)delete g[d[c]]}},hasData:function(a){return!n.isEmptyObject(this.cache[a[this.expando]]||{})},discard:function(a){a[this.expando]&&delete this.cache[a[this.expando]]}};var L=new K,M=new K,N=/^(?:\{[\w\W]*\}|\[[\w\W]*\])$/,O=/([A-Z])/g;function P(a,b,c){var d;if(void 0===c&&1===a.nodeType)if(d="data-"+b.replace(O,"-$1").toLowerCase(),c=a.getAttribute(d),"string"==typeof c){try{c="true"===c?!0:"false"===c?!1:"null"===c?null:+c+""===c?+c:N.test(c)?n.parseJSON(c):c}catch(e){}M.set(a,b,c)}else c=void 0;return c}n.extend({hasData:function(a){return M.hasData(a)||L.hasData(a)},data:function(a,b,c){return M.access(a,b,c) -},removeData:function(a,b){M.remove(a,b)},_data:function(a,b,c){return L.access(a,b,c)},_removeData:function(a,b){L.remove(a,b)}}),n.fn.extend({data:function(a,b){var c,d,e,f=this[0],g=f&&f.attributes;if(void 0===a){if(this.length&&(e=M.get(f),1===f.nodeType&&!L.get(f,"hasDataAttrs"))){c=g.length;while(c--)g[c]&&(d=g[c].name,0===d.indexOf("data-")&&(d=n.camelCase(d.slice(5)),P(f,d,e[d])));L.set(f,"hasDataAttrs",!0)}return e}return"object"==typeof a?this.each(function(){M.set(this,a)}):J(this,function(b){var c,d=n.camelCase(a);if(f&&void 0===b){if(c=M.get(f,a),void 0!==c)return c;if(c=M.get(f,d),void 0!==c)return c;if(c=P(f,d,void 0),void 0!==c)return c}else this.each(function(){var c=M.get(this,d);M.set(this,d,b),-1!==a.indexOf("-")&&void 0!==c&&M.set(this,a,b)})},null,b,arguments.length>1,null,!0)},removeData:function(a){return this.each(function(){M.remove(this,a)})}}),n.extend({queue:function(a,b,c){var d;return a?(b=(b||"fx")+"queue",d=L.get(a,b),c&&(!d||n.isArray(c)?d=L.access(a,b,n.makeArray(c)):d.push(c)),d||[]):void 0},dequeue:function(a,b){b=b||"fx";var c=n.queue(a,b),d=c.length,e=c.shift(),f=n._queueHooks(a,b),g=function(){n.dequeue(a,b)};"inprogress"===e&&(e=c.shift(),d--),e&&("fx"===b&&c.unshift("inprogress"),delete f.stop,e.call(a,g,f)),!d&&f&&f.empty.fire()},_queueHooks:function(a,b){var c=b+"queueHooks";return L.get(a,c)||L.access(a,c,{empty:n.Callbacks("once memory").add(function(){L.remove(a,[b+"queue",c])})})}}),n.fn.extend({queue:function(a,b){var c=2;return"string"!=typeof a&&(b=a,a="fx",c--),arguments.length<c?n.queue(this[0],a):void 0===b?this:this.each(function(){var c=n.queue(this,a,b);n._queueHooks(this,a),"fx"===a&&"inprogress"!==c[0]&&n.dequeue(this,a)})},dequeue:function(a){return this.each(function(){n.dequeue(this,a)})},clearQueue:function(a){return this.queue(a||"fx",[])},promise:function(a,b){var c,d=1,e=n.Deferred(),f=this,g=this.length,h=function(){--d||e.resolveWith(f,[f])};"string"!=typeof a&&(b=a,a=void 0),a=a||"fx";while(g--)c=L.get(f[g],a+"queueHooks"),c&&c.empty&&(d++,c.empty.add(h));return h(),e.promise(b)}});var Q=/[+-]?(?:\d*\.|)\d+(?:[eE][+-]?\d+|)/.source,R=["Top","Right","Bottom","Left"],S=function(a,b){return a=b||a,"none"===n.css(a,"display")||!n.contains(a.ownerDocument,a)},T=/^(?:checkbox|radio)$/i;!function(){var 
a=l.createDocumentFragment(),b=a.appendChild(l.createElement("div")),c=l.createElement("input");c.setAttribute("type","radio"),c.setAttribute("checked","checked"),c.setAttribute("name","t"),b.appendChild(c),k.checkClone=b.cloneNode(!0).cloneNode(!0).lastChild.checked,b.innerHTML="<textarea>x</textarea>",k.noCloneChecked=!!b.cloneNode(!0).lastChild.defaultValue}();var U="undefined";k.focusinBubbles="onfocusin"in a;var V=/^key/,W=/^(?:mouse|pointer|contextmenu)|click/,X=/^(?:focusinfocus|focusoutblur)$/,Y=/^([^.]*)(?:\.(.+)|)$/;function Z(){return!0}function $(){return!1}function _(){try{return l.activeElement}catch(a){}}n.event={global:{},add:function(a,b,c,d,e){var f,g,h,i,j,k,l,m,o,p,q,r=L.get(a);if(r){c.handler&&(f=c,c=f.handler,e=f.selector),c.guid||(c.guid=n.guid++),(i=r.events)||(i=r.events={}),(g=r.handle)||(g=r.handle=function(b){return typeof n!==U&&n.event.triggered!==b.type?n.event.dispatch.apply(a,arguments):void 0}),b=(b||"").match(E)||[""],j=b.length;while(j--)h=Y.exec(b[j])||[],o=q=h[1],p=(h[2]||"").split(".").sort(),o&&(l=n.event.special[o]||{},o=(e?l.delegateType:l.bindType)||o,l=n.event.special[o]||{},k=n.extend({type:o,origType:q,data:d,handler:c,guid:c.guid,selector:e,needsContext:e&&n.expr.match.needsContext.test(e),namespace:p.join(".")},f),(m=i[o])||(m=i[o]=[],m.delegateCount=0,l.setup&&l.setup.call(a,d,p,g)!==!1||a.addEventListener&&a.addEventListener(o,g,!1)),l.add&&(l.add.call(a,k),k.handler.guid||(k.handler.guid=c.guid)),e?m.splice(m.delegateCount++,0,k):m.push(k),n.event.global[o]=!0)}},remove:function(a,b,c,d,e){var f,g,h,i,j,k,l,m,o,p,q,r=L.hasData(a)&&L.get(a);if(r&&(i=r.events)){b=(b||"").match(E)||[""],j=b.length;while(j--)if(h=Y.exec(b[j])||[],o=q=h[1],p=(h[2]||"").split(".").sort(),o){l=n.event.special[o]||{},o=(d?l.delegateType:l.bindType)||o,m=i[o]||[],h=h[2]&&new RegExp("(^|\\.)"+p.join("\\.(?:.*\\.|)")+"(\\.|$)"),g=f=m.length;while(f--)k=m[f],!e&&q!==k.origType||c&&c.guid!==k.guid||h&&!h.test(k.namespace)||d&&d!==k.selector&&("**"!==d||!k.selector)||(m.splice(f,1),k.selector&&m.delegateCount--,l.remove&&l.remove.call(a,k));g&&!m.length&&(l.teardown&&l.teardown.call(a,p,r.handle)!==!1||n.removeEvent(a,o,r.handle),delete i[o])}else for(o in i)n.event.remove(a,o+b[j],c,d,!0);n.isEmptyObject(i)&&(delete r.handle,L.remove(a,"events"))}},trigger:function(b,c,d,e){var f,g,h,i,k,m,o,p=[d||l],q=j.call(b,"type")?b.type:b,r=j.call(b,"namespace")?b.namespace.split("."):[];if(g=h=d=d||l,3!==d.nodeType&&8!==d.nodeType&&!X.test(q+n.event.triggered)&&(q.indexOf(".")>=0&&(r=q.split("."),q=r.shift(),r.sort()),k=q.indexOf(":")<0&&"on"+q,b=b[n.expando]?b:new n.Event(q,"object"==typeof b&&b),b.isTrigger=e?2:3,b.namespace=r.join("."),b.namespace_re=b.namespace?new RegExp("(^|\\.)"+r.join("\\.(?:.*\\.|)")+"(\\.|$)"):null,b.result=void 0,b.target||(b.target=d),c=null==c?[b]:n.makeArray(c,[b]),o=n.event.special[q]||{},e||!o.trigger||o.trigger.apply(d,c)!==!1)){if(!e&&!o.noBubble&&!n.isWindow(d)){for(i=o.delegateType||q,X.test(i+q)||(g=g.parentNode);g;g=g.parentNode)p.push(g),h=g;h===(d.ownerDocument||l)&&p.push(h.defaultView||h.parentWindow||a)}f=0;while((g=p[f++])&&!b.isPropagationStopped())b.type=f>1?i:o.bindType||q,m=(L.get(g,"events")||{})[b.type]&&L.get(g,"handle"),m&&m.apply(g,c),m=k&&g[k],m&&m.apply&&n.acceptData(g)&&(b.result=m.apply(g,c),b.result===!1&&b.preventDefault());return 
b.type=q,e||b.isDefaultPrevented()||o._default&&o._default.apply(p.pop(),c)!==!1||!n.acceptData(d)||k&&n.isFunction(d[q])&&!n.isWindow(d)&&(h=d[k],h&&(d[k]=null),n.event.triggered=q,d[q](),n.event.triggered=void 0,h&&(d[k]=h)),b.result}},dispatch:function(a){a=n.event.fix(a);var b,c,e,f,g,h=[],i=d.call(arguments),j=(L.get(this,"events")||{})[a.type]||[],k=n.event.special[a.type]||{};if(i[0]=a,a.delegateTarget=this,!k.preDispatch||k.preDispatch.call(this,a)!==!1){h=n.event.handlers.call(this,a,j),b=0;while((f=h[b++])&&!a.isPropagationStopped()){a.currentTarget=f.elem,c=0;while((g=f.handlers[c++])&&!a.isImmediatePropagationStopped())(!a.namespace_re||a.namespace_re.test(g.namespace))&&(a.handleObj=g,a.data=g.data,e=((n.event.special[g.origType]||{}).handle||g.handler).apply(f.elem,i),void 0!==e&&(a.result=e)===!1&&(a.preventDefault(),a.stopPropagation()))}return k.postDispatch&&k.postDispatch.call(this,a),a.result}},handlers:function(a,b){var c,d,e,f,g=[],h=b.delegateCount,i=a.target;if(h&&i.nodeType&&(!a.button||"click"!==a.type))for(;i!==this;i=i.parentNode||this)if(i.disabled!==!0||"click"!==a.type){for(d=[],c=0;h>c;c++)f=b[c],e=f.selector+" ",void 0===d[e]&&(d[e]=f.needsContext?n(e,this).index(i)>=0:n.find(e,this,null,[i]).length),d[e]&&d.push(f);d.length&&g.push({elem:i,handlers:d})}return h<b.length&&g.push({elem:this,handlers:b.slice(h)}),g},props:"altKey bubbles cancelable ctrlKey currentTarget eventPhase metaKey relatedTarget shiftKey target timeStamp view which".split(" "),fixHooks:{},keyHooks:{props:"char charCode key keyCode".split(" "),filter:function(a,b){return null==a.which&&(a.which=null!=b.charCode?b.charCode:b.keyCode),a}},mouseHooks:{props:"button buttons clientX clientY offsetX offsetY pageX pageY screenX screenY toElement".split(" "),filter:function(a,b){var c,d,e,f=b.button;return null==a.pageX&&null!=b.clientX&&(c=a.target.ownerDocument||l,d=c.documentElement,e=c.body,a.pageX=b.clientX+(d&&d.scrollLeft||e&&e.scrollLeft||0)-(d&&d.clientLeft||e&&e.clientLeft||0),a.pageY=b.clientY+(d&&d.scrollTop||e&&e.scrollTop||0)-(d&&d.clientTop||e&&e.clientTop||0)),a.which||void 0===f||(a.which=1&f?1:2&f?3:4&f?2:0),a}},fix:function(a){if(a[n.expando])return a;var b,c,d,e=a.type,f=a,g=this.fixHooks[e];g||(this.fixHooks[e]=g=W.test(e)?this.mouseHooks:V.test(e)?this.keyHooks:{}),d=g.props?this.props.concat(g.props):this.props,a=new n.Event(f),b=d.length;while(b--)c=d[b],a[c]=f[c];return a.target||(a.target=l),3===a.target.nodeType&&(a.target=a.target.parentNode),g.filter?g.filter(a,f):a},special:{load:{noBubble:!0},focus:{trigger:function(){return this!==_()&&this.focus?(this.focus(),!1):void 0},delegateType:"focusin"},blur:{trigger:function(){return this===_()&&this.blur?(this.blur(),!1):void 0},delegateType:"focusout"},click:{trigger:function(){return"checkbox"===this.type&&this.click&&n.nodeName(this,"input")?(this.click(),!1):void 0},_default:function(a){return n.nodeName(a.target,"a")}},beforeunload:{postDispatch:function(a){void 0!==a.result&&a.originalEvent&&(a.originalEvent.returnValue=a.result)}}},simulate:function(a,b,c,d){var e=n.extend(new n.Event,c,{type:a,isSimulated:!0,originalEvent:{}});d?n.event.trigger(e,null,b):n.event.dispatch.call(b,e),e.isDefaultPrevented()&&c.preventDefault()}},n.removeEvent=function(a,b,c){a.removeEventListener&&a.removeEventListener(b,c,!1)},n.Event=function(a,b){return this instanceof n.Event?(a&&a.type?(this.originalEvent=a,this.type=a.type,this.isDefaultPrevented=a.defaultPrevented||void 
0===a.defaultPrevented&&a.returnValue===!1?Z:$):this.type=a,b&&n.extend(this,b),this.timeStamp=a&&a.timeStamp||n.now(),void(this[n.expando]=!0)):new n.Event(a,b)},n.Event.prototype={isDefaultPrevented:$,isPropagationStopped:$,isImmediatePropagationStopped:$,preventDefault:function(){var a=this.originalEvent;this.isDefaultPrevented=Z,a&&a.preventDefault&&a.preventDefault()},stopPropagation:function(){var a=this.originalEvent;this.isPropagationStopped=Z,a&&a.stopPropagation&&a.stopPropagation()},stopImmediatePropagation:function(){var a=this.originalEvent;this.isImmediatePropagationStopped=Z,a&&a.stopImmediatePropagation&&a.stopImmediatePropagation(),this.stopPropagation()}},n.each({mouseenter:"mouseover",mouseleave:"mouseout",pointerenter:"pointerover",pointerleave:"pointerout"},function(a,b){n.event.special[a]={delegateType:b,bindType:b,handle:function(a){var c,d=this,e=a.relatedTarget,f=a.handleObj;return(!e||e!==d&&!n.contains(d,e))&&(a.type=f.origType,c=f.handler.apply(this,arguments),a.type=b),c}}}),k.focusinBubbles||n.each({focus:"focusin",blur:"focusout"},function(a,b){var c=function(a){n.event.simulate(b,a.target,n.event.fix(a),!0)};n.event.special[b]={setup:function(){var d=this.ownerDocument||this,e=L.access(d,b);e||d.addEventListener(a,c,!0),L.access(d,b,(e||0)+1)},teardown:function(){var d=this.ownerDocument||this,e=L.access(d,b)-1;e?L.access(d,b,e):(d.removeEventListener(a,c,!0),L.remove(d,b))}}}),n.fn.extend({on:function(a,b,c,d,e){var f,g;if("object"==typeof a){"string"!=typeof b&&(c=c||b,b=void 0);for(g in a)this.on(g,b,c,a[g],e);return this}if(null==c&&null==d?(d=b,c=b=void 0):null==d&&("string"==typeof b?(d=c,c=void 0):(d=c,c=b,b=void 0)),d===!1)d=$;else if(!d)return this;return 1===e&&(f=d,d=function(a){return n().off(a),f.apply(this,arguments)},d.guid=f.guid||(f.guid=n.guid++)),this.each(function(){n.event.add(this,a,d,c,b)})},one:function(a,b,c,d){return this.on(a,b,c,d,1)},off:function(a,b,c){var d,e;if(a&&a.preventDefault&&a.handleObj)return d=a.handleObj,n(a.delegateTarget).off(d.namespace?d.origType+"."+d.namespace:d.origType,d.selector,d.handler),this;if("object"==typeof a){for(e in a)this.off(e,b,a[e]);return this}return(b===!1||"function"==typeof b)&&(c=b,b=void 0),c===!1&&(c=$),this.each(function(){n.event.remove(this,a,c,b)})},trigger:function(a,b){return this.each(function(){n.event.trigger(a,b,this)})},triggerHandler:function(a,b){var c=this[0];return c?n.event.trigger(a,b,c,!0):void 0}});var ab=/<(?!area|br|col|embed|hr|img|input|link|meta|param)(([\w:]+)[^>]*)\/>/gi,bb=/<([\w:]+)/,cb=/<|&#?\w+;/,db=/<(?:script|style|link)/i,eb=/checked\s*(?:[^=]|=\s*.checked.)/i,fb=/^$|\/(?:java|ecma)script/i,gb=/^true\/(.*)/,hb=/^\s*<!(?:\[CDATA\[|--)|(?:\]\]|--)>\s*$/g,ib={option:[1,"<select multiple='multiple'>","</select>"],thead:[1,"<table>","</table>"],col:[2,"<table><colgroup>","</colgroup></table>"],tr:[2,"<table><tbody>","</tbody></table>"],td:[3,"<table><tbody><tr>","</tr></tbody></table>"],_default:[0,"",""]};ib.optgroup=ib.option,ib.tbody=ib.tfoot=ib.colgroup=ib.caption=ib.thead,ib.th=ib.td;function jb(a,b){return n.nodeName(a,"table")&&n.nodeName(11!==b.nodeType?b:b.firstChild,"tr")?a.getElementsByTagName("tbody")[0]||a.appendChild(a.ownerDocument.createElement("tbody")):a}function kb(a){return a.type=(null!==a.getAttribute("type"))+"/"+a.type,a}function lb(a){var b=gb.exec(a.type);return b?a.type=b[1]:a.removeAttribute("type"),a}function mb(a,b){for(var c=0,d=a.length;d>c;c++)L.set(a[c],"globalEval",!b||L.get(b[c],"globalEval"))}function nb(a,b){var 
c,d,e,f,g,h,i,j;if(1===b.nodeType){if(L.hasData(a)&&(f=L.access(a),g=L.set(b,f),j=f.events)){delete g.handle,g.events={};for(e in j)for(c=0,d=j[e].length;d>c;c++)n.event.add(b,e,j[e][c])}M.hasData(a)&&(h=M.access(a),i=n.extend({},h),M.set(b,i))}}function ob(a,b){var c=a.getElementsByTagName?a.getElementsByTagName(b||"*"):a.querySelectorAll?a.querySelectorAll(b||"*"):[];return void 0===b||b&&n.nodeName(a,b)?n.merge([a],c):c}function pb(a,b){var c=b.nodeName.toLowerCase();"input"===c&&T.test(a.type)?b.checked=a.checked:("input"===c||"textarea"===c)&&(b.defaultValue=a.defaultValue)}n.extend({clone:function(a,b,c){var d,e,f,g,h=a.cloneNode(!0),i=n.contains(a.ownerDocument,a);if(!(k.noCloneChecked||1!==a.nodeType&&11!==a.nodeType||n.isXMLDoc(a)))for(g=ob(h),f=ob(a),d=0,e=f.length;e>d;d++)pb(f[d],g[d]);if(b)if(c)for(f=f||ob(a),g=g||ob(h),d=0,e=f.length;e>d;d++)nb(f[d],g[d]);else nb(a,h);return g=ob(h,"script"),g.length>0&&mb(g,!i&&ob(a,"script")),h},buildFragment:function(a,b,c,d){for(var e,f,g,h,i,j,k=b.createDocumentFragment(),l=[],m=0,o=a.length;o>m;m++)if(e=a[m],e||0===e)if("object"===n.type(e))n.merge(l,e.nodeType?[e]:e);else if(cb.test(e)){f=f||k.appendChild(b.createElement("div")),g=(bb.exec(e)||["",""])[1].toLowerCase(),h=ib[g]||ib._default,f.innerHTML=h[1]+e.replace(ab,"<$1></$2>")+h[2],j=h[0];while(j--)f=f.lastChild;n.merge(l,f.childNodes),f=k.firstChild,f.textContent=""}else l.push(b.createTextNode(e));k.textContent="",m=0;while(e=l[m++])if((!d||-1===n.inArray(e,d))&&(i=n.contains(e.ownerDocument,e),f=ob(k.appendChild(e),"script"),i&&mb(f),c)){j=0;while(e=f[j++])fb.test(e.type||"")&&c.push(e)}return k},cleanData:function(a){for(var b,c,d,e,f=n.event.special,g=0;void 0!==(c=a[g]);g++){if(n.acceptData(c)&&(e=c[L.expando],e&&(b=L.cache[e]))){if(b.events)for(d in b.events)f[d]?n.event.remove(c,d):n.removeEvent(c,d,b.handle);L.cache[e]&&delete L.cache[e]}delete M.cache[c[M.expando]]}}}),n.fn.extend({text:function(a){return J(this,function(a){return void 0===a?n.text(this):this.empty().each(function(){(1===this.nodeType||11===this.nodeType||9===this.nodeType)&&(this.textContent=a)})},null,a,arguments.length)},append:function(){return this.domManip(arguments,function(a){if(1===this.nodeType||11===this.nodeType||9===this.nodeType){var b=jb(this,a);b.appendChild(a)}})},prepend:function(){return this.domManip(arguments,function(a){if(1===this.nodeType||11===this.nodeType||9===this.nodeType){var b=jb(this,a);b.insertBefore(a,b.firstChild)}})},before:function(){return this.domManip(arguments,function(a){this.parentNode&&this.parentNode.insertBefore(a,this)})},after:function(){return this.domManip(arguments,function(a){this.parentNode&&this.parentNode.insertBefore(a,this.nextSibling)})},remove:function(a,b){for(var c,d=a?n.filter(a,this):this,e=0;null!=(c=d[e]);e++)b||1!==c.nodeType||n.cleanData(ob(c)),c.parentNode&&(b&&n.contains(c.ownerDocument,c)&&mb(ob(c,"script")),c.parentNode.removeChild(c));return this},empty:function(){for(var a,b=0;null!=(a=this[b]);b++)1===a.nodeType&&(n.cleanData(ob(a,!1)),a.textContent="");return this},clone:function(a,b){return a=null==a?!1:a,b=null==b?a:b,this.map(function(){return n.clone(this,a,b)})},html:function(a){return J(this,function(a){var b=this[0]||{},c=0,d=this.length;if(void 0===a&&1===b.nodeType)return b.innerHTML;if("string"==typeof 
a&&!db.test(a)&&!ib[(bb.exec(a)||["",""])[1].toLowerCase()]){a=a.replace(ab,"<$1></$2>");try{for(;d>c;c++)b=this[c]||{},1===b.nodeType&&(n.cleanData(ob(b,!1)),b.innerHTML=a);b=0}catch(e){}}b&&this.empty().append(a)},null,a,arguments.length)},replaceWith:function(){var a=arguments[0];return this.domManip(arguments,function(b){a=this.parentNode,n.cleanData(ob(this)),a&&a.replaceChild(b,this)}),a&&(a.length||a.nodeType)?this:this.remove()},detach:function(a){return this.remove(a,!0)},domManip:function(a,b){a=e.apply([],a);var c,d,f,g,h,i,j=0,l=this.length,m=this,o=l-1,p=a[0],q=n.isFunction(p);if(q||l>1&&"string"==typeof p&&!k.checkClone&&eb.test(p))return this.each(function(c){var d=m.eq(c);q&&(a[0]=p.call(this,c,d.html())),d.domManip(a,b)});if(l&&(c=n.buildFragment(a,this[0].ownerDocument,!1,this),d=c.firstChild,1===c.childNodes.length&&(c=d),d)){for(f=n.map(ob(c,"script"),kb),g=f.length;l>j;j++)h=c,j!==o&&(h=n.clone(h,!0,!0),g&&n.merge(f,ob(h,"script"))),b.call(this[j],h,j);if(g)for(i=f[f.length-1].ownerDocument,n.map(f,lb),j=0;g>j;j++)h=f[j],fb.test(h.type||"")&&!L.access(h,"globalEval")&&n.contains(i,h)&&(h.src?n._evalUrl&&n._evalUrl(h.src):n.globalEval(h.textContent.replace(hb,"")))}return this}}),n.each({appendTo:"append",prependTo:"prepend",insertBefore:"before",insertAfter:"after",replaceAll:"replaceWith"},function(a,b){n.fn[a]=function(a){for(var c,d=[],e=n(a),g=e.length-1,h=0;g>=h;h++)c=h===g?this:this.clone(!0),n(e[h])[b](c),f.apply(d,c.get());return this.pushStack(d)}});var qb,rb={};function sb(b,c){var d,e=n(c.createElement(b)).appendTo(c.body),f=a.getDefaultComputedStyle&&(d=a.getDefaultComputedStyle(e[0]))?d.display:n.css(e[0],"display");return e.detach(),f}function tb(a){var b=l,c=rb[a];return c||(c=sb(a,b),"none"!==c&&c||(qb=(qb||n("<iframe frameborder='0' width='0' height='0'/>")).appendTo(b.documentElement),b=qb[0].contentDocument,b.write(),b.close(),c=sb(a,b),qb.detach()),rb[a]=c),c}var ub=/^margin/,vb=new RegExp("^("+Q+")(?!px)[a-z%]+$","i"),wb=function(b){return b.ownerDocument.defaultView.opener?b.ownerDocument.defaultView.getComputedStyle(b,null):a.getComputedStyle(b,null)};function xb(a,b,c){var d,e,f,g,h=a.style;return c=c||wb(a),c&&(g=c.getPropertyValue(b)||c[b]),c&&(""!==g||n.contains(a.ownerDocument,a)||(g=n.style(a,b)),vb.test(g)&&ub.test(b)&&(d=h.width,e=h.minWidth,f=h.maxWidth,h.minWidth=h.maxWidth=h.width=g,g=c.width,h.width=d,h.minWidth=e,h.maxWidth=f)),void 0!==g?g+"":g}function yb(a,b){return{get:function(){return a()?void delete this.get:(this.get=b).apply(this,arguments)}}}!function(){var b,c,d=l.documentElement,e=l.createElement("div"),f=l.createElement("div");if(f.style){f.style.backgroundClip="content-box",f.cloneNode(!0).style.backgroundClip="",k.clearCloneStyle="content-box"===f.style.backgroundClip,e.style.cssText="border:0;width:0;height:0;top:0;left:-9999px;margin-top:1px;position:absolute",e.appendChild(f);function g(){f.style.cssText="-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;display:block;margin-top:1%;top:1%;border:1px;padding:1px;width:4px;position:absolute",f.innerHTML="",d.appendChild(e);var g=a.getComputedStyle(f,null);b="1%"!==g.top,c="4px"===g.width,d.removeChild(e)}a.getComputedStyle&&n.extend(k,{pixelPosition:function(){return g(),b},boxSizingReliable:function(){return null==c&&g(),c},reliableMarginRight:function(){var b,c=f.appendChild(l.createElement("div"));return 
c.style.cssText=f.style.cssText="-webkit-box-sizing:content-box;-moz-box-sizing:content-box;box-sizing:content-box;display:block;margin:0;border:0;padding:0",c.style.marginRight=c.style.width="0",f.style.width="1px",d.appendChild(e),b=!parseFloat(a.getComputedStyle(c,null).marginRight),d.removeChild(e),f.removeChild(c),b}})}}(),n.swap=function(a,b,c,d){var e,f,g={};for(f in b)g[f]=a.style[f],a.style[f]=b[f];e=c.apply(a,d||[]);for(f in b)a.style[f]=g[f];return e};var zb=/^(none|table(?!-c[ea]).+)/,Ab=new RegExp("^("+Q+")(.*)$","i"),Bb=new RegExp("^([+-])=("+Q+")","i"),Cb={position:"absolute",visibility:"hidden",display:"block"},Db={letterSpacing:"0",fontWeight:"400"},Eb=["Webkit","O","Moz","ms"];function Fb(a,b){if(b in a)return b;var c=b[0].toUpperCase()+b.slice(1),d=b,e=Eb.length;while(e--)if(b=Eb[e]+c,b in a)return b;return d}function Gb(a,b,c){var d=Ab.exec(b);return d?Math.max(0,d[1]-(c||0))+(d[2]||"px"):b}function Hb(a,b,c,d,e){for(var f=c===(d?"border":"content")?4:"width"===b?1:0,g=0;4>f;f+=2)"margin"===c&&(g+=n.css(a,c+R[f],!0,e)),d?("content"===c&&(g-=n.css(a,"padding"+R[f],!0,e)),"margin"!==c&&(g-=n.css(a,"border"+R[f]+"Width",!0,e))):(g+=n.css(a,"padding"+R[f],!0,e),"padding"!==c&&(g+=n.css(a,"border"+R[f]+"Width",!0,e)));return g}function Ib(a,b,c){var d=!0,e="width"===b?a.offsetWidth:a.offsetHeight,f=wb(a),g="border-box"===n.css(a,"boxSizing",!1,f);if(0>=e||null==e){if(e=xb(a,b,f),(0>e||null==e)&&(e=a.style[b]),vb.test(e))return e;d=g&&(k.boxSizingReliable()||e===a.style[b]),e=parseFloat(e)||0}return e+Hb(a,b,c||(g?"border":"content"),d,f)+"px"}function Jb(a,b){for(var c,d,e,f=[],g=0,h=a.length;h>g;g++)d=a[g],d.style&&(f[g]=L.get(d,"olddisplay"),c=d.style.display,b?(f[g]||"none"!==c||(d.style.display=""),""===d.style.display&&S(d)&&(f[g]=L.access(d,"olddisplay",tb(d.nodeName)))):(e=S(d),"none"===c&&e||L.set(d,"olddisplay",e?c:n.css(d,"display"))));for(g=0;h>g;g++)d=a[g],d.style&&(b&&"none"!==d.style.display&&""!==d.style.display||(d.style.display=b?f[g]||"":"none"));return a}n.extend({cssHooks:{opacity:{get:function(a,b){if(b){var c=xb(a,"opacity");return""===c?"1":c}}}},cssNumber:{columnCount:!0,fillOpacity:!0,flexGrow:!0,flexShrink:!0,fontWeight:!0,lineHeight:!0,opacity:!0,order:!0,orphans:!0,widows:!0,zIndex:!0,zoom:!0},cssProps:{"float":"cssFloat"},style:function(a,b,c,d){if(a&&3!==a.nodeType&&8!==a.nodeType&&a.style){var e,f,g,h=n.camelCase(b),i=a.style;return b=n.cssProps[h]||(n.cssProps[h]=Fb(i,h)),g=n.cssHooks[b]||n.cssHooks[h],void 0===c?g&&"get"in g&&void 0!==(e=g.get(a,!1,d))?e:i[b]:(f=typeof c,"string"===f&&(e=Bb.exec(c))&&(c=(e[1]+1)*e[2]+parseFloat(n.css(a,b)),f="number"),null!=c&&c===c&&("number"!==f||n.cssNumber[h]||(c+="px"),k.clearCloneStyle||""!==c||0!==b.indexOf("background")||(i[b]="inherit"),g&&"set"in g&&void 0===(c=g.set(a,c,d))||(i[b]=c)),void 0)}},css:function(a,b,c,d){var e,f,g,h=n.camelCase(b);return b=n.cssProps[h]||(n.cssProps[h]=Fb(a.style,h)),g=n.cssHooks[b]||n.cssHooks[h],g&&"get"in g&&(e=g.get(a,!0,c)),void 0===e&&(e=xb(a,b,d)),"normal"===e&&b in Db&&(e=Db[b]),""===c||c?(f=parseFloat(e),c===!0||n.isNumeric(f)?f||0:e):e}}),n.each(["height","width"],function(a,b){n.cssHooks[b]={get:function(a,c,d){return c?zb.test(n.css(a,"display"))&&0===a.offsetWidth?n.swap(a,Cb,function(){return Ib(a,b,d)}):Ib(a,b,d):void 0},set:function(a,c,d){var e=d&&wb(a);return Gb(a,c,d?Hb(a,b,d,"border-box"===n.css(a,"boxSizing",!1,e),e):0)}}}),n.cssHooks.marginRight=yb(k.reliableMarginRight,function(a,b){return 
b?n.swap(a,{display:"inline-block"},xb,[a,"marginRight"]):void 0}),n.each({margin:"",padding:"",border:"Width"},function(a,b){n.cssHooks[a+b]={expand:function(c){for(var d=0,e={},f="string"==typeof c?c.split(" "):[c];4>d;d++)e[a+R[d]+b]=f[d]||f[d-2]||f[0];return e}},ub.test(a)||(n.cssHooks[a+b].set=Gb)}),n.fn.extend({css:function(a,b){return J(this,function(a,b,c){var d,e,f={},g=0;if(n.isArray(b)){for(d=wb(a),e=b.length;e>g;g++)f[b[g]]=n.css(a,b[g],!1,d);return f}return void 0!==c?n.style(a,b,c):n.css(a,b)},a,b,arguments.length>1)},show:function(){return Jb(this,!0)},hide:function(){return Jb(this)},toggle:function(a){return"boolean"==typeof a?a?this.show():this.hide():this.each(function(){S(this)?n(this).show():n(this).hide()})}});function Kb(a,b,c,d,e){return new Kb.prototype.init(a,b,c,d,e)}n.Tween=Kb,Kb.prototype={constructor:Kb,init:function(a,b,c,d,e,f){this.elem=a,this.prop=c,this.easing=e||"swing",this.options=b,this.start=this.now=this.cur(),this.end=d,this.unit=f||(n.cssNumber[c]?"":"px")},cur:function(){var a=Kb.propHooks[this.prop];return a&&a.get?a.get(this):Kb.propHooks._default.get(this)},run:function(a){var b,c=Kb.propHooks[this.prop];return this.pos=b=this.options.duration?n.easing[this.easing](a,this.options.duration*a,0,1,this.options.duration):a,this.now=(this.end-this.start)*b+this.start,this.options.step&&this.options.step.call(this.elem,this.now,this),c&&c.set?c.set(this):Kb.propHooks._default.set(this),this}},Kb.prototype.init.prototype=Kb.prototype,Kb.propHooks={_default:{get:function(a){var b;return null==a.elem[a.prop]||a.elem.style&&null!=a.elem.style[a.prop]?(b=n.css(a.elem,a.prop,""),b&&"auto"!==b?b:0):a.elem[a.prop]},set:function(a){n.fx.step[a.prop]?n.fx.step[a.prop](a):a.elem.style&&(null!=a.elem.style[n.cssProps[a.prop]]||n.cssHooks[a.prop])?n.style(a.elem,a.prop,a.now+a.unit):a.elem[a.prop]=a.now}}},Kb.propHooks.scrollTop=Kb.propHooks.scrollLeft={set:function(a){a.elem.nodeType&&a.elem.parentNode&&(a.elem[a.prop]=a.now)}},n.easing={linear:function(a){return a},swing:function(a){return.5-Math.cos(a*Math.PI)/2}},n.fx=Kb.prototype.init,n.fx.step={};var Lb,Mb,Nb=/^(?:toggle|show|hide)$/,Ob=new RegExp("^(?:([+-])=|)("+Q+")([a-z%]*)$","i"),Pb=/queueHooks$/,Qb=[Vb],Rb={"*":[function(a,b){var c=this.createTween(a,b),d=c.cur(),e=Ob.exec(b),f=e&&e[3]||(n.cssNumber[a]?"":"px"),g=(n.cssNumber[a]||"px"!==f&&+d)&&Ob.exec(n.css(c.elem,a)),h=1,i=20;if(g&&g[3]!==f){f=f||g[3],e=e||[],g=+d||1;do h=h||".5",g/=h,n.style(c.elem,a,g+f);while(h!==(h=c.cur()/d)&&1!==h&&--i)}return e&&(g=c.start=+g||+d||0,c.unit=f,c.end=e[1]?g+(e[1]+1)*e[2]:+e[2]),c}]};function Sb(){return setTimeout(function(){Lb=void 0}),Lb=n.now()}function Tb(a,b){var c,d=0,e={height:a};for(b=b?1:0;4>d;d+=2-b)c=R[d],e["margin"+c]=e["padding"+c]=a;return b&&(e.opacity=e.width=a),e}function Ub(a,b,c){for(var d,e=(Rb[b]||[]).concat(Rb["*"]),f=0,g=e.length;g>f;f++)if(d=e[f].call(c,b,a))return d}function Vb(a,b,c){var d,e,f,g,h,i,j,k,l=this,m={},o=a.style,p=a.nodeType&&S(a),q=L.get(a,"fxshow");c.queue||(h=n._queueHooks(a,"fx"),null==h.unqueued&&(h.unqueued=0,i=h.empty.fire,h.empty.fire=function(){h.unqueued||i()}),h.unqueued++,l.always(function(){l.always(function(){h.unqueued--,n.queue(a,"fx").length||h.empty.fire()})})),1===a.nodeType&&("height"in b||"width"in 
b)&&(c.overflow=[o.overflow,o.overflowX,o.overflowY],j=n.css(a,"display"),k="none"===j?L.get(a,"olddisplay")||tb(a.nodeName):j,"inline"===k&&"none"===n.css(a,"float")&&(o.display="inline-block")),c.overflow&&(o.overflow="hidden",l.always(function(){o.overflow=c.overflow[0],o.overflowX=c.overflow[1],o.overflowY=c.overflow[2]}));for(d in b)if(e=b[d],Nb.exec(e)){if(delete b[d],f=f||"toggle"===e,e===(p?"hide":"show")){if("show"!==e||!q||void 0===q[d])continue;p=!0}m[d]=q&&q[d]||n.style(a,d)}else j=void 0;if(n.isEmptyObject(m))"inline"===("none"===j?tb(a.nodeName):j)&&(o.display=j);else{q?"hidden"in q&&(p=q.hidden):q=L.access(a,"fxshow",{}),f&&(q.hidden=!p),p?n(a).show():l.done(function(){n(a).hide()}),l.done(function(){var b;L.remove(a,"fxshow");for(b in m)n.style(a,b,m[b])});for(d in m)g=Ub(p?q[d]:0,d,l),d in q||(q[d]=g.start,p&&(g.end=g.start,g.start="width"===d||"height"===d?1:0))}}function Wb(a,b){var c,d,e,f,g;for(c in a)if(d=n.camelCase(c),e=b[d],f=a[c],n.isArray(f)&&(e=f[1],f=a[c]=f[0]),c!==d&&(a[d]=f,delete a[c]),g=n.cssHooks[d],g&&"expand"in g){f=g.expand(f),delete a[d];for(c in f)c in a||(a[c]=f[c],b[c]=e)}else b[d]=e}function Xb(a,b,c){var d,e,f=0,g=Qb.length,h=n.Deferred().always(function(){delete i.elem}),i=function(){if(e)return!1;for(var b=Lb||Sb(),c=Math.max(0,j.startTime+j.duration-b),d=c/j.duration||0,f=1-d,g=0,i=j.tweens.length;i>g;g++)j.tweens[g].run(f);return h.notifyWith(a,[j,f,c]),1>f&&i?c:(h.resolveWith(a,[j]),!1)},j=h.promise({elem:a,props:n.extend({},b),opts:n.extend(!0,{specialEasing:{}},c),originalProperties:b,originalOptions:c,startTime:Lb||Sb(),duration:c.duration,tweens:[],createTween:function(b,c){var d=n.Tween(a,j.opts,b,c,j.opts.specialEasing[b]||j.opts.easing);return j.tweens.push(d),d},stop:function(b){var c=0,d=b?j.tweens.length:0;if(e)return this;for(e=!0;d>c;c++)j.tweens[c].run(1);return b?h.resolveWith(a,[j,b]):h.rejectWith(a,[j,b]),this}}),k=j.props;for(Wb(k,j.opts.specialEasing);g>f;f++)if(d=Qb[f].call(j,a,k,j.opts))return d;return n.map(k,Ub,j),n.isFunction(j.opts.start)&&j.opts.start.call(a,j),n.fx.timer(n.extend(i,{elem:a,anim:j,queue:j.opts.queue})),j.progress(j.opts.progress).done(j.opts.done,j.opts.complete).fail(j.opts.fail).always(j.opts.always)}n.Animation=n.extend(Xb,{tweener:function(a,b){n.isFunction(a)?(b=a,a=["*"]):a=a.split(" ");for(var c,d=0,e=a.length;e>d;d++)c=a[d],Rb[c]=Rb[c]||[],Rb[c].unshift(b)},prefilter:function(a,b){b?Qb.unshift(a):Qb.push(a)}}),n.speed=function(a,b,c){var d=a&&"object"==typeof a?n.extend({},a):{complete:c||!c&&b||n.isFunction(a)&&a,duration:a,easing:c&&b||b&&!n.isFunction(b)&&b};return d.duration=n.fx.off?0:"number"==typeof d.duration?d.duration:d.duration in n.fx.speeds?n.fx.speeds[d.duration]:n.fx.speeds._default,(null==d.queue||d.queue===!0)&&(d.queue="fx"),d.old=d.complete,d.complete=function(){n.isFunction(d.old)&&d.old.call(this),d.queue&&n.dequeue(this,d.queue)},d},n.fn.extend({fadeTo:function(a,b,c,d){return this.filter(S).css("opacity",0).show().end().animate({opacity:b},a,c,d)},animate:function(a,b,c,d){var e=n.isEmptyObject(a),f=n.speed(b,c,d),g=function(){var b=Xb(this,n.extend({},a),f);(e||L.get(this,"finish"))&&b.stop(!0)};return g.finish=g,e||f.queue===!1?this.each(g):this.queue(f.queue,g)},stop:function(a,b,c){var d=function(a){var b=a.stop;delete a.stop,b(c)};return"string"!=typeof a&&(c=b,b=a,a=void 0),b&&a!==!1&&this.queue(a||"fx",[]),this.each(function(){var b=!0,e=null!=a&&a+"queueHooks",f=n.timers,g=L.get(this);if(e)g[e]&&g[e].stop&&d(g[e]);else for(e in 
g)g[e]&&g[e].stop&&Pb.test(e)&&d(g[e]);for(e=f.length;e--;)f[e].elem!==this||null!=a&&f[e].queue!==a||(f[e].anim.stop(c),b=!1,f.splice(e,1));(b||!c)&&n.dequeue(this,a)})},finish:function(a){return a!==!1&&(a=a||"fx"),this.each(function(){var b,c=L.get(this),d=c[a+"queue"],e=c[a+"queueHooks"],f=n.timers,g=d?d.length:0;for(c.finish=!0,n.queue(this,a,[]),e&&e.stop&&e.stop.call(this,!0),b=f.length;b--;)f[b].elem===this&&f[b].queue===a&&(f[b].anim.stop(!0),f.splice(b,1));for(b=0;g>b;b++)d[b]&&d[b].finish&&d[b].finish.call(this);delete c.finish})}}),n.each(["toggle","show","hide"],function(a,b){var c=n.fn[b];n.fn[b]=function(a,d,e){return null==a||"boolean"==typeof a?c.apply(this,arguments):this.animate(Tb(b,!0),a,d,e)}}),n.each({slideDown:Tb("show"),slideUp:Tb("hide"),slideToggle:Tb("toggle"),fadeIn:{opacity:"show"},fadeOut:{opacity:"hide"},fadeToggle:{opacity:"toggle"}},function(a,b){n.fn[a]=function(a,c,d){return this.animate(b,a,c,d)}}),n.timers=[],n.fx.tick=function(){var a,b=0,c=n.timers;for(Lb=n.now();b<c.length;b++)a=c[b],a()||c[b]!==a||c.splice(b--,1);c.length||n.fx.stop(),Lb=void 0},n.fx.timer=function(a){n.timers.push(a),a()?n.fx.start():n.timers.pop()},n.fx.interval=13,n.fx.start=function(){Mb||(Mb=setInterval(n.fx.tick,n.fx.interval))},n.fx.stop=function(){clearInterval(Mb),Mb=null},n.fx.speeds={slow:600,fast:200,_default:400},n.fn.delay=function(a,b){return a=n.fx?n.fx.speeds[a]||a:a,b=b||"fx",this.queue(b,function(b,c){var d=setTimeout(b,a);c.stop=function(){clearTimeout(d)}})},function(){var a=l.createElement("input"),b=l.createElement("select"),c=b.appendChild(l.createElement("option"));a.type="checkbox",k.checkOn=""!==a.value,k.optSelected=c.selected,b.disabled=!0,k.optDisabled=!c.disabled,a=l.createElement("input"),a.value="t",a.type="radio",k.radioValue="t"===a.value}();var Yb,Zb,$b=n.expr.attrHandle;n.fn.extend({attr:function(a,b){return J(this,n.attr,a,b,arguments.length>1)},removeAttr:function(a){return this.each(function(){n.removeAttr(this,a)})}}),n.extend({attr:function(a,b,c){var d,e,f=a.nodeType;if(a&&3!==f&&8!==f&&2!==f)return typeof a.getAttribute===U?n.prop(a,b,c):(1===f&&n.isXMLDoc(a)||(b=b.toLowerCase(),d=n.attrHooks[b]||(n.expr.match.bool.test(b)?Zb:Yb)),void 0===c?d&&"get"in d&&null!==(e=d.get(a,b))?e:(e=n.find.attr(a,b),null==e?void 0:e):null!==c?d&&"set"in d&&void 0!==(e=d.set(a,c,b))?e:(a.setAttribute(b,c+""),c):void n.removeAttr(a,b)) -},removeAttr:function(a,b){var c,d,e=0,f=b&&b.match(E);if(f&&1===a.nodeType)while(c=f[e++])d=n.propFix[c]||c,n.expr.match.bool.test(c)&&(a[d]=!1),a.removeAttribute(c)},attrHooks:{type:{set:function(a,b){if(!k.radioValue&&"radio"===b&&n.nodeName(a,"input")){var c=a.value;return a.setAttribute("type",b),c&&(a.value=c),b}}}}}),Zb={set:function(a,b,c){return b===!1?n.removeAttr(a,c):a.setAttribute(c,c),c}},n.each(n.expr.match.bool.source.match(/\w+/g),function(a,b){var c=$b[b]||n.find.attr;$b[b]=function(a,b,d){var e,f;return d||(f=$b[b],$b[b]=e,e=null!=c(a,b,d)?b.toLowerCase():null,$b[b]=f),e}});var _b=/^(?:input|select|textarea|button)$/i;n.fn.extend({prop:function(a,b){return J(this,n.prop,a,b,arguments.length>1)},removeProp:function(a){return this.each(function(){delete this[n.propFix[a]||a]})}}),n.extend({propFix:{"for":"htmlFor","class":"className"},prop:function(a,b,c){var d,e,f,g=a.nodeType;if(a&&3!==g&&8!==g&&2!==g)return f=1!==g||!n.isXMLDoc(a),f&&(b=n.propFix[b]||b,e=n.propHooks[b]),void 0!==c?e&&"set"in e&&void 0!==(d=e.set(a,c,b))?d:a[b]=c:e&&"get"in 
e&&null!==(d=e.get(a,b))?d:a[b]},propHooks:{tabIndex:{get:function(a){return a.hasAttribute("tabindex")||_b.test(a.nodeName)||a.href?a.tabIndex:-1}}}}),k.optSelected||(n.propHooks.selected={get:function(a){var b=a.parentNode;return b&&b.parentNode&&b.parentNode.selectedIndex,null}}),n.each(["tabIndex","readOnly","maxLength","cellSpacing","cellPadding","rowSpan","colSpan","useMap","frameBorder","contentEditable"],function(){n.propFix[this.toLowerCase()]=this});var ac=/[\t\r\n\f]/g;n.fn.extend({addClass:function(a){var b,c,d,e,f,g,h="string"==typeof a&&a,i=0,j=this.length;if(n.isFunction(a))return this.each(function(b){n(this).addClass(a.call(this,b,this.className))});if(h)for(b=(a||"").match(E)||[];j>i;i++)if(c=this[i],d=1===c.nodeType&&(c.className?(" "+c.className+" ").replace(ac," "):" ")){f=0;while(e=b[f++])d.indexOf(" "+e+" ")<0&&(d+=e+" ");g=n.trim(d),c.className!==g&&(c.className=g)}return this},removeClass:function(a){var b,c,d,e,f,g,h=0===arguments.length||"string"==typeof a&&a,i=0,j=this.length;if(n.isFunction(a))return this.each(function(b){n(this).removeClass(a.call(this,b,this.className))});if(h)for(b=(a||"").match(E)||[];j>i;i++)if(c=this[i],d=1===c.nodeType&&(c.className?(" "+c.className+" ").replace(ac," "):"")){f=0;while(e=b[f++])while(d.indexOf(" "+e+" ")>=0)d=d.replace(" "+e+" "," ");g=a?n.trim(d):"",c.className!==g&&(c.className=g)}return this},toggleClass:function(a,b){var c=typeof a;return"boolean"==typeof b&&"string"===c?b?this.addClass(a):this.removeClass(a):this.each(n.isFunction(a)?function(c){n(this).toggleClass(a.call(this,c,this.className,b),b)}:function(){if("string"===c){var b,d=0,e=n(this),f=a.match(E)||[];while(b=f[d++])e.hasClass(b)?e.removeClass(b):e.addClass(b)}else(c===U||"boolean"===c)&&(this.className&&L.set(this,"__className__",this.className),this.className=this.className||a===!1?"":L.get(this,"__className__")||"")})},hasClass:function(a){for(var b=" "+a+" ",c=0,d=this.length;d>c;c++)if(1===this[c].nodeType&&(" "+this[c].className+" ").replace(ac," ").indexOf(b)>=0)return!0;return!1}});var bc=/\r/g;n.fn.extend({val:function(a){var b,c,d,e=this[0];{if(arguments.length)return d=n.isFunction(a),this.each(function(c){var e;1===this.nodeType&&(e=d?a.call(this,c,n(this).val()):a,null==e?e="":"number"==typeof e?e+="":n.isArray(e)&&(e=n.map(e,function(a){return null==a?"":a+""})),b=n.valHooks[this.type]||n.valHooks[this.nodeName.toLowerCase()],b&&"set"in b&&void 0!==b.set(this,e,"value")||(this.value=e))});if(e)return b=n.valHooks[e.type]||n.valHooks[e.nodeName.toLowerCase()],b&&"get"in b&&void 0!==(c=b.get(e,"value"))?c:(c=e.value,"string"==typeof c?c.replace(bc,""):null==c?"":c)}}}),n.extend({valHooks:{option:{get:function(a){var b=n.find.attr(a,"value");return null!=b?b:n.trim(n.text(a))}},select:{get:function(a){for(var b,c,d=a.options,e=a.selectedIndex,f="select-one"===a.type||0>e,g=f?null:[],h=f?e+1:d.length,i=0>e?h:f?e:0;h>i;i++)if(c=d[i],!(!c.selected&&i!==e||(k.optDisabled?c.disabled:null!==c.getAttribute("disabled"))||c.parentNode.disabled&&n.nodeName(c.parentNode,"optgroup"))){if(b=n(c).val(),f)return b;g.push(b)}return g},set:function(a,b){var c,d,e=a.options,f=n.makeArray(b),g=e.length;while(g--)d=e[g],(d.selected=n.inArray(d.value,f)>=0)&&(c=!0);return c||(a.selectedIndex=-1),f}}}}),n.each(["radio","checkbox"],function(){n.valHooks[this]={set:function(a,b){return n.isArray(b)?a.checked=n.inArray(n(a).val(),b)>=0:void 0}},k.checkOn||(n.valHooks[this].get=function(a){return null===a.getAttribute("value")?"on":a.value})}),n.each("blur focus 
focusin focusout load resize scroll unload click dblclick mousedown mouseup mousemove mouseover mouseout mouseenter mouseleave change select submit keydown keypress keyup error contextmenu".split(" "),function(a,b){n.fn[b]=function(a,c){return arguments.length>0?this.on(b,null,a,c):this.trigger(b)}}),n.fn.extend({hover:function(a,b){return this.mouseenter(a).mouseleave(b||a)},bind:function(a,b,c){return this.on(a,null,b,c)},unbind:function(a,b){return this.off(a,null,b)},delegate:function(a,b,c,d){return this.on(b,a,c,d)},undelegate:function(a,b,c){return 1===arguments.length?this.off(a,"**"):this.off(b,a||"**",c)}});var cc=n.now(),dc=/\?/;n.parseJSON=function(a){return JSON.parse(a+"")},n.parseXML=function(a){var b,c;if(!a||"string"!=typeof a)return null;try{c=new DOMParser,b=c.parseFromString(a,"text/xml")}catch(d){b=void 0}return(!b||b.getElementsByTagName("parsererror").length)&&n.error("Invalid XML: "+a),b};var ec=/#.*$/,fc=/([?&])_=[^&]*/,gc=/^(.*?):[ \t]*([^\r\n]*)$/gm,hc=/^(?:about|app|app-storage|.+-extension|file|res|widget):$/,ic=/^(?:GET|HEAD)$/,jc=/^\/\//,kc=/^([\w.+-]+:)(?:\/\/(?:[^\/?#]*@|)([^\/?#:]*)(?::(\d+)|)|)/,lc={},mc={},nc="*/".concat("*"),oc=a.location.href,pc=kc.exec(oc.toLowerCase())||[];function qc(a){return function(b,c){"string"!=typeof b&&(c=b,b="*");var d,e=0,f=b.toLowerCase().match(E)||[];if(n.isFunction(c))while(d=f[e++])"+"===d[0]?(d=d.slice(1)||"*",(a[d]=a[d]||[]).unshift(c)):(a[d]=a[d]||[]).push(c)}}function rc(a,b,c,d){var e={},f=a===mc;function g(h){var i;return e[h]=!0,n.each(a[h]||[],function(a,h){var j=h(b,c,d);return"string"!=typeof j||f||e[j]?f?!(i=j):void 0:(b.dataTypes.unshift(j),g(j),!1)}),i}return g(b.dataTypes[0])||!e["*"]&&g("*")}function sc(a,b){var c,d,e=n.ajaxSettings.flatOptions||{};for(c in b)void 0!==b[c]&&((e[c]?a:d||(d={}))[c]=b[c]);return d&&n.extend(!0,a,d),a}function tc(a,b,c){var d,e,f,g,h=a.contents,i=a.dataTypes;while("*"===i[0])i.shift(),void 0===d&&(d=a.mimeType||b.getResponseHeader("Content-Type"));if(d)for(e in h)if(h[e]&&h[e].test(d)){i.unshift(e);break}if(i[0]in c)f=i[0];else{for(e in c){if(!i[0]||a.converters[e+" "+i[0]]){f=e;break}g||(g=e)}f=f||g}return f?(f!==i[0]&&i.unshift(f),c[f]):void 0}function uc(a,b,c,d){var e,f,g,h,i,j={},k=a.dataTypes.slice();if(k[1])for(g in a.converters)j[g.toLowerCase()]=a.converters[g];f=k.shift();while(f)if(a.responseFields[f]&&(c[a.responseFields[f]]=b),!i&&d&&a.dataFilter&&(b=a.dataFilter(b,a.dataType)),i=f,f=k.shift())if("*"===f)f=i;else if("*"!==i&&i!==f){if(g=j[i+" "+f]||j["* "+f],!g)for(e in j)if(h=e.split(" "),h[1]===f&&(g=j[i+" "+h[0]]||j["* "+h[0]])){g===!0?g=j[e]:j[e]!==!0&&(f=h[0],k.unshift(h[1]));break}if(g!==!0)if(g&&a["throws"])b=g(b);else try{b=g(b)}catch(l){return{state:"parsererror",error:g?l:"No conversion from "+i+" to "+f}}}return{state:"success",data:b}}n.extend({active:0,lastModified:{},etag:{},ajaxSettings:{url:oc,type:"GET",isLocal:hc.test(pc[1]),global:!0,processData:!0,async:!0,contentType:"application/x-www-form-urlencoded; charset=UTF-8",accepts:{"*":nc,text:"text/plain",html:"text/html",xml:"application/xml, text/xml",json:"application/json, text/javascript"},contents:{xml:/xml/,html:/html/,json:/json/},responseFields:{xml:"responseXML",text:"responseText",json:"responseJSON"},converters:{"* text":String,"text html":!0,"text json":n.parseJSON,"text xml":n.parseXML},flatOptions:{url:!0,context:!0}},ajaxSetup:function(a,b){return b?sc(sc(a,n.ajaxSettings),b):sc(n.ajaxSettings,a)},ajaxPrefilter:qc(lc),ajaxTransport:qc(mc),ajax:function(a,b){"object"==typeof 
a&&(b=a,a=void 0),b=b||{};var c,d,e,f,g,h,i,j,k=n.ajaxSetup({},b),l=k.context||k,m=k.context&&(l.nodeType||l.jquery)?n(l):n.event,o=n.Deferred(),p=n.Callbacks("once memory"),q=k.statusCode||{},r={},s={},t=0,u="canceled",v={readyState:0,getResponseHeader:function(a){var b;if(2===t){if(!f){f={};while(b=gc.exec(e))f[b[1].toLowerCase()]=b[2]}b=f[a.toLowerCase()]}return null==b?null:b},getAllResponseHeaders:function(){return 2===t?e:null},setRequestHeader:function(a,b){var c=a.toLowerCase();return t||(a=s[c]=s[c]||a,r[a]=b),this},overrideMimeType:function(a){return t||(k.mimeType=a),this},statusCode:function(a){var b;if(a)if(2>t)for(b in a)q[b]=[q[b],a[b]];else v.always(a[v.status]);return this},abort:function(a){var b=a||u;return c&&c.abort(b),x(0,b),this}};if(o.promise(v).complete=p.add,v.success=v.done,v.error=v.fail,k.url=((a||k.url||oc)+"").replace(ec,"").replace(jc,pc[1]+"//"),k.type=b.method||b.type||k.method||k.type,k.dataTypes=n.trim(k.dataType||"*").toLowerCase().match(E)||[""],null==k.crossDomain&&(h=kc.exec(k.url.toLowerCase()),k.crossDomain=!(!h||h[1]===pc[1]&&h[2]===pc[2]&&(h[3]||("http:"===h[1]?"80":"443"))===(pc[3]||("http:"===pc[1]?"80":"443")))),k.data&&k.processData&&"string"!=typeof k.data&&(k.data=n.param(k.data,k.traditional)),rc(lc,k,b,v),2===t)return v;i=n.event&&k.global,i&&0===n.active++&&n.event.trigger("ajaxStart"),k.type=k.type.toUpperCase(),k.hasContent=!ic.test(k.type),d=k.url,k.hasContent||(k.data&&(d=k.url+=(dc.test(d)?"&":"?")+k.data,delete k.data),k.cache===!1&&(k.url=fc.test(d)?d.replace(fc,"$1_="+cc++):d+(dc.test(d)?"&":"?")+"_="+cc++)),k.ifModified&&(n.lastModified[d]&&v.setRequestHeader("If-Modified-Since",n.lastModified[d]),n.etag[d]&&v.setRequestHeader("If-None-Match",n.etag[d])),(k.data&&k.hasContent&&k.contentType!==!1||b.contentType)&&v.setRequestHeader("Content-Type",k.contentType),v.setRequestHeader("Accept",k.dataTypes[0]&&k.accepts[k.dataTypes[0]]?k.accepts[k.dataTypes[0]]+("*"!==k.dataTypes[0]?", "+nc+"; q=0.01":""):k.accepts["*"]);for(j in k.headers)v.setRequestHeader(j,k.headers[j]);if(k.beforeSend&&(k.beforeSend.call(l,v,k)===!1||2===t))return v.abort();u="abort";for(j in{success:1,error:1,complete:1})v[j](k[j]);if(c=rc(mc,k,b,v)){v.readyState=1,i&&m.trigger("ajaxSend",[v,k]),k.async&&k.timeout>0&&(g=setTimeout(function(){v.abort("timeout")},k.timeout));try{t=1,c.send(r,x)}catch(w){if(!(2>t))throw w;x(-1,w)}}else x(-1,"No Transport");function x(a,b,f,h){var j,r,s,u,w,x=b;2!==t&&(t=2,g&&clearTimeout(g),c=void 0,e=h||"",v.readyState=a>0?4:0,j=a>=200&&300>a||304===a,f&&(u=tc(k,v,f)),u=uc(k,u,v,j),j?(k.ifModified&&(w=v.getResponseHeader("Last-Modified"),w&&(n.lastModified[d]=w),w=v.getResponseHeader("etag"),w&&(n.etag[d]=w)),204===a||"HEAD"===k.type?x="nocontent":304===a?x="notmodified":(x=u.state,r=u.data,s=u.error,j=!s)):(s=x,(a||!x)&&(x="error",0>a&&(a=0))),v.status=a,v.statusText=(b||x)+"",j?o.resolveWith(l,[r,x,v]):o.rejectWith(l,[v,x,s]),v.statusCode(q),q=void 0,i&&m.trigger(j?"ajaxSuccess":"ajaxError",[v,k,j?r:s]),p.fireWith(l,[v,x]),i&&(m.trigger("ajaxComplete",[v,k]),--n.active||n.event.trigger("ajaxStop")))}return v},getJSON:function(a,b,c){return n.get(a,b,c,"json")},getScript:function(a,b){return n.get(a,void 0,b,"script")}}),n.each(["get","post"],function(a,b){n[b]=function(a,c,d,e){return n.isFunction(c)&&(e=e||d,d=c,c=void 0),n.ajax({url:a,type:b,dataType:e,data:c,success:d})}}),n._evalUrl=function(a){return n.ajax({url:a,type:"GET",dataType:"script",async:!1,global:!1,"throws":!0})},n.fn.extend({wrapAll:function(a){var b;return 
n.isFunction(a)?this.each(function(b){n(this).wrapAll(a.call(this,b))}):(this[0]&&(b=n(a,this[0].ownerDocument).eq(0).clone(!0),this[0].parentNode&&b.insertBefore(this[0]),b.map(function(){var a=this;while(a.firstElementChild)a=a.firstElementChild;return a}).append(this)),this)},wrapInner:function(a){return this.each(n.isFunction(a)?function(b){n(this).wrapInner(a.call(this,b))}:function(){var b=n(this),c=b.contents();c.length?c.wrapAll(a):b.append(a)})},wrap:function(a){var b=n.isFunction(a);return this.each(function(c){n(this).wrapAll(b?a.call(this,c):a)})},unwrap:function(){return this.parent().each(function(){n.nodeName(this,"body")||n(this).replaceWith(this.childNodes)}).end()}}),n.expr.filters.hidden=function(a){return a.offsetWidth<=0&&a.offsetHeight<=0},n.expr.filters.visible=function(a){return!n.expr.filters.hidden(a)};var vc=/%20/g,wc=/\[\]$/,xc=/\r?\n/g,yc=/^(?:submit|button|image|reset|file)$/i,zc=/^(?:input|select|textarea|keygen)/i;function Ac(a,b,c,d){var e;if(n.isArray(b))n.each(b,function(b,e){c||wc.test(a)?d(a,e):Ac(a+"["+("object"==typeof e?b:"")+"]",e,c,d)});else if(c||"object"!==n.type(b))d(a,b);else for(e in b)Ac(a+"["+e+"]",b[e],c,d)}n.param=function(a,b){var c,d=[],e=function(a,b){b=n.isFunction(b)?b():null==b?"":b,d[d.length]=encodeURIComponent(a)+"="+encodeURIComponent(b)};if(void 0===b&&(b=n.ajaxSettings&&n.ajaxSettings.traditional),n.isArray(a)||a.jquery&&!n.isPlainObject(a))n.each(a,function(){e(this.name,this.value)});else for(c in a)Ac(c,a[c],b,e);return d.join("&").replace(vc,"+")},n.fn.extend({serialize:function(){return n.param(this.serializeArray())},serializeArray:function(){return this.map(function(){var a=n.prop(this,"elements");return a?n.makeArray(a):this}).filter(function(){var a=this.type;return this.name&&!n(this).is(":disabled")&&zc.test(this.nodeName)&&!yc.test(a)&&(this.checked||!T.test(a))}).map(function(a,b){var c=n(this).val();return null==c?null:n.isArray(c)?n.map(c,function(a){return{name:b.name,value:a.replace(xc,"\r\n")}}):{name:b.name,value:c.replace(xc,"\r\n")}}).get()}}),n.ajaxSettings.xhr=function(){try{return new XMLHttpRequest}catch(a){}};var Bc=0,Cc={},Dc={0:200,1223:204},Ec=n.ajaxSettings.xhr();a.attachEvent&&a.attachEvent("onunload",function(){for(var a in Cc)Cc[a]()}),k.cors=!!Ec&&"withCredentials"in Ec,k.ajax=Ec=!!Ec,n.ajaxTransport(function(a){var b;return k.cors||Ec&&!a.crossDomain?{send:function(c,d){var e,f=a.xhr(),g=++Bc;if(f.open(a.type,a.url,a.async,a.username,a.password),a.xhrFields)for(e in a.xhrFields)f[e]=a.xhrFields[e];a.mimeType&&f.overrideMimeType&&f.overrideMimeType(a.mimeType),a.crossDomain||c["X-Requested-With"]||(c["X-Requested-With"]="XMLHttpRequest");for(e in c)f.setRequestHeader(e,c[e]);b=function(a){return function(){b&&(delete Cc[g],b=f.onload=f.onerror=null,"abort"===a?f.abort():"error"===a?d(f.status,f.statusText):d(Dc[f.status]||f.status,f.statusText,"string"==typeof f.responseText?{text:f.responseText}:void 0,f.getAllResponseHeaders()))}},f.onload=b(),f.onerror=b("error"),b=Cc[g]=b("abort");try{f.send(a.hasContent&&a.data||null)}catch(h){if(b)throw h}},abort:function(){b&&b()}}:void 0}),n.ajaxSetup({accepts:{script:"text/javascript, application/javascript, application/ecmascript, application/x-ecmascript"},contents:{script:/(?:java|ecma)script/},converters:{"text script":function(a){return n.globalEval(a),a}}}),n.ajaxPrefilter("script",function(a){void 0===a.cache&&(a.cache=!1),a.crossDomain&&(a.type="GET")}),n.ajaxTransport("script",function(a){if(a.crossDomain){var 
b,c;return{send:function(d,e){b=n("<script>").prop({async:!0,charset:a.scriptCharset,src:a.url}).on("load error",c=function(a){b.remove(),c=null,a&&e("error"===a.type?404:200,a.type)}),l.head.appendChild(b[0])},abort:function(){c&&c()}}}});var Fc=[],Gc=/(=)\?(?=&|$)|\?\?/;n.ajaxSetup({jsonp:"callback",jsonpCallback:function(){var a=Fc.pop()||n.expando+"_"+cc++;return this[a]=!0,a}}),n.ajaxPrefilter("json jsonp",function(b,c,d){var e,f,g,h=b.jsonp!==!1&&(Gc.test(b.url)?"url":"string"==typeof b.data&&!(b.contentType||"").indexOf("application/x-www-form-urlencoded")&&Gc.test(b.data)&&"data");return h||"jsonp"===b.dataTypes[0]?(e=b.jsonpCallback=n.isFunction(b.jsonpCallback)?b.jsonpCallback():b.jsonpCallback,h?b[h]=b[h].replace(Gc,"$1"+e):b.jsonp!==!1&&(b.url+=(dc.test(b.url)?"&":"?")+b.jsonp+"="+e),b.converters["script json"]=function(){return g||n.error(e+" was not called"),g[0]},b.dataTypes[0]="json",f=a[e],a[e]=function(){g=arguments},d.always(function(){a[e]=f,b[e]&&(b.jsonpCallback=c.jsonpCallback,Fc.push(e)),g&&n.isFunction(f)&&f(g[0]),g=f=void 0}),"script"):void 0}),n.parseHTML=function(a,b,c){if(!a||"string"!=typeof a)return null;"boolean"==typeof b&&(c=b,b=!1),b=b||l;var d=v.exec(a),e=!c&&[];return d?[b.createElement(d[1])]:(d=n.buildFragment([a],b,e),e&&e.length&&n(e).remove(),n.merge([],d.childNodes))};var Hc=n.fn.load;n.fn.load=function(a,b,c){if("string"!=typeof a&&Hc)return Hc.apply(this,arguments);var d,e,f,g=this,h=a.indexOf(" ");return h>=0&&(d=n.trim(a.slice(h)),a=a.slice(0,h)),n.isFunction(b)?(c=b,b=void 0):b&&"object"==typeof b&&(e="POST"),g.length>0&&n.ajax({url:a,type:e,dataType:"html",data:b}).done(function(a){f=arguments,g.html(d?n("<div>").append(n.parseHTML(a)).find(d):a)}).complete(c&&function(a,b){g.each(c,f||[a.responseText,b,a])}),this},n.each(["ajaxStart","ajaxStop","ajaxComplete","ajaxError","ajaxSuccess","ajaxSend"],function(a,b){n.fn[b]=function(a){return this.on(b,a)}}),n.expr.filters.animated=function(a){return n.grep(n.timers,function(b){return a===b.elem}).length};var Ic=a.document.documentElement;function Jc(a){return n.isWindow(a)?a:9===a.nodeType&&a.defaultView}n.offset={setOffset:function(a,b,c){var d,e,f,g,h,i,j,k=n.css(a,"position"),l=n(a),m={};"static"===k&&(a.style.position="relative"),h=l.offset(),f=n.css(a,"top"),i=n.css(a,"left"),j=("absolute"===k||"fixed"===k)&&(f+i).indexOf("auto")>-1,j?(d=l.position(),g=d.top,e=d.left):(g=parseFloat(f)||0,e=parseFloat(i)||0),n.isFunction(b)&&(b=b.call(a,c,h)),null!=b.top&&(m.top=b.top-h.top+g),null!=b.left&&(m.left=b.left-h.left+e),"using"in b?b.using.call(a,m):l.css(m)}},n.fn.extend({offset:function(a){if(arguments.length)return void 0===a?this:this.each(function(b){n.offset.setOffset(this,a,b)});var b,c,d=this[0],e={top:0,left:0},f=d&&d.ownerDocument;if(f)return b=f.documentElement,n.contains(b,d)?(typeof d.getBoundingClientRect!==U&&(e=d.getBoundingClientRect()),c=Jc(f),{top:e.top+c.pageYOffset-b.clientTop,left:e.left+c.pageXOffset-b.clientLeft}):e},position:function(){if(this[0]){var a,b,c=this[0],d={top:0,left:0};return"fixed"===n.css(c,"position")?b=c.getBoundingClientRect():(a=this.offsetParent(),b=this.offset(),n.nodeName(a[0],"html")||(d=a.offset()),d.top+=n.css(a[0],"borderTopWidth",!0),d.left+=n.css(a[0],"borderLeftWidth",!0)),{top:b.top-d.top-n.css(c,"marginTop",!0),left:b.left-d.left-n.css(c,"marginLeft",!0)}}},offsetParent:function(){return this.map(function(){var a=this.offsetParent||Ic;while(a&&!n.nodeName(a,"html")&&"static"===n.css(a,"position"))a=a.offsetParent;return 
a||Ic})}}),n.each({scrollLeft:"pageXOffset",scrollTop:"pageYOffset"},function(b,c){var d="pageYOffset"===c;n.fn[b]=function(e){return J(this,function(b,e,f){var g=Jc(b);return void 0===f?g?g[c]:b[e]:void(g?g.scrollTo(d?a.pageXOffset:f,d?f:a.pageYOffset):b[e]=f)},b,e,arguments.length,null)}}),n.each(["top","left"],function(a,b){n.cssHooks[b]=yb(k.pixelPosition,function(a,c){return c?(c=xb(a,b),vb.test(c)?n(a).position()[b]+"px":c):void 0})}),n.each({Height:"height",Width:"width"},function(a,b){n.each({padding:"inner"+a,content:b,"":"outer"+a},function(c,d){n.fn[d]=function(d,e){var f=arguments.length&&(c||"boolean"!=typeof d),g=c||(d===!0||e===!0?"margin":"border");return J(this,function(b,c,d){var e;return n.isWindow(b)?b.document.documentElement["client"+a]:9===b.nodeType?(e=b.documentElement,Math.max(b.body["scroll"+a],e["scroll"+a],b.body["offset"+a],e["offset"+a],e["client"+a])):void 0===d?n.css(b,c,g):n.style(b,c,d,g)},b,f?d:void 0,f,null)}})}),n.fn.size=function(){return this.length},n.fn.andSelf=n.fn.addBack,"function"==typeof define&&define.amd&&define("jquery",[],function(){return n});var Kc=a.jQuery,Lc=a.$;return n.noConflict=function(b){return a.$===n&&(a.$=Lc),b&&a.jQuery===n&&(a.jQuery=Kc),n},typeof b===U&&(a.jQuery=a.$=n),n}); diff --git a/synapse/static/client/register/js/jquery-3.4.1.min.js b/synapse/static/client/register/js/jquery-3.4.1.min.js new file mode 100644 index 0000000000..a1c07fd803 --- /dev/null +++ b/synapse/static/client/register/js/jquery-3.4.1.min.js @@ -0,0 +1,2 @@ +/*! jQuery v3.4.1 | (c) JS Foundation and other contributors | jquery.org/license */ +!function(e,t){"use strict";"object"==typeof module&&"object"==typeof module.exports?module.exports=e.document?t(e,!0):function(e){if(!e.document)throw new Error("jQuery requires a window with a document");return t(e)}:t(e)}("undefined"!=typeof window?window:this,function(C,e){"use strict";var t=[],E=C.document,r=Object.getPrototypeOf,s=t.slice,g=t.concat,u=t.push,i=t.indexOf,n={},o=n.toString,v=n.hasOwnProperty,a=v.toString,l=a.call(Object),y={},m=function(e){return"function"==typeof e&&"number"!=typeof e.nodeType},x=function(e){return null!=e&&e===e.window},c={type:!0,src:!0,nonce:!0,noModule:!0};function b(e,t,n){var r,i,o=(n=n||E).createElement("script");if(o.text=e,t)for(r in c)(i=t[r]||t.getAttribute&&t.getAttribute(r))&&o.setAttribute(r,i);n.head.appendChild(o).parentNode.removeChild(o)}function w(e){return null==e?e+"":"object"==typeof e||"function"==typeof e?n[o.call(e)]||"object":typeof e}var f="3.4.1",k=function(e,t){return new k.fn.init(e,t)},p=/^[\s\uFEFF\xA0]+|[\s\uFEFF\xA0]+$/g;function d(e){var t=!!e&&"length"in e&&e.length,n=w(e);return!m(e)&&!x(e)&&("array"===n||0===t||"number"==typeof t&&0<t&&t-1 in e)}k.fn=k.prototype={jquery:f,constructor:k,length:0,toArray:function(){return s.call(this)},get:function(e){return null==e?s.call(this):e<0?this[e+this.length]:this[e]},pushStack:function(e){var t=k.merge(this.constructor(),e);return t.prevObject=this,t},each:function(e){return k.each(this,e)},map:function(n){return this.pushStack(k.map(this,function(e,t){return n.call(e,t,e)}))},slice:function(){return this.pushStack(s.apply(this,arguments))},first:function(){return this.eq(0)},last:function(){return this.eq(-1)},eq:function(e){var t=this.length,n=+e+(e<0?t:0);return this.pushStack(0<=n&&n<t?[this[n]]:[])},end:function(){return this.prevObject||this.constructor()},push:u,sort:t.sort,splice:t.splice},k.extend=k.fn.extend=function(){var 
e,t,n,r,i,o,a=arguments[0]||{},s=1,u=arguments.length,l=!1;for("boolean"==typeof a&&(l=a,a=arguments[s]||{},s++),"object"==typeof a||m(a)||(a={}),s===u&&(a=this,s--);s<u;s++)if(null!=(e=arguments[s]))for(t in e)r=e[t],"__proto__"!==t&&a!==r&&(l&&r&&(k.isPlainObject(r)||(i=Array.isArray(r)))?(n=a[t],o=i&&!Array.isArray(n)?[]:i||k.isPlainObject(n)?n:{},i=!1,a[t]=k.extend(l,o,r)):void 0!==r&&(a[t]=r));return a},k.extend({expando:"jQuery"+(f+Math.random()).replace(/\D/g,""),isReady:!0,error:function(e){throw new Error(e)},noop:function(){},isPlainObject:function(e){var t,n;return!(!e||"[object Object]"!==o.call(e))&&(!(t=r(e))||"function"==typeof(n=v.call(t,"constructor")&&t.constructor)&&a.call(n)===l)},isEmptyObject:function(e){var t;for(t in e)return!1;return!0},globalEval:function(e,t){b(e,{nonce:t&&t.nonce})},each:function(e,t){var n,r=0;if(d(e)){for(n=e.length;r<n;r++)if(!1===t.call(e[r],r,e[r]))break}else for(r in e)if(!1===t.call(e[r],r,e[r]))break;return e},trim:function(e){return null==e?"":(e+"").replace(p,"")},makeArray:function(e,t){var n=t||[];return null!=e&&(d(Object(e))?k.merge(n,"string"==typeof e?[e]:e):u.call(n,e)),n},inArray:function(e,t,n){return null==t?-1:i.call(t,e,n)},merge:function(e,t){for(var n=+t.length,r=0,i=e.length;r<n;r++)e[i++]=t[r];return e.length=i,e},grep:function(e,t,n){for(var r=[],i=0,o=e.length,a=!n;i<o;i++)!t(e[i],i)!==a&&r.push(e[i]);return r},map:function(e,t,n){var r,i,o=0,a=[];if(d(e))for(r=e.length;o<r;o++)null!=(i=t(e[o],o,n))&&a.push(i);else for(o in e)null!=(i=t(e[o],o,n))&&a.push(i);return g.apply([],a)},guid:1,support:y}),"function"==typeof Symbol&&(k.fn[Symbol.iterator]=t[Symbol.iterator]),k.each("Boolean Number String Function Array Date RegExp Object Error Symbol".split(" "),function(e,t){n["[object "+t+"]"]=t.toLowerCase()});var h=function(n){var e,d,b,o,i,h,f,g,w,u,l,T,C,a,E,v,s,c,y,k="sizzle"+1*new Date,m=n.document,S=0,r=0,p=ue(),x=ue(),N=ue(),A=ue(),D=function(e,t){return e===t&&(l=!0),0},j={}.hasOwnProperty,t=[],q=t.pop,L=t.push,H=t.push,O=t.slice,P=function(e,t){for(var n=0,r=e.length;n<r;n++)if(e[n]===t)return n;return-1},R="checked|selected|async|autofocus|autoplay|controls|defer|disabled|hidden|ismap|loop|multiple|open|readonly|required|scoped",M="[\\x20\\t\\r\\n\\f]",I="(?:\\\\.|[\\w-]|[^\0-\\xa0])+",W="\\["+M+"*("+I+")(?:"+M+"*([*^$|!~]?=)"+M+"*(?:'((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\"|("+I+"))|)"+M+"*\\]",$=":("+I+")(?:\\((('((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\")|((?:\\\\.|[^\\\\()[\\]]|"+W+")*)|.*)\\)|)",F=new RegExp(M+"+","g"),B=new RegExp("^"+M+"+|((?:^|[^\\\\])(?:\\\\.)*)"+M+"+$","g"),_=new RegExp("^"+M+"*,"+M+"*"),z=new RegExp("^"+M+"*([>+~]|"+M+")"+M+"*"),U=new RegExp(M+"|>"),X=new RegExp($),V=new RegExp("^"+I+"$"),G={ID:new RegExp("^#("+I+")"),CLASS:new RegExp("^\\.("+I+")"),TAG:new RegExp("^("+I+"|[*])"),ATTR:new RegExp("^"+W),PSEUDO:new RegExp("^"+$),CHILD:new RegExp("^:(only|first|last|nth|nth-last)-(child|of-type)(?:\\("+M+"*(even|odd|(([+-]|)(\\d*)n|)"+M+"*(?:([+-]|)"+M+"*(\\d+)|))"+M+"*\\)|)","i"),bool:new RegExp("^(?:"+R+")$","i"),needsContext:new RegExp("^"+M+"*[>+~]|:(even|odd|eq|gt|lt|nth|first|last)(?:\\("+M+"*((?:-\\d)?\\d*)"+M+"*\\)|)(?=[^-]|$)","i")},Y=/HTML$/i,Q=/^(?:input|select|textarea|button)$/i,J=/^h\d$/i,K=/^[^{]+\{\s*\[native \w/,Z=/^(?:#([\w-]+)|(\w+)|\.([\w-]+))$/,ee=/[+~]/,te=new RegExp("\\\\([\\da-f]{1,6}"+M+"?|("+M+")|.)","ig"),ne=function(e,t,n){var r="0x"+t-65536;return 
r!=r||n?t:r<0?String.fromCharCode(r+65536):String.fromCharCode(r>>10|55296,1023&r|56320)},re=/([\0-\x1f\x7f]|^-?\d)|^-$|[^\0-\x1f\x7f-\uFFFF\w-]/g,ie=function(e,t){return t?"\0"===e?"\ufffd":e.slice(0,-1)+"\\"+e.charCodeAt(e.length-1).toString(16)+" ":"\\"+e},oe=function(){T()},ae=be(function(e){return!0===e.disabled&&"fieldset"===e.nodeName.toLowerCase()},{dir:"parentNode",next:"legend"});try{H.apply(t=O.call(m.childNodes),m.childNodes),t[m.childNodes.length].nodeType}catch(e){H={apply:t.length?function(e,t){L.apply(e,O.call(t))}:function(e,t){var n=e.length,r=0;while(e[n++]=t[r++]);e.length=n-1}}}function se(t,e,n,r){var i,o,a,s,u,l,c,f=e&&e.ownerDocument,p=e?e.nodeType:9;if(n=n||[],"string"!=typeof t||!t||1!==p&&9!==p&&11!==p)return n;if(!r&&((e?e.ownerDocument||e:m)!==C&&T(e),e=e||C,E)){if(11!==p&&(u=Z.exec(t)))if(i=u[1]){if(9===p){if(!(a=e.getElementById(i)))return n;if(a.id===i)return n.push(a),n}else if(f&&(a=f.getElementById(i))&&y(e,a)&&a.id===i)return n.push(a),n}else{if(u[2])return H.apply(n,e.getElementsByTagName(t)),n;if((i=u[3])&&d.getElementsByClassName&&e.getElementsByClassName)return H.apply(n,e.getElementsByClassName(i)),n}if(d.qsa&&!A[t+" "]&&(!v||!v.test(t))&&(1!==p||"object"!==e.nodeName.toLowerCase())){if(c=t,f=e,1===p&&U.test(t)){(s=e.getAttribute("id"))?s=s.replace(re,ie):e.setAttribute("id",s=k),o=(l=h(t)).length;while(o--)l[o]="#"+s+" "+xe(l[o]);c=l.join(","),f=ee.test(t)&&ye(e.parentNode)||e}try{return H.apply(n,f.querySelectorAll(c)),n}catch(e){A(t,!0)}finally{s===k&&e.removeAttribute("id")}}}return g(t.replace(B,"$1"),e,n,r)}function ue(){var r=[];return function e(t,n){return r.push(t+" ")>b.cacheLength&&delete e[r.shift()],e[t+" "]=n}}function le(e){return e[k]=!0,e}function ce(e){var t=C.createElement("fieldset");try{return!!e(t)}catch(e){return!1}finally{t.parentNode&&t.parentNode.removeChild(t),t=null}}function fe(e,t){var n=e.split("|"),r=n.length;while(r--)b.attrHandle[n[r]]=t}function pe(e,t){var n=t&&e,r=n&&1===e.nodeType&&1===t.nodeType&&e.sourceIndex-t.sourceIndex;if(r)return r;if(n)while(n=n.nextSibling)if(n===t)return-1;return e?1:-1}function de(t){return function(e){return"input"===e.nodeName.toLowerCase()&&e.type===t}}function he(n){return function(e){var t=e.nodeName.toLowerCase();return("input"===t||"button"===t)&&e.type===n}}function ge(t){return function(e){return"form"in e?e.parentNode&&!1===e.disabled?"label"in e?"label"in e.parentNode?e.parentNode.disabled===t:e.disabled===t:e.isDisabled===t||e.isDisabled!==!t&&ae(e)===t:e.disabled===t:"label"in e&&e.disabled===t}}function ve(a){return le(function(o){return o=+o,le(function(e,t){var n,r=a([],e.length,o),i=r.length;while(i--)e[n=r[i]]&&(e[n]=!(t[n]=e[n]))})})}function ye(e){return e&&"undefined"!=typeof e.getElementsByTagName&&e}for(e in d=se.support={},i=se.isXML=function(e){var t=e.namespaceURI,n=(e.ownerDocument||e).documentElement;return!Y.test(t||n&&n.nodeName||"HTML")},T=se.setDocument=function(e){var t,n,r=e?e.ownerDocument||e:m;return r!==C&&9===r.nodeType&&r.documentElement&&(a=(C=r).documentElement,E=!i(C),m!==C&&(n=C.defaultView)&&n.top!==n&&(n.addEventListener?n.addEventListener("unload",oe,!1):n.attachEvent&&n.attachEvent("onunload",oe)),d.attributes=ce(function(e){return e.className="i",!e.getAttribute("className")}),d.getElementsByTagName=ce(function(e){return e.appendChild(C.createComment("")),!e.getElementsByTagName("*").length}),d.getElementsByClassName=K.test(C.getElementsByClassName),d.getById=ce(function(e){return 
a.appendChild(e).id=k,!C.getElementsByName||!C.getElementsByName(k).length}),d.getById?(b.filter.ID=function(e){var t=e.replace(te,ne);return function(e){return e.getAttribute("id")===t}},b.find.ID=function(e,t){if("undefined"!=typeof t.getElementById&&E){var n=t.getElementById(e);return n?[n]:[]}}):(b.filter.ID=function(e){var n=e.replace(te,ne);return function(e){var t="undefined"!=typeof e.getAttributeNode&&e.getAttributeNode("id");return t&&t.value===n}},b.find.ID=function(e,t){if("undefined"!=typeof t.getElementById&&E){var n,r,i,o=t.getElementById(e);if(o){if((n=o.getAttributeNode("id"))&&n.value===e)return[o];i=t.getElementsByName(e),r=0;while(o=i[r++])if((n=o.getAttributeNode("id"))&&n.value===e)return[o]}return[]}}),b.find.TAG=d.getElementsByTagName?function(e,t){return"undefined"!=typeof t.getElementsByTagName?t.getElementsByTagName(e):d.qsa?t.querySelectorAll(e):void 0}:function(e,t){var n,r=[],i=0,o=t.getElementsByTagName(e);if("*"===e){while(n=o[i++])1===n.nodeType&&r.push(n);return r}return o},b.find.CLASS=d.getElementsByClassName&&function(e,t){if("undefined"!=typeof t.getElementsByClassName&&E)return t.getElementsByClassName(e)},s=[],v=[],(d.qsa=K.test(C.querySelectorAll))&&(ce(function(e){a.appendChild(e).innerHTML="<a id='"+k+"'></a><select id='"+k+"-\r\\' msallowcapture=''><option selected=''></option></select>",e.querySelectorAll("[msallowcapture^='']").length&&v.push("[*^$]="+M+"*(?:''|\"\")"),e.querySelectorAll("[selected]").length||v.push("\\["+M+"*(?:value|"+R+")"),e.querySelectorAll("[id~="+k+"-]").length||v.push("~="),e.querySelectorAll(":checked").length||v.push(":checked"),e.querySelectorAll("a#"+k+"+*").length||v.push(".#.+[+~]")}),ce(function(e){e.innerHTML="<a href='' disabled='disabled'></a><select disabled='disabled'><option/></select>";var t=C.createElement("input");t.setAttribute("type","hidden"),e.appendChild(t).setAttribute("name","D"),e.querySelectorAll("[name=d]").length&&v.push("name"+M+"*[*^$|!~]?="),2!==e.querySelectorAll(":enabled").length&&v.push(":enabled",":disabled"),a.appendChild(e).disabled=!0,2!==e.querySelectorAll(":disabled").length&&v.push(":enabled",":disabled"),e.querySelectorAll("*,:x"),v.push(",.*:")})),(d.matchesSelector=K.test(c=a.matches||a.webkitMatchesSelector||a.mozMatchesSelector||a.oMatchesSelector||a.msMatchesSelector))&&ce(function(e){d.disconnectedMatch=c.call(e,"*"),c.call(e,"[s!='']:x"),s.push("!=",$)}),v=v.length&&new RegExp(v.join("|")),s=s.length&&new RegExp(s.join("|")),t=K.test(a.compareDocumentPosition),y=t||K.test(a.contains)?function(e,t){var n=9===e.nodeType?e.documentElement:e,r=t&&t.parentNode;return e===r||!(!r||1!==r.nodeType||!(n.contains?n.contains(r):e.compareDocumentPosition&&16&e.compareDocumentPosition(r)))}:function(e,t){if(t)while(t=t.parentNode)if(t===e)return!0;return!1},D=t?function(e,t){if(e===t)return l=!0,0;var n=!e.compareDocumentPosition-!t.compareDocumentPosition;return n||(1&(n=(e.ownerDocument||e)===(t.ownerDocument||t)?e.compareDocumentPosition(t):1)||!d.sortDetached&&t.compareDocumentPosition(e)===n?e===C||e.ownerDocument===m&&y(m,e)?-1:t===C||t.ownerDocument===m&&y(m,t)?1:u?P(u,e)-P(u,t):0:4&n?-1:1)}:function(e,t){if(e===t)return l=!0,0;var n,r=0,i=e.parentNode,o=t.parentNode,a=[e],s=[t];if(!i||!o)return e===C?-1:t===C?1:i?-1:o?1:u?P(u,e)-P(u,t):0;if(i===o)return pe(e,t);n=e;while(n=n.parentNode)a.unshift(n);n=t;while(n=n.parentNode)s.unshift(n);while(a[r]===s[r])r++;return r?pe(a[r],s[r]):a[r]===m?-1:s[r]===m?1:0}),C},se.matches=function(e,t){return 
se(e,null,null,t)},se.matchesSelector=function(e,t){if((e.ownerDocument||e)!==C&&T(e),d.matchesSelector&&E&&!A[t+" "]&&(!s||!s.test(t))&&(!v||!v.test(t)))try{var n=c.call(e,t);if(n||d.disconnectedMatch||e.document&&11!==e.document.nodeType)return n}catch(e){A(t,!0)}return 0<se(t,C,null,[e]).length},se.contains=function(e,t){return(e.ownerDocument||e)!==C&&T(e),y(e,t)},se.attr=function(e,t){(e.ownerDocument||e)!==C&&T(e);var n=b.attrHandle[t.toLowerCase()],r=n&&j.call(b.attrHandle,t.toLowerCase())?n(e,t,!E):void 0;return void 0!==r?r:d.attributes||!E?e.getAttribute(t):(r=e.getAttributeNode(t))&&r.specified?r.value:null},se.escape=function(e){return(e+"").replace(re,ie)},se.error=function(e){throw new Error("Syntax error, unrecognized expression: "+e)},se.uniqueSort=function(e){var t,n=[],r=0,i=0;if(l=!d.detectDuplicates,u=!d.sortStable&&e.slice(0),e.sort(D),l){while(t=e[i++])t===e[i]&&(r=n.push(i));while(r--)e.splice(n[r],1)}return u=null,e},o=se.getText=function(e){var t,n="",r=0,i=e.nodeType;if(i){if(1===i||9===i||11===i){if("string"==typeof e.textContent)return e.textContent;for(e=e.firstChild;e;e=e.nextSibling)n+=o(e)}else if(3===i||4===i)return e.nodeValue}else while(t=e[r++])n+=o(t);return n},(b=se.selectors={cacheLength:50,createPseudo:le,match:G,attrHandle:{},find:{},relative:{">":{dir:"parentNode",first:!0}," ":{dir:"parentNode"},"+":{dir:"previousSibling",first:!0},"~":{dir:"previousSibling"}},preFilter:{ATTR:function(e){return e[1]=e[1].replace(te,ne),e[3]=(e[3]||e[4]||e[5]||"").replace(te,ne),"~="===e[2]&&(e[3]=" "+e[3]+" "),e.slice(0,4)},CHILD:function(e){return e[1]=e[1].toLowerCase(),"nth"===e[1].slice(0,3)?(e[3]||se.error(e[0]),e[4]=+(e[4]?e[5]+(e[6]||1):2*("even"===e[3]||"odd"===e[3])),e[5]=+(e[7]+e[8]||"odd"===e[3])):e[3]&&se.error(e[0]),e},PSEUDO:function(e){var t,n=!e[6]&&e[2];return G.CHILD.test(e[0])?null:(e[3]?e[2]=e[4]||e[5]||"":n&&X.test(n)&&(t=h(n,!0))&&(t=n.indexOf(")",n.length-t)-n.length)&&(e[0]=e[0].slice(0,t),e[2]=n.slice(0,t)),e.slice(0,3))}},filter:{TAG:function(e){var t=e.replace(te,ne).toLowerCase();return"*"===e?function(){return!0}:function(e){return e.nodeName&&e.nodeName.toLowerCase()===t}},CLASS:function(e){var t=p[e+" "];return t||(t=new RegExp("(^|"+M+")"+e+"("+M+"|$)"))&&p(e,function(e){return t.test("string"==typeof e.className&&e.className||"undefined"!=typeof e.getAttribute&&e.getAttribute("class")||"")})},ATTR:function(n,r,i){return function(e){var t=se.attr(e,n);return null==t?"!="===r:!r||(t+="","="===r?t===i:"!="===r?t!==i:"^="===r?i&&0===t.indexOf(i):"*="===r?i&&-1<t.indexOf(i):"$="===r?i&&t.slice(-i.length)===i:"~="===r?-1<(" "+t.replace(F," ")+" ").indexOf(i):"|="===r&&(t===i||t.slice(0,i.length+1)===i+"-"))}},CHILD:function(h,e,t,g,v){var y="nth"!==h.slice(0,3),m="last"!==h.slice(-4),x="of-type"===e;return 1===g&&0===v?function(e){return!!e.parentNode}:function(e,t,n){var r,i,o,a,s,u,l=y!==m?"nextSibling":"previousSibling",c=e.parentNode,f=x&&e.nodeName.toLowerCase(),p=!n&&!x,d=!1;if(c){if(y){while(l){a=e;while(a=a[l])if(x?a.nodeName.toLowerCase()===f:1===a.nodeType)return!1;u=l="only"===h&&!u&&"nextSibling"}return!0}if(u=[m?c.firstChild:c.lastChild],m&&p){d=(s=(r=(i=(o=(a=c)[k]||(a[k]={}))[a.uniqueID]||(o[a.uniqueID]={}))[h]||[])[0]===S&&r[1])&&r[2],a=s&&c.childNodes[s];while(a=++s&&a&&a[l]||(d=s=0)||u.pop())if(1===a.nodeType&&++d&&a===e){i[h]=[S,s,d];break}}else 
if(p&&(d=s=(r=(i=(o=(a=e)[k]||(a[k]={}))[a.uniqueID]||(o[a.uniqueID]={}))[h]||[])[0]===S&&r[1]),!1===d)while(a=++s&&a&&a[l]||(d=s=0)||u.pop())if((x?a.nodeName.toLowerCase()===f:1===a.nodeType)&&++d&&(p&&((i=(o=a[k]||(a[k]={}))[a.uniqueID]||(o[a.uniqueID]={}))[h]=[S,d]),a===e))break;return(d-=v)===g||d%g==0&&0<=d/g}}},PSEUDO:function(e,o){var t,a=b.pseudos[e]||b.setFilters[e.toLowerCase()]||se.error("unsupported pseudo: "+e);return a[k]?a(o):1<a.length?(t=[e,e,"",o],b.setFilters.hasOwnProperty(e.toLowerCase())?le(function(e,t){var n,r=a(e,o),i=r.length;while(i--)e[n=P(e,r[i])]=!(t[n]=r[i])}):function(e){return a(e,0,t)}):a}},pseudos:{not:le(function(e){var r=[],i=[],s=f(e.replace(B,"$1"));return s[k]?le(function(e,t,n,r){var i,o=s(e,null,r,[]),a=e.length;while(a--)(i=o[a])&&(e[a]=!(t[a]=i))}):function(e,t,n){return r[0]=e,s(r,null,n,i),r[0]=null,!i.pop()}}),has:le(function(t){return function(e){return 0<se(t,e).length}}),contains:le(function(t){return t=t.replace(te,ne),function(e){return-1<(e.textContent||o(e)).indexOf(t)}}),lang:le(function(n){return V.test(n||"")||se.error("unsupported lang: "+n),n=n.replace(te,ne).toLowerCase(),function(e){var t;do{if(t=E?e.lang:e.getAttribute("xml:lang")||e.getAttribute("lang"))return(t=t.toLowerCase())===n||0===t.indexOf(n+"-")}while((e=e.parentNode)&&1===e.nodeType);return!1}}),target:function(e){var t=n.location&&n.location.hash;return t&&t.slice(1)===e.id},root:function(e){return e===a},focus:function(e){return e===C.activeElement&&(!C.hasFocus||C.hasFocus())&&!!(e.type||e.href||~e.tabIndex)},enabled:ge(!1),disabled:ge(!0),checked:function(e){var t=e.nodeName.toLowerCase();return"input"===t&&!!e.checked||"option"===t&&!!e.selected},selected:function(e){return e.parentNode&&e.parentNode.selectedIndex,!0===e.selected},empty:function(e){for(e=e.firstChild;e;e=e.nextSibling)if(e.nodeType<6)return!1;return!0},parent:function(e){return!b.pseudos.empty(e)},header:function(e){return J.test(e.nodeName)},input:function(e){return Q.test(e.nodeName)},button:function(e){var t=e.nodeName.toLowerCase();return"input"===t&&"button"===e.type||"button"===t},text:function(e){var t;return"input"===e.nodeName.toLowerCase()&&"text"===e.type&&(null==(t=e.getAttribute("type"))||"text"===t.toLowerCase())},first:ve(function(){return[0]}),last:ve(function(e,t){return[t-1]}),eq:ve(function(e,t,n){return[n<0?n+t:n]}),even:ve(function(e,t){for(var n=0;n<t;n+=2)e.push(n);return e}),odd:ve(function(e,t){for(var n=1;n<t;n+=2)e.push(n);return e}),lt:ve(function(e,t,n){for(var r=n<0?n+t:t<n?t:n;0<=--r;)e.push(r);return e}),gt:ve(function(e,t,n){for(var r=n<0?n+t:n;++r<t;)e.push(r);return e})}}).pseudos.nth=b.pseudos.eq,{radio:!0,checkbox:!0,file:!0,password:!0,image:!0})b.pseudos[e]=de(e);for(e in{submit:!0,reset:!0})b.pseudos[e]=he(e);function me(){}function xe(e){for(var t=0,n=e.length,r="";t<n;t++)r+=e[t].value;return r}function be(s,e,t){var u=e.dir,l=e.next,c=l||u,f=t&&"parentNode"===c,p=r++;return e.first?function(e,t,n){while(e=e[u])if(1===e.nodeType||f)return s(e,t,n);return!1}:function(e,t,n){var r,i,o,a=[S,p];if(n){while(e=e[u])if((1===e.nodeType||f)&&s(e,t,n))return!0}else while(e=e[u])if(1===e.nodeType||f)if(i=(o=e[k]||(e[k]={}))[e.uniqueID]||(o[e.uniqueID]={}),l&&l===e.nodeName.toLowerCase())e=e[u]||e;else{if((r=i[c])&&r[0]===S&&r[1]===p)return a[2]=r[2];if((i[c]=a)[2]=s(e,t,n))return!0}return!1}}function we(i){return 1<i.length?function(e,t,n){var r=i.length;while(r--)if(!i[r](e,t,n))return!1;return!0}:i[0]}function Te(e,t,n,r,i){for(var 
o,a=[],s=0,u=e.length,l=null!=t;s<u;s++)(o=e[s])&&(n&&!n(o,r,i)||(a.push(o),l&&t.push(s)));return a}function Ce(d,h,g,v,y,e){return v&&!v[k]&&(v=Ce(v)),y&&!y[k]&&(y=Ce(y,e)),le(function(e,t,n,r){var i,o,a,s=[],u=[],l=t.length,c=e||function(e,t,n){for(var r=0,i=t.length;r<i;r++)se(e,t[r],n);return n}(h||"*",n.nodeType?[n]:n,[]),f=!d||!e&&h?c:Te(c,s,d,n,r),p=g?y||(e?d:l||v)?[]:t:f;if(g&&g(f,p,n,r),v){i=Te(p,u),v(i,[],n,r),o=i.length;while(o--)(a=i[o])&&(p[u[o]]=!(f[u[o]]=a))}if(e){if(y||d){if(y){i=[],o=p.length;while(o--)(a=p[o])&&i.push(f[o]=a);y(null,p=[],i,r)}o=p.length;while(o--)(a=p[o])&&-1<(i=y?P(e,a):s[o])&&(e[i]=!(t[i]=a))}}else p=Te(p===t?p.splice(l,p.length):p),y?y(null,t,p,r):H.apply(t,p)})}function Ee(e){for(var i,t,n,r=e.length,o=b.relative[e[0].type],a=o||b.relative[" "],s=o?1:0,u=be(function(e){return e===i},a,!0),l=be(function(e){return-1<P(i,e)},a,!0),c=[function(e,t,n){var r=!o&&(n||t!==w)||((i=t).nodeType?u(e,t,n):l(e,t,n));return i=null,r}];s<r;s++)if(t=b.relative[e[s].type])c=[be(we(c),t)];else{if((t=b.filter[e[s].type].apply(null,e[s].matches))[k]){for(n=++s;n<r;n++)if(b.relative[e[n].type])break;return Ce(1<s&&we(c),1<s&&xe(e.slice(0,s-1).concat({value:" "===e[s-2].type?"*":""})).replace(B,"$1"),t,s<n&&Ee(e.slice(s,n)),n<r&&Ee(e=e.slice(n)),n<r&&xe(e))}c.push(t)}return we(c)}return me.prototype=b.filters=b.pseudos,b.setFilters=new me,h=se.tokenize=function(e,t){var n,r,i,o,a,s,u,l=x[e+" "];if(l)return t?0:l.slice(0);a=e,s=[],u=b.preFilter;while(a){for(o in n&&!(r=_.exec(a))||(r&&(a=a.slice(r[0].length)||a),s.push(i=[])),n=!1,(r=z.exec(a))&&(n=r.shift(),i.push({value:n,type:r[0].replace(B," ")}),a=a.slice(n.length)),b.filter)!(r=G[o].exec(a))||u[o]&&!(r=u[o](r))||(n=r.shift(),i.push({value:n,type:o,matches:r}),a=a.slice(n.length));if(!n)break}return t?a.length:a?se.error(e):x(e,s).slice(0)},f=se.compile=function(e,t){var n,v,y,m,x,r,i=[],o=[],a=N[e+" "];if(!a){t||(t=h(e)),n=t.length;while(n--)(a=Ee(t[n]))[k]?i.push(a):o.push(a);(a=N(e,(v=o,m=0<(y=i).length,x=0<v.length,r=function(e,t,n,r,i){var o,a,s,u=0,l="0",c=e&&[],f=[],p=w,d=e||x&&b.find.TAG("*",i),h=S+=null==p?1:Math.random()||.1,g=d.length;for(i&&(w=t===C||t||i);l!==g&&null!=(o=d[l]);l++){if(x&&o){a=0,t||o.ownerDocument===C||(T(o),n=!E);while(s=v[a++])if(s(o,t||C,n)){r.push(o);break}i&&(S=h)}m&&((o=!s&&o)&&u--,e&&c.push(o))}if(u+=l,m&&l!==u){a=0;while(s=y[a++])s(c,f,t,n);if(e){if(0<u)while(l--)c[l]||f[l]||(f[l]=q.call(r));f=Te(f)}H.apply(r,f),i&&!e&&0<f.length&&1<u+y.length&&se.uniqueSort(r)}return i&&(S=h,w=p),c},m?le(r):r))).selector=e}return a},g=se.select=function(e,t,n,r){var i,o,a,s,u,l="function"==typeof e&&e,c=!r&&h(e=l.selector||e);if(n=n||[],1===c.length){if(2<(o=c[0]=c[0].slice(0)).length&&"ID"===(a=o[0]).type&&9===t.nodeType&&E&&b.relative[o[1].type]){if(!(t=(b.find.ID(a.matches[0].replace(te,ne),t)||[])[0]))return n;l&&(t=t.parentNode),e=e.slice(o.shift().value.length)}i=G.needsContext.test(e)?0:o.length;while(i--){if(a=o[i],b.relative[s=a.type])break;if((u=b.find[s])&&(r=u(a.matches[0].replace(te,ne),ee.test(o[0].type)&&ye(t.parentNode)||t))){if(o.splice(i,1),!(e=r.length&&xe(o)))return H.apply(n,r),n;break}}}return(l||f(e,c))(r,t,!E,n,!t||ee.test(e)&&ye(t.parentNode)||t),n},d.sortStable=k.split("").sort(D).join("")===k,d.detectDuplicates=!!l,T(),d.sortDetached=ce(function(e){return 1&e.compareDocumentPosition(C.createElement("fieldset"))}),ce(function(e){return e.innerHTML="<a href='#'></a>","#"===e.firstChild.getAttribute("href")})||fe("type|href|height|width",function(e,t,n){if(!n)return 
e.getAttribute(t,"type"===t.toLowerCase()?1:2)}),d.attributes&&ce(function(e){return e.innerHTML="<input/>",e.firstChild.setAttribute("value",""),""===e.firstChild.getAttribute("value")})||fe("value",function(e,t,n){if(!n&&"input"===e.nodeName.toLowerCase())return e.defaultValue}),ce(function(e){return null==e.getAttribute("disabled")})||fe(R,function(e,t,n){var r;if(!n)return!0===e[t]?t.toLowerCase():(r=e.getAttributeNode(t))&&r.specified?r.value:null}),se}(C);k.find=h,k.expr=h.selectors,k.expr[":"]=k.expr.pseudos,k.uniqueSort=k.unique=h.uniqueSort,k.text=h.getText,k.isXMLDoc=h.isXML,k.contains=h.contains,k.escapeSelector=h.escape;var T=function(e,t,n){var r=[],i=void 0!==n;while((e=e[t])&&9!==e.nodeType)if(1===e.nodeType){if(i&&k(e).is(n))break;r.push(e)}return r},S=function(e,t){for(var n=[];e;e=e.nextSibling)1===e.nodeType&&e!==t&&n.push(e);return n},N=k.expr.match.needsContext;function A(e,t){return e.nodeName&&e.nodeName.toLowerCase()===t.toLowerCase()}var D=/^<([a-z][^\/\0>:\x20\t\r\n\f]*)[\x20\t\r\n\f]*\/?>(?:<\/\1>|)$/i;function j(e,n,r){return m(n)?k.grep(e,function(e,t){return!!n.call(e,t,e)!==r}):n.nodeType?k.grep(e,function(e){return e===n!==r}):"string"!=typeof n?k.grep(e,function(e){return-1<i.call(n,e)!==r}):k.filter(n,e,r)}k.filter=function(e,t,n){var r=t[0];return n&&(e=":not("+e+")"),1===t.length&&1===r.nodeType?k.find.matchesSelector(r,e)?[r]:[]:k.find.matches(e,k.grep(t,function(e){return 1===e.nodeType}))},k.fn.extend({find:function(e){var t,n,r=this.length,i=this;if("string"!=typeof e)return this.pushStack(k(e).filter(function(){for(t=0;t<r;t++)if(k.contains(i[t],this))return!0}));for(n=this.pushStack([]),t=0;t<r;t++)k.find(e,i[t],n);return 1<r?k.uniqueSort(n):n},filter:function(e){return this.pushStack(j(this,e||[],!1))},not:function(e){return this.pushStack(j(this,e||[],!0))},is:function(e){return!!j(this,"string"==typeof e&&N.test(e)?k(e):e||[],!1).length}});var q,L=/^(?:\s*(<[\w\W]+>)[^>]*|#([\w-]+))$/;(k.fn.init=function(e,t,n){var r,i;if(!e)return this;if(n=n||q,"string"==typeof e){if(!(r="<"===e[0]&&">"===e[e.length-1]&&3<=e.length?[null,e,null]:L.exec(e))||!r[1]&&t)return!t||t.jquery?(t||n).find(e):this.constructor(t).find(e);if(r[1]){if(t=t instanceof k?t[0]:t,k.merge(this,k.parseHTML(r[1],t&&t.nodeType?t.ownerDocument||t:E,!0)),D.test(r[1])&&k.isPlainObject(t))for(r in t)m(this[r])?this[r](t[r]):this.attr(r,t[r]);return this}return(i=E.getElementById(r[2]))&&(this[0]=i,this.length=1),this}return e.nodeType?(this[0]=e,this.length=1,this):m(e)?void 0!==n.ready?n.ready(e):e(k):k.makeArray(e,this)}).prototype=k.fn,q=k(E);var H=/^(?:parents|prev(?:Until|All))/,O={children:!0,contents:!0,next:!0,prev:!0};function P(e,t){while((e=e[t])&&1!==e.nodeType);return e}k.fn.extend({has:function(e){var t=k(e,this),n=t.length;return this.filter(function(){for(var e=0;e<n;e++)if(k.contains(this,t[e]))return!0})},closest:function(e,t){var n,r=0,i=this.length,o=[],a="string"!=typeof e&&k(e);if(!N.test(e))for(;r<i;r++)for(n=this[r];n&&n!==t;n=n.parentNode)if(n.nodeType<11&&(a?-1<a.index(n):1===n.nodeType&&k.find.matchesSelector(n,e))){o.push(n);break}return this.pushStack(1<o.length?k.uniqueSort(o):o)},index:function(e){return e?"string"==typeof e?i.call(k(e),this[0]):i.call(this,e.jquery?e[0]:e):this[0]&&this[0].parentNode?this.first().prevAll().length:-1},add:function(e,t){return this.pushStack(k.uniqueSort(k.merge(this.get(),k(e,t))))},addBack:function(e){return this.add(null==e?this.prevObject:this.prevObject.filter(e))}}),k.each({parent:function(e){var 
t=e.parentNode;return t&&11!==t.nodeType?t:null},parents:function(e){return T(e,"parentNode")},parentsUntil:function(e,t,n){return T(e,"parentNode",n)},next:function(e){return P(e,"nextSibling")},prev:function(e){return P(e,"previousSibling")},nextAll:function(e){return T(e,"nextSibling")},prevAll:function(e){return T(e,"previousSibling")},nextUntil:function(e,t,n){return T(e,"nextSibling",n)},prevUntil:function(e,t,n){return T(e,"previousSibling",n)},siblings:function(e){return S((e.parentNode||{}).firstChild,e)},children:function(e){return S(e.firstChild)},contents:function(e){return"undefined"!=typeof e.contentDocument?e.contentDocument:(A(e,"template")&&(e=e.content||e),k.merge([],e.childNodes))}},function(r,i){k.fn[r]=function(e,t){var n=k.map(this,i,e);return"Until"!==r.slice(-5)&&(t=e),t&&"string"==typeof t&&(n=k.filter(t,n)),1<this.length&&(O[r]||k.uniqueSort(n),H.test(r)&&n.reverse()),this.pushStack(n)}});var R=/[^\x20\t\r\n\f]+/g;function M(e){return e}function I(e){throw e}function W(e,t,n,r){var i;try{e&&m(i=e.promise)?i.call(e).done(t).fail(n):e&&m(i=e.then)?i.call(e,t,n):t.apply(void 0,[e].slice(r))}catch(e){n.apply(void 0,[e])}}k.Callbacks=function(r){var e,n;r="string"==typeof r?(e=r,n={},k.each(e.match(R)||[],function(e,t){n[t]=!0}),n):k.extend({},r);var i,t,o,a,s=[],u=[],l=-1,c=function(){for(a=a||r.once,o=i=!0;u.length;l=-1){t=u.shift();while(++l<s.length)!1===s[l].apply(t[0],t[1])&&r.stopOnFalse&&(l=s.length,t=!1)}r.memory||(t=!1),i=!1,a&&(s=t?[]:"")},f={add:function(){return s&&(t&&!i&&(l=s.length-1,u.push(t)),function n(e){k.each(e,function(e,t){m(t)?r.unique&&f.has(t)||s.push(t):t&&t.length&&"string"!==w(t)&&n(t)})}(arguments),t&&!i&&c()),this},remove:function(){return k.each(arguments,function(e,t){var n;while(-1<(n=k.inArray(t,s,n)))s.splice(n,1),n<=l&&l--}),this},has:function(e){return e?-1<k.inArray(e,s):0<s.length},empty:function(){return s&&(s=[]),this},disable:function(){return a=u=[],s=t="",this},disabled:function(){return!s},lock:function(){return a=u=[],t||i||(s=t=""),this},locked:function(){return!!a},fireWith:function(e,t){return a||(t=[e,(t=t||[]).slice?t.slice():t],u.push(t),i||c()),this},fire:function(){return f.fireWith(this,arguments),this},fired:function(){return!!o}};return f},k.extend({Deferred:function(e){var o=[["notify","progress",k.Callbacks("memory"),k.Callbacks("memory"),2],["resolve","done",k.Callbacks("once memory"),k.Callbacks("once memory"),0,"resolved"],["reject","fail",k.Callbacks("once memory"),k.Callbacks("once memory"),1,"rejected"]],i="pending",a={state:function(){return i},always:function(){return s.done(arguments).fail(arguments),this},"catch":function(e){return a.then(null,e)},pipe:function(){var i=arguments;return k.Deferred(function(r){k.each(o,function(e,t){var n=m(i[t[4]])&&i[t[4]];s[t[1]](function(){var e=n&&n.apply(this,arguments);e&&m(e.promise)?e.promise().progress(r.notify).done(r.resolve).fail(r.reject):r[t[0]+"With"](this,n?[e]:arguments)})}),i=null}).promise()},then:function(t,n,r){var u=0;function l(i,o,a,s){return function(){var n=this,r=arguments,e=function(){var e,t;if(!(i<u)){if((e=a.apply(n,r))===o.promise())throw new TypeError("Thenable self-resolution");t=e&&("object"==typeof e||"function"==typeof e)&&e.then,m(t)?s?t.call(e,l(u,o,M,s),l(u,o,I,s)):(u++,t.call(e,l(u,o,M,s),l(u,o,I,s),l(u,o,M,o.notifyWith))):(a!==M&&(n=void 0,r=[e]),(s||o.resolveWith)(n,r))}},t=s?e:function(){try{e()}catch(e){k.Deferred.exceptionHook&&k.Deferred.exceptionHook(e,t.stackTrace),u<=i+1&&(a!==I&&(n=void 
0,r=[e]),o.rejectWith(n,r))}};i?t():(k.Deferred.getStackHook&&(t.stackTrace=k.Deferred.getStackHook()),C.setTimeout(t))}}return k.Deferred(function(e){o[0][3].add(l(0,e,m(r)?r:M,e.notifyWith)),o[1][3].add(l(0,e,m(t)?t:M)),o[2][3].add(l(0,e,m(n)?n:I))}).promise()},promise:function(e){return null!=e?k.extend(e,a):a}},s={};return k.each(o,function(e,t){var n=t[2],r=t[5];a[t[1]]=n.add,r&&n.add(function(){i=r},o[3-e][2].disable,o[3-e][3].disable,o[0][2].lock,o[0][3].lock),n.add(t[3].fire),s[t[0]]=function(){return s[t[0]+"With"](this===s?void 0:this,arguments),this},s[t[0]+"With"]=n.fireWith}),a.promise(s),e&&e.call(s,s),s},when:function(e){var n=arguments.length,t=n,r=Array(t),i=s.call(arguments),o=k.Deferred(),a=function(t){return function(e){r[t]=this,i[t]=1<arguments.length?s.call(arguments):e,--n||o.resolveWith(r,i)}};if(n<=1&&(W(e,o.done(a(t)).resolve,o.reject,!n),"pending"===o.state()||m(i[t]&&i[t].then)))return o.then();while(t--)W(i[t],a(t),o.reject);return o.promise()}});var $=/^(Eval|Internal|Range|Reference|Syntax|Type|URI)Error$/;k.Deferred.exceptionHook=function(e,t){C.console&&C.console.warn&&e&&$.test(e.name)&&C.console.warn("jQuery.Deferred exception: "+e.message,e.stack,t)},k.readyException=function(e){C.setTimeout(function(){throw e})};var F=k.Deferred();function B(){E.removeEventListener("DOMContentLoaded",B),C.removeEventListener("load",B),k.ready()}k.fn.ready=function(e){return F.then(e)["catch"](function(e){k.readyException(e)}),this},k.extend({isReady:!1,readyWait:1,ready:function(e){(!0===e?--k.readyWait:k.isReady)||(k.isReady=!0)!==e&&0<--k.readyWait||F.resolveWith(E,[k])}}),k.ready.then=F.then,"complete"===E.readyState||"loading"!==E.readyState&&!E.documentElement.doScroll?C.setTimeout(k.ready):(E.addEventListener("DOMContentLoaded",B),C.addEventListener("load",B));var _=function(e,t,n,r,i,o,a){var s=0,u=e.length,l=null==n;if("object"===w(n))for(s in i=!0,n)_(e,t,s,n[s],!0,o,a);else if(void 0!==r&&(i=!0,m(r)||(a=!0),l&&(a?(t.call(e,r),t=null):(l=t,t=function(e,t,n){return l.call(k(e),n)})),t))for(;s<u;s++)t(e[s],n,a?r:r.call(e[s],s,t(e[s],n)));return i?e:l?t.call(e):u?t(e[0],n):o},z=/^-ms-/,U=/-([a-z])/g;function X(e,t){return t.toUpperCase()}function V(e){return e.replace(z,"ms-").replace(U,X)}var G=function(e){return 1===e.nodeType||9===e.nodeType||!+e.nodeType};function Y(){this.expando=k.expando+Y.uid++}Y.uid=1,Y.prototype={cache:function(e){var t=e[this.expando];return t||(t={},G(e)&&(e.nodeType?e[this.expando]=t:Object.defineProperty(e,this.expando,{value:t,configurable:!0}))),t},set:function(e,t,n){var r,i=this.cache(e);if("string"==typeof t)i[V(t)]=n;else for(r in t)i[V(r)]=t[r];return i},get:function(e,t){return void 0===t?this.cache(e):e[this.expando]&&e[this.expando][V(t)]},access:function(e,t,n){return void 0===t||t&&"string"==typeof t&&void 0===n?this.get(e,t):(this.set(e,t,n),void 0!==n?n:t)},remove:function(e,t){var n,r=e[this.expando];if(void 0!==r){if(void 0!==t){n=(t=Array.isArray(t)?t.map(V):(t=V(t))in r?[t]:t.match(R)||[]).length;while(n--)delete r[t[n]]}(void 0===t||k.isEmptyObject(r))&&(e.nodeType?e[this.expando]=void 0:delete e[this.expando])}},hasData:function(e){var t=e[this.expando];return void 0!==t&&!k.isEmptyObject(t)}};var Q=new Y,J=new Y,K=/^(?:\{[\w\W]*\}|\[[\w\W]*\])$/,Z=/[A-Z]/g;function ee(e,t,n){var r,i;if(void 
0===n&&1===e.nodeType)if(r="data-"+t.replace(Z,"-$&").toLowerCase(),"string"==typeof(n=e.getAttribute(r))){try{n="true"===(i=n)||"false"!==i&&("null"===i?null:i===+i+""?+i:K.test(i)?JSON.parse(i):i)}catch(e){}J.set(e,t,n)}else n=void 0;return n}k.extend({hasData:function(e){return J.hasData(e)||Q.hasData(e)},data:function(e,t,n){return J.access(e,t,n)},removeData:function(e,t){J.remove(e,t)},_data:function(e,t,n){return Q.access(e,t,n)},_removeData:function(e,t){Q.remove(e,t)}}),k.fn.extend({data:function(n,e){var t,r,i,o=this[0],a=o&&o.attributes;if(void 0===n){if(this.length&&(i=J.get(o),1===o.nodeType&&!Q.get(o,"hasDataAttrs"))){t=a.length;while(t--)a[t]&&0===(r=a[t].name).indexOf("data-")&&(r=V(r.slice(5)),ee(o,r,i[r]));Q.set(o,"hasDataAttrs",!0)}return i}return"object"==typeof n?this.each(function(){J.set(this,n)}):_(this,function(e){var t;if(o&&void 0===e)return void 0!==(t=J.get(o,n))?t:void 0!==(t=ee(o,n))?t:void 0;this.each(function(){J.set(this,n,e)})},null,e,1<arguments.length,null,!0)},removeData:function(e){return this.each(function(){J.remove(this,e)})}}),k.extend({queue:function(e,t,n){var r;if(e)return t=(t||"fx")+"queue",r=Q.get(e,t),n&&(!r||Array.isArray(n)?r=Q.access(e,t,k.makeArray(n)):r.push(n)),r||[]},dequeue:function(e,t){t=t||"fx";var n=k.queue(e,t),r=n.length,i=n.shift(),o=k._queueHooks(e,t);"inprogress"===i&&(i=n.shift(),r--),i&&("fx"===t&&n.unshift("inprogress"),delete o.stop,i.call(e,function(){k.dequeue(e,t)},o)),!r&&o&&o.empty.fire()},_queueHooks:function(e,t){var n=t+"queueHooks";return Q.get(e,n)||Q.access(e,n,{empty:k.Callbacks("once memory").add(function(){Q.remove(e,[t+"queue",n])})})}}),k.fn.extend({queue:function(t,n){var e=2;return"string"!=typeof t&&(n=t,t="fx",e--),arguments.length<e?k.queue(this[0],t):void 0===n?this:this.each(function(){var e=k.queue(this,t,n);k._queueHooks(this,t),"fx"===t&&"inprogress"!==e[0]&&k.dequeue(this,t)})},dequeue:function(e){return this.each(function(){k.dequeue(this,e)})},clearQueue:function(e){return this.queue(e||"fx",[])},promise:function(e,t){var n,r=1,i=k.Deferred(),o=this,a=this.length,s=function(){--r||i.resolveWith(o,[o])};"string"!=typeof e&&(t=e,e=void 0),e=e||"fx";while(a--)(n=Q.get(o[a],e+"queueHooks"))&&n.empty&&(r++,n.empty.add(s));return s(),i.promise(t)}});var te=/[+-]?(?:\d*\.|)\d+(?:[eE][+-]?\d+|)/.source,ne=new RegExp("^(?:([+-])=|)("+te+")([a-z%]*)$","i"),re=["Top","Right","Bottom","Left"],ie=E.documentElement,oe=function(e){return k.contains(e.ownerDocument,e)},ae={composed:!0};ie.getRootNode&&(oe=function(e){return k.contains(e.ownerDocument,e)||e.getRootNode(ae)===e.ownerDocument});var se=function(e,t){return"none"===(e=t||e).style.display||""===e.style.display&&oe(e)&&"none"===k.css(e,"display")},ue=function(e,t,n,r){var i,o,a={};for(o in t)a[o]=e.style[o],e.style[o]=t[o];for(o in i=n.apply(e,r||[]),t)e.style[o]=a[o];return i};function le(e,t,n,r){var i,o,a=20,s=r?function(){return r.cur()}:function(){return k.css(e,t,"")},u=s(),l=n&&n[3]||(k.cssNumber[t]?"":"px"),c=e.nodeType&&(k.cssNumber[t]||"px"!==l&&+u)&&ne.exec(k.css(e,t));if(c&&c[3]!==l){u/=2,l=l||c[3],c=+u||1;while(a--)k.style(e,t,c+l),(1-o)*(1-(o=s()/u||.5))<=0&&(a=0),c/=o;c*=2,k.style(e,t,c+l),n=n||[]}return n&&(c=+c||+u||0,i=n[1]?c+(n[1]+1)*n[2]:+n[2],r&&(r.unit=l,r.start=c,r.end=i)),i}var ce={};function fe(e,t){for(var n,r,i,o,a,s,u,l=[],c=0,f=e.length;c<f;c++)(r=e[c]).style&&(n=r.style.display,t?("none"===n&&(l[c]=Q.get(r,"display")||null,l[c]||(r.style.display="")),""===r.style.display&&se(r)&&(l[c]=(u=a=o=void 
0,a=(i=r).ownerDocument,s=i.nodeName,(u=ce[s])||(o=a.body.appendChild(a.createElement(s)),u=k.css(o,"display"),o.parentNode.removeChild(o),"none"===u&&(u="block"),ce[s]=u)))):"none"!==n&&(l[c]="none",Q.set(r,"display",n)));for(c=0;c<f;c++)null!=l[c]&&(e[c].style.display=l[c]);return e}k.fn.extend({show:function(){return fe(this,!0)},hide:function(){return fe(this)},toggle:function(e){return"boolean"==typeof e?e?this.show():this.hide():this.each(function(){se(this)?k(this).show():k(this).hide()})}});var pe=/^(?:checkbox|radio)$/i,de=/<([a-z][^\/\0>\x20\t\r\n\f]*)/i,he=/^$|^module$|\/(?:java|ecma)script/i,ge={option:[1,"<select multiple='multiple'>","</select>"],thead:[1,"<table>","</table>"],col:[2,"<table><colgroup>","</colgroup></table>"],tr:[2,"<table><tbody>","</tbody></table>"],td:[3,"<table><tbody><tr>","</tr></tbody></table>"],_default:[0,"",""]};function ve(e,t){var n;return n="undefined"!=typeof e.getElementsByTagName?e.getElementsByTagName(t||"*"):"undefined"!=typeof e.querySelectorAll?e.querySelectorAll(t||"*"):[],void 0===t||t&&A(e,t)?k.merge([e],n):n}function ye(e,t){for(var n=0,r=e.length;n<r;n++)Q.set(e[n],"globalEval",!t||Q.get(t[n],"globalEval"))}ge.optgroup=ge.option,ge.tbody=ge.tfoot=ge.colgroup=ge.caption=ge.thead,ge.th=ge.td;var me,xe,be=/<|&#?\w+;/;function we(e,t,n,r,i){for(var o,a,s,u,l,c,f=t.createDocumentFragment(),p=[],d=0,h=e.length;d<h;d++)if((o=e[d])||0===o)if("object"===w(o))k.merge(p,o.nodeType?[o]:o);else if(be.test(o)){a=a||f.appendChild(t.createElement("div")),s=(de.exec(o)||["",""])[1].toLowerCase(),u=ge[s]||ge._default,a.innerHTML=u[1]+k.htmlPrefilter(o)+u[2],c=u[0];while(c--)a=a.lastChild;k.merge(p,a.childNodes),(a=f.firstChild).textContent=""}else p.push(t.createTextNode(o));f.textContent="",d=0;while(o=p[d++])if(r&&-1<k.inArray(o,r))i&&i.push(o);else if(l=oe(o),a=ve(f.appendChild(o),"script"),l&&ye(a),n){c=0;while(o=a[c++])he.test(o.type||"")&&n.push(o)}return f}me=E.createDocumentFragment().appendChild(E.createElement("div")),(xe=E.createElement("input")).setAttribute("type","radio"),xe.setAttribute("checked","checked"),xe.setAttribute("name","t"),me.appendChild(xe),y.checkClone=me.cloneNode(!0).cloneNode(!0).lastChild.checked,me.innerHTML="<textarea>x</textarea>",y.noCloneChecked=!!me.cloneNode(!0).lastChild.defaultValue;var Te=/^key/,Ce=/^(?:mouse|pointer|contextmenu|drag|drop)|click/,Ee=/^([^.]*)(?:\.(.+)|)/;function ke(){return!0}function Se(){return!1}function Ne(e,t){return e===function(){try{return E.activeElement}catch(e){}}()==("focus"===t)}function Ae(e,t,n,r,i,o){var a,s;if("object"==typeof t){for(s in"string"!=typeof n&&(r=r||n,n=void 0),t)Ae(e,s,n,r,t[s],o);return e}if(null==r&&null==i?(i=n,r=n=void 0):null==i&&("string"==typeof n?(i=r,r=void 0):(i=r,r=n,n=void 0)),!1===i)i=Se;else if(!i)return e;return 1===o&&(a=i,(i=function(e){return k().off(e),a.apply(this,arguments)}).guid=a.guid||(a.guid=k.guid++)),e.each(function(){k.event.add(this,t,i,r,n)})}function De(e,i,o){o?(Q.set(e,i,!1),k.event.add(e,i,{namespace:!1,handler:function(e){var t,n,r=Q.get(this,i);if(1&e.isTrigger&&this[i]){if(r.length)(k.event.special[i]||{}).delegateType&&e.stopPropagation();else if(r=s.call(arguments),Q.set(this,i,r),t=o(this,i),this[i](),r!==(n=Q.get(this,i))||t?Q.set(this,i,!1):n={},r!==n)return e.stopImmediatePropagation(),e.preventDefault(),n.value}else r.length&&(Q.set(this,i,{value:k.event.trigger(k.extend(r[0],k.Event.prototype),r.slice(1),this)}),e.stopImmediatePropagation())}})):void 
0===Q.get(e,i)&&k.event.add(e,i,ke)}k.event={global:{},add:function(t,e,n,r,i){var o,a,s,u,l,c,f,p,d,h,g,v=Q.get(t);if(v){n.handler&&(n=(o=n).handler,i=o.selector),i&&k.find.matchesSelector(ie,i),n.guid||(n.guid=k.guid++),(u=v.events)||(u=v.events={}),(a=v.handle)||(a=v.handle=function(e){return"undefined"!=typeof k&&k.event.triggered!==e.type?k.event.dispatch.apply(t,arguments):void 0}),l=(e=(e||"").match(R)||[""]).length;while(l--)d=g=(s=Ee.exec(e[l])||[])[1],h=(s[2]||"").split(".").sort(),d&&(f=k.event.special[d]||{},d=(i?f.delegateType:f.bindType)||d,f=k.event.special[d]||{},c=k.extend({type:d,origType:g,data:r,handler:n,guid:n.guid,selector:i,needsContext:i&&k.expr.match.needsContext.test(i),namespace:h.join(".")},o),(p=u[d])||((p=u[d]=[]).delegateCount=0,f.setup&&!1!==f.setup.call(t,r,h,a)||t.addEventListener&&t.addEventListener(d,a)),f.add&&(f.add.call(t,c),c.handler.guid||(c.handler.guid=n.guid)),i?p.splice(p.delegateCount++,0,c):p.push(c),k.event.global[d]=!0)}},remove:function(e,t,n,r,i){var o,a,s,u,l,c,f,p,d,h,g,v=Q.hasData(e)&&Q.get(e);if(v&&(u=v.events)){l=(t=(t||"").match(R)||[""]).length;while(l--)if(d=g=(s=Ee.exec(t[l])||[])[1],h=(s[2]||"").split(".").sort(),d){f=k.event.special[d]||{},p=u[d=(r?f.delegateType:f.bindType)||d]||[],s=s[2]&&new RegExp("(^|\\.)"+h.join("\\.(?:.*\\.|)")+"(\\.|$)"),a=o=p.length;while(o--)c=p[o],!i&&g!==c.origType||n&&n.guid!==c.guid||s&&!s.test(c.namespace)||r&&r!==c.selector&&("**"!==r||!c.selector)||(p.splice(o,1),c.selector&&p.delegateCount--,f.remove&&f.remove.call(e,c));a&&!p.length&&(f.teardown&&!1!==f.teardown.call(e,h,v.handle)||k.removeEvent(e,d,v.handle),delete u[d])}else for(d in u)k.event.remove(e,d+t[l],n,r,!0);k.isEmptyObject(u)&&Q.remove(e,"handle events")}},dispatch:function(e){var t,n,r,i,o,a,s=k.event.fix(e),u=new Array(arguments.length),l=(Q.get(this,"events")||{})[s.type]||[],c=k.event.special[s.type]||{};for(u[0]=s,t=1;t<arguments.length;t++)u[t]=arguments[t];if(s.delegateTarget=this,!c.preDispatch||!1!==c.preDispatch.call(this,s)){a=k.event.handlers.call(this,s,l),t=0;while((i=a[t++])&&!s.isPropagationStopped()){s.currentTarget=i.elem,n=0;while((o=i.handlers[n++])&&!s.isImmediatePropagationStopped())s.rnamespace&&!1!==o.namespace&&!s.rnamespace.test(o.namespace)||(s.handleObj=o,s.data=o.data,void 0!==(r=((k.event.special[o.origType]||{}).handle||o.handler).apply(i.elem,u))&&!1===(s.result=r)&&(s.preventDefault(),s.stopPropagation()))}return c.postDispatch&&c.postDispatch.call(this,s),s.result}},handlers:function(e,t){var n,r,i,o,a,s=[],u=t.delegateCount,l=e.target;if(u&&l.nodeType&&!("click"===e.type&&1<=e.button))for(;l!==this;l=l.parentNode||this)if(1===l.nodeType&&("click"!==e.type||!0!==l.disabled)){for(o=[],a={},n=0;n<u;n++)void 0===a[i=(r=t[n]).selector+" "]&&(a[i]=r.needsContext?-1<k(i,this).index(l):k.find(i,this,null,[l]).length),a[i]&&o.push(r);o.length&&s.push({elem:l,handlers:o})}return l=this,u<t.length&&s.push({elem:l,handlers:t.slice(u)}),s},addProp:function(t,e){Object.defineProperty(k.Event.prototype,t,{enumerable:!0,configurable:!0,get:m(e)?function(){if(this.originalEvent)return e(this.originalEvent)}:function(){if(this.originalEvent)return this.originalEvent[t]},set:function(e){Object.defineProperty(this,t,{enumerable:!0,configurable:!0,writable:!0,value:e})}})},fix:function(e){return e[k.expando]?e:new k.Event(e)},special:{load:{noBubble:!0},click:{setup:function(e){var t=this||e;return pe.test(t.type)&&t.click&&A(t,"input")&&De(t,"click",ke),!1},trigger:function(e){var t=this||e;return 
pe.test(t.type)&&t.click&&A(t,"input")&&De(t,"click"),!0},_default:function(e){var t=e.target;return pe.test(t.type)&&t.click&&A(t,"input")&&Q.get(t,"click")||A(t,"a")}},beforeunload:{postDispatch:function(e){void 0!==e.result&&e.originalEvent&&(e.originalEvent.returnValue=e.result)}}}},k.removeEvent=function(e,t,n){e.removeEventListener&&e.removeEventListener(t,n)},k.Event=function(e,t){if(!(this instanceof k.Event))return new k.Event(e,t);e&&e.type?(this.originalEvent=e,this.type=e.type,this.isDefaultPrevented=e.defaultPrevented||void 0===e.defaultPrevented&&!1===e.returnValue?ke:Se,this.target=e.target&&3===e.target.nodeType?e.target.parentNode:e.target,this.currentTarget=e.currentTarget,this.relatedTarget=e.relatedTarget):this.type=e,t&&k.extend(this,t),this.timeStamp=e&&e.timeStamp||Date.now(),this[k.expando]=!0},k.Event.prototype={constructor:k.Event,isDefaultPrevented:Se,isPropagationStopped:Se,isImmediatePropagationStopped:Se,isSimulated:!1,preventDefault:function(){var e=this.originalEvent;this.isDefaultPrevented=ke,e&&!this.isSimulated&&e.preventDefault()},stopPropagation:function(){var e=this.originalEvent;this.isPropagationStopped=ke,e&&!this.isSimulated&&e.stopPropagation()},stopImmediatePropagation:function(){var e=this.originalEvent;this.isImmediatePropagationStopped=ke,e&&!this.isSimulated&&e.stopImmediatePropagation(),this.stopPropagation()}},k.each({altKey:!0,bubbles:!0,cancelable:!0,changedTouches:!0,ctrlKey:!0,detail:!0,eventPhase:!0,metaKey:!0,pageX:!0,pageY:!0,shiftKey:!0,view:!0,"char":!0,code:!0,charCode:!0,key:!0,keyCode:!0,button:!0,buttons:!0,clientX:!0,clientY:!0,offsetX:!0,offsetY:!0,pointerId:!0,pointerType:!0,screenX:!0,screenY:!0,targetTouches:!0,toElement:!0,touches:!0,which:function(e){var t=e.button;return null==e.which&&Te.test(e.type)?null!=e.charCode?e.charCode:e.keyCode:!e.which&&void 0!==t&&Ce.test(e.type)?1&t?1:2&t?3:4&t?2:0:e.which}},k.event.addProp),k.each({focus:"focusin",blur:"focusout"},function(e,t){k.event.special[e]={setup:function(){return De(this,e,Ne),!1},trigger:function(){return De(this,e),!0},delegateType:t}}),k.each({mouseenter:"mouseover",mouseleave:"mouseout",pointerenter:"pointerover",pointerleave:"pointerout"},function(e,i){k.event.special[e]={delegateType:i,bindType:i,handle:function(e){var t,n=e.relatedTarget,r=e.handleObj;return n&&(n===this||k.contains(this,n))||(e.type=r.origType,t=r.handler.apply(this,arguments),e.type=i),t}}}),k.fn.extend({on:function(e,t,n,r){return Ae(this,e,t,n,r)},one:function(e,t,n,r){return Ae(this,e,t,n,r,1)},off:function(e,t,n){var r,i;if(e&&e.preventDefault&&e.handleObj)return r=e.handleObj,k(e.delegateTarget).off(r.namespace?r.origType+"."+r.namespace:r.origType,r.selector,r.handler),this;if("object"==typeof e){for(i in e)this.off(i,t,e[i]);return this}return!1!==t&&"function"!=typeof t||(n=t,t=void 0),!1===n&&(n=Se),this.each(function(){k.event.remove(this,e,n,t)})}});var je=/<(?!area|br|col|embed|hr|img|input|link|meta|param)(([a-z][^\/\0>\x20\t\r\n\f]*)[^>]*)\/>/gi,qe=/<script|<style|<link/i,Le=/checked\s*(?:[^=]|=\s*.checked.)/i,He=/^\s*<!(?:\[CDATA\[|--)|(?:\]\]|--)>\s*$/g;function Oe(e,t){return A(e,"table")&&A(11!==t.nodeType?t:t.firstChild,"tr")&&k(e).children("tbody")[0]||e}function Pe(e){return e.type=(null!==e.getAttribute("type"))+"/"+e.type,e}function Re(e){return"true/"===(e.type||"").slice(0,5)?e.type=e.type.slice(5):e.removeAttribute("type"),e}function Me(e,t){var n,r,i,o,a,s,u,l;if(1===t.nodeType){if(Q.hasData(e)&&(o=Q.access(e),a=Q.set(t,o),l=o.events))for(i in delete 
a.handle,a.events={},l)for(n=0,r=l[i].length;n<r;n++)k.event.add(t,i,l[i][n]);J.hasData(e)&&(s=J.access(e),u=k.extend({},s),J.set(t,u))}}function Ie(n,r,i,o){r=g.apply([],r);var e,t,a,s,u,l,c=0,f=n.length,p=f-1,d=r[0],h=m(d);if(h||1<f&&"string"==typeof d&&!y.checkClone&&Le.test(d))return n.each(function(e){var t=n.eq(e);h&&(r[0]=d.call(this,e,t.html())),Ie(t,r,i,o)});if(f&&(t=(e=we(r,n[0].ownerDocument,!1,n,o)).firstChild,1===e.childNodes.length&&(e=t),t||o)){for(s=(a=k.map(ve(e,"script"),Pe)).length;c<f;c++)u=e,c!==p&&(u=k.clone(u,!0,!0),s&&k.merge(a,ve(u,"script"))),i.call(n[c],u,c);if(s)for(l=a[a.length-1].ownerDocument,k.map(a,Re),c=0;c<s;c++)u=a[c],he.test(u.type||"")&&!Q.access(u,"globalEval")&&k.contains(l,u)&&(u.src&&"module"!==(u.type||"").toLowerCase()?k._evalUrl&&!u.noModule&&k._evalUrl(u.src,{nonce:u.nonce||u.getAttribute("nonce")}):b(u.textContent.replace(He,""),u,l))}return n}function We(e,t,n){for(var r,i=t?k.filter(t,e):e,o=0;null!=(r=i[o]);o++)n||1!==r.nodeType||k.cleanData(ve(r)),r.parentNode&&(n&&oe(r)&&ye(ve(r,"script")),r.parentNode.removeChild(r));return e}k.extend({htmlPrefilter:function(e){return e.replace(je,"<$1></$2>")},clone:function(e,t,n){var r,i,o,a,s,u,l,c=e.cloneNode(!0),f=oe(e);if(!(y.noCloneChecked||1!==e.nodeType&&11!==e.nodeType||k.isXMLDoc(e)))for(a=ve(c),r=0,i=(o=ve(e)).length;r<i;r++)s=o[r],u=a[r],void 0,"input"===(l=u.nodeName.toLowerCase())&&pe.test(s.type)?u.checked=s.checked:"input"!==l&&"textarea"!==l||(u.defaultValue=s.defaultValue);if(t)if(n)for(o=o||ve(e),a=a||ve(c),r=0,i=o.length;r<i;r++)Me(o[r],a[r]);else Me(e,c);return 0<(a=ve(c,"script")).length&&ye(a,!f&&ve(e,"script")),c},cleanData:function(e){for(var t,n,r,i=k.event.special,o=0;void 0!==(n=e[o]);o++)if(G(n)){if(t=n[Q.expando]){if(t.events)for(r in t.events)i[r]?k.event.remove(n,r):k.removeEvent(n,r,t.handle);n[Q.expando]=void 0}n[J.expando]&&(n[J.expando]=void 0)}}}),k.fn.extend({detach:function(e){return We(this,e,!0)},remove:function(e){return We(this,e)},text:function(e){return _(this,function(e){return void 0===e?k.text(this):this.empty().each(function(){1!==this.nodeType&&11!==this.nodeType&&9!==this.nodeType||(this.textContent=e)})},null,e,arguments.length)},append:function(){return Ie(this,arguments,function(e){1!==this.nodeType&&11!==this.nodeType&&9!==this.nodeType||Oe(this,e).appendChild(e)})},prepend:function(){return Ie(this,arguments,function(e){if(1===this.nodeType||11===this.nodeType||9===this.nodeType){var t=Oe(this,e);t.insertBefore(e,t.firstChild)}})},before:function(){return Ie(this,arguments,function(e){this.parentNode&&this.parentNode.insertBefore(e,this)})},after:function(){return Ie(this,arguments,function(e){this.parentNode&&this.parentNode.insertBefore(e,this.nextSibling)})},empty:function(){for(var e,t=0;null!=(e=this[t]);t++)1===e.nodeType&&(k.cleanData(ve(e,!1)),e.textContent="");return this},clone:function(e,t){return e=null!=e&&e,t=null==t?e:t,this.map(function(){return k.clone(this,e,t)})},html:function(e){return _(this,function(e){var t=this[0]||{},n=0,r=this.length;if(void 0===e&&1===t.nodeType)return t.innerHTML;if("string"==typeof e&&!qe.test(e)&&!ge[(de.exec(e)||["",""])[1].toLowerCase()]){e=k.htmlPrefilter(e);try{for(;n<r;n++)1===(t=this[n]||{}).nodeType&&(k.cleanData(ve(t,!1)),t.innerHTML=e);t=0}catch(e){}}t&&this.empty().append(e)},null,e,arguments.length)},replaceWith:function(){var n=[];return Ie(this,arguments,function(e){var 
t=this.parentNode;k.inArray(this,n)<0&&(k.cleanData(ve(this)),t&&t.replaceChild(e,this))},n)}}),k.each({appendTo:"append",prependTo:"prepend",insertBefore:"before",insertAfter:"after",replaceAll:"replaceWith"},function(e,a){k.fn[e]=function(e){for(var t,n=[],r=k(e),i=r.length-1,o=0;o<=i;o++)t=o===i?this:this.clone(!0),k(r[o])[a](t),u.apply(n,t.get());return this.pushStack(n)}});var $e=new RegExp("^("+te+")(?!px)[a-z%]+$","i"),Fe=function(e){var t=e.ownerDocument.defaultView;return t&&t.opener||(t=C),t.getComputedStyle(e)},Be=new RegExp(re.join("|"),"i");function _e(e,t,n){var r,i,o,a,s=e.style;return(n=n||Fe(e))&&(""!==(a=n.getPropertyValue(t)||n[t])||oe(e)||(a=k.style(e,t)),!y.pixelBoxStyles()&&$e.test(a)&&Be.test(t)&&(r=s.width,i=s.minWidth,o=s.maxWidth,s.minWidth=s.maxWidth=s.width=a,a=n.width,s.width=r,s.minWidth=i,s.maxWidth=o)),void 0!==a?a+"":a}function ze(e,t){return{get:function(){if(!e())return(this.get=t).apply(this,arguments);delete this.get}}}!function(){function e(){if(u){s.style.cssText="position:absolute;left:-11111px;width:60px;margin-top:1px;padding:0;border:0",u.style.cssText="position:relative;display:block;box-sizing:border-box;overflow:scroll;margin:auto;border:1px;padding:1px;width:60%;top:1%",ie.appendChild(s).appendChild(u);var e=C.getComputedStyle(u);n="1%"!==e.top,a=12===t(e.marginLeft),u.style.right="60%",o=36===t(e.right),r=36===t(e.width),u.style.position="absolute",i=12===t(u.offsetWidth/3),ie.removeChild(s),u=null}}function t(e){return Math.round(parseFloat(e))}var n,r,i,o,a,s=E.createElement("div"),u=E.createElement("div");u.style&&(u.style.backgroundClip="content-box",u.cloneNode(!0).style.backgroundClip="",y.clearCloneStyle="content-box"===u.style.backgroundClip,k.extend(y,{boxSizingReliable:function(){return e(),r},pixelBoxStyles:function(){return e(),o},pixelPosition:function(){return e(),n},reliableMarginLeft:function(){return e(),a},scrollboxSize:function(){return e(),i}}))}();var Ue=["Webkit","Moz","ms"],Xe=E.createElement("div").style,Ve={};function Ge(e){var t=k.cssProps[e]||Ve[e];return t||(e in Xe?e:Ve[e]=function(e){var t=e[0].toUpperCase()+e.slice(1),n=Ue.length;while(n--)if((e=Ue[n]+t)in Xe)return e}(e)||e)}var Ye=/^(none|table(?!-c[ea]).+)/,Qe=/^--/,Je={position:"absolute",visibility:"hidden",display:"block"},Ke={letterSpacing:"0",fontWeight:"400"};function Ze(e,t,n){var r=ne.exec(t);return r?Math.max(0,r[2]-(n||0))+(r[3]||"px"):t}function et(e,t,n,r,i,o){var a="width"===t?1:0,s=0,u=0;if(n===(r?"border":"content"))return 0;for(;a<4;a+=2)"margin"===n&&(u+=k.css(e,n+re[a],!0,i)),r?("content"===n&&(u-=k.css(e,"padding"+re[a],!0,i)),"margin"!==n&&(u-=k.css(e,"border"+re[a]+"Width",!0,i))):(u+=k.css(e,"padding"+re[a],!0,i),"padding"!==n?u+=k.css(e,"border"+re[a]+"Width",!0,i):s+=k.css(e,"border"+re[a]+"Width",!0,i));return!r&&0<=o&&(u+=Math.max(0,Math.ceil(e["offset"+t[0].toUpperCase()+t.slice(1)]-o-u-s-.5))||0),u}function tt(e,t,n){var r=Fe(e),i=(!y.boxSizingReliable()||n)&&"border-box"===k.css(e,"boxSizing",!1,r),o=i,a=_e(e,t,r),s="offset"+t[0].toUpperCase()+t.slice(1);if($e.test(a)){if(!n)return a;a="auto"}return(!y.boxSizingReliable()&&i||"auto"===a||!parseFloat(a)&&"inline"===k.css(e,"display",!1,r))&&e.getClientRects().length&&(i="border-box"===k.css(e,"boxSizing",!1,r),(o=s in e)&&(a=e[s])),(a=parseFloat(a)||0)+et(e,t,n||(i?"border":"content"),o,r,a)+"px"}function nt(e,t,n,r,i){return new nt.prototype.init(e,t,n,r,i)}k.extend({cssHooks:{opacity:{get:function(e,t){if(t){var 
n=_e(e,"opacity");return""===n?"1":n}}}},cssNumber:{animationIterationCount:!0,columnCount:!0,fillOpacity:!0,flexGrow:!0,flexShrink:!0,fontWeight:!0,gridArea:!0,gridColumn:!0,gridColumnEnd:!0,gridColumnStart:!0,gridRow:!0,gridRowEnd:!0,gridRowStart:!0,lineHeight:!0,opacity:!0,order:!0,orphans:!0,widows:!0,zIndex:!0,zoom:!0},cssProps:{},style:function(e,t,n,r){if(e&&3!==e.nodeType&&8!==e.nodeType&&e.style){var i,o,a,s=V(t),u=Qe.test(t),l=e.style;if(u||(t=Ge(s)),a=k.cssHooks[t]||k.cssHooks[s],void 0===n)return a&&"get"in a&&void 0!==(i=a.get(e,!1,r))?i:l[t];"string"===(o=typeof n)&&(i=ne.exec(n))&&i[1]&&(n=le(e,t,i),o="number"),null!=n&&n==n&&("number"!==o||u||(n+=i&&i[3]||(k.cssNumber[s]?"":"px")),y.clearCloneStyle||""!==n||0!==t.indexOf("background")||(l[t]="inherit"),a&&"set"in a&&void 0===(n=a.set(e,n,r))||(u?l.setProperty(t,n):l[t]=n))}},css:function(e,t,n,r){var i,o,a,s=V(t);return Qe.test(t)||(t=Ge(s)),(a=k.cssHooks[t]||k.cssHooks[s])&&"get"in a&&(i=a.get(e,!0,n)),void 0===i&&(i=_e(e,t,r)),"normal"===i&&t in Ke&&(i=Ke[t]),""===n||n?(o=parseFloat(i),!0===n||isFinite(o)?o||0:i):i}}),k.each(["height","width"],function(e,u){k.cssHooks[u]={get:function(e,t,n){if(t)return!Ye.test(k.css(e,"display"))||e.getClientRects().length&&e.getBoundingClientRect().width?tt(e,u,n):ue(e,Je,function(){return tt(e,u,n)})},set:function(e,t,n){var r,i=Fe(e),o=!y.scrollboxSize()&&"absolute"===i.position,a=(o||n)&&"border-box"===k.css(e,"boxSizing",!1,i),s=n?et(e,u,n,a,i):0;return a&&o&&(s-=Math.ceil(e["offset"+u[0].toUpperCase()+u.slice(1)]-parseFloat(i[u])-et(e,u,"border",!1,i)-.5)),s&&(r=ne.exec(t))&&"px"!==(r[3]||"px")&&(e.style[u]=t,t=k.css(e,u)),Ze(0,t,s)}}}),k.cssHooks.marginLeft=ze(y.reliableMarginLeft,function(e,t){if(t)return(parseFloat(_e(e,"marginLeft"))||e.getBoundingClientRect().left-ue(e,{marginLeft:0},function(){return e.getBoundingClientRect().left}))+"px"}),k.each({margin:"",padding:"",border:"Width"},function(i,o){k.cssHooks[i+o]={expand:function(e){for(var t=0,n={},r="string"==typeof e?e.split(" "):[e];t<4;t++)n[i+re[t]+o]=r[t]||r[t-2]||r[0];return n}},"margin"!==i&&(k.cssHooks[i+o].set=Ze)}),k.fn.extend({css:function(e,t){return _(this,function(e,t,n){var r,i,o={},a=0;if(Array.isArray(t)){for(r=Fe(e),i=t.length;a<i;a++)o[t[a]]=k.css(e,t[a],!1,r);return o}return void 0!==n?k.style(e,t,n):k.css(e,t)},e,t,1<arguments.length)}}),((k.Tween=nt).prototype={constructor:nt,init:function(e,t,n,r,i,o){this.elem=e,this.prop=n,this.easing=i||k.easing._default,this.options=t,this.start=this.now=this.cur(),this.end=r,this.unit=o||(k.cssNumber[n]?"":"px")},cur:function(){var e=nt.propHooks[this.prop];return e&&e.get?e.get(this):nt.propHooks._default.get(this)},run:function(e){var t,n=nt.propHooks[this.prop];return this.options.duration?this.pos=t=k.easing[this.easing](e,this.options.duration*e,0,1,this.options.duration):this.pos=t=e,this.now=(this.end-this.start)*t+this.start,this.options.step&&this.options.step.call(this.elem,this.now,this),n&&n.set?n.set(this):nt.propHooks._default.set(this),this}}).init.prototype=nt.prototype,(nt.propHooks={_default:{get:function(e){var t;return 
1!==e.elem.nodeType||null!=e.elem[e.prop]&&null==e.elem.style[e.prop]?e.elem[e.prop]:(t=k.css(e.elem,e.prop,""))&&"auto"!==t?t:0},set:function(e){k.fx.step[e.prop]?k.fx.step[e.prop](e):1!==e.elem.nodeType||!k.cssHooks[e.prop]&&null==e.elem.style[Ge(e.prop)]?e.elem[e.prop]=e.now:k.style(e.elem,e.prop,e.now+e.unit)}}}).scrollTop=nt.propHooks.scrollLeft={set:function(e){e.elem.nodeType&&e.elem.parentNode&&(e.elem[e.prop]=e.now)}},k.easing={linear:function(e){return e},swing:function(e){return.5-Math.cos(e*Math.PI)/2},_default:"swing"},k.fx=nt.prototype.init,k.fx.step={};var rt,it,ot,at,st=/^(?:toggle|show|hide)$/,ut=/queueHooks$/;function lt(){it&&(!1===E.hidden&&C.requestAnimationFrame?C.requestAnimationFrame(lt):C.setTimeout(lt,k.fx.interval),k.fx.tick())}function ct(){return C.setTimeout(function(){rt=void 0}),rt=Date.now()}function ft(e,t){var n,r=0,i={height:e};for(t=t?1:0;r<4;r+=2-t)i["margin"+(n=re[r])]=i["padding"+n]=e;return t&&(i.opacity=i.width=e),i}function pt(e,t,n){for(var r,i=(dt.tweeners[t]||[]).concat(dt.tweeners["*"]),o=0,a=i.length;o<a;o++)if(r=i[o].call(n,t,e))return r}function dt(o,e,t){var n,a,r=0,i=dt.prefilters.length,s=k.Deferred().always(function(){delete u.elem}),u=function(){if(a)return!1;for(var e=rt||ct(),t=Math.max(0,l.startTime+l.duration-e),n=1-(t/l.duration||0),r=0,i=l.tweens.length;r<i;r++)l.tweens[r].run(n);return s.notifyWith(o,[l,n,t]),n<1&&i?t:(i||s.notifyWith(o,[l,1,0]),s.resolveWith(o,[l]),!1)},l=s.promise({elem:o,props:k.extend({},e),opts:k.extend(!0,{specialEasing:{},easing:k.easing._default},t),originalProperties:e,originalOptions:t,startTime:rt||ct(),duration:t.duration,tweens:[],createTween:function(e,t){var n=k.Tween(o,l.opts,e,t,l.opts.specialEasing[e]||l.opts.easing);return l.tweens.push(n),n},stop:function(e){var t=0,n=e?l.tweens.length:0;if(a)return this;for(a=!0;t<n;t++)l.tweens[t].run(1);return e?(s.notifyWith(o,[l,1,0]),s.resolveWith(o,[l,e])):s.rejectWith(o,[l,e]),this}}),c=l.props;for(!function(e,t){var n,r,i,o,a;for(n in e)if(i=t[r=V(n)],o=e[n],Array.isArray(o)&&(i=o[1],o=e[n]=o[0]),n!==r&&(e[r]=o,delete e[n]),(a=k.cssHooks[r])&&"expand"in a)for(n in o=a.expand(o),delete e[r],o)n in e||(e[n]=o[n],t[n]=i);else t[r]=i}(c,l.opts.specialEasing);r<i;r++)if(n=dt.prefilters[r].call(l,o,c,l.opts))return m(n.stop)&&(k._queueHooks(l.elem,l.opts.queue).stop=n.stop.bind(n)),n;return k.map(c,pt,l),m(l.opts.start)&&l.opts.start.call(o,l),l.progress(l.opts.progress).done(l.opts.done,l.opts.complete).fail(l.opts.fail).always(l.opts.always),k.fx.timer(k.extend(u,{elem:o,anim:l,queue:l.opts.queue})),l}k.Animation=k.extend(dt,{tweeners:{"*":[function(e,t){var n=this.createTween(e,t);return le(n.elem,e,ne.exec(t),n),n}]},tweener:function(e,t){m(e)?(t=e,e=["*"]):e=e.match(R);for(var n,r=0,i=e.length;r<i;r++)n=e[r],dt.tweeners[n]=dt.tweeners[n]||[],dt.tweeners[n].unshift(t)},prefilters:[function(e,t,n){var r,i,o,a,s,u,l,c,f="width"in t||"height"in t,p=this,d={},h=e.style,g=e.nodeType&&se(e),v=Q.get(e,"fxshow");for(r in n.queue||(null==(a=k._queueHooks(e,"fx")).unqueued&&(a.unqueued=0,s=a.empty.fire,a.empty.fire=function(){a.unqueued||s()}),a.unqueued++,p.always(function(){p.always(function(){a.unqueued--,k.queue(e,"fx").length||a.empty.fire()})})),t)if(i=t[r],st.test(i)){if(delete t[r],o=o||"toggle"===i,i===(g?"hide":"show")){if("show"!==i||!v||void 0===v[r])continue;g=!0}d[r]=v&&v[r]||k.style(e,r)}if((u=!k.isEmptyObject(t))||!k.isEmptyObject(d))for(r in 
f&&1===e.nodeType&&(n.overflow=[h.overflow,h.overflowX,h.overflowY],null==(l=v&&v.display)&&(l=Q.get(e,"display")),"none"===(c=k.css(e,"display"))&&(l?c=l:(fe([e],!0),l=e.style.display||l,c=k.css(e,"display"),fe([e]))),("inline"===c||"inline-block"===c&&null!=l)&&"none"===k.css(e,"float")&&(u||(p.done(function(){h.display=l}),null==l&&(c=h.display,l="none"===c?"":c)),h.display="inline-block")),n.overflow&&(h.overflow="hidden",p.always(function(){h.overflow=n.overflow[0],h.overflowX=n.overflow[1],h.overflowY=n.overflow[2]})),u=!1,d)u||(v?"hidden"in v&&(g=v.hidden):v=Q.access(e,"fxshow",{display:l}),o&&(v.hidden=!g),g&&fe([e],!0),p.done(function(){for(r in g||fe([e]),Q.remove(e,"fxshow"),d)k.style(e,r,d[r])})),u=pt(g?v[r]:0,r,p),r in v||(v[r]=u.start,g&&(u.end=u.start,u.start=0))}],prefilter:function(e,t){t?dt.prefilters.unshift(e):dt.prefilters.push(e)}}),k.speed=function(e,t,n){var r=e&&"object"==typeof e?k.extend({},e):{complete:n||!n&&t||m(e)&&e,duration:e,easing:n&&t||t&&!m(t)&&t};return k.fx.off?r.duration=0:"number"!=typeof r.duration&&(r.duration in k.fx.speeds?r.duration=k.fx.speeds[r.duration]:r.duration=k.fx.speeds._default),null!=r.queue&&!0!==r.queue||(r.queue="fx"),r.old=r.complete,r.complete=function(){m(r.old)&&r.old.call(this),r.queue&&k.dequeue(this,r.queue)},r},k.fn.extend({fadeTo:function(e,t,n,r){return this.filter(se).css("opacity",0).show().end().animate({opacity:t},e,n,r)},animate:function(t,e,n,r){var i=k.isEmptyObject(t),o=k.speed(e,n,r),a=function(){var e=dt(this,k.extend({},t),o);(i||Q.get(this,"finish"))&&e.stop(!0)};return a.finish=a,i||!1===o.queue?this.each(a):this.queue(o.queue,a)},stop:function(i,e,o){var a=function(e){var t=e.stop;delete e.stop,t(o)};return"string"!=typeof i&&(o=e,e=i,i=void 0),e&&!1!==i&&this.queue(i||"fx",[]),this.each(function(){var e=!0,t=null!=i&&i+"queueHooks",n=k.timers,r=Q.get(this);if(t)r[t]&&r[t].stop&&a(r[t]);else for(t in r)r[t]&&r[t].stop&&ut.test(t)&&a(r[t]);for(t=n.length;t--;)n[t].elem!==this||null!=i&&n[t].queue!==i||(n[t].anim.stop(o),e=!1,n.splice(t,1));!e&&o||k.dequeue(this,i)})},finish:function(a){return!1!==a&&(a=a||"fx"),this.each(function(){var e,t=Q.get(this),n=t[a+"queue"],r=t[a+"queueHooks"],i=k.timers,o=n?n.length:0;for(t.finish=!0,k.queue(this,a,[]),r&&r.stop&&r.stop.call(this,!0),e=i.length;e--;)i[e].elem===this&&i[e].queue===a&&(i[e].anim.stop(!0),i.splice(e,1));for(e=0;e<o;e++)n[e]&&n[e].finish&&n[e].finish.call(this);delete t.finish})}}),k.each(["toggle","show","hide"],function(e,r){var i=k.fn[r];k.fn[r]=function(e,t,n){return null==e||"boolean"==typeof e?i.apply(this,arguments):this.animate(ft(r,!0),e,t,n)}}),k.each({slideDown:ft("show"),slideUp:ft("hide"),slideToggle:ft("toggle"),fadeIn:{opacity:"show"},fadeOut:{opacity:"hide"},fadeToggle:{opacity:"toggle"}},function(e,r){k.fn[e]=function(e,t,n){return this.animate(r,e,t,n)}}),k.timers=[],k.fx.tick=function(){var e,t=0,n=k.timers;for(rt=Date.now();t<n.length;t++)(e=n[t])()||n[t]!==e||n.splice(t--,1);n.length||k.fx.stop(),rt=void 0},k.fx.timer=function(e){k.timers.push(e),k.fx.start()},k.fx.interval=13,k.fx.start=function(){it||(it=!0,lt())},k.fx.stop=function(){it=null},k.fx.speeds={slow:600,fast:200,_default:400},k.fn.delay=function(r,e){return r=k.fx&&k.fx.speeds[r]||r,e=e||"fx",this.queue(e,function(e,t){var 
n=C.setTimeout(e,r);t.stop=function(){C.clearTimeout(n)}})},ot=E.createElement("input"),at=E.createElement("select").appendChild(E.createElement("option")),ot.type="checkbox",y.checkOn=""!==ot.value,y.optSelected=at.selected,(ot=E.createElement("input")).value="t",ot.type="radio",y.radioValue="t"===ot.value;var ht,gt=k.expr.attrHandle;k.fn.extend({attr:function(e,t){return _(this,k.attr,e,t,1<arguments.length)},removeAttr:function(e){return this.each(function(){k.removeAttr(this,e)})}}),k.extend({attr:function(e,t,n){var r,i,o=e.nodeType;if(3!==o&&8!==o&&2!==o)return"undefined"==typeof e.getAttribute?k.prop(e,t,n):(1===o&&k.isXMLDoc(e)||(i=k.attrHooks[t.toLowerCase()]||(k.expr.match.bool.test(t)?ht:void 0)),void 0!==n?null===n?void k.removeAttr(e,t):i&&"set"in i&&void 0!==(r=i.set(e,n,t))?r:(e.setAttribute(t,n+""),n):i&&"get"in i&&null!==(r=i.get(e,t))?r:null==(r=k.find.attr(e,t))?void 0:r)},attrHooks:{type:{set:function(e,t){if(!y.radioValue&&"radio"===t&&A(e,"input")){var n=e.value;return e.setAttribute("type",t),n&&(e.value=n),t}}}},removeAttr:function(e,t){var n,r=0,i=t&&t.match(R);if(i&&1===e.nodeType)while(n=i[r++])e.removeAttribute(n)}}),ht={set:function(e,t,n){return!1===t?k.removeAttr(e,n):e.setAttribute(n,n),n}},k.each(k.expr.match.bool.source.match(/\w+/g),function(e,t){var a=gt[t]||k.find.attr;gt[t]=function(e,t,n){var r,i,o=t.toLowerCase();return n||(i=gt[o],gt[o]=r,r=null!=a(e,t,n)?o:null,gt[o]=i),r}});var vt=/^(?:input|select|textarea|button)$/i,yt=/^(?:a|area)$/i;function mt(e){return(e.match(R)||[]).join(" ")}function xt(e){return e.getAttribute&&e.getAttribute("class")||""}function bt(e){return Array.isArray(e)?e:"string"==typeof e&&e.match(R)||[]}k.fn.extend({prop:function(e,t){return _(this,k.prop,e,t,1<arguments.length)},removeProp:function(e){return this.each(function(){delete this[k.propFix[e]||e]})}}),k.extend({prop:function(e,t,n){var r,i,o=e.nodeType;if(3!==o&&8!==o&&2!==o)return 1===o&&k.isXMLDoc(e)||(t=k.propFix[t]||t,i=k.propHooks[t]),void 0!==n?i&&"set"in i&&void 0!==(r=i.set(e,n,t))?r:e[t]=n:i&&"get"in i&&null!==(r=i.get(e,t))?r:e[t]},propHooks:{tabIndex:{get:function(e){var t=k.find.attr(e,"tabindex");return t?parseInt(t,10):vt.test(e.nodeName)||yt.test(e.nodeName)&&e.href?0:-1}}},propFix:{"for":"htmlFor","class":"className"}}),y.optSelected||(k.propHooks.selected={get:function(e){var t=e.parentNode;return t&&t.parentNode&&t.parentNode.selectedIndex,null},set:function(e){var t=e.parentNode;t&&(t.selectedIndex,t.parentNode&&t.parentNode.selectedIndex)}}),k.each(["tabIndex","readOnly","maxLength","cellSpacing","cellPadding","rowSpan","colSpan","useMap","frameBorder","contentEditable"],function(){k.propFix[this.toLowerCase()]=this}),k.fn.extend({addClass:function(t){var e,n,r,i,o,a,s,u=0;if(m(t))return this.each(function(e){k(this).addClass(t.call(this,e,xt(this)))});if((e=bt(t)).length)while(n=this[u++])if(i=xt(n),r=1===n.nodeType&&" "+mt(i)+" "){a=0;while(o=e[a++])r.indexOf(" "+o+" ")<0&&(r+=o+" ");i!==(s=mt(r))&&n.setAttribute("class",s)}return this},removeClass:function(t){var e,n,r,i,o,a,s,u=0;if(m(t))return this.each(function(e){k(this).removeClass(t.call(this,e,xt(this)))});if(!arguments.length)return this.attr("class","");if((e=bt(t)).length)while(n=this[u++])if(i=xt(n),r=1===n.nodeType&&" "+mt(i)+" "){a=0;while(o=e[a++])while(-1<r.indexOf(" "+o+" "))r=r.replace(" "+o+" "," ");i!==(s=mt(r))&&n.setAttribute("class",s)}return this},toggleClass:function(i,t){var o=typeof i,a="string"===o||Array.isArray(i);return"boolean"==typeof 
t&&a?t?this.addClass(i):this.removeClass(i):m(i)?this.each(function(e){k(this).toggleClass(i.call(this,e,xt(this),t),t)}):this.each(function(){var e,t,n,r;if(a){t=0,n=k(this),r=bt(i);while(e=r[t++])n.hasClass(e)?n.removeClass(e):n.addClass(e)}else void 0!==i&&"boolean"!==o||((e=xt(this))&&Q.set(this,"__className__",e),this.setAttribute&&this.setAttribute("class",e||!1===i?"":Q.get(this,"__className__")||""))})},hasClass:function(e){var t,n,r=0;t=" "+e+" ";while(n=this[r++])if(1===n.nodeType&&-1<(" "+mt(xt(n))+" ").indexOf(t))return!0;return!1}});var wt=/\r/g;k.fn.extend({val:function(n){var r,e,i,t=this[0];return arguments.length?(i=m(n),this.each(function(e){var t;1===this.nodeType&&(null==(t=i?n.call(this,e,k(this).val()):n)?t="":"number"==typeof t?t+="":Array.isArray(t)&&(t=k.map(t,function(e){return null==e?"":e+""})),(r=k.valHooks[this.type]||k.valHooks[this.nodeName.toLowerCase()])&&"set"in r&&void 0!==r.set(this,t,"value")||(this.value=t))})):t?(r=k.valHooks[t.type]||k.valHooks[t.nodeName.toLowerCase()])&&"get"in r&&void 0!==(e=r.get(t,"value"))?e:"string"==typeof(e=t.value)?e.replace(wt,""):null==e?"":e:void 0}}),k.extend({valHooks:{option:{get:function(e){var t=k.find.attr(e,"value");return null!=t?t:mt(k.text(e))}},select:{get:function(e){var t,n,r,i=e.options,o=e.selectedIndex,a="select-one"===e.type,s=a?null:[],u=a?o+1:i.length;for(r=o<0?u:a?o:0;r<u;r++)if(((n=i[r]).selected||r===o)&&!n.disabled&&(!n.parentNode.disabled||!A(n.parentNode,"optgroup"))){if(t=k(n).val(),a)return t;s.push(t)}return s},set:function(e,t){var n,r,i=e.options,o=k.makeArray(t),a=i.length;while(a--)((r=i[a]).selected=-1<k.inArray(k.valHooks.option.get(r),o))&&(n=!0);return n||(e.selectedIndex=-1),o}}}}),k.each(["radio","checkbox"],function(){k.valHooks[this]={set:function(e,t){if(Array.isArray(t))return e.checked=-1<k.inArray(k(e).val(),t)}},y.checkOn||(k.valHooks[this].get=function(e){return null===e.getAttribute("value")?"on":e.value})}),y.focusin="onfocusin"in C;var Tt=/^(?:focusinfocus|focusoutblur)$/,Ct=function(e){e.stopPropagation()};k.extend(k.event,{trigger:function(e,t,n,r){var i,o,a,s,u,l,c,f,p=[n||E],d=v.call(e,"type")?e.type:e,h=v.call(e,"namespace")?e.namespace.split("."):[];if(o=f=a=n=n||E,3!==n.nodeType&&8!==n.nodeType&&!Tt.test(d+k.event.triggered)&&(-1<d.indexOf(".")&&(d=(h=d.split(".")).shift(),h.sort()),u=d.indexOf(":")<0&&"on"+d,(e=e[k.expando]?e:new k.Event(d,"object"==typeof e&&e)).isTrigger=r?2:3,e.namespace=h.join("."),e.rnamespace=e.namespace?new RegExp("(^|\\.)"+h.join("\\.(?:.*\\.|)")+"(\\.|$)"):null,e.result=void 0,e.target||(e.target=n),t=null==t?[e]:k.makeArray(t,[e]),c=k.event.special[d]||{},r||!c.trigger||!1!==c.trigger.apply(n,t))){if(!r&&!c.noBubble&&!x(n)){for(s=c.delegateType||d,Tt.test(s+d)||(o=o.parentNode);o;o=o.parentNode)p.push(o),a=o;a===(n.ownerDocument||E)&&p.push(a.defaultView||a.parentWindow||C)}i=0;while((o=p[i++])&&!e.isPropagationStopped())f=o,e.type=1<i?s:c.bindType||d,(l=(Q.get(o,"events")||{})[e.type]&&Q.get(o,"handle"))&&l.apply(o,t),(l=u&&o[u])&&l.apply&&G(o)&&(e.result=l.apply(o,t),!1===e.result&&e.preventDefault());return e.type=d,r||e.isDefaultPrevented()||c._default&&!1!==c._default.apply(p.pop(),t)||!G(n)||u&&m(n[d])&&!x(n)&&((a=n[u])&&(n[u]=null),k.event.triggered=d,e.isPropagationStopped()&&f.addEventListener(d,Ct),n[d](),e.isPropagationStopped()&&f.removeEventListener(d,Ct),k.event.triggered=void 0,a&&(n[u]=a)),e.result}},simulate:function(e,t,n){var r=k.extend(new 
k.Event,n,{type:e,isSimulated:!0});k.event.trigger(r,null,t)}}),k.fn.extend({trigger:function(e,t){return this.each(function(){k.event.trigger(e,t,this)})},triggerHandler:function(e,t){var n=this[0];if(n)return k.event.trigger(e,t,n,!0)}}),y.focusin||k.each({focus:"focusin",blur:"focusout"},function(n,r){var i=function(e){k.event.simulate(r,e.target,k.event.fix(e))};k.event.special[r]={setup:function(){var e=this.ownerDocument||this,t=Q.access(e,r);t||e.addEventListener(n,i,!0),Q.access(e,r,(t||0)+1)},teardown:function(){var e=this.ownerDocument||this,t=Q.access(e,r)-1;t?Q.access(e,r,t):(e.removeEventListener(n,i,!0),Q.remove(e,r))}}});var Et=C.location,kt=Date.now(),St=/\?/;k.parseXML=function(e){var t;if(!e||"string"!=typeof e)return null;try{t=(new C.DOMParser).parseFromString(e,"text/xml")}catch(e){t=void 0}return t&&!t.getElementsByTagName("parsererror").length||k.error("Invalid XML: "+e),t};var Nt=/\[\]$/,At=/\r?\n/g,Dt=/^(?:submit|button|image|reset|file)$/i,jt=/^(?:input|select|textarea|keygen)/i;function qt(n,e,r,i){var t;if(Array.isArray(e))k.each(e,function(e,t){r||Nt.test(n)?i(n,t):qt(n+"["+("object"==typeof t&&null!=t?e:"")+"]",t,r,i)});else if(r||"object"!==w(e))i(n,e);else for(t in e)qt(n+"["+t+"]",e[t],r,i)}k.param=function(e,t){var n,r=[],i=function(e,t){var n=m(t)?t():t;r[r.length]=encodeURIComponent(e)+"="+encodeURIComponent(null==n?"":n)};if(null==e)return"";if(Array.isArray(e)||e.jquery&&!k.isPlainObject(e))k.each(e,function(){i(this.name,this.value)});else for(n in e)qt(n,e[n],t,i);return r.join("&")},k.fn.extend({serialize:function(){return k.param(this.serializeArray())},serializeArray:function(){return this.map(function(){var e=k.prop(this,"elements");return e?k.makeArray(e):this}).filter(function(){var e=this.type;return this.name&&!k(this).is(":disabled")&&jt.test(this.nodeName)&&!Dt.test(e)&&(this.checked||!pe.test(e))}).map(function(e,t){var n=k(this).val();return null==n?null:Array.isArray(n)?k.map(n,function(e){return{name:t.name,value:e.replace(At,"\r\n")}}):{name:t.name,value:n.replace(At,"\r\n")}}).get()}});var Lt=/%20/g,Ht=/#.*$/,Ot=/([?&])_=[^&]*/,Pt=/^(.*?):[ \t]*([^\r\n]*)$/gm,Rt=/^(?:GET|HEAD)$/,Mt=/^\/\//,It={},Wt={},$t="*/".concat("*"),Ft=E.createElement("a");function Bt(o){return function(e,t){"string"!=typeof e&&(t=e,e="*");var n,r=0,i=e.toLowerCase().match(R)||[];if(m(t))while(n=i[r++])"+"===n[0]?(n=n.slice(1)||"*",(o[n]=o[n]||[]).unshift(t)):(o[n]=o[n]||[]).push(t)}}function _t(t,i,o,a){var s={},u=t===Wt;function l(e){var r;return s[e]=!0,k.each(t[e]||[],function(e,t){var n=t(i,o,a);return"string"!=typeof n||u||s[n]?u?!(r=n):void 0:(i.dataTypes.unshift(n),l(n),!1)}),r}return l(i.dataTypes[0])||!s["*"]&&l("*")}function zt(e,t){var n,r,i=k.ajaxSettings.flatOptions||{};for(n in t)void 0!==t[n]&&((i[n]?e:r||(r={}))[n]=t[n]);return r&&k.extend(!0,e,r),e}Ft.href=Et.href,k.extend({active:0,lastModified:{},etag:{},ajaxSettings:{url:Et.href,type:"GET",isLocal:/^(?:about|app|app-storage|.+-extension|file|res|widget):$/.test(Et.protocol),global:!0,processData:!0,async:!0,contentType:"application/x-www-form-urlencoded; charset=UTF-8",accepts:{"*":$t,text:"text/plain",html:"text/html",xml:"application/xml, text/xml",json:"application/json, text/javascript"},contents:{xml:/\bxml\b/,html:/\bhtml/,json:/\bjson\b/},responseFields:{xml:"responseXML",text:"responseText",json:"responseJSON"},converters:{"* text":String,"text html":!0,"text json":JSON.parse,"text xml":k.parseXML},flatOptions:{url:!0,context:!0}},ajaxSetup:function(e,t){return 
t?zt(zt(e,k.ajaxSettings),t):zt(k.ajaxSettings,e)},ajaxPrefilter:Bt(It),ajaxTransport:Bt(Wt),ajax:function(e,t){"object"==typeof e&&(t=e,e=void 0),t=t||{};var c,f,p,n,d,r,h,g,i,o,v=k.ajaxSetup({},t),y=v.context||v,m=v.context&&(y.nodeType||y.jquery)?k(y):k.event,x=k.Deferred(),b=k.Callbacks("once memory"),w=v.statusCode||{},a={},s={},u="canceled",T={readyState:0,getResponseHeader:function(e){var t;if(h){if(!n){n={};while(t=Pt.exec(p))n[t[1].toLowerCase()+" "]=(n[t[1].toLowerCase()+" "]||[]).concat(t[2])}t=n[e.toLowerCase()+" "]}return null==t?null:t.join(", ")},getAllResponseHeaders:function(){return h?p:null},setRequestHeader:function(e,t){return null==h&&(e=s[e.toLowerCase()]=s[e.toLowerCase()]||e,a[e]=t),this},overrideMimeType:function(e){return null==h&&(v.mimeType=e),this},statusCode:function(e){var t;if(e)if(h)T.always(e[T.status]);else for(t in e)w[t]=[w[t],e[t]];return this},abort:function(e){var t=e||u;return c&&c.abort(t),l(0,t),this}};if(x.promise(T),v.url=((e||v.url||Et.href)+"").replace(Mt,Et.protocol+"//"),v.type=t.method||t.type||v.method||v.type,v.dataTypes=(v.dataType||"*").toLowerCase().match(R)||[""],null==v.crossDomain){r=E.createElement("a");try{r.href=v.url,r.href=r.href,v.crossDomain=Ft.protocol+"//"+Ft.host!=r.protocol+"//"+r.host}catch(e){v.crossDomain=!0}}if(v.data&&v.processData&&"string"!=typeof v.data&&(v.data=k.param(v.data,v.traditional)),_t(It,v,t,T),h)return T;for(i in(g=k.event&&v.global)&&0==k.active++&&k.event.trigger("ajaxStart"),v.type=v.type.toUpperCase(),v.hasContent=!Rt.test(v.type),f=v.url.replace(Ht,""),v.hasContent?v.data&&v.processData&&0===(v.contentType||"").indexOf("application/x-www-form-urlencoded")&&(v.data=v.data.replace(Lt,"+")):(o=v.url.slice(f.length),v.data&&(v.processData||"string"==typeof v.data)&&(f+=(St.test(f)?"&":"?")+v.data,delete v.data),!1===v.cache&&(f=f.replace(Ot,"$1"),o=(St.test(f)?"&":"?")+"_="+kt+++o),v.url=f+o),v.ifModified&&(k.lastModified[f]&&T.setRequestHeader("If-Modified-Since",k.lastModified[f]),k.etag[f]&&T.setRequestHeader("If-None-Match",k.etag[f])),(v.data&&v.hasContent&&!1!==v.contentType||t.contentType)&&T.setRequestHeader("Content-Type",v.contentType),T.setRequestHeader("Accept",v.dataTypes[0]&&v.accepts[v.dataTypes[0]]?v.accepts[v.dataTypes[0]]+("*"!==v.dataTypes[0]?", "+$t+"; q=0.01":""):v.accepts["*"]),v.headers)T.setRequestHeader(i,v.headers[i]);if(v.beforeSend&&(!1===v.beforeSend.call(y,T,v)||h))return T.abort();if(u="abort",b.add(v.complete),T.done(v.success),T.fail(v.error),c=_t(Wt,v,t,T)){if(T.readyState=1,g&&m.trigger("ajaxSend",[T,v]),h)return T;v.async&&0<v.timeout&&(d=C.setTimeout(function(){T.abort("timeout")},v.timeout));try{h=!1,c.send(a,l)}catch(e){if(h)throw e;l(-1,e)}}else l(-1,"No Transport");function l(e,t,n,r){var i,o,a,s,u,l=t;h||(h=!0,d&&C.clearTimeout(d),c=void 0,p=r||"",T.readyState=0<e?4:0,i=200<=e&&e<300||304===e,n&&(s=function(e,t,n){var r,i,o,a,s=e.contents,u=e.dataTypes;while("*"===u[0])u.shift(),void 0===r&&(r=e.mimeType||t.getResponseHeader("Content-Type"));if(r)for(i in s)if(s[i]&&s[i].test(r)){u.unshift(i);break}if(u[0]in n)o=u[0];else{for(i in n){if(!u[0]||e.converters[i+" "+u[0]]){o=i;break}a||(a=i)}o=o||a}if(o)return o!==u[0]&&u.unshift(o),n[o]}(v,T,n)),s=function(e,t,n,r){var i,o,a,s,u,l={},c=e.dataTypes.slice();if(c[1])for(a in e.converters)l[a.toLowerCase()]=e.converters[a];o=c.shift();while(o)if(e.responseFields[o]&&(n[e.responseFields[o]]=t),!u&&r&&e.dataFilter&&(t=e.dataFilter(t,e.dataType)),u=o,o=c.shift())if("*"===o)o=u;else if("*"!==u&&u!==o){if(!(a=l[u+" 
"+o]||l["* "+o]))for(i in l)if((s=i.split(" "))[1]===o&&(a=l[u+" "+s[0]]||l["* "+s[0]])){!0===a?a=l[i]:!0!==l[i]&&(o=s[0],c.unshift(s[1]));break}if(!0!==a)if(a&&e["throws"])t=a(t);else try{t=a(t)}catch(e){return{state:"parsererror",error:a?e:"No conversion from "+u+" to "+o}}}return{state:"success",data:t}}(v,s,T,i),i?(v.ifModified&&((u=T.getResponseHeader("Last-Modified"))&&(k.lastModified[f]=u),(u=T.getResponseHeader("etag"))&&(k.etag[f]=u)),204===e||"HEAD"===v.type?l="nocontent":304===e?l="notmodified":(l=s.state,o=s.data,i=!(a=s.error))):(a=l,!e&&l||(l="error",e<0&&(e=0))),T.status=e,T.statusText=(t||l)+"",i?x.resolveWith(y,[o,l,T]):x.rejectWith(y,[T,l,a]),T.statusCode(w),w=void 0,g&&m.trigger(i?"ajaxSuccess":"ajaxError",[T,v,i?o:a]),b.fireWith(y,[T,l]),g&&(m.trigger("ajaxComplete",[T,v]),--k.active||k.event.trigger("ajaxStop")))}return T},getJSON:function(e,t,n){return k.get(e,t,n,"json")},getScript:function(e,t){return k.get(e,void 0,t,"script")}}),k.each(["get","post"],function(e,i){k[i]=function(e,t,n,r){return m(t)&&(r=r||n,n=t,t=void 0),k.ajax(k.extend({url:e,type:i,dataType:r,data:t,success:n},k.isPlainObject(e)&&e))}}),k._evalUrl=function(e,t){return k.ajax({url:e,type:"GET",dataType:"script",cache:!0,async:!1,global:!1,converters:{"text script":function(){}},dataFilter:function(e){k.globalEval(e,t)}})},k.fn.extend({wrapAll:function(e){var t;return this[0]&&(m(e)&&(e=e.call(this[0])),t=k(e,this[0].ownerDocument).eq(0).clone(!0),this[0].parentNode&&t.insertBefore(this[0]),t.map(function(){var e=this;while(e.firstElementChild)e=e.firstElementChild;return e}).append(this)),this},wrapInner:function(n){return m(n)?this.each(function(e){k(this).wrapInner(n.call(this,e))}):this.each(function(){var e=k(this),t=e.contents();t.length?t.wrapAll(n):e.append(n)})},wrap:function(t){var n=m(t);return this.each(function(e){k(this).wrapAll(n?t.call(this,e):t)})},unwrap:function(e){return this.parent(e).not("body").each(function(){k(this).replaceWith(this.childNodes)}),this}}),k.expr.pseudos.hidden=function(e){return!k.expr.pseudos.visible(e)},k.expr.pseudos.visible=function(e){return!!(e.offsetWidth||e.offsetHeight||e.getClientRects().length)},k.ajaxSettings.xhr=function(){try{return new C.XMLHttpRequest}catch(e){}};var Ut={0:200,1223:204},Xt=k.ajaxSettings.xhr();y.cors=!!Xt&&"withCredentials"in Xt,y.ajax=Xt=!!Xt,k.ajaxTransport(function(i){var o,a;if(y.cors||Xt&&!i.crossDomain)return{send:function(e,t){var n,r=i.xhr();if(r.open(i.type,i.url,i.async,i.username,i.password),i.xhrFields)for(n in i.xhrFields)r[n]=i.xhrFields[n];for(n in i.mimeType&&r.overrideMimeType&&r.overrideMimeType(i.mimeType),i.crossDomain||e["X-Requested-With"]||(e["X-Requested-With"]="XMLHttpRequest"),e)r.setRequestHeader(n,e[n]);o=function(e){return function(){o&&(o=a=r.onload=r.onerror=r.onabort=r.ontimeout=r.onreadystatechange=null,"abort"===e?r.abort():"error"===e?"number"!=typeof r.status?t(0,"error"):t(r.status,r.statusText):t(Ut[r.status]||r.status,r.statusText,"text"!==(r.responseType||"text")||"string"!=typeof r.responseText?{binary:r.response}:{text:r.responseText},r.getAllResponseHeaders()))}},r.onload=o(),a=r.onerror=r.ontimeout=o("error"),void 0!==r.onabort?r.onabort=a:r.onreadystatechange=function(){4===r.readyState&&C.setTimeout(function(){o&&a()})},o=o("abort");try{r.send(i.hasContent&&i.data||null)}catch(e){if(o)throw e}},abort:function(){o&&o()}}}),k.ajaxPrefilter(function(e){e.crossDomain&&(e.contents.script=!1)}),k.ajaxSetup({accepts:{script:"text/javascript, application/javascript, 
application/ecmascript, application/x-ecmascript"},contents:{script:/\b(?:java|ecma)script\b/},converters:{"text script":function(e){return k.globalEval(e),e}}}),k.ajaxPrefilter("script",function(e){void 0===e.cache&&(e.cache=!1),e.crossDomain&&(e.type="GET")}),k.ajaxTransport("script",function(n){var r,i;if(n.crossDomain||n.scriptAttrs)return{send:function(e,t){r=k("<script>").attr(n.scriptAttrs||{}).prop({charset:n.scriptCharset,src:n.url}).on("load error",i=function(e){r.remove(),i=null,e&&t("error"===e.type?404:200,e.type)}),E.head.appendChild(r[0])},abort:function(){i&&i()}}});var Vt,Gt=[],Yt=/(=)\?(?=&|$)|\?\?/;k.ajaxSetup({jsonp:"callback",jsonpCallback:function(){var e=Gt.pop()||k.expando+"_"+kt++;return this[e]=!0,e}}),k.ajaxPrefilter("json jsonp",function(e,t,n){var r,i,o,a=!1!==e.jsonp&&(Yt.test(e.url)?"url":"string"==typeof e.data&&0===(e.contentType||"").indexOf("application/x-www-form-urlencoded")&&Yt.test(e.data)&&"data");if(a||"jsonp"===e.dataTypes[0])return r=e.jsonpCallback=m(e.jsonpCallback)?e.jsonpCallback():e.jsonpCallback,a?e[a]=e[a].replace(Yt,"$1"+r):!1!==e.jsonp&&(e.url+=(St.test(e.url)?"&":"?")+e.jsonp+"="+r),e.converters["script json"]=function(){return o||k.error(r+" was not called"),o[0]},e.dataTypes[0]="json",i=C[r],C[r]=function(){o=arguments},n.always(function(){void 0===i?k(C).removeProp(r):C[r]=i,e[r]&&(e.jsonpCallback=t.jsonpCallback,Gt.push(r)),o&&m(i)&&i(o[0]),o=i=void 0}),"script"}),y.createHTMLDocument=((Vt=E.implementation.createHTMLDocument("").body).innerHTML="<form></form><form></form>",2===Vt.childNodes.length),k.parseHTML=function(e,t,n){return"string"!=typeof e?[]:("boolean"==typeof t&&(n=t,t=!1),t||(y.createHTMLDocument?((r=(t=E.implementation.createHTMLDocument("")).createElement("base")).href=E.location.href,t.head.appendChild(r)):t=E),o=!n&&[],(i=D.exec(e))?[t.createElement(i[1])]:(i=we([e],t,o),o&&o.length&&k(o).remove(),k.merge([],i.childNodes)));var r,i,o},k.fn.load=function(e,t,n){var r,i,o,a=this,s=e.indexOf(" ");return-1<s&&(r=mt(e.slice(s)),e=e.slice(0,s)),m(t)?(n=t,t=void 0):t&&"object"==typeof t&&(i="POST"),0<a.length&&k.ajax({url:e,type:i||"GET",dataType:"html",data:t}).done(function(e){o=arguments,a.html(r?k("<div>").append(k.parseHTML(e)).find(r):e)}).always(n&&function(e,t){a.each(function(){n.apply(this,o||[e.responseText,t,e])})}),this},k.each(["ajaxStart","ajaxStop","ajaxComplete","ajaxError","ajaxSuccess","ajaxSend"],function(e,t){k.fn[t]=function(e){return this.on(t,e)}}),k.expr.pseudos.animated=function(t){return k.grep(k.timers,function(e){return t===e.elem}).length},k.offset={setOffset:function(e,t,n){var r,i,o,a,s,u,l=k.css(e,"position"),c=k(e),f={};"static"===l&&(e.style.position="relative"),s=c.offset(),o=k.css(e,"top"),u=k.css(e,"left"),("absolute"===l||"fixed"===l)&&-1<(o+u).indexOf("auto")?(a=(r=c.position()).top,i=r.left):(a=parseFloat(o)||0,i=parseFloat(u)||0),m(t)&&(t=t.call(e,n,k.extend({},s))),null!=t.top&&(f.top=t.top-s.top+a),null!=t.left&&(f.left=t.left-s.left+i),"using"in t?t.using.call(e,f):c.css(f)}},k.fn.extend({offset:function(t){if(arguments.length)return void 0===t?this:this.each(function(e){k.offset.setOffset(this,t,e)});var e,n,r=this[0];return r?r.getClientRects().length?(e=r.getBoundingClientRect(),n=r.ownerDocument.defaultView,{top:e.top+n.pageYOffset,left:e.left+n.pageXOffset}):{top:0,left:0}:void 0},position:function(){if(this[0]){var 
e,t,n,r=this[0],i={top:0,left:0};if("fixed"===k.css(r,"position"))t=r.getBoundingClientRect();else{t=this.offset(),n=r.ownerDocument,e=r.offsetParent||n.documentElement;while(e&&(e===n.body||e===n.documentElement)&&"static"===k.css(e,"position"))e=e.parentNode;e&&e!==r&&1===e.nodeType&&((i=k(e).offset()).top+=k.css(e,"borderTopWidth",!0),i.left+=k.css(e,"borderLeftWidth",!0))}return{top:t.top-i.top-k.css(r,"marginTop",!0),left:t.left-i.left-k.css(r,"marginLeft",!0)}}},offsetParent:function(){return this.map(function(){var e=this.offsetParent;while(e&&"static"===k.css(e,"position"))e=e.offsetParent;return e||ie})}}),k.each({scrollLeft:"pageXOffset",scrollTop:"pageYOffset"},function(t,i){var o="pageYOffset"===i;k.fn[t]=function(e){return _(this,function(e,t,n){var r;if(x(e)?r=e:9===e.nodeType&&(r=e.defaultView),void 0===n)return r?r[i]:e[t];r?r.scrollTo(o?r.pageXOffset:n,o?n:r.pageYOffset):e[t]=n},t,e,arguments.length)}}),k.each(["top","left"],function(e,n){k.cssHooks[n]=ze(y.pixelPosition,function(e,t){if(t)return t=_e(e,n),$e.test(t)?k(e).position()[n]+"px":t})}),k.each({Height:"height",Width:"width"},function(a,s){k.each({padding:"inner"+a,content:s,"":"outer"+a},function(r,o){k.fn[o]=function(e,t){var n=arguments.length&&(r||"boolean"!=typeof e),i=r||(!0===e||!0===t?"margin":"border");return _(this,function(e,t,n){var r;return x(e)?0===o.indexOf("outer")?e["inner"+a]:e.document.documentElement["client"+a]:9===e.nodeType?(r=e.documentElement,Math.max(e.body["scroll"+a],r["scroll"+a],e.body["offset"+a],r["offset"+a],r["client"+a])):void 0===n?k.css(e,t,i):k.style(e,t,n,i)},s,n?e:void 0,n)}})}),k.each("blur focus focusin focusout resize scroll click dblclick mousedown mouseup mousemove mouseover mouseout mouseenter mouseleave change select submit keydown keypress keyup contextmenu".split(" "),function(e,n){k.fn[n]=function(e,t){return 0<arguments.length?this.on(n,null,e,t):this.trigger(n)}}),k.fn.extend({hover:function(e,t){return this.mouseenter(e).mouseleave(t||e)}}),k.fn.extend({bind:function(e,t,n){return this.on(e,null,t,n)},unbind:function(e,t){return this.off(e,null,t)},delegate:function(e,t,n,r){return this.on(t,e,n,r)},undelegate:function(e,t,n){return 1===arguments.length?this.off(e,"**"):this.off(t,e||"**",n)}}),k.proxy=function(e,t){var n,r,i;if("string"==typeof t&&(n=e[t],t=e,e=n),m(e))return r=s.call(arguments,2),(i=function(){return e.apply(t||this,r.concat(s.call(arguments)))}).guid=e.guid=e.guid||k.guid++,i},k.holdReady=function(e){e?k.readyWait++:k.ready(!0)},k.isArray=Array.isArray,k.parseJSON=JSON.parse,k.nodeName=A,k.isFunction=m,k.isWindow=x,k.camelCase=V,k.type=w,k.now=Date.now,k.isNumeric=function(e){var t=k.type(e);return("number"===t||"string"===t)&&!isNaN(e-parseFloat(e))},"function"==typeof define&&define.amd&&define("jquery",[],function(){return k});var Qt=C.jQuery,Jt=C.$;return k.noConflict=function(e){return C.$===k&&(C.$=Jt),e&&C.jQuery===k&&(C.jQuery=Qt),k},e||(C.jQuery=C.$=k),k}); diff --git a/synapse/static/client/register/js/register.js b/synapse/static/client/register/js/register.js index b62763a293..3547f7be4f 100644 --- a/synapse/static/client/register/js/register.js +++ b/synapse/static/client/register/js/register.js @@ -33,7 +33,7 @@ var setupCaptcha = function() { callback: Recaptcha.focus_response_field }); window.matrixRegistration.isUsingRecaptcha = true; - }).error(errorFunc); + }).fail(errorFunc); }; @@ -49,7 +49,7 @@ var submitCaptcha = function(user, pwd) { $.post(matrixRegistration.endpoint, JSON.stringify(data), function(response) { 
console.log("Success -> "+JSON.stringify(response)); submitPassword(user, pwd, response.session); - }).error(function(err) { + }).fail(function(err) { Recaptcha.reload(); errorFunc(err); }); @@ -67,7 +67,7 @@ var submitPassword = function(user, pwd, session) { matrixRegistration.onRegistered( response.home_server, response.user_id, response.access_token ); - }).error(errorFunc); + }).fail(errorFunc); }; var errorFunc = function(err) { @@ -113,5 +113,5 @@ matrixRegistration.signUp = function() { matrixRegistration.onRegistered = function(hs_url, user_id, access_token) { // clobber this function - console.log("onRegistered - This function should be replaced to proceed."); + console.warn("onRegistered - This function should be replaced to proceed."); }; diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e7f6ea7286..ec89f645d4 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018,2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,518 +14,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -import calendar -import logging -import time - -from twisted.internet import defer - -from synapse.api.constants import PresenceState -from synapse.storage.devices import DeviceStore -from synapse.storage.user_erasure_store import UserErasureStore -from synapse.util.caches.stream_change_cache import StreamChangeCache - -from .account_data import AccountDataStore -from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore -from .client_ips import ClientIpStore -from .deviceinbox import DeviceInboxStore -from .directory import DirectoryStore -from .e2e_room_keys import EndToEndRoomKeyStore -from .end_to_end_keys import EndToEndKeyStore -from .engines import PostgresEngine -from .event_federation import EventFederationStore -from .event_push_actions import EventPushActionsStore -from .events import EventsStore -from .events_bg_updates import EventsBackgroundUpdatesStore -from .filtering import FilteringStore -from .group_server import GroupServerStore -from .keys import KeyStore -from .media_repository import MediaRepositoryStore -from .monthly_active_users import MonthlyActiveUsersStore -from .openid import OpenIdStore -from .presence import PresenceStore, UserPresenceState -from .profile import ProfileStore -from .push_rule import PushRuleStore -from .pusher import PusherStore -from .receipts import ReceiptsStore -from .registration import RegistrationStore -from .rejections import RejectionsStore -from .relations import RelationsStore -from .room import RoomStore -from .roommember import RoomMemberStore -from .search import SearchStore -from .signatures import SignatureStore -from .state import StateStore -from .stats import StatsStore -from .stream import StreamStore -from .tags import TagsStore -from .transactions import TransactionStore -from .user_directory import UserDirectoryStore -from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerator - -logger = logging.getLogger(__name__) - - -class DataStore( - EventsBackgroundUpdatesStore, - RoomMemberStore, - RoomStore, - RegistrationStore, - StreamStore, - ProfileStore, - PresenceStore, - TransactionStore, - DirectoryStore, - KeyStore, - StateStore, - SignatureStore, - 
ApplicationServiceStore, - EventsStore, - EventFederationStore, - MediaRepositoryStore, - RejectionsStore, - FilteringStore, - PusherStore, - PushRuleStore, - ApplicationServiceTransactionStore, - ReceiptsStore, - EndToEndKeyStore, - EndToEndRoomKeyStore, - SearchStore, - TagsStore, - AccountDataStore, - EventPushActionsStore, - OpenIdStore, - ClientIpStore, - DeviceStore, - DeviceInboxStore, - UserDirectoryStore, - GroupServerStore, - UserErasureStore, - MonthlyActiveUsersStore, - StatsStore, - RelationsStore, -): - def __init__(self, db_conn, hs): - self.hs = hs - self._clock = hs.get_clock() - self.database_engine = hs.database_engine - - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - extra_tables=[("local_invites", "stream_id")], - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - self._presence_id_gen = StreamIdGenerator( - db_conn, "presence_stream", "stream_id" - ) - self._device_inbox_id_gen = StreamIdGenerator( - db_conn, "device_max_stream_id", "stream_id" - ) - self._public_room_id_gen = StreamIdGenerator( - db_conn, "public_room_list_stream", "stream_id" - ) - self._device_list_id_gen = StreamIdGenerator( - db_conn, "device_lists_stream", "stream_id" - ) - - self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") - self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") - self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") - self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") - self._push_rules_stream_id_gen = ChainedIdGenerator( - self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" - ) - self._pushers_id_gen = StreamIdGenerator( - db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] - ) - self._group_updates_id_gen = StreamIdGenerator( - db_conn, "local_group_updates", "stream_id" - ) - - if isinstance(self.database_engine, PostgresEngine): - self._cache_id_gen = StreamIdGenerator( - db_conn, "cache_invalidation_stream", "stream_id" - ) - else: - self._cache_id_gen = None - - self._presence_on_startup = self._get_active_presence(db_conn) - - presence_cache_prefill, min_presence_val = self._get_cache_dict( - db_conn, - "presence_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self._presence_id_gen.get_current_token(), - ) - self.presence_stream_cache = StreamChangeCache( - "PresenceStreamChangeCache", - min_presence_val, - prefilled_cache=presence_cache_prefill, - ) - - max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( - db_conn, - "device_inbox", - entity_column="user_id", - stream_column="stream_id", - max_value=max_device_inbox_id, - limit=1000, - ) - self._device_inbox_stream_cache = StreamChangeCache( - "DeviceInboxStreamChangeCache", - min_device_inbox_id, - prefilled_cache=device_inbox_prefill, - ) - # The federation outbox and the local device inbox uses the same - # stream_id generator. 
- device_outbox_prefill, min_device_outbox_id = self._get_cache_dict( - db_conn, - "device_federation_outbox", - entity_column="destination", - stream_column="stream_id", - max_value=max_device_inbox_id, - limit=1000, - ) - self._device_federation_outbox_stream_cache = StreamChangeCache( - "DeviceFederationOutboxStreamChangeCache", - min_device_outbox_id, - prefilled_cache=device_outbox_prefill, - ) - - device_list_max = self._device_list_id_gen.get_current_token() - self._device_list_stream_cache = StreamChangeCache( - "DeviceListStreamChangeCache", device_list_max - ) - self._device_list_federation_stream_cache = StreamChangeCache( - "DeviceListFederationStreamChangeCache", device_list_max - ) - - events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( - db_conn, - "current_state_delta_stream", - entity_column="room_id", - stream_column="stream_id", - max_value=events_max, # As we share the stream id with events token - limit=1000, - ) - self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", - min_curr_state_delta_id, - prefilled_cache=curr_state_delta_prefill, - ) - - _group_updates_prefill, min_group_updates_id = self._get_cache_dict( - db_conn, - "local_group_updates", - entity_column="user_id", - stream_column="stream_id", - max_value=self._group_updates_id_gen.get_current_token(), - limit=1000, - ) - self._group_updates_stream_cache = StreamChangeCache( - "_group_updates_stream_cache", - min_group_updates_id, - prefilled_cache=_group_updates_prefill, - ) - - self._stream_order_on_start = self.get_room_max_stream_ordering() - self._min_stream_order_on_start = self.get_room_min_stream_ordering() - - # Used in _generate_user_daily_visits to keep track of progress - self._last_user_visit_update = self._get_start_of_day() - - super(DataStore, self).__init__(db_conn, hs) - - def take_presence_startup_info(self): - active_on_startup = self._presence_on_startup - self._presence_on_startup = None - return active_on_startup - - def _get_active_presence(self, db_conn): - """Fetch non-offline presence from the database so that we can register - the appropriate time outs. - """ - - sql = ( - "SELECT user_id, state, last_active_ts, last_federation_update_ts," - " last_user_sync_ts, status_msg, currently_active FROM presence_stream" - " WHERE state != ?" - ) - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, (PresenceState.OFFLINE,)) - rows = self.cursor_to_dict(txn) - txn.close() - - for row in rows: - row["currently_active"] = bool(row["currently_active"]) - - return [UserPresenceState(**row) for row in rows] - - def count_daily_users(self): - """ - Counts the number of users who used this homeserver in the last 24 hours. - """ - yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.runInteraction("count_daily_users", self._count_users, yesterday) - - def count_monthly_users(self): - """ - Counts the number of users who used this homeserver in the last 30 days. - Note this method is intended for phonehome metrics only and is different - from the mau figure in synapse.storage.monthly_active_users which, - amongst other things, includes a 3 day grace period before a user counts. 
- """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return self.runInteraction( - "count_monthly_users", self._count_users, thirty_days_ago - ) - - def _count_users(self, txn, time_from): - """ - Returns number of users seen in the past time_from period - """ - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT user_id FROM user_ips - WHERE last_seen > ? - GROUP BY user_id - ) u - """ - txn.execute(sql, (time_from,)) - count, = txn.fetchone() - return count - - def count_r30_users(self): - """ - Counts the number of 30 day retained users, defined as:- - * Users who have created their accounts more than 30 days ago - * Where last seen at most 30 days ago - * Where account creation and last_seen are > 30 days apart - - Returns counts globaly for a given user as well as breaking - by platform - """ - - def _count_r30_users(txn): - thirty_days_in_secs = 86400 * 30 - now = int(self._clock.time()) - thirty_days_ago_in_secs = now - thirty_days_in_secs - - sql = """ - SELECT platform, COALESCE(count(*), 0) FROM ( - SELECT - users.name, platform, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen, - CASE - WHEN user_agent LIKE '%%Android%%' THEN 'android' - WHEN user_agent LIKE '%%iOS%%' THEN 'ios' - WHEN user_agent LIKE '%%Electron%%' THEN 'electron' - WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' - WHEN user_agent LIKE '%%Gecko%%' THEN 'web' - ELSE 'unknown' - END - AS platform - FROM user_ips - ) uip - ON users.name = uip.user_id - AND users.appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, platform, users.creation_ts - ) u GROUP BY platform - """ - - results = {} - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - for row in txn: - if row[0] == "unknown": - pass - results[row[0]] = row[1] - - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT users.name, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen - FROM user_ips - ) uip - ON users.name = uip.user_id - AND appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, users.creation_ts - ) u - """ - - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - count, = txn.fetchone() - results["all"] = count - - return results - - return self.runInteraction("count_r30_users", _count_r30_users) - - def _get_start_of_day(self): - """ - Returns millisecond unixtime for start of UTC day. - """ - now = time.gmtime() - today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) - return today_start * 1000 - - def generate_user_daily_visits(self): - """ - Generates daily visit data for use in cohort/ retention analysis - """ - - def _generate_user_daily_visits(txn): - logger.info("Calling _generate_user_daily_visits") - today_start = self._get_start_of_day() - a_day_in_milliseconds = 24 * 60 * 60 * 1000 - now = self.clock.time_msec() - - sql = """ - INSERT INTO user_daily_visits (user_id, device_id, timestamp) - SELECT u.user_id, u.device_id, ? - FROM user_ips AS u - LEFT JOIN ( - SELECT user_id, device_id, timestamp FROM user_daily_visits - WHERE timestamp = ? - ) udv - ON u.user_id = udv.user_id AND u.device_id=udv.device_id - INNER JOIN users ON users.name=u.user_id - WHERE last_seen > ? AND last_seen <= ? 
- AND udv.timestamp IS NULL AND users.is_guest=0 - AND users.appservice_id IS NULL - GROUP BY u.user_id, u.device_id - """ - - # This means that the day has rolled over but there could still - # be entries from the previous day. There is an edge case - # where if the user logs in at 23:59 and overwrites their - # last_seen at 00:01 then they will not be counted in the - # previous day's stats - it is important that the query is run - # often to minimise this case. - if today_start > self._last_user_visit_update: - yesterday_start = today_start - a_day_in_milliseconds - txn.execute( - sql, - ( - yesterday_start, - yesterday_start, - self._last_user_visit_update, - today_start, - ), - ) - self._last_user_visit_update = today_start - - txn.execute( - sql, (today_start, today_start, self._last_user_visit_update, now) - ) - # Update _last_user_visit_update to now. The reason to do this - # rather just clamping to the beginning of the day is to limit - # the size of the join - meaning that the query can be run more - # frequently - self._last_user_visit_update = now - - return self.runInteraction( - "generate_user_daily_visits", _generate_user_daily_visits - ) - - def get_users(self): - """Function to reterive a list of users in users table. - - Args: - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self._simple_select_list( - table="users", - keyvalues={}, - retcols=["name", "password_hash", "is_guest", "admin", "user_type"], - desc="get_users", - ) - - @defer.inlineCallbacks - def get_users_paginate(self, order, start, limit): - """Function to reterive a paginated list of users from - users list. This will return a json object, which contains - list of users and the total number of users in users table. - - Args: - order (str): column name to order the select by this column - start (int): start number to begin the query from - limit (int): number of rows to reterive - Returns: - defer.Deferred: resolves to json object {list[dict[str, Any]], count} - """ - users = yield self.runInteraction( - "get_users_paginate", - self._simple_select_list_paginate_txn, - table="users", - keyvalues={"is_guest": False}, - orderby=order, - start=start, - limit=limit, - retcols=["name", "password_hash", "is_guest", "admin", "user_type"], - ) - count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn) - retval = {"users": users, "total": count} - return retval - - def search_users(self, term): - """Function to search users list for one or more users with - the matched term. - - Args: - term (str): search term - col (str): column to query term should be matched to - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self._simple_search_list( - table="users", - term=term, - col="name", - retcols=["name", "password_hash", "is_guest", "admin", "user_type"], - desc="search_users", - ) - - -def are_all_users_on_domain(txn, database_engine, domain): - sql = database_engine.convert_param_style( - "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" - ) - pat = "%:" + domain - txn.execute(sql, (pat,)) - num_not_matching = txn.fetchall()[0][0] - if num_not_matching == 0: - return True - return False +""" +The storage layer is split up into multiple parts to allow Synapse to run +against different configurations of databases (e.g. single or multiple +databases). The `Database` class represents a single physical database. The +`data_stores` are classes that talk directly to a `Database` instance and have +associated schemas, background updates, etc. 
On top of those there are classes +that provide high level interfaces that combine calls to multiple `data_stores`. + +There are also schemas that get applied to every database, regardless of the +data stores associated with them (e.g. the schema version tables), which are +stored in `synapse.storage.schema`. +""" + +from synapse.storage.data_stores import DataStores +from synapse.storage.data_stores.main import DataStore +from synapse.storage.persist_events import EventsPersistenceStorage +from synapse.storage.purge_events import PurgeEventsStorage +from synapse.storage.state import StateGroupStorage + +__all__ = ["DataStores", "DataStore"] + + +class Storage(object): + """The high level interfaces for talking to various storage layers. + """ + + def __init__(self, hs, stores: DataStores): + # We include the main data store here mainly so that we don't have to + # rewrite all the existing code to split it into high vs low level + # interfaces. + self.main = stores.main + + self.persistence = EventsPersistenceStorage(hs, stores) + self.purge_events = PurgeEventsStorage(hs, stores) + self.state = StateGroupStorage(hs, stores) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index abe16334ec..bfce541ca7 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -14,1403 +14,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging import random -import sys -import threading -import time - -from six import PY2, iteritems, iterkeys, itervalues -from six.moves import builtins, intern, range +from abc import ABCMeta +from typing import Any, Optional from canonicaljson import json -from prometheus_client import Histogram - -from twisted.internet import defer - -from synapse.api.errors import StoreError -from synapse.logging.context import LoggingContext, PreserveLoggingContext -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.types import get_domain_from_id -from synapse.util import batch_iter -from synapse.util.caches.descriptors import Cache -from synapse.util.stringutils import exception_to_unicode -# import a function which will return a monotonic time, in seconds -try: - # on python 3, use time.monotonic, since time.clock can go backwards - from time import monotonic as monotonic_time -except ImportError: - # ... but python 2 doesn't have it - from time import clock as monotonic_time +from synapse.storage.database import LoggingTransaction # noqa: F401 +from synapse.storage.database import make_in_list_sql_clause # noqa: F401 +from synapse.storage.database import Database +from synapse.types import Collection, get_domain_from_id logger = logging.getLogger(__name__) -try: - MAX_TXN_ID = sys.maxint - 1 -except AttributeError: - # python 3 does not have a maximum int value - MAX_TXN_ID = 2 ** 63 - 1 - -sql_logger = logging.getLogger("synapse.storage.SQL") -transaction_logger = logging.getLogger("synapse.storage.txn") -perf_logger = logging.getLogger("synapse.storage.TIME") - -sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec") - -sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"]) -sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"]) - - -# Unique indexes which have been added in background updates. 
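The new module docstring and `Storage` class in the synapse/storage/__init__.py hunk above describe the reworked layering: per-database data stores underneath, with high-level interfaces that combine them on top. The short sketch below restates that composition as code. It is illustrative only and not part of this change; it uses only names added in that hunk (`Storage` and its `main`, `persistence`, `purge_events` and `state` attributes) and assumes `hs` and `stores` are an already-built homeserver and `DataStores` instance, as elsewhere in Synapse.

# Illustrative sketch, not part of this diff.
from synapse.storage import Storage


def build_storage_interfaces(hs, stores):
    storage = Storage(hs, stores)

    # Low-level: the "main" data store (the per-database DataStore).
    main_store = storage.main

    # High-level: interfaces that may combine calls to multiple data stores.
    persistence = storage.persistence    # event persistence
    purge = storage.purge_events         # purging of events
    state = storage.state                # state group storage

    return main_store, persistence, purge, state
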
Maps from table name -# to the name of the background update which added the unique index to that table. -# -# This is used by the upsert logic to figure out which tables are safe to do a proper -# UPSERT on: until the relevant background update has completed, we -# have to emulate an upsert by locking the table. -# -UNIQUE_INDEX_BACKGROUND_UPDATES = { - "user_ips": "user_ips_device_unique_index", - "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", - "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", - "event_search": "event_search_event_id_idx", -} - -# This is a special cache name we use to batch multiple invalidations of caches -# based on the current state when notifying workers over replication. -_CURRENT_STATE_CACHE_NAME = "cs_cache_fake" - -class LoggingTransaction(object): - """An object that almost-transparently proxies for the 'txn' object - passed to the constructor. Adds logging and metrics to the .execute() - method. +# some of our subclasses have abstract methods, so we use the ABCMeta metaclass. +class SQLBaseStore(metaclass=ABCMeta): + """Base class for data stores that holds helper functions. - Args: - txn: The database transcation object to wrap. - name (str): The name of this transactions for logging. - database_engine (Sqlite3Engine|PostgresEngine) - after_callbacks(list|None): A list that callbacks will be appended to - that have been added by `call_after` which should be run on - successful completion of the transaction. None indicates that no - callbacks should be allowed to be scheduled to run. - exception_callbacks(list|None): A list that callbacks will be appended - to that have been added by `call_on_exception` which should be run - if transaction ends with an error. None indicates that no callbacks - should be allowed to be scheduled to run. + Note that multiple instances of this class will exist as there will be one + per data store (and not one per physical database). """ - __slots__ = [ - "txn", - "name", - "database_engine", - "after_callbacks", - "exception_callbacks", - ] - - def __init__( - self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None - ): - object.__setattr__(self, "txn", txn) - object.__setattr__(self, "name", name) - object.__setattr__(self, "database_engine", database_engine) - object.__setattr__(self, "after_callbacks", after_callbacks) - object.__setattr__(self, "exception_callbacks", exception_callbacks) - - def call_after(self, callback, *args, **kwargs): - """Call the given callback on the main twisted thread after the - transaction has finished. Used to invalidate the caches on the - correct thread. 
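As an illustrative aside (not part of this changeset): call_after is normally invoked from within a transaction function so that a cache entry is only invalidated once the transaction has actually committed, on the main reactor thread. A minimal sketch of that pattern inside a hypothetical store method; the table, method and cache names are invented for the example:

    def update_displayname_txn(self, txn, user_id, new_displayname):
        # Hypothetical store method: table and cached getter are illustrative only.
        txn.execute(
            "UPDATE profiles SET displayname = ? WHERE user_id = ?",
            (new_displayname, user_id),
        )
        # Only invalidate the cached getter once the transaction has committed.
        txn.call_after(self.get_displayname.invalidate, (user_id,))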
- """ - self.after_callbacks.append((callback, args, kwargs)) - - def call_on_exception(self, callback, *args, **kwargs): - self.exception_callbacks.append((callback, args, kwargs)) - - def __getattr__(self, name): - return getattr(self.txn, name) - - def __setattr__(self, name, value): - setattr(self.txn, name, value) - - def __iter__(self): - return self.txn.__iter__() - - def execute_batch(self, sql, args): - if isinstance(self.database_engine, PostgresEngine): - from psycopg2.extras import execute_batch - - self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) - else: - for val in args: - self.execute(sql, val) - - def execute(self, sql, *args): - self._do_execute(self.txn.execute, sql, *args) - - def executemany(self, sql, *args): - self._do_execute(self.txn.executemany, sql, *args) - - def _make_sql_one_line(self, sql): - "Strip newlines out of SQL so that the loggers in the DB are on one line" - return " ".join(l.strip() for l in sql.splitlines() if l.strip()) - - def _do_execute(self, func, sql, *args): - sql = self._make_sql_one_line(sql) - - # TODO(paul): Maybe use 'info' and 'debug' for values? - sql_logger.debug("[SQL] {%s} %s", self.name, sql) - - sql = self.database_engine.convert_param_style(sql) - if args: - try: - sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) - except Exception: - # Don't let logging failures stop SQL from working - pass - - start = time.time() - - try: - return func(sql, *args) - except Exception as e: - logger.debug("[SQL FAIL] {%s} %s", self.name, e) - raise - finally: - secs = time.time() - start - sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) - sql_query_timer.labels(sql.split()[0]).observe(secs) - - -class PerformanceCounters(object): - def __init__(self): - self.current_counters = {} - self.previous_counters = {} - - def update(self, key, duration_secs): - count, cum_time = self.current_counters.get(key, (0, 0)) - count += 1 - cum_time += duration_secs - self.current_counters[key] = (count, cum_time) - - def interval(self, interval_duration_secs, limit=3): - counters = [] - for name, (count, cum_time) in iteritems(self.current_counters): - prev_count, prev_time = self.previous_counters.get(name, (0, 0)) - counters.append( - ( - (cum_time - prev_time) / interval_duration_secs, - count - prev_count, - name, - ) - ) - - self.previous_counters = dict(self.current_counters) - - counters.sort(reverse=True) - - top_n_counters = ", ".join( - "%s(%d): %.3f%%" % (name, count, 100 * ratio) - for ratio, count, name in counters[:limit] - ) - - return top_n_counters - - -class SQLBaseStore(object): - _TXN_ID = 0 - - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self._db_pool = hs.get_db_pool() - - self._previous_txn_total_time = 0 - self._current_txn_total_time = 0 - self._previous_loop_ts = 0 - - # TODO(paul): These can eventually be removed once the metrics code - # is running in mainline, and we have some nice monitoring frontends - # to watch it - self._txn_perf_counters = PerformanceCounters() - - self._get_event_cache = Cache( - "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size - ) - - self._event_fetch_lock = threading.Condition() - self._event_fetch_list = [] - self._event_fetch_ongoing = 0 - - self._pending_ds = [] - - self.database_engine = hs.database_engine - - # A set of tables that are not safe to use native upserts in. 
- self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) - - self._account_validity = self.hs.config.account_validity - - # We add the user_directory_search table to the blacklist on SQLite - # because the existing search table does not have an index, making it - # unsafe to use native upserts. - if isinstance(self.database_engine, Sqlite3Engine): - self._unsafe_to_upsert_tables.add("user_directory_search") - - if self.database_engine.can_native_upsert: - # Check ASAP (and then later, every 1s) to see if we have finished - # background updates of tables that aren't safe to update. - self._clock.call_later( - 0.0, - run_as_background_process, - "upsert_safety_check", - self._check_safe_to_upsert, - ) - + self.database_engine = database.engine + self.db = database self.rand = random.SystemRandom() - if self._account_validity.enabled: - self._clock.call_later( - 0.0, - run_as_background_process, - "account_validity_set_expiration_dates", - self._set_expiration_date_when_missing, - ) - - @defer.inlineCallbacks - def _check_safe_to_upsert(self): - """ - Is it safe to use native UPSERT? - - If there are background updates, we will need to wait, as they may be - the addition of indexes that set the UNIQUE constraint that we require. - - If the background updates have not completed, wait 15 sec and check again. - """ - updates = yield self._simple_select_list( - "background_updates", - keyvalues=None, - retcols=["update_name"], - desc="check_background_updates", - ) - updates = [x["update_name"] for x in updates] - - for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): - if update_name not in updates: - logger.debug("Now safe to upsert in %s", table) - self._unsafe_to_upsert_tables.discard(table) - - # If there's any updates still running, reschedule to run. - if updates: - self._clock.call_later( - 15.0, - run_as_background_process, - "upsert_safety_check", - self._check_safe_to_upsert, - ) - - @defer.inlineCallbacks - def _set_expiration_date_when_missing(self): - """ - Retrieves the list of registered users that don't have an expiration date, and - adds an expiration date for each of them. - """ - - def select_users_with_no_expiration_date_txn(txn): - """Retrieves the list of registered users with no expiration date from the - database, filtering out deactivated users. - """ - sql = ( - "SELECT users.name FROM users" - " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" - " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" - ) - txn.execute(sql, []) - - res = self.cursor_to_dict(txn) - if res: - for user in res: - self.set_expiration_date_for_user_txn( - txn, user["name"], use_delta=True - ) - - yield self.runInteraction( - "get_users_with_no_expiration_date", - select_users_with_no_expiration_date_txn, - ) - - def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): - """Sets an expiration date to the account with the given user ID. - - Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be - now + validity period. If set to True, this expiration date will be a - random value in the [now + period - d ; now + period] range, d being a - delta equal to 10% of the validity period. 
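As an illustrative aside (not part of this changeset): a standalone worked example of the use_delta behaviour described here, assuming a 30-day validity period, so the delta d is 3 days:

    import random
    import time

    period_ms = 30 * 24 * 3600 * 1000          # assumed 30-day validity period
    startup_job_max_delta = period_ms // 10    # d = 10% of the period, i.e. 3 days

    now_ms = int(time.time() * 1000)
    expiration_ts = now_ms + period_ms
    # With use_delta=True the expiry is pulled back by up to d, so it lands
    # somewhere in [now + 27 days, now + 30 days).
    expiration_ts = random.SystemRandom().randrange(
        expiration_ts - startup_job_max_delta, expiration_ts
    )
    print((expiration_ts - now_ms) / (24 * 3600 * 1000), "days from now")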
- """ - now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - - if use_delta: - expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, - expiration_ts, - ) - - self._simple_insert_txn( - txn, - "account_validity", - values={ - "user_id": user_id, - "expiration_ts_ms": expiration_ts, - "email_sent": False, - }, - ) - - def start_profiling(self): - self._previous_loop_ts = monotonic_time() - - def loop(): - curr = self._current_txn_total_time - prev = self._previous_txn_total_time - self._previous_txn_total_time = curr - - time_now = monotonic_time() - time_then = self._previous_loop_ts - self._previous_loop_ts = time_now - - duration = time_now - time_then - ratio = (curr - prev) / duration - - top_three_counters = self._txn_perf_counters.interval(duration, limit=3) - - perf_logger.info( - "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters - ) - - self._clock.looping_call(loop, 10000) - - def _new_transaction( - self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs - ): - start = monotonic_time() - txn_id = self._TXN_ID - - # We don't really need these to be unique, so lets stop it from - # growing really large. - self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) - - name = "%s-%x" % (desc, txn_id) - - transaction_logger.debug("[TXN START] {%s}", name) - - try: - i = 0 - N = 5 - while True: - try: - txn = conn.cursor() - txn = LoggingTransaction( - txn, - name, - self.database_engine, - after_callbacks, - exception_callbacks, - ) - r = func(txn, *args, **kwargs) - conn.commit() - return r - except self.database_engine.module.OperationalError as e: - # This can happen if the database disappears mid - # transaction. - logger.warning( - "[TXN OPERROR] {%s} %s %d/%d", - name, - exception_to_unicode(e), - i, - N, - ) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) - ) - continue - raise - except self.database_engine.module.DatabaseError as e: - if self.database_engine.is_deadlock(e): - logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", - name, - exception_to_unicode(e1), - ) - continue - raise - except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", name, e) - raise - finally: - end = monotonic_time() - duration = end - start - - LoggingContext.current_context().add_database_transaction(duration) - - transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) - - self._current_txn_total_time += duration - self._txn_perf_counters.update(desc, duration) - sql_txn_timer.labels(desc).observe(duration) - - @defer.inlineCallbacks - def runInteraction(self, desc, func, *args, **kwargs): - """Starts a transaction on the database and runs a given function - - Arguments: - desc (str): description of the transaction, for logging and metrics - func (func): callback function, which will be called with a - database transaction (twisted.enterprise.adbapi.Transaction) as - its first argument, followed by `args` and `kwargs`. 
- - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` - - Returns: - Deferred: The result of func - """ - after_callbacks = [] - exception_callbacks = [] - - if LoggingContext.current_context() == LoggingContext.sentinel: - logger.warn("Starting db txn '%s' from sentinel context", desc) - - try: - result = yield self.runWithConnection( - self._new_transaction, - desc, - after_callbacks, - exception_callbacks, - func, - *args, - **kwargs - ) - - for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) - except: # noqa: E722, as we reraise the exception this is fine. - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) - raise - - return result - - @defer.inlineCallbacks - def runWithConnection(self, func, *args, **kwargs): - """Wraps the .runWithConnection() method on the underlying db_pool. - - Arguments: - func (func): callback function, which will be called with a - database connection (twisted.enterprise.adbapi.Connection) as - its first argument, followed by `args` and `kwargs`. - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` - - Returns: - Deferred: The result of func - """ - parent_context = LoggingContext.current_context() - if parent_context == LoggingContext.sentinel: - logger.warn( - "Starting db connection from sentinel context: metrics will be lost" - ) - parent_context = None - - start_time = monotonic_time() - - def inner_func(conn, *args, **kwargs): - with LoggingContext("runWithConnection", parent_context) as context: - sched_duration_sec = monotonic_time() - start_time - sql_scheduling_timer.observe(sched_duration_sec) - context.add_database_scheduled(sched_duration_sec) - - if self.database_engine.is_connection_closed(conn): - logger.debug("Reconnecting closed database connection") - conn.reconnect() - - return func(conn, *args, **kwargs) - - with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs) - - return result - - @staticmethod - def cursor_to_dict(cursor): - """Converts a SQL cursor into an list of dicts. - - Args: - cursor : The DBAPI cursor which has executed a query. - Returns: - A list of dicts where the key is the column header. - """ - col_headers = list(intern(str(column[0])) for column in cursor.description) - results = list(dict(zip(col_headers, row)) for row in cursor) - return results - - def _execute(self, desc, decoder, query, *args): - """Runs a single query for a result set. - - Args: - decoder - The function which can resolve the cursor results to - something meaningful. - query - The query string to execute - *args - Query args. - Returns: - The result of decoder(results) - """ - - def interaction(txn): - txn.execute(query, args) - if decoder: - return decoder(txn) - else: - return txn.fetchall() - - return self.runInteraction(desc, interaction) - - # "Simple" SQL API methods that operate on a single table with no JOINs, - # no complex WHERE clauses, just a dict of values for columns. - - @defer.inlineCallbacks - def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"): - """Executes an INSERT query on the named table. - - Args: - table : string giving the table name - values : dict of new column names and values for them - or_ignore : bool stating whether an exception should be raised - when a conflicting row already exists. 
If True, False will be - returned by the function instead - desc : string giving a description of the transaction - - Returns: - bool: Whether the row was inserted or not. Only useful when - `or_ignore` is True - """ - try: - yield self.runInteraction(desc, self._simple_insert_txn, table, values) - except self.database_engine.module.IntegrityError: - # We have to do or_ignore flag at this layer, since we can't reuse - # a cursor after we receive an error from the db. - if not or_ignore: - raise - return False - return True - - @staticmethod - def _simple_insert_txn(txn, table, values): - keys, vals = zip(*values.items()) - - sql = "INSERT INTO %s (%s) VALUES(%s)" % ( - table, - ", ".join(k for k in keys), - ", ".join("?" for _ in keys), - ) - - txn.execute(sql, vals) - - def _simple_insert_many(self, table, values, desc): - return self.runInteraction(desc, self._simple_insert_many_txn, table, values) - - @staticmethod - def _simple_insert_many_txn(txn, table, values): - if not values: - return - - # This is a *slight* abomination to get a list of tuples of key names - # and a list of tuples of value names. - # - # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}] - # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)] - # - # The sort is to ensure that we don't rely on dictionary iteration - # order. - keys, vals = zip( - *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] - ) - - for k in keys: - if k != keys[0]: - raise RuntimeError("All items must have the same keys") - - sql = "INSERT INTO %s (%s) VALUES(%s)" % ( - table, - ", ".join(k for k in keys[0]), - ", ".join("?" for _ in keys[0]), - ) - - txn.executemany(sql, vals) - - @defer.inlineCallbacks - def _simple_upsert( - self, - table, - keyvalues, - values, - insertion_values={}, - desc="_simple_upsert", - lock=True, - ): - """ - - `lock` should generally be set to True (the default), but can be set - to False if either of the following are true: - - * there is a UNIQUE INDEX on the key columns. In this case a conflict - will cause an IntegrityError in which case this function will retry - the update. - - * we somehow know that we are the only thread which will be updating - this table. - - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key columns and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - Deferred(None or bool): Native upserts always return None. Emulated - upserts return True if a new entry was created, False if an existing - one was updated. - """ - attempts = 0 - while True: - try: - result = yield self.runInteraction( - desc, - self._simple_upsert_txn, - table, - keyvalues, - values, - insertion_values, - lock=lock, - ) - return result - except self.database_engine.module.IntegrityError as e: - attempts += 1 - if attempts >= 5: - # don't retry forever, because things other than races - # can cause IntegrityErrors - raise - - # presumably we raced with another transaction: let's retry. - logger.warn( - "IntegrityError when upserting into %s; retrying: %s", table, e - ) - - def _simple_upsert_txn( - self, txn, table, keyvalues, values, insertion_values={}, lock=True - ): - """ - Pick the UPSERT method which works best on the platform. Either the - native one (Pg9.5+, recent SQLites), or fall back to an emulated method. - - Args: - txn: The transaction to use. 
- table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - None or bool: Native upserts always return None. Emulated - upserts return True if a new entry was created, False if an existing - one was updated. - """ - if ( - self.database_engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ): - return self._simple_upsert_txn_native_upsert( - txn, table, keyvalues, values, insertion_values=insertion_values - ) - else: - return self._simple_upsert_txn_emulated( - txn, - table, - keyvalues, - values, - insertion_values=insertion_values, - lock=lock, - ) - - def _simple_upsert_txn_emulated( - self, txn, table, keyvalues, values, insertion_values={}, lock=True - ): - """ - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - bool: Return True if a new entry was created, False if an existing - one was updated. - """ - # We need to lock the table :(, unless we're *really* careful - if lock: - self.database_engine.lock_table(txn, table) - - def _getwhere(key): - # If the value we're passing in is None (aka NULL), we need to use - # IS, not =, as NULL = NULL equals NULL (False). - if keyvalues[key] is None: - return "%s IS ?" % (key,) - else: - return "%s = ?" % (key,) - - if not values: - # If `values` is empty, then all of the values we care about are in - # the unique key, so there is nothing to UPDATE. We can just do a - # SELECT instead to see if it exists. - sql = "SELECT 1 FROM %s WHERE %s" % ( - table, - " AND ".join(_getwhere(k) for k in keyvalues), - ) - sqlargs = list(keyvalues.values()) - txn.execute(sql, sqlargs) - if txn.fetchall(): - # We have an existing record. - return False - else: - # First try to update. - sql = "UPDATE %s SET %s WHERE %s" % ( - table, - ", ".join("%s = ?" % (k,) for k in values), - " AND ".join(_getwhere(k) for k in keyvalues), - ) - sqlargs = list(values.values()) + list(keyvalues.values()) - - txn.execute(sql, sqlargs) - if txn.rowcount > 0: - # successfully updated at least one row. - return False - - # We didn't find any existing rows, so insert a new one - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(values) - allvalues.update(insertion_values) - - sql = "INSERT INTO %s (%s) VALUES (%s)" % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues), - ) - txn.execute(sql, list(allvalues.values())) - # successfully inserted - return True - - def _simple_upsert_txn_native_upsert( - self, txn, table, keyvalues, values, insertion_values={} - ): - """ - Use the native UPSERT functionality in recent PostgreSQL versions. 
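As an illustrative aside (not part of this changeset): the native path described here reduces the upsert to a single INSERT ... ON CONFLICT ... DO UPDATE statement. A standalone sketch of the same idea against a throwaway SQLite 3.24+ database; the device_seen table and its columns are invented for the example:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE device_seen ("
        " user_id TEXT, device_id TEXT, last_seen INTEGER,"
        " UNIQUE (user_id, device_id))"
    )

    def upsert_last_seen(conn, user_id, device_id, last_seen):
        # Key columns go in ON CONFLICT (...), value columns in DO UPDATE SET,
        # mirroring the keyvalues/values split used by the helper.
        conn.execute(
            "INSERT INTO device_seen (user_id, device_id, last_seen)"
            " VALUES (?, ?, ?)"
            " ON CONFLICT (user_id, device_id)"
            " DO UPDATE SET last_seen = excluded.last_seen",
            (user_id, device_id, last_seen),
        )

    upsert_last_seen(conn, "@alice:example.com", "DEV1", 1000)
    upsert_last_seen(conn, "@alice:example.com", "DEV1", 2000)  # updates in place
    print(conn.execute("SELECT * FROM device_seen").fetchall())
    # [('@alice:example.com', 'DEV1', 2000)]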
- - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - Returns: - None - """ - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(insertion_values) - - if not values: - latter = "NOTHING" - else: - allvalues.update(values) - latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - - sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues), - ", ".join(k for k in keyvalues), - latter, - ) - txn.execute(sql, list(allvalues.values())) - - def _simple_upsert_many_txn( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - if ( - self.database_engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ): - return self._simple_upsert_many_txn_native_upsert( - txn, table, key_names, key_values, value_names, value_values - ) - else: - return self._simple_upsert_many_txn_emulated( - txn, table, key_names, key_values, value_names, value_values - ) - - def _simple_upsert_many_txn_emulated( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times, but without native UPSERT support or batching. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - # No value columns, therefore make a blank list so that the following - # zip() works correctly. - if not value_names: - value_values = [() for x in range(len(key_values))] - - for keyv, valv in zip(key_values, value_values): - _keys = {x: y for x, y in zip(key_names, keyv)} - _vals = {x: y for x, y in zip(value_names, valv)} - - self._simple_upsert_txn_emulated(txn, table, _keys, _vals) - - def _simple_upsert_many_txn_native_upsert( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times, using batching where possible. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - allnames = [] - allnames.extend(key_names) - allnames.extend(value_names) - - if not value_names: - # No value columns, therefore make a blank list so that the - # following zip() works correctly. - latter = "NOTHING" - value_values = [() for x in range(len(key_values))] - else: - latter = "UPDATE SET " + ", ".join( - k + "=EXCLUDED." 
+ k for k in value_names - ) - - sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( - table, - ", ".join(k for k in allnames), - ", ".join("?" for _ in allnames), - ", ".join(key_names), - latter, - ) - - args = [] - - for x, y in zip(key_values, value_values): - args.append(tuple(x) + tuple(y)) - - return txn.execute_batch(sql, args) - - def _simple_select_one( - self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one" - ): - """Executes a SELECT query on the named table, which is expected to - return a single row, returning multiple columns from it. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - retcols : list of strings giving the names of the columns to return - - allow_none : If true, return None instead of failing if the SELECT - statement returns no rows - """ - return self.runInteraction( - desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none - ) - - def _simple_select_one_onecol( - self, - table, - keyvalues, - retcol, - allow_none=False, - desc="_simple_select_one_onecol", - ): - """Executes a SELECT query on the named table, which is expected to - return a single row, returning a single column from it. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - retcol : string giving the name of the column to return - """ - return self.runInteraction( - desc, - self._simple_select_one_onecol_txn, - table, - keyvalues, - retcol, - allow_none=allow_none, - ) - - @classmethod - def _simple_select_one_onecol_txn( - cls, txn, table, keyvalues, retcol, allow_none=False - ): - ret = cls._simple_select_onecol_txn( - txn, table=table, keyvalues=keyvalues, retcol=retcol - ) - - if ret: - return ret[0] - else: - if allow_none: - return None - else: - raise StoreError(404, "No row found") - - @staticmethod - def _simple_select_onecol_txn(txn, table, keyvalues, retcol): - sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} - - if keyvalues: - sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) - txn.execute(sql, list(keyvalues.values())) - else: - txn.execute(sql) - - return [r[0] for r in txn] - - def _simple_select_onecol( - self, table, keyvalues, retcol, desc="_simple_select_onecol" - ): - """Executes a SELECT query on the named table, which returns a list - comprising of the values of the named column from the selected rows. - - Args: - table (str): table name - keyvalues (dict|None): column names and values to select the rows with - retcol (str): column whos value we wish to retrieve. - - Returns: - Deferred: Results in a list - """ - return self.runInteraction( - desc, self._simple_select_onecol_txn, table, keyvalues, retcol - ) - - def _simple_select_list( - self, table, keyvalues, retcols, desc="_simple_select_list" - ): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - table (str): the table name - keyvalues (dict[str, Any] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. 
- retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.runInteraction( - desc, self._simple_select_list_txn, table, keyvalues, retcols - ) - - @classmethod - def _simple_select_list_txn(cls, txn, table, keyvalues, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - retcols (iterable[str]): the names of the columns to return - """ - if keyvalues: - sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - txn.execute(sql, list(keyvalues.values())) - else: - sql = "SELECT %s FROM %s" % (", ".join(retcols), table) - txn.execute(sql) - - return cls.cursor_to_dict(txn) - - @defer.inlineCallbacks - def _simple_select_many_batch( - self, - table, - column, - iterable, - retcols, - keyvalues={}, - desc="_simple_select_many_batch", - batch_size=100, - ): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Filters rows by if value of `column` is in `iterable`. - - Args: - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return - """ - results = [] - - if not iterable: - return results - - # iterables can not be sliced, so convert it to a list first - it_list = list(iterable) - - chunks = [ - it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) - ] - for chunk in chunks: - rows = yield self.runInteraction( - desc, - self._simple_select_many_txn, - table, - column, - chunk, - keyvalues, - retcols, - ) - - results.extend(rows) - - return results - - @classmethod - def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Filters rows by if value of `column` is in `iterable`. - - Args: - txn : Transaction object - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return - """ - if not iterable: - return [] - - sql = "SELECT %s FROM %s" % (", ".join(retcols), table) - - clauses = [] - values = [] - clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable))) - values.extend(iterable) - - for key, value in iteritems(keyvalues): - clauses.append("%s = ?" % (key,)) - values.append(value) - - if clauses: - sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) - - txn.execute(sql, values) - return cls.cursor_to_dict(txn) - - def _simple_update(self, table, keyvalues, updatevalues, desc): - return self.runInteraction( - desc, self._simple_update_txn, table, keyvalues, updatevalues - ) - - @staticmethod - def _simple_update_txn(txn, table, keyvalues, updatevalues): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" 
% k for k in iterkeys(keyvalues)) - else: - where = "" - - update_sql = "UPDATE %s SET %s %s" % ( - table, - ", ".join("%s = ?" % (k,) for k in updatevalues), - where, - ) - - txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values())) - - return txn.rowcount - - def _simple_update_one( - self, table, keyvalues, updatevalues, desc="_simple_update_one" - ): - """Executes an UPDATE query on the named table, setting new values for - columns in a row matching the key values. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - updatevalues : dict giving column names and values to update - retcols : optional list of column names to return - - If present, retcols gives a list of column names on which to perform - a SELECT statement *before* performing the UPDATE statement. The values - of these will be returned in a dict. - - These are performed within the same transaction, allowing an atomic - get-and-set. This can be used to implement compare-and-set by putting - the update column in the 'keyvalues' dict as well. - """ - return self.runInteraction( - desc, self._simple_update_one_txn, table, keyvalues, updatevalues - ) - - @classmethod - def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): - rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues) - - if rowcount == 0: - raise StoreError(404, "No row found (%s)" % (table,)) - if rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - @staticmethod - def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): - select_sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - - txn.execute(select_sql, list(keyvalues.values())) - row = txn.fetchone() - - if not row: - if allow_none: - return None - raise StoreError(404, "No row found (%s)" % (table,)) - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - return dict(zip(retcols, row)) - - def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): - """Executes a DELETE query on the named table, expecting to delete a - single row. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues) - - @staticmethod - def _simple_delete_one_txn(txn, table, keyvalues): - """Executes a DELETE query on the named table, expecting to delete a - single row. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - sql = "DELETE FROM %s WHERE %s" % ( - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - - txn.execute(sql, list(keyvalues.values())) - if txn.rowcount == 0: - raise StoreError(404, "No row found (%s)" % (table,)) - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - def _simple_delete(self, table, keyvalues, desc): - return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues) - - @staticmethod - def _simple_delete_txn(txn, table, keyvalues): - sql = "DELETE FROM %s WHERE %s" % ( - table, - " AND ".join("%s = ?" 
% (k,) for k in keyvalues), - ) - - txn.execute(sql, list(keyvalues.values())) - return txn.rowcount - - def _simple_delete_many(self, table, column, iterable, keyvalues, desc): - return self.runInteraction( - desc, self._simple_delete_many_txn, table, column, iterable, keyvalues - ) - - @staticmethod - def _simple_delete_many_txn(txn, table, column, iterable, keyvalues): - """Executes a DELETE query on the named table. - - Filters rows by if value of `column` is in `iterable`. - - Args: - txn : Transaction object - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - - Returns: - int: Number rows deleted - """ - if not iterable: - return 0 - - sql = "DELETE FROM %s" % table - - clauses = [] - values = [] - clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable))) - values.extend(iterable) - - for key, value in iteritems(keyvalues): - clauses.append("%s = ?" % (key,)) - values.append(value) - - if clauses: - sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) - txn.execute(sql, values) - - return txn.rowcount - - def _get_cache_dict( - self, db_conn, table, entity_column, stream_column, max_value, limit=100000 - ): - # Fetch a mapping of room_id -> max stream position for "recent" rooms. - # It doesn't really matter how many we get, the StreamChangeCache will - # do the right thing to ensure it respects the max size of cache. - sql = ( - "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" - " WHERE %(stream)s > ? - %(limit)s" - " GROUP BY %(entity)s" - ) % { - "table": table, - "entity": entity_column, - "stream": stream_column, - "limit": limit, - } - - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, (int(max_value),)) - - cache = {row[0]: int(row[1]) for row in txn} - - txn.close() - - if cache: - min_val = min(itervalues(cache)) - else: - min_val = max_value - - return cache, min_val - - def _invalidate_cache_and_stream(self, txn, cache_func, keys): - """Invalidates the cache and adds it to the cache stream so slaves - will know to invalidate their caches. - - This should only be used to invalidate caches where slaves won't - otherwise know from other replication streams that the cache should - be invalidated. - """ - txn.call_after(cache_func.invalidate, keys) - self._send_invalidation_to_replication(txn, cache_func.__name__, keys) - - def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): - """Special case invalidation of caches based on current state. - - We special case this so that we can batch the cache invalidations into a - single replication poke. - - Args: - txn - room_id (str): Room where state changed - members_changed (iterable[str]): The user_ids of members that have changed - """ - txn.call_after(self._invalidate_state_caches, room_id, members_changed) - - if members_changed: - # We need to be careful that the size of the `members_changed` list - # isn't so large that it causes problems sending over replication, so we - # send them in chunks. - # Max line length is 16K, and max user ID length is 255, so 50 should - # be safe. - for chunk in batch_iter(members_changed, 50): - keys = itertools.chain([room_id], chunk) - self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, keys - ) - else: - # if no members changed, we still need to invalidate the other caches. 
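As an illustrative aside (not part of this changeset): the 50-key chunks above come from the batch_iter helper in synapse.util. A standalone reimplementation of the same chunking behaviour, shown here only to make the batching concrete:

    from itertools import islice

    def batch_iter(iterable, size):
        # Yield successive tuples of at most `size` items until exhausted.
        it = iter(iterable)
        return iter(lambda: tuple(islice(it, size)), ())

    members_changed = ["@user%d:example.com" % i for i in range(120)]
    print([len(chunk) for chunk in batch_iter(members_changed, 50)])
    # [50, 50, 20]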
- self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, [room_id] - ) + def process_replication_rows(self, stream_name, instance_name, token, rows): + pass def _invalidate_state_caches(self, room_id, members_changed): """Invalidates caches that are based on the current state, but does @@ -1421,7 +56,7 @@ class SQLBaseStore(object): members_changed (iterable[str]): The user_ids of members that have changed """ - for host in set(get_domain_from_id(u) for u in members_changed): + for host in {get_domain_from_id(u) for u in members_changed}: self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) self._attempt_to_invalidate_cache("was_host_joined", (room_id, host)) @@ -1429,242 +64,29 @@ class SQLBaseStore(object): self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,)) - def _attempt_to_invalidate_cache(self, cache_name, key): + def _attempt_to_invalidate_cache( + self, cache_name: str, key: Optional[Collection[Any]] + ): """Attempts to invalidate the cache of the given name, ignoring if the cache doesn't exist. Mainly used for invalidating caches on workers, where they may not have the cache. Args: - cache_name (str) - key (tuple) + cache_name + key: Entry to invalidate. If None then invalidates the entire + cache. """ + try: - getattr(self, cache_name).invalidate(key) + if key is None: + getattr(self, cache_name).invalidate_all() + else: + getattr(self, cache_name).invalidate(tuple(key)) except AttributeError: # We probably haven't pulled in the cache in this worker, # which is fine. pass - def _send_invalidation_to_replication(self, txn, cache_name, keys): - """Notifies replication that given cache has been invalidated. - - Note that this does *not* invalidate the cache locally. - - Args: - txn - cache_name (str) - keys (iterable[str]) - """ - - if isinstance(self.database_engine, PostgresEngine): - # get_next() returns a context manager which is designed to wrap - # the transaction. However, we want to only get an ID when we want - # to use it, here, so we need to call __enter__ manually, and have - # __exit__ called after the transaction finishes. - ctx = self._cache_id_gen.get_next() - stream_id = ctx.__enter__() - txn.call_on_exception(ctx.__exit__, None, None, None) - txn.call_after(ctx.__exit__, None, None, None) - txn.call_after(self.hs.get_notifier().on_new_replication_data) - - self._simple_insert_txn( - txn, - table="cache_invalidation_stream", - values={ - "stream_id": stream_id, - "cache_func": cache_name, - "keys": list(keys), - "invalidation_ts": self.clock.time_msec(), - }, - ) - - def get_all_updated_caches(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_caches_txn(txn): - # We purposefully don't bound by the current token, as we want to - # send across cache invalidations as quickly as possible. Cache - # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" 
- ) - txn.execute(sql, (last_id, limit)) - return txn.fetchall() - - return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn) - - def get_cache_stream_token(self): - if self._cache_id_gen: - return self._cache_id_gen.get_current_token() - else: - return 0 - - def _simple_select_list_paginate( - self, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction="ASC", - desc="_simple_select_list_paginate", - ): - """ - Executes a SELECT query on the named table with start and limit, - of row numbers, which may return zero or number of rows from start to limit, - returning the result as a list of dicts. - - Args: - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - orderby (str): Column to order the results by. - start (int): Index to begin the query at. - limit (int): Number of results to return. - retcols (iterable[str]): the names of the columns to return - order_direction (str): Whether the results should be ordered "ASC" or "DESC". - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.runInteraction( - desc, - self._simple_select_list_paginate_txn, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction=order_direction, - ) - - @classmethod - def _simple_select_list_paginate_txn( - cls, - txn, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction="ASC", - ): - """ - Executes a SELECT query on the named table with start and limit, - of row numbers, which may return zero or number of rows from start to limit, - returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - orderby (str): Column to order the results by. - start (int): Index to begin the query at. - limit (int): Number of results to return. - retcols (iterable[str]): the names of the columns to return - order_direction (str): Whether the results should be ordered "ASC" or "DESC". - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - if order_direction not in ["ASC", "DESC"]: - raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") - - if keyvalues: - where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues) - else: - where_clause = "" - - sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( - ", ".join(retcols), - table, - where_clause, - orderby, - order_direction, - ) - txn.execute(sql, list(keyvalues.values()) + [limit, start]) - - return cls.cursor_to_dict(txn) - - def get_user_count_txn(self, txn): - """Get a total number of registered users in the users list. - - Args: - txn : Transaction object - Returns: - int : number of users - """ - sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" - txn.execute(sql_count) - return txn.fetchone()[0] - - def _simple_search_list( - self, table, term, col, retcols, desc="_simple_search_list" - ): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - table (str): the table name - term (str | None): - term for searching the table matched to a column. 
- col (str): column to query term should be matched to - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None - """ - - return self.runInteraction( - desc, self._simple_search_list_txn, table, term, col, retcols - ) - - @classmethod - def _simple_search_list_txn(cls, txn, table, term, col, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - term (str | None): - term for searching the table matched to a column. - col (str): column to query term should be matched to - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None - """ - if term: - sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) - termvalues = ["%%" + term + "%%"] - txn.execute(sql, termvalues) - else: - return 0 - - return cls.cursor_to_dict(txn) - - @property - def database_engine_name(self): - return self.database_engine.module.__name__ - - def get_server_version(self): - """Returns a string describing the server version number""" - return self.database_engine.server_version - - -class _RollbackButIsFineException(Exception): - """ This exception is used to rollback a transaction without implying - something went wrong. - """ - - pass - def db_to_json(db_content): """ @@ -1678,11 +100,6 @@ def db_to_json(db_content): if isinstance(db_content, memoryview): db_content = db_content.tobytes() - # psycopg2 on Python 2 returns buffer objects, which we need to cast to - # bytes to decode - if PY2 and isinstance(db_content, builtins.buffer): - db_content = bytes(db_content) - # Decode it to a Unicode string before feeding it to json.loads, so we # consistenty get a Unicode-containing object out. if isinstance(db_content, (bytes, bytearray)): diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index e5f0668f09..59f3394b0a 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Optional from canonicaljson import json @@ -22,7 +23,6 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from . import engines -from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -74,7 +74,7 @@ class BackgroundUpdatePerformance(object): return float(self.total_item_count) / float(self.total_duration_ms) -class BackgroundUpdateStore(SQLBaseStore): +class BackgroundUpdater(object): """ Background updates are updates to the database that run in the background. Each update processes a batch of data at once. We attempt to limit the impact of each update by monitoring how long each batch takes to @@ -86,30 +86,34 @@ class BackgroundUpdateStore(SQLBaseStore): BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_DURATION_MS = 100 - def __init__(self, db_conn, hs): - super(BackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, hs, database): + self._clock = hs.get_clock() + self.db = database + + # if a background update is currently running, its name. 
+ self._current_background_update = None # type: Optional[str] + self._background_update_performance = {} - self._background_update_queue = [] self._background_update_handlers = {} self._all_done = False def start_doing_background_updates(self): - run_as_background_process("background_updates", self._run_background_updates) + run_as_background_process("background_updates", self.run_background_updates) - @defer.inlineCallbacks - def _run_background_updates(self): + async def run_background_updates(self, sleep=True): logger.info("Starting background schema updates") while True: - yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0) + if sleep: + await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0) try: - result = yield self.do_next_background_update( + result = await self.do_next_background_update( self.BACKGROUND_UPDATE_DURATION_MS ) except Exception: logger.exception("Error doing update") else: - if result is None: + if result: logger.info( "No more background updates to do." " Unscheduling background update task." @@ -117,30 +121,29 @@ class BackgroundUpdateStore(SQLBaseStore): self._all_done = True return None - @defer.inlineCallbacks - def has_completed_background_updates(self): + async def has_completed_background_updates(self) -> bool: """Check if all the background updates have completed Returns: - Deferred[bool]: True if all background updates have completed + True if all background updates have completed """ # if we've previously determined that there is nothing left to do, that # is easy if self._all_done: return True - # obviously, if we have things in our queue, we're not done. - if self._background_update_queue: + # obviously, if we are currently processing an update, we're not done. + if self._current_background_update: return False # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates # itself, but still wants to wait for them to happen. - updates = yield self._simple_select_onecol( + updates = await self.db.simple_select_onecol( "background_updates", keyvalues=None, retcol="1", - desc="check_background_updates", + desc="has_completed_background_updates", ) if not updates: self._all_done = True @@ -148,42 +151,79 @@ class BackgroundUpdateStore(SQLBaseStore): return False - @defer.inlineCallbacks - def do_next_background_update(self, desired_duration_ms): + async def has_completed_background_update(self, update_name) -> bool: + """Check if the given background update has finished running. + """ + if self._all_done: + return True + + if update_name == self._current_background_update: + return False + + update_exists = await self.db.simple_select_one_onecol( + "background_updates", + keyvalues={"update_name": update_name}, + retcol="1", + desc="has_completed_background_update", + allow_none=True, + ) + + return not update_exists + + async def do_next_background_update(self, desired_duration_ms: float) -> bool: """Does some amount of work on the next queued background update + Returns once some amount of work is done. + Args: desired_duration_ms(float): How long we want to spend updating. Returns: - A deferred that completes once some amount of work is done. - The deferred will have a value of None if there is currently - no more work to do. 
+ True if we have finished running all the background updates, otherwise False """ - if not self._background_update_queue: - updates = yield self._simple_select_list( - "background_updates", - keyvalues=None, - retcols=("update_name", "depends_on"), + + def get_background_updates_txn(txn): + txn.execute( + """ + SELECT update_name, depends_on FROM background_updates + ORDER BY ordering, update_name + """ ) - in_flight = set(update["update_name"] for update in updates) - for update in updates: - if update["depends_on"] not in in_flight: - self._background_update_queue.append(update["update_name"]) + return self.db.cursor_to_dict(txn) - if not self._background_update_queue: - # no work left to do - return None + if not self._current_background_update: + all_pending_updates = await self.db.runInteraction( + "background_updates", get_background_updates_txn, + ) + if not all_pending_updates: + # no work left to do + return True + + # find the first update which isn't dependent on another one in the queue. + pending = {update["update_name"] for update in all_pending_updates} + for upd in all_pending_updates: + depends_on = upd["depends_on"] + if not depends_on or depends_on not in pending: + break + logger.info( + "Not starting on bg update %s until %s is done", + upd["update_name"], + depends_on, + ) + else: + # if we get to the end of that for loop, there is a problem + raise Exception( + "Unable to find a background update which doesn't depend on " + "another: dependency cycle?" + ) - # pop from the front, and add back to the back - update_name = self._background_update_queue.pop(0) - self._background_update_queue.append(update_name) + self._current_background_update = upd["update_name"] - res = yield self._do_background_update(update_name, desired_duration_ms) - return res + await self._do_background_update(desired_duration_ms) + return False - @defer.inlineCallbacks - def _do_background_update(self, update_name, desired_duration_ms): + async def _do_background_update(self, desired_duration_ms: float) -> int: + update_name = self._current_background_update logger.info("Starting update batch on background update '%s'", update_name) update_handler = self._background_update_handlers[update_name] @@ -203,7 +243,7 @@ class BackgroundUpdateStore(SQLBaseStore): else: batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE - progress_json = yield self._simple_select_one_onecol( + progress_json = await self.db.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="progress_json", @@ -212,13 +252,13 @@ class BackgroundUpdateStore(SQLBaseStore): progress = json.loads(progress_json) time_start = self._clock.time_msec() - items_updated = yield update_handler(progress, batch_size) + items_updated = await update_handler(progress, batch_size) time_stop = self._clock.time_msec() duration_ms = time_stop - time_start logger.info( - "Updating %r. Updated %r items in %rms." + "Running background update %r. Processed %r items in %rms." " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)", update_name, items_updated, @@ -241,7 +281,9 @@ class BackgroundUpdateStore(SQLBaseStore): * A dict of the current progress * An integer count of the number of items to update in this batch. - The handler should return a deferred integer count of items updated. + The handler should return a deferred or coroutine which returns an integer count + of items updated. + The handler is responsible for updating the progress of the update. 
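As an illustrative aside (not part of this changeset), a handler conforming to this contract might be registered from a data store roughly as sketched below. The update name, the example_table query, and the assumption that the BackgroundUpdater is reachable as database.updates / self.db.updates are all illustrative; only register_background_update_handler, _background_update_progress and _end_background_update come from this file:

    class ExampleStore(SQLBaseStore):
        def __init__(self, database, db_conn, hs):
            super().__init__(database, db_conn, hs)
            # Assumes the BackgroundUpdater is exposed as `database.updates`;
            # the update name and table below are made up.
            database.updates.register_background_update_handler(
                "example_populate_index", self._example_populate_index
            )

        async def _example_populate_index(self, progress, batch_size):
            # Resume from the position recorded on the previous batch, if any.
            last_id = progress.get("last_id", 0)

            def do_batch_txn(txn):
                txn.execute(
                    "SELECT id FROM example_table WHERE id > ? ORDER BY id LIMIT ?",
                    (last_id, batch_size),
                )
                return [row[0] for row in txn.fetchall()]

            ids = await self.db.runInteraction("example_populate_index", do_batch_txn)
            if not ids:
                # Nothing left to do: remove the row from background_updates.
                await self.db.updates._end_background_update("example_populate_index")
                return 0

            await self.db.updates._background_update_progress(
                "example_populate_index", {"last_id": ids[-1]}
            )
            return len(ids)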
Args: @@ -357,7 +399,7 @@ class BackgroundUpdateStore(SQLBaseStore): logger.debug("[SQL] %s", sql) c.execute(sql) - if isinstance(self.database_engine, engines.PostgresEngine): + if isinstance(self.db.engine, engines.PostgresEngine): runner = create_index_psql elif psql_only: runner = None @@ -368,33 +410,12 @@ class BackgroundUpdateStore(SQLBaseStore): def updater(progress, batch_size): if runner is not None: logger.info("Adding index %s to %s", index_name, table) - yield self.runWithConnection(runner) + yield self.db.runWithConnection(runner) yield self._end_background_update(update_name) return 1 self.register_background_update_handler(update_name, updater) - def start_background_update(self, update_name, progress): - """Starts a background update running. - - Args: - update_name: The update to set running. - progress: The initial state of the progress of the update. - - Returns: - A deferred that completes once the task has been added to the - queue. - """ - # Clear the background update queue so that we will pick up the new - # task on the next iteration of do_background_update. - self._background_update_queue = [] - progress_json = json.dumps(progress) - - return self._simple_insert( - "background_updates", - {"update_name": update_name, "progress_json": progress_json}, - ) - def _end_background_update(self, update_name): """Removes a completed background update task from the queue. @@ -403,13 +424,31 @@ class BackgroundUpdateStore(SQLBaseStore): Returns: A deferred that completes once the task is removed. """ - self._background_update_queue = [ - name for name in self._background_update_queue if name != update_name - ] - return self._simple_delete_one( + if update_name != self._current_background_update: + raise Exception( + "Cannot end background update %s which isn't currently running" + % update_name + ) + self._current_background_update = None + return self.db.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) + def _background_update_progress(self, update_name: str, progress: dict): + """Update the progress of a background update + + Args: + update_name: The name of the background update task + progress: The progress of the update. + """ + + return self.db.runInteraction( + "background_update_progress", + self._background_update_progress_txn, + update_name, + progress, + ) + def _background_update_progress_txn(self, txn, update_name, progress): """Update the progress of a background update @@ -421,7 +460,7 @@ class BackgroundUpdateStore(SQLBaseStore): progress_json = json.dumps(progress) - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, "background_updates", keyvalues={"update_name": update_name}, diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py new file mode 100644 index 0000000000..599ee470d4 --- /dev/null +++ b/synapse/storage/data_stores/__init__.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.data_stores.main.events import PersistEventsStore +from synapse.storage.data_stores.state import StateGroupDataStore +from synapse.storage.database import Database, make_conn +from synapse.storage.engines import create_engine +from synapse.storage.prepare_database import prepare_database + +logger = logging.getLogger(__name__) + + +class DataStores(object): + """The various data stores. + + These are low level interfaces to physical databases. + + Attributes: + main (DataStore) + """ + + def __init__(self, main_store_class, hs): + # Note we pass in the main store class here as workers use a different main + # store. + + self.databases = [] + self.main = None + self.state = None + self.persist_events = None + + for database_config in hs.config.database.databases: + db_name = database_config.name + engine = create_engine(database_config.config) + + with make_conn(database_config, engine) as db_conn: + logger.info("Preparing database %r...", db_name) + + engine.check_database(db_conn) + prepare_database( + db_conn, engine, hs.config, data_stores=database_config.data_stores, + ) + + database = Database(hs, database_config, engine) + + if "main" in database_config.data_stores: + logger.info("Starting 'main' data store") + + # Sanity check we don't try and configure the main store on + # multiple databases. + if self.main: + raise Exception("'main' data store already configured") + + self.main = main_store_class(database, db_conn, hs) + + # If we're on a process that can persist events also + # instantiate a `PersistEventsStore` + if hs.config.worker.writers.events == hs.get_instance_name(): + self.persist_events = PersistEventsStore( + hs, database, self.main + ) + + if "state" in database_config.data_stores: + logger.info("Starting 'state' data store") + + # Sanity check we don't try and configure the state store on + # multiple databases. + if self.state: + raise Exception("'state' data store already configured") + + self.state = StateGroupDataStore(database, db_conn, hs) + + db_conn.commit() + + self.databases.append(database) + + logger.info("Database %r prepared", db_name) + + # Sanity check that we have actually configured all the required stores. + if not self.main: + raise Exception("No 'main' data store configured") + + if not self.state: + raise Exception("No 'state' data store configured") diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py new file mode 100644 index 0000000000..4b4763c701 --- /dev/null +++ b/synapse/storage/data_stores/main/__init__.py @@ -0,0 +1,592 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
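To make the wiring of the new `DataStores` class in `synapse/storage/data_stores/__init__.py` above easier to follow, here is a toy, standalone sketch of the routing rule its constructor implements: each physical database declares which logical data stores it hosts, a logical store may only be configured once, and both "main" and "state" must end up configured somewhere. The function name and database names below are invented for the example and are not Synapse APIs.

```python
from typing import Dict, List


def route_data_stores(databases: Dict[str, List[str]]) -> Dict[str, str]:
    """Map each logical data store to the physical database hosting it."""
    routed: Dict[str, str] = {}
    for db_name, data_stores in databases.items():
        for store in data_stores:
            # A logical store may only live on one physical database.
            if store in routed:
                raise Exception("'%s' data store already configured" % store)
            routed[store] = db_name

    # Both required stores must be configured somewhere.
    for required in ("main", "state"):
        if required not in routed:
            raise Exception("No '%s' data store configured" % required)

    return routed


# For example, a split deployment with state groups on their own database:
print(route_data_stores({"master": ["main"], "state_db": ["state"]}))
# -> {'main': 'master', 'state': 'state_db'}
```

With a single database hosting both stores this degenerates to `{"master": ["main", "state"]}`, which corresponds to the default single-database deployment.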
+ +import calendar +import logging +import time + +from synapse.api.constants import PresenceState +from synapse.config.homeserver import HomeServerConfig +from synapse.storage.database import Database +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import ( + IdGenerator, + MultiWriterIdGenerator, + StreamIdGenerator, +) +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from .account_data import AccountDataStore +from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore +from .cache import CacheInvalidationWorkerStore +from .censor_events import CensorEventsStore +from .client_ips import ClientIpStore +from .deviceinbox import DeviceInboxStore +from .devices import DeviceStore +from .directory import DirectoryStore +from .e2e_room_keys import EndToEndRoomKeyStore +from .end_to_end_keys import EndToEndKeyStore +from .event_federation import EventFederationStore +from .event_push_actions import EventPushActionsStore +from .events_bg_updates import EventsBackgroundUpdatesStore +from .filtering import FilteringStore +from .group_server import GroupServerStore +from .keys import KeyStore +from .media_repository import MediaRepositoryStore +from .metrics import ServerMetricsStore +from .monthly_active_users import MonthlyActiveUsersStore +from .openid import OpenIdStore +from .presence import PresenceStore, UserPresenceState +from .profile import ProfileStore +from .purge_events import PurgeEventsStore +from .push_rule import PushRuleStore +from .pusher import PusherStore +from .receipts import ReceiptsStore +from .registration import RegistrationStore +from .rejections import RejectionsStore +from .relations import RelationsStore +from .room import RoomStore +from .roommember import RoomMemberStore +from .search import SearchStore +from .signatures import SignatureStore +from .state import StateStore +from .stats import StatsStore +from .stream import StreamStore +from .tags import TagsStore +from .transactions import TransactionStore +from .ui_auth import UIAuthStore +from .user_directory import UserDirectoryStore +from .user_erasure_store import UserErasureStore + +logger = logging.getLogger(__name__) + + +class DataStore( + EventsBackgroundUpdatesStore, + RoomMemberStore, + RoomStore, + RegistrationStore, + StreamStore, + ProfileStore, + PresenceStore, + TransactionStore, + DirectoryStore, + KeyStore, + StateStore, + SignatureStore, + ApplicationServiceStore, + PurgeEventsStore, + EventFederationStore, + MediaRepositoryStore, + RejectionsStore, + FilteringStore, + PusherStore, + PushRuleStore, + ApplicationServiceTransactionStore, + ReceiptsStore, + EndToEndKeyStore, + EndToEndRoomKeyStore, + SearchStore, + TagsStore, + AccountDataStore, + EventPushActionsStore, + OpenIdStore, + ClientIpStore, + DeviceStore, + DeviceInboxStore, + UserDirectoryStore, + GroupServerStore, + UserErasureStore, + MonthlyActiveUsersStore, + StatsStore, + RelationsStore, + CensorEventsStore, + UIAuthStore, + CacheInvalidationWorkerStore, + ServerMetricsStore, +): + def __init__(self, database: Database, db_conn, hs): + self.hs = hs + self._clock = hs.get_clock() + self.database_engine = database.engine + + self._presence_id_gen = StreamIdGenerator( + db_conn, "presence_stream", "stream_id" + ) + self._device_inbox_id_gen = StreamIdGenerator( + db_conn, "device_max_stream_id", "stream_id" + ) + self._public_room_id_gen = StreamIdGenerator( + db_conn, "public_room_list_stream", "stream_id" + ) + self._device_list_id_gen = 
StreamIdGenerator( + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ], + ) + self._cross_signing_id_gen = StreamIdGenerator( + db_conn, "e2e_cross_signing_keys", "stream_id" + ) + + self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") + self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") + self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") + self._pushers_id_gen = StreamIdGenerator( + db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] + ) + self._group_updates_id_gen = StreamIdGenerator( + db_conn, "local_group_updates", "stream_id" + ) + + if isinstance(self.database_engine, PostgresEngine): + self._cache_id_gen = MultiWriterIdGenerator( + db_conn, + database, + instance_name="master", + table="cache_invalidation_stream_by_instance", + instance_column="instance_name", + id_column="stream_id", + sequence_name="cache_invalidation_stream_seq", + ) + else: + self._cache_id_gen = None + + super(DataStore, self).__init__(database, db_conn, hs) + + self._presence_on_startup = self._get_active_presence(db_conn) + + presence_cache_prefill, min_presence_val = self.db.get_cache_dict( + db_conn, + "presence_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self._presence_id_gen.get_current_token(), + ) + self.presence_stream_cache = StreamChangeCache( + "PresenceStreamChangeCache", + min_presence_val, + prefilled_cache=presence_cache_prefill, + ) + + max_device_inbox_id = self._device_inbox_id_gen.get_current_token() + device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict( + db_conn, + "device_inbox", + entity_column="user_id", + stream_column="stream_id", + max_value=max_device_inbox_id, + limit=1000, + ) + self._device_inbox_stream_cache = StreamChangeCache( + "DeviceInboxStreamChangeCache", + min_device_inbox_id, + prefilled_cache=device_inbox_prefill, + ) + # The federation outbox and the local device inbox uses the same + # stream_id generator. 
+ device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict( + db_conn, + "device_federation_outbox", + entity_column="destination", + stream_column="stream_id", + max_value=max_device_inbox_id, + limit=1000, + ) + self._device_federation_outbox_stream_cache = StreamChangeCache( + "DeviceFederationOutboxStreamChangeCache", + min_device_outbox_id, + prefilled_cache=device_outbox_prefill, + ) + + device_list_max = self._device_list_id_gen.get_current_token() + self._device_list_stream_cache = StreamChangeCache( + "DeviceListStreamChangeCache", device_list_max + ) + self._user_signature_stream_cache = StreamChangeCache( + "UserSignatureStreamChangeCache", device_list_max + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", device_list_max + ) + + events_max = self._stream_id_gen.get_current_token() + curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( + db_conn, + "current_state_delta_stream", + entity_column="room_id", + stream_column="stream_id", + max_value=events_max, # As we share the stream id with events token + limit=1000, + ) + self._curr_state_delta_stream_cache = StreamChangeCache( + "_curr_state_delta_stream_cache", + min_curr_state_delta_id, + prefilled_cache=curr_state_delta_prefill, + ) + + _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict( + db_conn, + "local_group_updates", + entity_column="user_id", + stream_column="stream_id", + max_value=self._group_updates_id_gen.get_current_token(), + limit=1000, + ) + self._group_updates_stream_cache = StreamChangeCache( + "_group_updates_stream_cache", + min_group_updates_id, + prefilled_cache=_group_updates_prefill, + ) + + self._stream_order_on_start = self.get_room_max_stream_ordering() + self._min_stream_order_on_start = self.get_room_min_stream_ordering() + + # Used in _generate_user_daily_visits to keep track of progress + self._last_user_visit_update = self._get_start_of_day() + + def take_presence_startup_info(self): + active_on_startup = self._presence_on_startup + self._presence_on_startup = None + return active_on_startup + + def _get_active_presence(self, db_conn): + """Fetch non-offline presence from the database so that we can register + the appropriate time outs. + """ + + sql = ( + "SELECT user_id, state, last_active_ts, last_federation_update_ts," + " last_user_sync_ts, status_msg, currently_active FROM presence_stream" + " WHERE state != ?" + ) + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (PresenceState.OFFLINE,)) + rows = self.db.cursor_to_dict(txn) + txn.close() + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return [UserPresenceState(**row) for row in rows] + + def count_daily_users(self): + """ + Counts the number of users who used this homeserver in the last 24 hours. + """ + yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) + return self.db.runInteraction("count_daily_users", self._count_users, yesterday) + + def count_monthly_users(self): + """ + Counts the number of users who used this homeserver in the last 30 days. + Note this method is intended for phonehome metrics only and is different + from the mau figure in synapse.storage.monthly_active_users which, + amongst other things, includes a 3 day grace period before a user counts. 
+ """ + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + return self.db.runInteraction( + "count_monthly_users", self._count_users, thirty_days_ago + ) + + def _count_users(self, txn, time_from): + """ + Returns number of users seen in the past time_from period + """ + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? + GROUP BY user_id + ) u + """ + txn.execute(sql, (time_from,)) + (count,) = txn.fetchone() + return count + + def count_r30_users(self): + """ + Counts the number of 30 day retained users, defined as:- + * Users who have created their accounts more than 30 days ago + * Where last seen at most 30 days ago + * Where account creation and last_seen are > 30 days apart + + Returns counts globaly for a given user as well as breaking + by platform + """ + + def _count_r30_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + thirty_days_ago_in_secs = now - thirty_days_in_secs + + sql = """ + SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT + users.name, platform, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM user_ips + ) uip + ON users.name = uip.user_id + AND users.appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, platform, users.creation_ts + ) u GROUP BY platform + """ + + results = {} + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + for row in txn: + if row[0] == "unknown": + pass + results[row[0]] = row[1] + + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT users.name, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen + FROM user_ips + ) uip + ON users.name = uip.user_id + AND appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, users.creation_ts + ) u + """ + + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + (count,) = txn.fetchone() + results["all"] = count + + return results + + return self.db.runInteraction("count_r30_users", _count_r30_users) + + def _get_start_of_day(self): + """ + Returns millisecond unixtime for start of UTC day. + """ + now = time.gmtime() + today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) + return today_start * 1000 + + def generate_user_daily_visits(self): + """ + Generates daily visit data for use in cohort/ retention analysis + """ + + def _generate_user_daily_visits(txn): + logger.info("Calling _generate_user_daily_visits") + today_start = self._get_start_of_day() + a_day_in_milliseconds = 24 * 60 * 60 * 1000 + now = self.clock.time_msec() + + sql = """ + INSERT INTO user_daily_visits (user_id, device_id, timestamp) + SELECT u.user_id, u.device_id, ? + FROM user_ips AS u + LEFT JOIN ( + SELECT user_id, device_id, timestamp FROM user_daily_visits + WHERE timestamp = ? + ) udv + ON u.user_id = udv.user_id AND u.device_id=udv.device_id + INNER JOIN users ON users.name=u.user_id + WHERE last_seen > ? AND last_seen <= ? 
+ AND udv.timestamp IS NULL AND users.is_guest=0 + AND users.appservice_id IS NULL + GROUP BY u.user_id, u.device_id + """ + + # This means that the day has rolled over but there could still + # be entries from the previous day. There is an edge case + # where if the user logs in at 23:59 and overwrites their + # last_seen at 00:01 then they will not be counted in the + # previous day's stats - it is important that the query is run + # often to minimise this case. + if today_start > self._last_user_visit_update: + yesterday_start = today_start - a_day_in_milliseconds + txn.execute( + sql, + ( + yesterday_start, + yesterday_start, + self._last_user_visit_update, + today_start, + ), + ) + self._last_user_visit_update = today_start + + txn.execute( + sql, (today_start, today_start, self._last_user_visit_update, now) + ) + # Update _last_user_visit_update to now. The reason to do this + # rather just clamping to the beginning of the day is to limit + # the size of the join - meaning that the query can be run more + # frequently + self._last_user_visit_update = now + + return self.db.runInteraction( + "generate_user_daily_visits", _generate_user_daily_visits + ) + + def get_users(self): + """Function to retrieve a list of users in users table. + + Args: + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.db.simple_select_list( + table="users", + keyvalues={}, + retcols=[ + "name", + "password_hash", + "is_guest", + "admin", + "user_type", + "deactivated", + ], + desc="get_users", + ) + + def get_users_paginate( + self, start, limit, name=None, guests=True, deactivated=False + ): + """Function to retrieve a paginated list of users from + users list. This will return a json list of users and the + total number of users matching the filter criteria. + + Args: + start (int): start number to begin the query from + limit (int): number of rows to retrieve + name (string): filter for user names + guests (bool): whether to in include guest users + deactivated (bool): whether to include deactivated users + Returns: + defer.Deferred: resolves to list[dict[str, Any]], int + """ + + def get_users_paginate_txn(txn): + filters = [] + args = [] + + if name: + filters.append("name LIKE ?") + args.append("%" + name + "%") + + if not guests: + filters.append("is_guest = 0") + + if not deactivated: + filters.append("deactivated = 0") + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause) + txn.execute(sql, args) + count = txn.fetchone()[0] + + args = [self.hs.config.server_name] + args + [limit, start] + sql = """ + SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url + FROM users as u + LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? + {} + ORDER BY u.name LIMIT ? OFFSET ? + """.format( + where_clause + ) + txn.execute(sql, args) + users = self.db.cursor_to_dict(txn) + return users, count + + return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn) + + def search_users(self, term): + """Function to search users list for one or more users with + the matched term. 
+ + Args: + term (str): search term + col (str): column to query term should be matched to + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.db.simple_search_list( + table="users", + term=term, + col="name", + retcols=["name", "password_hash", "is_guest", "admin", "user_type"], + desc="search_users", + ) + + +def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig): + """Called before upgrading an existing database to check that it is broadly sane + compared with the configuration. + """ + domain = config.server_name + + sql = database_engine.convert_param_style( + "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" + ) + pat = "%:" + domain + cur.execute(sql, (pat,)) + num_not_matching = cur.fetchall()[0][0] + if num_not_matching == 0: + return + + raise Exception( + "Found users in database not native to %s!\n" + "You cannot change a synapse server_name after it's been configured" + % (domain,) + ) + + +__all__ = ["DataStore", "check_database_before_upgrade"] diff --git a/synapse/storage/account_data.py b/synapse/storage/data_stores/main/account_data.py index 6afbfc0d74..b58f04d00d 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -16,12 +16,14 @@ import abc import logging +from typing import List, Tuple from canonicaljson import json from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -38,13 +40,13 @@ class AccountDataWorkerStore(SQLBaseStore): # the abstract methods being implemented.
__metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max ) - super(AccountDataWorkerStore, self).__init__(db_conn, hs) + super(AccountDataWorkerStore, self).__init__(database, db_conn, hs) @abc.abstractmethod def get_max_account_data_stream_id(self): @@ -67,7 +69,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_user_txn(txn): - rows = self._simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "account_data", {"user_id": user_id}, @@ -78,7 +80,7 @@ class AccountDataWorkerStore(SQLBaseStore): row["account_data_type"]: json.loads(row["content"]) for row in rows } - rows = self._simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, @@ -92,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore): return global_account_data, by_room - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn ) @@ -102,7 +104,7 @@ class AccountDataWorkerStore(SQLBaseStore): Returns: Deferred: A dict """ - result = yield self._simple_select_one_onecol( + result = yield self.db.simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", @@ -127,7 +129,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_room_txn(txn): - rows = self._simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, @@ -138,7 +140,7 @@ class AccountDataWorkerStore(SQLBaseStore): row["account_data_type"]: json.loads(row["content"]) for row in rows } - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_room", get_account_data_for_room_txn ) @@ -156,7 +158,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_room_and_type_txn(txn): - content_json = self._simple_select_one_onecol_txn( + content_json = self.db.simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ @@ -170,45 +172,68 @@ class AccountDataWorkerStore(SQLBaseStore): return json.loads(content_json) if content_json else None - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) - def get_all_updated_account_data( - self, last_global_id, last_room_id, current_id, limit - ): - """Get all the client account_data that has changed on the server + async def get_updated_global_account_data( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple[int, str, str]]: + """Get the global account_data that has changed, for the account_data stream + Args: - last_global_id(int): The position to fetch from for top level data - last_room_id(int): The position to fetch from for per room data - current_id(int): The position to fetch up to. + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + Returns: - A deferred pair of lists of tuples of stream_id int, user_id string, - room_id string, type string, and content string. + A list of tuples of stream_id int, user_id string, + and type string. 
""" - if last_room_id == current_id and last_global_id == current_id: - return defer.succeed(([], [])) + if last_id == current_id: + return [] - def get_updated_account_data_txn(txn): + def get_updated_global_account_data_txn(txn): sql = ( - "SELECT stream_id, user_id, account_data_type, content" + "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) - txn.execute(sql, (last_global_id, current_id, limit)) - global_results = txn.fetchall() + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return await self.db.runInteraction( + "get_updated_global_account_data", get_updated_global_account_data_txn + ) + + async def get_updated_room_account_data( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple[int, str, str, str]]: + """Get the global account_data that has changed, for the account_data stream + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: + A list of tuples of stream_id int, user_id string, + room_id string and type string. + """ + if last_id == current_id: + return [] + + def get_updated_room_account_data_txn(txn): sql = ( - "SELECT stream_id, user_id, room_id, account_data_type, content" + "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) - txn.execute(sql, (last_room_id, current_id, limit)) - room_results = txn.fetchall() - return global_results, room_results + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() - return self.runInteraction( - "get_all_updated_account_data_txn", get_updated_account_data_txn + return await self.db.runInteraction( + "get_updated_room_account_data", get_updated_room_account_data_txn ) def get_updated_account_data_for_user(self, user_id, stream_id): @@ -250,9 +275,9 @@ class AccountDataWorkerStore(SQLBaseStore): user_id, int(stream_id) ) if not changed: - return {}, {} + return defer.succeed(({}, {})) - return self.runInteraction( + return self.db.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) @@ -270,12 +295,18 @@ class AccountDataWorkerStore(SQLBaseStore): class AccountDataStore(AccountDataWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( - db_conn, "account_data_max_stream_id", "stream_id" + db_conn, + "account_data_max_stream_id", + "stream_id", + extra_tables=[ + ("room_account_data", "stream_id"), + ("room_tags_revisions", "stream_id"), + ], ) - super(AccountDataStore, self).__init__(db_conn, hs) + super(AccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): """Get the current max stream id for the private user data stream @@ -300,9 +331,9 @@ class AccountDataStore(AccountDataWorkerStore): with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint - # on (user_id, room_id, account_data_type) so _simple_upsert will + # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. 
- yield self._simple_upsert( + yield self.db.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ @@ -346,9 +377,9 @@ class AccountDataStore(AccountDataWorkerStore): with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on - # (user_id, account_data_type) so _simple_upsert will retry if + # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. - yield self._simple_upsert( + yield self.db.simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, @@ -362,6 +393,10 @@ class AccountDataStore(AccountDataWorkerStore): # doesn't sound any worse than the whole update getting lost, # which is what would happen if we combined the two into one # transaction. + # + # Note: This is only here for backwards compat to allow admins to + # roll back to a previous Synapse version. Next time we update the + # database version we can remove this table. yield self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) @@ -380,6 +415,10 @@ class AccountDataStore(AccountDataWorkerStore): next_id(int): The the revision to advance to. """ + # Note: This is only here for backwards compat to allow admins to + # roll back to a previous Synapse version. Next time we update the + # database version we can remove this table. + def _update(txn): update_max_id_sql = ( "UPDATE account_data_max_stream_id" @@ -388,4 +427,4 @@ class AccountDataStore(AccountDataWorkerStore): ) txn.execute(update_max_id_sql, (next_id, next_id)) - return self.runInteraction("update_account_data_max_stream_id", _update) + return self.db.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/appservice.py b/synapse/storage/data_stores/main/appservice.py index 435b2acd4d..7a1fe8cdd2 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -22,20 +22,20 @@ from twisted.internet import defer from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices -from synapse.storage.events_worker import EventsWorkerStore - -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database logger = logging.getLogger(__name__) def _make_exclusive_regex(services_cache): - # We precompie a regex constructed from all the regexes that the AS's + # We precompile a regex constructed from all the regexes that the AS's # have registered for exclusive users. 
exclusive_user_regexes = [ regex.pattern for service in services_cache - for regex in service.get_exlusive_user_regexes() + for regex in service.get_exclusive_user_regexes() ] if exclusive_user_regexes: exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) @@ -49,13 +49,13 @@ def _make_exclusive_regex(services_cache): class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs) + super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs) def get_app_services(self): return self.services_cache @@ -134,8 +134,8 @@ class ApplicationServiceTransactionWorkerStore( A Deferred which resolves to a list of ApplicationServices, which may be empty. """ - results = yield self._simple_select_list( - "application_services_state", dict(state=state), ["as_id"] + results = yield self.db.simple_select_list( + "application_services_state", {"state": state}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore as_list = self.get_app_services() @@ -156,9 +156,9 @@ class ApplicationServiceTransactionWorkerStore( Returns: A Deferred which resolves to ApplicationServiceState. """ - result = yield self._simple_select_one( + result = yield self.db.simple_select_one( "application_services_state", - dict(as_id=service.id), + {"as_id": service.id}, ["state"], allow_none=True, desc="get_appservice_state", @@ -176,8 +176,8 @@ class ApplicationServiceTransactionWorkerStore( Returns: A Deferred which resolves when the state was set successfully. """ - return self._simple_upsert( - "application_services_state", dict(as_id=service.id), dict(state=state) + return self.db.simple_upsert( + "application_services_state", {"as_id": service.id}, {"state": state} ) def create_appservice_txn(self, service, events): @@ -217,7 +217,7 @@ class ApplicationServiceTransactionWorkerStore( ) return AppServiceTransaction(service=service, id=new_txn_id, events=events) - return self.runInteraction("create_appservice_txn", _create_appservice_txn) + return self.db.runInteraction("create_appservice_txn", _create_appservice_txn) def complete_appservice_txn(self, txn_id, service): """Completes an application service transaction. 
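As a reference for the calling convention this hunk (and most of the patch) migrates to, the sketch below reads and writes the `application_services_state` table through the new `self.db.simple_*` helpers rather than the old `self._simple_*` methods. It is illustrative only and is written as a free function taking a `store` argument; it assumes `store` is any data store with the new `.db` attribute (a `Database`).

```python
from twisted.internet import defer


@defer.inlineCallbacks
def example_bump_appservice_state(store, as_id, state):
    # Upsert the row keyed on as_id, mirroring set_appservice_state() above.
    yield store.db.simple_upsert(
        "application_services_state", {"as_id": as_id}, {"state": state}
    )

    # Read it back, mirroring get_appservice_state() above.
    row = yield store.db.simple_select_one(
        "application_services_state",
        {"as_id": as_id},
        ["state"],
        allow_none=True,
        desc="example_get_appservice_state",
    )
    return row["state"] if row else None
```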
@@ -250,19 +250,23 @@ class ApplicationServiceTransactionWorkerStore( ) # Set current txn_id for AS to 'txn_id' - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, "application_services_state", - dict(as_id=service.id), - dict(last_txn=txn_id), + {"as_id": service.id}, + {"last_txn": txn_id}, ) # Delete txn - self._simple_delete_txn( - txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) + self.db.simple_delete_txn( + txn, + "application_services_txns", + {"txn_id": txn_id, "as_id": service.id}, ) - return self.runInteraction("complete_appservice_txn", _complete_appservice_txn) + return self.db.runInteraction( + "complete_appservice_txn", _complete_appservice_txn + ) @defer.inlineCallbacks def get_oldest_unsent_txn(self, service): @@ -284,7 +288,7 @@ class ApplicationServiceTransactionWorkerStore( " ORDER BY txn_id ASC LIMIT 1", (service.id,), ) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return None @@ -292,7 +296,7 @@ class ApplicationServiceTransactionWorkerStore( return entry - entry = yield self.runInteraction( + entry = yield self.db.runInteraction( "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn ) @@ -322,7 +326,7 @@ class ApplicationServiceTransactionWorkerStore( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) ) - return self.runInteraction( + return self.db.runInteraction( "set_appservice_last_pos", set_appservice_last_pos_txn ) @@ -351,7 +355,7 @@ class ApplicationServiceTransactionWorkerStore( return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.runInteraction( + upper_bound, event_ids = yield self.db.runInteraction( "get_new_events_for_appservice", get_new_events_for_appservice_txn ) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py new file mode 100644 index 0000000000..eac5a4e55b --- /dev/null +++ b/synapse/storage/data_stores/main/cache.py @@ -0,0 +1,280 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging +from typing import Any, Iterable, Optional, Tuple + +from synapse.api.constants import EventTypes +from synapse.replication.tcp.streams.events import ( + EventsStreamCurrentStateRow, + EventsStreamEventRow, +) +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database +from synapse.storage.engines import PostgresEngine +from synapse.util.iterutils import batch_iter + +logger = logging.getLogger(__name__) + + +# This is a special cache name we use to batch multiple invalidations of caches +# based on the current state when notifying workers over replication. 
+CURRENT_STATE_CACHE_NAME = "cs_cache_fake" + + +class CacheInvalidationWorkerStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super().__init__(database, db_conn, hs) + + self._instance_name = hs.get_instance_name() + + async def get_all_updated_caches( + self, instance_name: str, last_id: int, current_id: int, limit: int + ): + """Fetches cache invalidation rows between the two given IDs written + by the given instance. Returns at most `limit` rows. + """ + + if last_id == current_id: + return [] + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = """ + SELECT stream_id, cache_func, keys, invalidation_ts + FROM cache_invalidation_stream_by_instance + WHERE stream_id > ? AND instance_name = ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, instance_name, limit)) + return txn.fetchall() + + return await self.db.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == "events": + for row in rows: + self._process_event_stream_row(token, row) + elif stream_name == "backfill": + for row in rows: + self._invalidate_caches_for_event( + -token, + row.event_id, + row.room_id, + row.type, + row.state_key, + row.redacts, + row.relates_to, + backfilled=True, + ) + elif stream_name == "caches": + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) + + for row in rows: + if row.cache_func == CURRENT_STATE_CACHE_NAME: + if row.keys is None: + raise Exception( + "Can't send an 'invalidate all' for current state cache" + ) + + room_id = row.keys[0] + members_changed = set(row.keys[1:]) + self._invalidate_state_caches(room_id, members_changed) + else: + self._attempt_to_invalidate_cache(row.cache_func, row.keys) + + super().process_replication_rows(stream_name, instance_name, token, rows) + + def _process_event_stream_row(self, token, row): + data = row.data + + if row.type == EventsStreamEventRow.TypeId: + self._invalidate_caches_for_event( + token, + data.event_id, + data.room_id, + data.type, + data.state_key, + data.redacts, + data.relates_to, + backfilled=False, + ) + elif row.type == EventsStreamCurrentStateRow.TypeId: + self._curr_state_delta_stream_cache.entity_has_changed( + row.data.room_id, token + ) + + if data.type == EventTypes.Member: + self.get_rooms_for_user_with_stream_ordering.invalidate( + (data.state_key,) + ) + else: + raise Exception("Unknown events stream row type %s" % (row.type,)) + + def _invalidate_caches_for_event( + self, + stream_ordering, + event_id, + room_id, + etype, + state_key, + redacts, + relates_to, + backfilled, + ): + self._invalidate_get_event_cache(event_id) + + self.get_latest_event_ids_in_room.invalidate((room_id,)) + + self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) + + if not backfilled: + self._events_stream_cache.entity_has_changed(room_id, stream_ordering) + + if redacts: + self._invalidate_get_event_cache(redacts) + + if etype == EventTypes.Member: + self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) + self.get_invited_rooms_for_local_user.invalidate((state_key,)) + + if relates_to: + self.get_relations_for_event.invalidate_many((relates_to,)) + self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) + 
self.get_applicable_edit.invalidate((relates_to,)) + + async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + cache_func = getattr(self, cache_name, None) + if not cache_func: + return + + cache_func.invalidate(keys) + await self.db.runInteraction( + "invalidate_cache_and_stream", + self._send_invalidation_to_replication, + cache_func.__name__, + keys, + ) + + def _invalidate_cache_and_stream(self, txn, cache_func, keys): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + txn.call_after(cache_func.invalidate, keys) + self._send_invalidation_to_replication(txn, cache_func.__name__, keys) + + def _invalidate_all_cache_and_stream(self, txn, cache_func): + """Invalidates the entire cache and adds it to the cache stream so slaves + will know to invalidate their caches. + """ + + txn.call_after(cache_func.invalidate_all) + self._send_invalidation_to_replication(txn, cache_func.__name__, None) + + def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): + """Special case invalidation of caches based on current state. + + We special case this so that we can batch the cache invalidations into a + single replication poke. + + Args: + txn + room_id (str): Room where state changed + members_changed (iterable[str]): The user_ids of members that have changed + """ + txn.call_after(self._invalidate_state_caches, room_id, members_changed) + + if members_changed: + # We need to be careful that the size of the `members_changed` list + # isn't so large that it causes problems sending over replication, so we + # send them in chunks. + # Max line length is 16K, and max user ID length is 255, so 50 should + # be safe. + for chunk in batch_iter(members_changed, 50): + keys = itertools.chain([room_id], chunk) + self._send_invalidation_to_replication( + txn, CURRENT_STATE_CACHE_NAME, keys + ) + else: + # if no members changed, we still need to invalidate the other caches. + self._send_invalidation_to_replication( + txn, CURRENT_STATE_CACHE_NAME, [room_id] + ) + + def _send_invalidation_to_replication( + self, txn, cache_name: str, keys: Optional[Iterable[Any]] + ): + """Notifies replication that given cache has been invalidated. + + Note that this does *not* invalidate the cache locally. + + Args: + txn + cache_name + keys: Entry to invalidate. If None will invalidate all. + """ + + if cache_name == CURRENT_STATE_CACHE_NAME and keys is None: + raise Exception( + "Can't stream invalidate all with magic current state cache" + ) + + if isinstance(self.database_engine, PostgresEngine): + # get_next() returns a context manager which is designed to wrap + # the transaction. However, we want to only get an ID when we want + # to use it, here, so we need to call __enter__ manually, and have + # __exit__ called after the transaction finishes. 
+ stream_id = self._cache_id_gen.get_next_txn(txn) + txn.call_after(self.hs.get_notifier().on_new_replication_data) + + if keys is not None: + keys = list(keys) + + self.db.simple_insert_txn( + txn, + table="cache_invalidation_stream_by_instance", + values={ + "stream_id": stream_id, + "instance_name": self._instance_name, + "cache_func": cache_name, + "keys": keys, + "invalidation_ts": self.clock.time_msec(), + }, + ) + + def get_cache_stream_token(self, instance_name): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token(instance_name) + else: + return 0 diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/data_stores/main/censor_events.py new file mode 100644 index 0000000000..2d48261724 --- /dev/null +++ b/synapse/storage/data_stores/main/censor_events.py @@ -0,0 +1,208 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from twisted.internet import defer + +from synapse.events.utils import prune_event_dict +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore +from synapse.storage.data_stores.main.events import encode_json +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + + +class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore): + def __init__(self, database: Database, db_conn, hs: "HomeServer"): + super().__init__(database, db_conn, hs) + + def _censor_redactions(): + return run_as_background_process( + "_censor_redactions", self._censor_redactions + ) + + if self.hs.config.redaction_retention_period is not None: + hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + + async def _censor_redactions(self): + """Censors all redactions older than the configured period that haven't + been censored yet. + + By censor we mean update the event_json table with the redacted event. + """ + + if self.hs.config.redaction_retention_period is None: + return + + if not ( + await self.db.updates.has_completed_background_update( + "redactions_have_censored_ts_idx" + ) + ): + # We don't want to run this until the appropriate index has been + # created. + return + + before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period + + # We fetch all redactions that: + # 1. point to an event we have, + # 2. has a received_ts from before the cut off, and + # 3. we haven't yet censored. + # + # This is limited to 100 events to ensure that we don't try and do too + # much at once. We'll get called again so this should eventually catch + # up. 
+ sql = """ + SELECT redactions.event_id, redacts FROM redactions + LEFT JOIN events AS original_event ON ( + redacts = original_event.event_id + ) + WHERE NOT have_censored + AND redactions.received_ts <= ? + ORDER BY redactions.received_ts ASC + LIMIT ? + """ + + rows = await self.db.execute( + "_censor_redactions_fetch", None, sql, before_ts, 100 + ) + + updates = [] + + for redaction_id, event_id in rows: + redaction_event = await self.get_event(redaction_id, allow_none=True) + original_event = await self.get_event( + event_id, allow_rejected=True, allow_none=True + ) + + # The SQL above ensures that we have both the redaction and + # original event, so if the `get_event` calls return None it + # means that the redaction wasn't allowed. Either way we know that + # the result won't change so we mark the fact that we've checked. + if ( + redaction_event + and original_event + and original_event.internal_metadata.is_redacted() + ): + # Redaction was allowed + pruned_json = encode_json( + prune_event_dict( + original_event.room_version, original_event.get_dict() + ) + ) + else: + # Redaction wasn't allowed + pruned_json = None + + updates.append((redaction_id, event_id, pruned_json)) + + def _update_censor_txn(txn): + for redaction_id, event_id, pruned_json in updates: + if pruned_json: + self._censor_event_txn(txn, event_id, pruned_json) + + self.db.simple_update_one_txn( + txn, + table="redactions", + keyvalues={"event_id": redaction_id}, + updatevalues={"have_censored": True}, + ) + + await self.db.runInteraction("_update_censor_txn", _update_censor_txn) + + def _censor_event_txn(self, txn, event_id, pruned_json): + """Censor an event by replacing its JSON in the event_json table with the + provided pruned JSON. + + Args: + txn (LoggingTransaction): The database transaction. + event_id (str): The ID of the event to censor. + pruned_json (str): The pruned JSON + """ + self.db.simple_update_one_txn( + txn, + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": pruned_json}, + ) + + @defer.inlineCallbacks + def expire_event(self, event_id): + """Retrieve and expire an event that has expired, and delete its associated + expiry timestamp. If the event can't be retrieved, delete its associated + timestamp so we don't try to expire it again in the future. + + Args: + event_id (str): The ID of the event to delete. + """ + # Try to retrieve the event's content from the database or the event cache. + event = yield self.get_event(event_id) + + def delete_expired_event_txn(txn): + # Delete the expiry timestamp associated with this event from the database. + self._delete_event_expiry_txn(txn, event_id) + + if not event: + # If we can't find the event, log a warning and delete the expiry date + # from the database so that we don't try to expire it again in the + # future. + logger.warning( + "Can't expire event %s because we don't have it.", event_id + ) + return + + # Prune the event's dict then convert it to JSON. + pruned_json = encode_json( + prune_event_dict(event.room_version, event.get_dict()) + ) + + # Update the event_json table to replace the event's JSON with the pruned + # JSON. + self._censor_event_txn(txn, event.event_id, pruned_json) + + # We need to invalidate the event cache entry for this event because we + # changed its content in the database. We can't call + # self._invalidate_cache_and_stream because self.get_event_cache isn't of the + # right type. 
+ txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) + # Send that invalidation to replication so that other workers also invalidate + # the event cache. + self._send_invalidation_to_replication( + txn, "_get_event_cache", (event.event_id,) + ) + + yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn) + + def _delete_event_expiry_txn(self, txn, event_id): + """Delete the expiry timestamp associated with an event ID without deleting the + actual event. + + Args: + txn (LoggingTransaction): The transaction to use to perform the deletion. + event_id (str): The event ID to delete the associated expiry timestamp of. + """ + return self.db.simple_delete_txn( + txn=txn, table="event_expiry", keyvalues={"event_id": event_id} + ) diff --git a/synapse/storage/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 6db8c54077..71f8d43a76 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -19,11 +19,10 @@ from six import iteritems from twisted.internet import defer -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.caches import CACHE_SIZE_FACTOR - -from . import background_updates -from ._base import Cache +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database, make_tuple_comparison_clause +from synapse.util.caches.descriptors import Cache logger = logging.getLogger(__name__) @@ -33,46 +32,41 @@ logger = logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 120 * 1000 -class ClientIpStore(background_updates.BackgroundUpdateStore): - def __init__(self, db_conn, hs): - - self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR - ) - - super(ClientIpStore, self).__init__(db_conn, hs) +class ClientIpBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_device_index", index_name="user_ips_device_id", table="user_ips", columns=["user_id", "device_id", "last_seen"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_last_seen_index", index_name="user_ips_last_seen", table="user_ips", columns=["user_id", "last_seen"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_last_seen_only_index", index_name="user_ips_last_seen_only", table="user_ips", columns=["last_seen"], ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_analyze", self._analyze_user_ip ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_remove_dupes", self._remove_user_ip_dupes ) # Register a unique index - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_device_unique_index", index_name="user_ips_user_token_ip_unique_index", table="user_ips", @@ -81,18 +75,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): ) # Drop the old non-unique index - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique ) - # (user_id, 
access_token, ip,) -> (user_agent, device_id, last_seen) - self._batch_row_update = {} - - self._client_ip_looper = self._clock.looping_call( - self._update_client_ips_batch, 5 * 1000 - ) - self.hs.get_reactor().addSystemEventTrigger( - "before", "shutdown", self._update_client_ips_batch + # Update the last seen info in devices. + self.db.updates.register_background_update_handler( + "devices_last_seen", self._devices_last_seen_update ) @defer.inlineCallbacks @@ -102,8 +91,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.close() - yield self.runWithConnection(f) - yield self._end_background_update("user_ips_drop_nonunique_index") + yield self.db.runWithConnection(f) + yield self.db.updates._end_background_update("user_ips_drop_nonunique_index") return 1 @defer.inlineCallbacks @@ -117,9 +106,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): def user_ips_analyze(txn): txn.execute("ANALYZE user_ips") - yield self.runInteraction("user_ips_analyze", user_ips_analyze) + yield self.db.runInteraction("user_ips_analyze", user_ips_analyze) - yield self._end_background_update("user_ips_analyze") + yield self.db.updates._end_background_update("user_ips_analyze") return 1 @@ -151,7 +140,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): return None # Get a last seen that has roughly `batch_size` since `begin_last_seen` - end_last_seen = yield self.runInteraction( + end_last_seen = yield self.db.runInteraction( "user_ips_dups_get_last_seen", get_last_seen ) @@ -282,18 +271,116 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): (user_id, access_token, ip, device_id, user_agent, last_seen), ) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} ) - yield self.runInteraction("user_ips_dups_remove", remove) + yield self.db.runInteraction("user_ips_dups_remove", remove) if last: - yield self._end_background_update("user_ips_remove_dupes") + yield self.db.updates._end_background_update("user_ips_remove_dupes") return batch_size @defer.inlineCallbacks + def _devices_last_seen_update(self, progress, batch_size): + """Background update to insert last seen info into devices table + """ + + last_user_id = progress.get("last_user_id", "") + last_device_id = progress.get("last_device_id", "") + + def _devices_last_seen_update_txn(txn): + # This consists of two queries: + # + # 1. The sub-query searches for the next N devices and joins + # against user_ips to find the max last_seen associated with + # that device. + # 2. The outer query then joins again against user_ips on + # user/device/last_seen. This *should* hopefully only + # return one row, but if it does return more than one then + # we'll just end up updating the same device row multiple + # times, which is fine. + + where_clause, where_args = make_tuple_comparison_clause( + self.database_engine, + [("user_id", last_user_id), ("device_id", last_device_id)], + ) + + sql = """ + SELECT + last_seen, ip, user_agent, user_id, device_id + FROM ( + SELECT + user_id, device_id, MAX(u.last_seen) AS last_seen + FROM devices + INNER JOIN user_ips AS u USING (user_id, device_id) + WHERE %(where_clause)s + GROUP BY user_id, device_id + ORDER BY user_id ASC, device_id ASC + LIMIT ? 
+ ) c + INNER JOIN user_ips AS u USING (user_id, device_id, last_seen) + """ % { + "where_clause": where_clause + } + txn.execute(sql, where_args + [batch_size]) + + rows = txn.fetchall() + if not rows: + return 0 + + sql = """ + UPDATE devices + SET last_seen = ?, ip = ?, user_agent = ? + WHERE user_id = ? AND device_id = ? + """ + txn.execute_batch(sql, rows) + + _, _, _, user_id, device_id = rows[-1] + self.db.updates._background_update_progress_txn( + txn, + "devices_last_seen", + {"last_user_id": user_id, "last_device_id": device_id}, + ) + + return len(rows) + + updated = yield self.db.runInteraction( + "_devices_last_seen_update", _devices_last_seen_update_txn + ) + + if not updated: + yield self.db.updates._end_background_update("devices_last_seen") + + return updated + + +class ClientIpStore(ClientIpBackgroundUpdateStore): + def __init__(self, database: Database, db_conn, hs): + + self.client_ip_last_seen = Cache( + name="client_ip_last_seen", keylen=4, max_entries=50000 + ) + + super(ClientIpStore, self).__init__(database, db_conn, hs) + + self.user_ips_max_age = hs.config.user_ips_max_age + + # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) + self._batch_row_update = {} + + self._client_ip_looper = self._clock.looping_call( + self._update_client_ips_batch, 5 * 1000 + ) + self.hs.get_reactor().addSystemEventTrigger( + "before", "shutdown", self._update_client_ips_batch + ) + + if self.user_ips_max_age: + self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) + + @defer.inlineCallbacks def insert_client_ip( self, user_id, access_token, ip, user_agent, device_id, now=None ): @@ -314,23 +401,22 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) + @wrap_as_background_process("update_client_ips") def _update_client_ips_batch(self): # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.db.is_running(): return - def update(): - to_update = self._batch_row_update - self._batch_row_update = {} - return self.runInteraction( - "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update - ) + to_update = self._batch_row_update + self._batch_row_update = {} - return run_as_background_process("update_client_ips", update) + return self.db.runInteraction( + "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update + ) def _update_client_ips_batch_txn(self, txn, to_update): - if "user_ips" in self._unsafe_to_upsert_tables or ( + if "user_ips" in self.db._unsafe_to_upsert_tables or ( not self.database_engine.can_native_upsert ): self.database_engine.lock_table(txn, "user_ips") @@ -339,7 +425,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry try: - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="user_ips", keyvalues={ @@ -354,6 +440,23 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): }, lock=False, ) + + # Technically an access token might not be associated with + # a device so we need to check. + if device_id: + # this is always an update rather than an upsert: the row should + # already exist, and if it doesn't, that may be because it has been + # deleted, and we don't want to re-create it. 
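The ClientIpStore initialisation above buffers client-IP rows in an in-memory dict keyed by (user_id, access_token, ip) and flushes the whole batch on a timer (and at shutdown) instead of writing one row per request; the transaction that continues below then upserts each buffered row and pushes the latest info onto the matching devices row. A standalone sketch of that buffer-and-flush pattern, using sqlite3 and illustrative class/table names rather than Synapse's Database and Clock machinery (it assumes an SQLite build new enough for ON CONFLICT upserts):

import sqlite3
import time

class BatchedClientIpWriter:
    """Buffers (user_id, access_token, ip) -> last-seen info and flushes it in one batch."""

    def __init__(self, conn):
        self._conn = conn
        # (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
        self._pending = {}

    def record(self, user_id, access_token, ip, user_agent, device_id, now=None):
        now = now or int(time.time() * 1000)
        # Overwriting the dict entry coalesces repeated hits on the same key,
        # so at most one row per key is written per flush.
        self._pending[(user_id, access_token, ip)] = (user_agent, device_id, now)

    def flush(self):
        # Swap the batch out first so new requests can keep accumulating
        # while the old batch is written.
        to_update, self._pending = self._pending, {}
        with self._conn:
            for (user_id, token, ip), (user_agent, device_id, last_seen) in to_update.items():
                self._conn.execute(
                    "INSERT INTO user_ips (user_id, access_token, ip, user_agent, device_id, last_seen)"
                    " VALUES (?, ?, ?, ?, ?, ?)"
                    " ON CONFLICT (user_id, access_token, ip)"
                    " DO UPDATE SET user_agent = excluded.user_agent,"
                    "               device_id = excluded.device_id,"
                    "               last_seen = excluded.last_seen",
                    (user_id, token, ip, user_agent, device_id, last_seen),
                )

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE user_ips (user_id TEXT, access_token TEXT, ip TEXT,"
    " user_agent TEXT, device_id TEXT, last_seen BIGINT,"
    " UNIQUE (user_id, access_token, ip))"
)
writer = BatchedClientIpWriter(conn)
writer.record("@alice:example.com", "tok", "10.0.0.1", "curl", "DEV1")
writer.record("@alice:example.com", "tok", "10.0.0.1", "curl", "DEV1")  # coalesced with the first hit
writer.flush()
print(conn.execute("SELECT count(*) FROM user_ips").fetchone()[0])  # -> 1

In the hunk the flush is driven by a looping_call every five seconds and is wrapped with @wrap_as_background_process so its run time and failures are tracked like any other background job.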
+ self.db.simple_update_txn( + txn, + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + updatevalues={ + "user_agent": user_agent, + "last_seen": last_seen, + "ip": ip, + }, + ) except Exception as e: # Failed to upsert, log and continue logger.error("Failed to insert client IP %r: %r", entry, e) @@ -372,19 +475,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): keys giving the column names """ - res = yield self.runInteraction( - "get_last_client_ip_by_device", - self._get_last_client_ip_by_device_txn, - user_id, - device_id, - retcols=( - "user_id", - "access_token", - "ip", - "user_agent", - "device_id", - "last_seen", - ), + keyvalues = {"user_id": user_id} + if device_id is not None: + keyvalues["device_id"] = device_id + + res = yield self.db.simple_select_list( + table="devices", + keyvalues=keyvalues, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), ) ret = {(d["user_id"], d["device_id"]): d for d in res} @@ -403,42 +501,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): } return ret - @classmethod - def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols): - where_clauses = [] - bindings = [] - if device_id is None: - where_clauses.append("user_id = ?") - bindings.extend((user_id,)) - else: - where_clauses.append("(user_id = ? AND device_id = ?)") - bindings.extend((user_id, device_id)) - - if not where_clauses: - return [] - - inner_select = ( - "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips " - "WHERE %(where)s " - "GROUP BY user_id, device_id" - ) % {"where": " OR ".join(where_clauses)} - - sql = ( - "SELECT %(retcols)s FROM user_ips " - "JOIN (%(inner_select)s) ips ON" - " user_ips.last_seen = ips.mls AND" - " user_ips.user_id = ips.user_id AND" - " (user_ips.device_id = ips.device_id OR" - " (user_ips.device_id IS NULL AND ips.device_id IS NULL)" - " )" - ) % { - "retcols": ",".join("user_ips." + c for c in retcols), - "inner_select": inner_select, - } - - txn.execute(sql, bindings) - return cls.cursor_to_dict(txn) - @defer.inlineCallbacks def get_user_ip_and_agents(self, user): user_id = user.to_string() @@ -450,7 +512,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "last_seen"], @@ -461,7 +523,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) for row in rows ) - return list( + return [ { "access_token": access_token, "ip": ip, @@ -469,4 +531,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): "last_seen": last_seen, } for (access_token, ip), (user_agent, last_seen) in iteritems(results) - ) + ] + + @wrap_as_background_process("prune_old_user_ips") + async def _prune_old_user_ips(self): + """Removes entries in user IPs older than the configured period. + """ + + if self.user_ips_max_age is None: + # Nothing to do + return + + if not await self.db.updates.has_completed_background_update( + "devices_last_seen" + ): + # Only start pruning if we have finished populating the devices + # last seen info. 
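The pruning transaction that continues below deletes expired user_ips rows in bounded chunks: rather than running DELETE ... WHERE last_seen <= ? over a potentially huge table in one go, it finds the newest last_seen that has no more than N rows at or before it (COALESCE-ing to -1 when nothing matches) and deletes only up to that point. A runnable sketch of the same bounded-delete shape against a toy table; the batch size of 5 is illustrative (the hunk uses 5000), and the while-loop stands in for the periodic looping_call that makes progress over time:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user_ips (user_id TEXT, last_seen BIGINT)")
conn.executemany(
    "INSERT INTO user_ips VALUES (?, ?)",
    [("@alice:example.com", ts) for ts in range(100)],
)

BOUNDED_DELETE = """
    DELETE FROM user_ips
    WHERE last_seen <= (
        SELECT COALESCE(MAX(last_seen), -1)
        FROM (
            SELECT last_seen FROM user_ips
            WHERE last_seen <= ?
            ORDER BY last_seen ASC
            LIMIT 5
        ) AS u
    )
"""

cutoff = 50  # in the hunk this is now_ms - user_ips_max_age
deleted = 1
while deleted:
    with conn:
        deleted = conn.execute(BOUNDED_DELETE, (cutoff,)).rowcount
print(conn.execute("SELECT count(*) FROM user_ips").fetchone()[0])  # 49 rows newer than the cutoff remain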
+ return + + # We do a slightly funky SQL delete to ensure we don't try and delete + # too much at once (as the table may be very large from before we + # started pruning). + # + # This works by finding the max last_seen that is less than the given + # time, but has no more than N rows before it, deleting all rows with + # a lesser last_seen time. (We COALESCE so that the sub-SELECT always + # returns exactly one row). + sql = """ + DELETE FROM user_ips + WHERE last_seen <= ( + SELECT COALESCE(MAX(last_seen), -1) + FROM ( + SELECT last_seen FROM user_ips + WHERE last_seen <= ? + ORDER BY last_seen ASC + LIMIT 5000 + ) AS u + ) + """ + + timestamp = self.clock.time_msec() - self.user_ips_max_age + + def _prune_old_user_ips_txn(txn): + txn.execute(sql, (timestamp,)) + + await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 6b7458304e..9a1178fb39 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -20,8 +20,8 @@ from canonicaljson import json from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.storage._base import SQLBaseStore -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -69,7 +69,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): stream_pos = current_stream_id return messages, stream_pos - return self.runInteraction( + return self.db.runInteraction( "get_new_messages_for_device", get_new_messages_for_device_txn ) @@ -109,7 +109,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount - count = yield self.runInteraction( + count = yield self.db.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn ) @@ -178,7 +178,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): stream_pos = current_stream_id return messages, stream_pos - return self.runInteraction( + return self.db.runInteraction( "get_new_device_msgs_for_remote", get_new_messages_for_remote_destination_txn, ) @@ -203,28 +203,92 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) txn.execute(sql, (destination, up_to_stream_id)) - return self.runInteraction( + return self.db.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) + def get_all_new_device_messages(self, last_pos, current_pos, limit): + """ + Args: + last_pos(int): + current_pos(int): + limit(int): + Returns: + A deferred list of rows from the device inbox + """ + if last_pos == current_pos: + return defer.succeed([]) -class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): + def get_all_new_device_messages_txn(txn): + # We limit like this as we might have multiple rows per stream_id, and + # we want to make sure we always get all entries for any stream_id + # we return. + upper_pos = min(current_pos, last_pos + limit) + sql = ( + "SELECT max(stream_id), user_id" + " FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY user_id" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows = txn.fetchall() + + sql = ( + "SELECT max(stream_id), destination" + " FROM device_federation_outbox" + " WHERE ? < stream_id AND stream_id <= ?" 
+ " GROUP BY destination" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows.extend(txn) + + # Order by ascending stream ordering + rows.sort() + + return rows + + return self.db.runInteraction( + "get_all_new_device_messages", get_all_new_device_messages_txn + ) + + +class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, db_conn, hs): - super(DeviceInboxStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_inbox_stream_index", index_name="device_inbox_stream_id_user_id", table="device_inbox", columns=["stream_id", "user_id"], ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) + @defer.inlineCallbacks + def _background_drop_index_device_inbox(self, progress, batch_size): + def reindex_txn(conn): + txn = conn.cursor() + txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") + txn.close() + + yield self.db.runWithConnection(reindex_txn) + + yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) + + return 1 + + +class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, database: Database, db_conn, hs): + super(DeviceInboxStore, self).__init__(database, db_conn, hs) + # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. self._last_device_delete_cache = ExpiringCache( @@ -274,7 +338,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.runInteraction( + yield self.db.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) for user_id in local_messages_by_user_then_device.keys(): @@ -294,7 +358,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. - already_inserted = self._simple_select_one_txn( + already_inserted = self.db.simple_select_one_txn( txn, table="device_federation_inbox", keyvalues={"origin": origin, "message_id": message_id}, @@ -306,7 +370,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): # Add an entry for this message_id so that we know we've processed # it. - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="device_federation_inbox", values={ @@ -324,7 +388,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.runInteraction( + yield self.db.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, now_ms, @@ -347,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): devices = list(messages_by_device.keys()) if len(devices) == 1 and devices[0] == "*": # Handle wildcard device_ids. - sql = "SELECT device_id FROM devices" " WHERE user_id = ?" + sql = "SELECT device_id FROM devices WHERE user_id = ?" 
txn.execute(sql, (user_id,)) message_json = json.dumps(messages_by_device["*"]) for row in txn: @@ -358,15 +422,15 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): else: if not devices: continue - sql = ( - "SELECT device_id FROM devices" - " WHERE user_id = ? AND device_id IN (" - + ",".join("?" * len(devices)) - + ")" + + clause, args = make_in_list_sql_clause( + txn.database_engine, "device_id", devices ) + sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause + # TODO: Maybe this needs to be done in batches if there are # too many local devices for a given user. - txn.execute(sql, [user_id] + devices) + txn.execute(sql, [user_id] + list(args)) for row in txn: # Only insert into the local inbox if the device exists on # this server @@ -391,60 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): rows.append((user_id, device_id, stream_id, message_json)) txn.executemany(sql, rows) - - def get_all_new_device_messages(self, last_pos, current_pos, limit): - """ - Args: - last_pos(int): - current_pos(int): - limit(int): - Returns: - A deferred list of rows from the device inbox - """ - if last_pos == current_pos: - return defer.succeed([]) - - def get_all_new_device_messages_txn(txn): - # We limit like this as we might have multiple rows per stream_id, and - # we want to make sure we always get all entries for any stream_id - # we return. - upper_pos = min(current_pos, last_pos + limit) - sql = ( - "SELECT max(stream_id), user_id" - " FROM device_inbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY user_id" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows = txn.fetchall() - - sql = ( - "SELECT max(stream_id), destination" - " FROM device_federation_outbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY destination" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows.extend(txn) - - # Order by ascending stream ordering - rows.sort() - - return rows - - return self.runInteraction( - "get_all_new_device_messages", get_all_new_device_messages_txn - ) - - @defer.inlineCallbacks - def _background_drop_index_device_inbox(self, progress, batch_size): - def reindex_txn(conn): - txn = conn.cursor() - txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") - txn.close() - - yield self.runWithConnection(reindex_txn) - - yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) - - return 1 diff --git a/synapse/storage/devices.py b/synapse/storage/data_stores/main/devices.py index 79a58df591..fb9f798e29 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
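The deviceinbox hunk above swaps the hand-rolled ",".join("?" * len(devices)) placeholder list for make_in_list_sql_clause, which returns the SQL fragment together with its arguments so the two cannot drift apart. A sketch of what such a helper can look like; the "= ANY(?)" branch for array-capable engines is an assumption about how Postgres might be handled, not a copy of Synapse's implementation:

import sqlite3
from typing import Iterable, List, Tuple

def make_in_list_sql_clause(
    supports_any: bool, column: str, iterable: Iterable
) -> Tuple[str, List]:
    """Return an SQL fragment testing `column` against `iterable`, plus its bound args."""
    values = list(iterable)
    if supports_any:
        # Engines with array parameters (e.g. Postgres via psycopg2) can take
        # the whole list as a single argument.
        return "%s = ANY(?)" % (column,), [values]
    # Otherwise expand to one placeholder per value.
    return "%s IN (%s)" % (column, ", ".join("?" * len(values))), values

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE devices (user_id TEXT, device_id TEXT)")
conn.executemany(
    "INSERT INTO devices VALUES (?, ?)",
    [("@alice:example.com", d) for d in ("DEV1", "DEV2", "DEV3")],
)

clause, args = make_in_list_sql_clause(False, "device_id", ["DEV1", "DEV3"])
sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
print(conn.execute(sql, ["@alice:example.com"] + list(args)).fetchall())
# [('DEV1',), ('DEV3',)]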
import logging +from typing import List, Optional, Set, Tuple from six import iteritems @@ -20,7 +23,7 @@ from canonicaljson import json from twisted.internet import defer -from synapse.api.errors import StoreError +from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( get_active_span_text_map, set_tag, @@ -28,10 +31,21 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import Cache, SQLBaseStore, db_to_json -from synapse.storage.background_updates import BackgroundUpdateStore -from synapse.util import batch_iter -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import ( + Database, + LoggingTransaction, + make_tuple_comparison_clause, +) +from synapse.types import Collection, get_verify_key_from_cross_signing_key +from synapse.util.caches.descriptors import ( + Cache, + cached, + cachedInlineCallbacks, + cachedList, +) +from synapse.util.iterutils import batch_iter +from synapse.util.stringutils import shortstr logger = logging.getLogger(__name__) @@ -39,10 +53,13 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( "drop_device_list_streams_non_unique_indexes" ) +BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" + class DeviceWorkerStore(SQLBaseStore): def get_device(self, user_id, device_id): - """Retrieve a device. + """Retrieve a device. Only returns devices that are not marked as + hidden. Args: user_id (str): The ID of the user which owns the device @@ -52,16 +69,17 @@ class DeviceWorkerStore(SQLBaseStore): Raises: StoreError: if the device is not found """ - return self._simple_select_one( + return self.db.simple_select_one( table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), desc="get_device", ) @defer.inlineCallbacks def get_devices_by_user(self, user_id): - """Retrieve all of a user's registered devices. + """Retrieve all of a user's registered devices. Only returns devices + that are not marked as hidden. Args: user_id (str): @@ -70,9 +88,9 @@ class DeviceWorkerStore(SQLBaseStore): containing "device_id", "user_id" and "display_name" for each device. """ - devices = yield self._simple_select_list( + devices = yield self.db.simple_select_list( table="devices", - keyvalues={"user_id": user_id}, + keyvalues={"user_id": user_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), desc="get_devices_by_user", ) @@ -81,13 +99,18 @@ class DeviceWorkerStore(SQLBaseStore): @trace @defer.inlineCallbacks - def get_devices_by_remote(self, destination, from_stream_id, limit): - """Get stream of updates to send to remote servers + def get_device_updates_by_remote(self, destination, from_stream_id, limit): + """Get a stream of device updates to send to the given remote server. 
+ Args: + destination (str): The host the device updates are intended for + from_stream_id (int): The minimum stream_id to filter updates by, exclusive + limit (int): Maximum number of device updates to return Returns: - Deferred[tuple[int, list[dict]]]: + Deferred[tuple[int, list[tuple[string,dict]]]]: current stream id (ie, the stream id of the last update included in the - response), and the list of updates + response), and the list of updates, where each update is a pair of EDU + type and EDU contents """ now_stream_id = self._device_list_id_gen.get_current_token() @@ -97,36 +120,49 @@ class DeviceWorkerStore(SQLBaseStore): if not has_changed: return now_stream_id, [] - # We retrieve n+1 devices from the list of outbound pokes where n is - # our outbound device update limit. We then check if the very last - # device has the same stream_id as the second-to-last device. If so, - # then we ignore all devices with that stream_id and only send the - # devices with a lower stream_id. - # - # If when culling the list we end up with no devices afterwards, we - # consider the device update to be too large, and simply skip the - # stream_id; the rationale being that such a large device list update - # is likely an error. - updates = yield self.runInteraction( - "get_devices_by_remote", - self._get_devices_by_remote_txn, + updates = yield self.db.runInteraction( + "get_device_updates_by_remote", + self._get_device_updates_by_remote_txn, destination, from_stream_id, now_stream_id, - limit + 1, + limit, ) # Return an empty list if there are no updates if not updates: return now_stream_id, [] - # if we have exceeded the limit, we need to exclude any results with the - # same stream_id as the last row. - if len(updates) > limit: - stream_id_cutoff = updates[-1][2] - now_stream_id = stream_id_cutoff - 1 - else: - stream_id_cutoff = None + # get the cross-signing keys of the users in the list, so that we can + # determine which of the device changes were cross-signing keys + users = {r[0] for r in updates} + master_key_by_user = {} + self_signing_key_by_user = {} + for user in users: + cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master") + if cross_signing_key: + key_id, verify_key = get_verify_key_from_cross_signing_key( + cross_signing_key + ) + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID + master_key_by_user[user] = { + "key_info": cross_signing_key, + "device_id": verify_key.version, + } + + cross_signing_key = yield self.get_e2e_cross_signing_key( + user, "self_signing" + ) + if cross_signing_key: + key_id, verify_key = get_verify_key_from_cross_signing_key( + cross_signing_key + ) + self_signing_key_by_user[user] = { + "key_info": cross_signing_key, + "device_id": verify_key.version, + } # Perform the equivalent of a GROUP BY # @@ -135,7 +171,6 @@ class DeviceWorkerStore(SQLBaseStore): # the max stream_id across each set of duplicate entries # # maps (user_id, device_id) -> (stream_id, opentracing_context) - # as long as their stream_id does not match that of the last row # # opentracing_context contains the opentracing metadata for the request # that created the poke @@ -144,39 +179,43 @@ class DeviceWorkerStore(SQLBaseStore): # context which created the Edu. 
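The loop that follows collapses the raw outbound-poke rows into at most one pending update per (user_id, device_id), keeping the entry with the highest stream_id, which is the in-Python equivalent of GROUP BY ... MAX(stream_id). A minimal sketch of that dedup step over illustrative tuples, with the cross-signing-key special case omitted:

# Rows as returned by the outbound-pokes query: (user_id, device_id, stream_id, context)
updates = [
    ("@alice:example.com", "DEV1", 10, "{}"),
    ("@alice:example.com", "DEV1", 12, "{}"),   # newer poke for the same device
    ("@alice:example.com", "DEV2", 11, "{}"),
]

# (user_id, device_id) -> (stream_id, context), keeping only the newest poke per device.
query_map = {}
for user_id, device_id, stream_id, context in updates:
    key = (user_id, device_id)
    previous_stream_id, _ = query_map.get(key, (0, None))
    if stream_id > previous_stream_id:
        query_map[key] = (stream_id, context)

print(query_map)
# {('@alice:example.com', 'DEV1'): (12, '{}'), ('@alice:example.com', 'DEV2'): (11, '{}')}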
query_map = {} - for update in updates: - if stream_id_cutoff is not None and update[2] >= stream_id_cutoff: - # Stop processing updates - break - - key = (update[0], update[1]) - - update_context = update[3] - update_stream_id = update[2] - - previous_update_stream_id, _ = query_map.get(key, (0, None)) - - if update_stream_id > previous_update_stream_id: - query_map[key] = (update_stream_id, update_context) + cross_signing_keys_by_user = {} + for user_id, device_id, update_stream_id, update_context in updates: + if ( + user_id in master_key_by_user + and device_id == master_key_by_user[user_id]["device_id"] + ): + result = cross_signing_keys_by_user.setdefault(user_id, {}) + result["master_key"] = master_key_by_user[user_id]["key_info"] + elif ( + user_id in self_signing_key_by_user + and device_id == self_signing_key_by_user[user_id]["device_id"] + ): + result = cross_signing_keys_by_user.setdefault(user_id, {}) + result["self_signing_key"] = self_signing_key_by_user[user_id][ + "key_info" + ] + else: + key = (user_id, device_id) - # If we didn't find any updates with a stream_id lower than the cutoff, it - # means that there are more than limit updates all of which have the same - # steam_id. + previous_update_stream_id, _ = query_map.get(key, (0, None)) - # That should only happen if a client is spamming the server with new - # devices, in which case E2E isn't going to work well anyway. We'll just - # skip that stream_id and return an empty list, and continue with the next - # stream_id next time. - if not query_map: - return stream_id_cutoff, [] + if update_stream_id > previous_update_stream_id: + query_map[key] = (update_stream_id, update_context) results = yield self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) + # add the updated cross-signing keys to the results list + for user_id, result in iteritems(cross_signing_keys_by_user): + result["user_id"] = user_id + # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + results.append(("org.matrix.signing_key_update", result)) + return now_stream_id, results - def _get_devices_by_remote_txn( + def _get_device_updates_by_remote_txn( self, txn, destination, from_stream_id, now_stream_id, limit ): """Return device update information for a given remote destination @@ -191,13 +230,14 @@ class DeviceWorkerStore(SQLBaseStore): Returns: List: List of device updates """ + # get the list of device updates that need to be sent sql = """ SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes - WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? + WHERE destination = ? AND ? < stream_id AND stream_id <= ? ORDER BY stream_id LIMIT ? 
""" - txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit)) + txn.execute(sql, (destination, from_stream_id, now_stream_id, limit)) return list(txn) @@ -216,12 +256,16 @@ class DeviceWorkerStore(SQLBaseStore): List[Dict]: List of objects representing an device update EDU """ - devices = yield self.runInteraction( - "_get_e2e_device_keys_txn", - self._get_e2e_device_keys_txn, - query_map.keys(), - include_all_devices=True, - include_deleted_devices=True, + devices = ( + yield self.db.runInteraction( + "_get_e2e_device_keys_txn", + self._get_e2e_device_keys_txn, + query_map.keys(), + include_all_devices=True, + include_deleted_devices=True, + ) + if query_map + else {} ) results = [] @@ -231,7 +275,14 @@ class DeviceWorkerStore(SQLBaseStore): prev_id = yield self._get_last_device_update_for_remote_user( destination, user_id, from_stream_id ) - for device_id, device in iteritems(user_devices): + + # make sure we go through the devices in stream order + device_ids = sorted( + user_devices.keys(), key=lambda i: query_map[(user_id, i)][0], + ) + + for device_id in device_ids: + device = user_devices[device_id] stream_id, opentracing_context = query_map[(user_id, device_id)] result = { "user_id": user_id, @@ -247,13 +298,20 @@ class DeviceWorkerStore(SQLBaseStore): key_json = device.get("key_json", None) if key_json: result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name else: result["deleted"] = True - results.append(result) + results.append(("m.device_list_update", result)) return results @@ -270,12 +328,12 @@ class DeviceWorkerStore(SQLBaseStore): rows = txn.fetchall() return rows[0][0] - return self.runInteraction("get_last_device_update_for_remote_user", f) + return self.db.runInteraction("get_last_device_update_for_remote_user", f) def mark_as_sent_devices_by_remote(self, destination, stream_id): """Mark that updates have successfully been sent to the destination. """ - return self.runInteraction( + return self.db.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, destination, @@ -284,32 +342,23 @@ class DeviceWorkerStore(SQLBaseStore): def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): # We update the device_lists_outbound_last_success with the successfully - # poked users. We do the join to see which users need to be inserted and - # which updated. + # poked users. sql = """ - SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL) + SELECT user_id, coalesce(max(o.stream_id), 0) FROM device_lists_outbound_pokes as o - LEFT JOIN device_lists_outbound_last_success as s - USING (destination, user_id) WHERE destination = ? AND o.stream_id <= ? GROUP BY user_id """ txn.execute(sql, (destination, stream_id)) rows = txn.fetchall() - sql = """ - UPDATE device_lists_outbound_last_success - SET stream_id = ? - WHERE destination = ? AND user_id = ? - """ - txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2])) - - sql = """ - INSERT INTO device_lists_outbound_last_success - (destination, user_id, stream_id) VALUES (?, ?, ?) 
- """ - txn.executemany( - sql, ((destination, row[0], row[1]) for row in rows if not row[2]) + self.db.simple_upsert_many_txn( + txn=txn, + table="device_lists_outbound_last_success", + key_names=("destination", "user_id"), + key_values=((destination, user_id) for user_id, _ in rows), + value_names=("stream_id",), + value_values=((stream_id,) for _, stream_id in rows), ) # Delete all sent outbound pokes @@ -319,6 +368,41 @@ class DeviceWorkerStore(SQLBaseStore): """ txn.execute(sql, (destination, stream_id)) + @defer.inlineCallbacks + def add_user_signature_change_to_streams(self, from_user_id, user_ids): + """Persist that a user has made new signatures + + Args: + from_user_id (str): the user who made the signatures + user_ids (list[str]): the users who were signed + """ + + with self._device_list_id_gen.get_next() as stream_id: + yield self.db.runInteraction( + "add_user_sig_change_to_streams", + self._add_user_signature_change_txn, + from_user_id, + user_ids, + stream_id, + ) + return stream_id + + def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id): + txn.call_after( + self._user_signature_stream_cache.entity_has_changed, + from_user_id, + stream_id, + ) + self.db.simple_insert_txn( + txn, + "user_signature_stream", + values={ + "stream_id": stream_id, + "from_user_id": from_user_id, + "user_ids": json.dumps(user_ids), + }, + ) + def get_device_stream_token(self): return self._device_list_id_gen.get_current_token() @@ -336,11 +420,17 @@ class DeviceWorkerStore(SQLBaseStore): a set of user_ids and results_map is a mapping of user_id -> device_id -> device_info """ - user_ids = set(user_id for user_id, _ in query_list) + user_ids = {user_id for user_id, _ in query_list} user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) - user_ids_in_cache = set( - user_id for user_id, stream_id in user_map.items() if stream_id + + # We go and check if any of the users need to have their device lists + # resynced. If they do then we remove them from the cached list. 
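The mark_as_sent change above records per-user last-success stream ids with a single upsert-many rather than a join that decides which rows need an UPDATE and which an INSERT. A runnable sketch of the equivalent batched upsert; the helper name and signature are illustrative rather than Synapse's simple_upsert_many_txn, and it assumes an engine with ON CONFLICT support:

import sqlite3

def upsert_many(conn, table, key_names, key_values, value_names, value_values):
    """Insert-or-update one row per key tuple."""
    columns = ", ".join(key_names + value_names)
    placeholders = ", ".join("?" * (len(key_names) + len(value_names)))
    updates = ", ".join("%s = excluded.%s" % (c, c) for c in value_names)
    sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO UPDATE SET %s" % (
        table, columns, placeholders, ", ".join(key_names), updates,
    )
    conn.executemany(sql, [tuple(k) + tuple(v) for k, v in zip(key_values, value_values)])

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE device_lists_outbound_last_success ("
    " destination TEXT, user_id TEXT, stream_id BIGINT,"
    " UNIQUE (destination, user_id))"
)

# Rows as produced by the SELECT in the hunk: (user_id, max stream_id sent to this destination)
rows = [("@alice:example.com", 10), ("@bob:example.com", 12)]
with conn:
    upsert_many(
        conn,
        "device_lists_outbound_last_success",
        ("destination", "user_id"),
        [("remote.example.org", user_id) for user_id, _ in rows],
        ("stream_id",),
        [(stream_id,) for _, stream_id in rows],
    )
print(conn.execute("SELECT * FROM device_lists_outbound_last_success").fetchall())
# [('remote.example.org', '@alice:example.com', 10), ('remote.example.org', '@bob:example.com', 12)]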
+ users_needing_resync = yield self.get_user_ids_requiring_device_list_resync( + user_ids ) + user_ids_in_cache = { + user_id for user_id, stream_id in user_map.items() if stream_id + } - users_needing_resync user_ids_not_in_cache = user_ids - user_ids_in_cache results = {} @@ -352,7 +442,7 @@ class DeviceWorkerStore(SQLBaseStore): device = yield self._get_cached_user_device(user_id, device_id) results.setdefault(user_id, {})[device_id] = device else: - results[user_id] = yield self._get_cached_devices_for_user(user_id) + results[user_id] = yield self.get_cached_devices_for_user(user_id) set_tag("in_cache", results) set_tag("not_in_cache", user_ids_not_in_cache) @@ -361,7 +451,7 @@ class DeviceWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2, tree=True) def _get_cached_user_device(self, user_id, device_id): - content = yield self._simple_select_one_onecol( + content = yield self.db.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="content", @@ -370,12 +460,12 @@ class DeviceWorkerStore(SQLBaseStore): return db_to_json(content) @cachedInlineCallbacks() - def _get_cached_devices_for_user(self, user_id): - devices = yield self._simple_select_list( + def get_cached_devices_for_user(self, user_id): + devices = yield self.db.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, retcols=("device_id", "content"), - desc="_get_cached_devices_for_user", + desc="get_cached_devices_for_user", ) return { device["device_id"]: db_to_json(device["content"]) for device in devices @@ -387,7 +477,7 @@ class DeviceWorkerStore(SQLBaseStore): Returns: (stream_id, devices) """ - return self.runInteraction( + return self.db.runInteraction( "get_devices_with_keys_by_user", self._get_devices_with_keys_by_user_txn, user_id, @@ -409,6 +499,13 @@ class DeviceWorkerStore(SQLBaseStore): key_json = device.get("key_json", None) if key_json: result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name @@ -435,8 +532,8 @@ class DeviceWorkerStore(SQLBaseStore): # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. - to_check = list( - self._device_list_stream_cache.get_entities_changed(user_ids, from_key) + to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key ) if not to_check: @@ -448,35 +545,71 @@ class DeviceWorkerStore(SQLBaseStore): sql = """ SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? - AND user_id IN (%s) + AND """ for chunk in batch_iter(to_check, 100): - txn.execute(sql % (",".join("?" for _ in chunk),), (from_key,) + chunk) + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", chunk + ) + txn.execute(sql + clause, (from_key,) + tuple(args)) changes.update(user_id for user_id, in txn) return changes - return self.runInteraction( + return self.db.runInteraction( "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) - def get_all_device_list_changes_for_remotes(self, from_key, to_key): - """Return a list of `(stream_id, user_id, destination)` which is the - combined list of changes to devices, and which destinations need to be - poked. 
`destination` may be None if no destinations need to be poked. + @defer.inlineCallbacks + def get_users_whose_signatures_changed(self, user_id, from_key): + """Get the users who have new cross-signing signatures made by `user_id` since + `from_key`. + + Args: + user_id (str): the user who made the signatures + from_key (str): The device lists stream token """ - # We do a group by here as there can be a large number of duplicate - # entries, since we throw away device IDs. + from_key = int(from_key) + if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): + sql = """ + SELECT DISTINCT user_ids FROM user_signature_stream + WHERE from_user_id = ? AND stream_id > ? + """ + rows = yield self.db.execute( + "get_users_whose_signatures_changed", None, sql, user_id, from_key + ) + return {user for row in rows for user in json.loads(row[0])} + else: + return set() + + async def get_all_device_list_changes_for_remotes( + self, from_key: int, to_key: int, limit: int, + ) -> List[Tuple[int, str]]: + """Return a list of `(stream_id, entity)` which is the combined list of + changes to devices and which destinations need to be poked. Entity is + either a user ID (starting with '@') or a remote destination. + """ + + # This query Does The Right Thing where it'll correctly apply the + # bounds to the inner queries. sql = """ - SELECT MAX(stream_id) AS stream_id, user_id, destination - FROM device_lists_stream - LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) + SELECT stream_id, entity FROM ( + SELECT stream_id, user_id AS entity FROM device_lists_stream + UNION ALL + SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes + ) AS e WHERE ? < stream_id AND stream_id <= ? - GROUP BY user_id, destination + LIMIT ? """ - return self._execute( - "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key + + return await self.db.execute( + "get_all_device_list_changes_for_remotes", + None, + sql, + from_key, + to_key, + limit, ) @cached(max_entries=10000) @@ -484,7 +617,7 @@ class DeviceWorkerStore(SQLBaseStore): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", @@ -498,7 +631,7 @@ class DeviceWorkerStore(SQLBaseStore): inlineCallbacks=True, ) def get_device_list_last_stream_id_for_remotes(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", iterable=user_ids, @@ -511,20 +644,72 @@ class DeviceWorkerStore(SQLBaseStore): return results + @defer.inlineCallbacks + def get_user_ids_requiring_device_list_resync( + self, user_ids: Optional[Collection[str]] = None, + ) -> Set[str]: + """Given a list of remote users return the list of users that we + should resync the device lists for. If None is given instead of a list, + return every user that we should resync the device lists for. -class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(DeviceStore, self).__init__(db_conn, hs) + Returns: + The IDs of users whose device lists need resync. 
+ """ + if user_ids: + rows = yield self.db.simple_select_many_batch( + table="device_lists_remote_resync", + column="user_id", + iterable=user_ids, + retcols=("user_id",), + desc="get_user_ids_requiring_device_list_resync_with_iterable", + ) + else: + rows = yield self.db.simple_select_list( + table="device_lists_remote_resync", + keyvalues=None, + retcols=("user_id",), + desc="get_user_ids_requiring_device_list_resync", + ) - # Map of (user_id, device_id) -> bool. If there is an entry that implies - # the device exists. - self.device_id_exists_cache = Cache( - name="device_id_exists", keylen=2, max_entries=10000 + return {row["user_id"] for row in rows} + + def mark_remote_user_device_cache_as_stale(self, user_id: str): + """Records that the server has reason to believe the cache of the devices + for the remote users is out of date. + """ + return self.db.simple_upsert( + table="device_lists_remote_resync", + keyvalues={"user_id": user_id}, + values={}, + insertion_values={"added_ts": self._clock.time_msec()}, + desc="make_remote_user_device_cache_as_stale", ) - self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) + def mark_remote_user_device_list_as_unsubscribed(self, user_id): + """Mark that we no longer track device lists for remote user. + """ - self.register_background_index_update( + def _mark_remote_user_device_list_as_unsubscribed_txn(txn): + self.db.simple_delete_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_device_list_last_stream_id_for_remote, (user_id,) + ) + + return self.db.runInteraction( + "mark_remote_user_device_list_as_unsubscribed", + _mark_remote_user_device_list_as_unsubscribed_txn, + ) + + +class DeviceBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) + + self.db.updates.register_background_index_update( "device_lists_stream_idx", index_name="device_lists_stream_user_id", table="device_lists_stream", @@ -532,7 +717,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): ) # create a unique index on device_lists_remote_cache - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_remote_cache_unique_idx", index_name="device_lists_remote_cache_unique_id", table="device_lists_remote_cache", @@ -541,7 +726,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): ) # And one on device_lists_remote_extremeties - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_remote_extremeties_unique_idx", index_name="device_lists_remote_extremeties_unique_idx", table="device_lists_remote_extremeties", @@ -550,11 +735,112 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): ) # once they complete, we can remove the old non-unique indexes. 
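The get_all_device_list_changes_for_remotes query above merges the user-facing and federation-facing streams with a UNION ALL inside a derived table, so a single set of stream_id bounds and a single LIMIT apply across both sources. A runnable sketch of the same shape over toy tables (an ORDER BY is added here purely to make the demo output deterministic):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE device_lists_stream (stream_id BIGINT, user_id TEXT)")
conn.execute("CREATE TABLE device_lists_outbound_pokes (stream_id BIGINT, destination TEXT)")
conn.executemany(
    "INSERT INTO device_lists_stream VALUES (?, ?)",
    [(1, "@alice:example.com"), (3, "@bob:example.com")],
)
conn.executemany(
    "INSERT INTO device_lists_outbound_pokes VALUES (?, ?)",
    [(2, "remote.example.org"), (4, "other.example.org")],
)

sql = """
    SELECT stream_id, entity FROM (
        SELECT stream_id, user_id AS entity FROM device_lists_stream
        UNION ALL
        SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
    ) AS e
    WHERE ? < stream_id AND stream_id <= ?
    ORDER BY stream_id
    LIMIT ?
"""
print(conn.execute(sql, (0, 10, 3)).fetchall())
# [(1, '@alice:example.com'), (2, 'remote.example.org'), (3, '@bob:example.com')]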
- self.register_background_update_handler( + self.db.updates.register_background_update_handler( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, self._drop_device_list_streams_non_unique_indexes, ) + # clear out duplicate device list outbound pokes + self.db.updates.register_background_update_handler( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, + ) + + # a pair of background updates that were added during the 1.14 release cycle, + # but replaced with 58/06dlols_unique_idx.py + self.db.updates.register_noop_background_update( + "device_lists_outbound_last_success_unique_idx", + ) + self.db.updates.register_noop_background_update( + "drop_device_lists_outbound_last_success_non_unique_idx", + ) + + @defer.inlineCallbacks + def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): + def f(conn): + txn = conn.cursor() + txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") + txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") + txn.close() + + yield self.db.runWithConnection(f) + yield self.db.updates._end_background_update( + DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES + ) + return 1 + + async def _remove_duplicate_outbound_pokes(self, progress, batch_size): + # for some reason, we have accumulated duplicate entries in + # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less + # efficient. + # + # For each duplicate, we delete all the existing rows and put one back. + + KEY_COLS = ["stream_id", "destination", "user_id", "device_id"] + last_row = progress.get( + "last_row", + {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, + ) + + def _txn(txn): + clause, args = make_tuple_comparison_clause( + self.db.engine, [(x, last_row[x]) for x in KEY_COLS] + ) + sql = """ + SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts + FROM device_lists_outbound_pokes + WHERE %s + GROUP BY %s + HAVING count(*) > 1 + ORDER BY %s + LIMIT ? + """ % ( + clause, # WHERE + ",".join(KEY_COLS), # GROUP BY + ",".join(KEY_COLS), # ORDER BY + ) + txn.execute(sql, args + [batch_size]) + rows = self.db.cursor_to_dict(txn) + + row = None + for row in rows: + self.db.simple_delete_txn( + txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS}, + ) + + row["sent"] = False + self.db.simple_insert_txn( + txn, "device_lists_outbound_pokes", row, + ) + + if row: + self.db.updates._background_update_progress_txn( + txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row}, + ) + + return len(rows) + + rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn) + + if not rows: + await self.db.updates._end_background_update( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES + ) + + return rows + + +class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): + def __init__(self, database: Database, db_conn, hs): + super(DeviceStore, self).__init__(database, db_conn, hs) + + # Map of (user_id, device_id) -> bool. If there is an entry that implies + # the device exists. 
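Both the devices_last_seen update earlier and the _remove_duplicate_outbound_pokes update above resume from the last processed row by comparing a compound key via make_tuple_comparison_clause, i.e. keyset pagination over (user_id, device_id) or (stream_id, destination, user_id, device_id). A sketch of what such a helper can produce; the expanded nested-OR form shown here is one portable way to express "(a, b) > (?, ?)" and is an assumption, not necessarily the exact clause Synapse generates:

from typing import List, Tuple

def make_tuple_comparison_clause(keys: List[Tuple[str, object]]) -> Tuple[str, List]:
    """Return SQL testing (col1, col2, ...) > (val1, val2, ...), plus the bound args.

    Expands to a nested-OR form that works on engines without row-value
    comparisons, e.g. for [("user_id", u), ("device_id", d)]:
        ((user_id > ?) OR (user_id = ? AND device_id > ?))
    """
    clauses = []
    args: List = []
    for i, (column, value) in enumerate(keys):
        equalities = " AND ".join("%s = ?" % col for col, _ in keys[:i])
        prefix = ("%s AND " % equalities) if equalities else ""
        clauses.append("(%s%s > ?)" % (prefix, column))
        args.extend([v for _, v in keys[:i]] + [value])
    return "(%s)" % " OR ".join(clauses), args

clause, args = make_tuple_comparison_clause(
    [("user_id", "@alice:example.com"), ("device_id", "DEV2")]
)
print(clause)  # ((user_id > ?) OR (user_id = ? AND device_id > ?))
print(args)    # ['@alice:example.com', '@alice:example.com', 'DEV2']

Each background-update batch then persists the last key it saw via _background_update_progress_txn, so the next batch picks up strictly after it.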
+ self.device_id_exists_cache = Cache( + name="device_id_exists", keylen=2, max_entries=10000 + ) + + self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) + @defer.inlineCallbacks def store_device(self, user_id, device_id, initial_device_display_name): """Ensure the given device is known; add it to the store if not @@ -567,24 +853,39 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): Returns: defer.Deferred: boolean whether the device was inserted or an existing device existed with that ID. + Raises: + StoreError: if the device is already in use """ key = (user_id, device_id) if self.device_id_exists_cache.get(key, None): return False try: - inserted = yield self._simple_insert( + inserted = yield self.db.simple_insert( "devices", values={ "user_id": user_id, "device_id": device_id, "display_name": initial_device_display_name, + "hidden": False, }, desc="store_device", or_ignore=True, ) + if not inserted: + # if the device already exists, check if it's a real device, or + # if the device ID is reserved by something else + hidden = yield self.db.simple_select_one_onecol( + "devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="hidden", + ) + if hidden: + raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) self.device_id_exists_cache.prefill(key, True) return inserted + except StoreError: + raise except Exception as e: logger.error( "store_device with device_id=%s(%r) user_id=%s(%r)" @@ -609,9 +910,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): Returns: defer.Deferred """ - yield self._simple_delete_one( + yield self.db.simple_delete_one( table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, desc="delete_device", ) @@ -627,18 +928,19 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): Returns: defer.Deferred """ - yield self._simple_delete_many( + yield self.db.simple_delete_many( table="devices", column="device_id", iterable=device_ids, - keyvalues={"user_id": user_id}, + keyvalues={"user_id": user_id, "hidden": False}, desc="delete_devices", ) for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) def update_device(self, user_id, device_id, new_display_name=None): - """Update a device. + """Update a device. Only updates the device if it is not marked as + hidden. Args: user_id (str): The ID of the user which owns the device @@ -655,24 +957,13 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): updates["display_name"] = new_display_name if not updates: return defer.succeed(None) - return self._simple_update_one( + return self.db.simple_update_one( table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, desc="update_device", ) - @defer.inlineCallbacks - def mark_remote_user_device_list_as_unsubscribed(self, user_id): - """Mark that we no longer track device lists for remote user. 
- """ - yield self._simple_delete( - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - desc="mark_remote_user_device_list_as_unsubscribed", - ) - self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) - def update_remote_device_list_cache_entry( self, user_id, device_id, content, stream_id ): @@ -690,7 +981,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): Returns: Deferred[None] """ - return self.runInteraction( + return self.db.runInteraction( "update_remote_device_list_cache_entry", self._update_remote_device_list_cache_entry_txn, user_id, @@ -703,7 +994,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): self, txn, user_id, device_id, content, stream_id ): if content.get("deleted"): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -711,7 +1002,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) else: - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -722,12 +1013,12 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): ) txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id)) - txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,)) txn.call_after( self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -751,7 +1042,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): Returns: Deferred[None] """ - return self.runInteraction( + return self.db.runInteraction( "update_remote_device_list_cache", self._update_remote_device_list_cache_txn, user_id, @@ -760,11 +1051,11 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): ) def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="device_lists_remote_cache", values=[ @@ -777,13 +1068,13 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): ], ) - txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,)) txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) txn.call_after( self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -793,34 +1084,61 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): lock=False, ) + # If we're replacing the remote user's device list cache presumably + # we've done a full resync, so we remove the entry that says we need + # to resync + self.db.simple_delete_txn( + txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id}, + ) + @defer.inlineCallbacks def add_device_change_to_streams(self, user_id, device_ids, hosts): """Persist that a user's devices have been updated, and which hosts (if any) should be poked. 
""" - with self._device_list_id_gen.get_next() as stream_id: - yield self.runInteraction( - "add_device_change_to_streams", - self._add_device_change_txn, + if not device_ids: + return + + with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: + yield self.db.runInteraction( + "add_device_change_to_stream", + self._add_device_change_to_stream_txn, + user_id, + device_ids, + stream_ids, + ) + + if not hosts: + return stream_ids[-1] + + context = get_active_span_text_map() + with self._device_list_id_gen.get_next_mult( + len(hosts) * len(device_ids) + ) as stream_ids: + yield self.db.runInteraction( + "add_device_outbound_poke_to_stream", + self._add_device_outbound_poke_to_stream_txn, user_id, device_ids, hosts, - stream_id, + stream_ids, + context, ) - return stream_id - def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): - now = self._clock.time_msec() + return stream_ids[-1] + def _add_device_change_to_stream_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Collection[str], + stream_ids: List[str], + ): txn.call_after( - self._device_list_stream_cache.entity_has_changed, user_id, stream_id + self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1], ) - for host in hosts: - txn.call_after( - self._device_list_federation_stream_cache.entity_has_changed, - host, - stream_id, - ) + + min_stream_id = stream_ids[0] # Delete older entries in the table, as we really only care about # when the latest change happened. @@ -829,27 +1147,38 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): DELETE FROM device_lists_stream WHERE user_id = ? AND device_id = ? AND stream_id < ? """, - [(user_id, device_id, stream_id) for device_id in device_ids], + [(user_id, device_id, min_stream_id) for device_id in device_ids], ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="device_lists_stream", values=[ {"stream_id": stream_id, "user_id": user_id, "device_id": device_id} - for device_id in device_ids + for stream_id, device_id in zip(stream_ids, device_ids) ], ) - context = get_active_span_text_map() + def _add_device_outbound_poke_to_stream_txn( + self, txn, user_id, device_ids, hosts, stream_ids, context, + ): + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, + stream_ids[-1], + ) - self._simple_insert_many_txn( + now = self._clock.time_msec() + next_stream_id = iter(stream_ids) + + self.db.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", values=[ { "destination": destination, - "stream_id": stream_id, + "stream_id": next(next_stream_id), "user_id": user_id, "device_id": device_id, "sent": False, @@ -863,18 +1192,47 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): ], ) - def _prune_old_outbound_device_pokes(self): + def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000): """Delete old entries out of the device_lists_outbound_pokes to ensure - that we don't fill up due to dead servers. We keep one entry per - (destination, user_id) tuple to ensure that the prev_ids remain correct - if the server does come back. + that we don't fill up due to dead servers. + + Normally, we try to send device updates as a delta since a previous known point: + this is done by setting the prev_id in the m.device_list_update EDU. However, + for that to work, we have to have a complete record of each change to + each device, which can add up to quite a lot of data. 
+ + An alternative mechanism is that, if the remote server sees that it has missed + an entry in the stream_id sequence for a given user, it will request a full + list of that user's devices. Hence, we can reduce the amount of data we have to + store (and transmit in some future transaction), by clearing almost everything + for a given destination out of the database, and having the remote server + resync. + + All we need to do is make sure we keep at least one row for each + (user, destination) pair, to remind us to send a m.device_list_update EDU for + that user when the destination comes back. It doesn't matter which device + we keep. """ - yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000 + yesterday = self._clock.time_msec() - prune_age def _prune_txn(txn): + # look for (user, destination) pairs which have an update older than + # the cutoff. + # + # For each pair, we also need to know the most recent stream_id, and + # an arbitrary device_id at that stream_id. select_sql = """ - SELECT destination, user_id, max(stream_id) as stream_id - FROM device_lists_outbound_pokes + SELECT + dlop1.destination, + dlop1.user_id, + MAX(dlop1.stream_id) AS stream_id, + (SELECT MIN(dlop2.device_id) AS device_id FROM + device_lists_outbound_pokes dlop2 + WHERE dlop2.destination = dlop1.destination AND + dlop2.user_id=dlop1.user_id AND + dlop2.stream_id=MAX(dlop1.stream_id) + ) + FROM device_lists_outbound_pokes dlop1 GROUP BY destination, user_id HAVING min(ts) < ? AND count(*) > 1 """ @@ -885,14 +1243,29 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): if not rows: return + logger.info( + "Pruning old outbound device list updates for %i users/destinations: %s", + len(rows), + shortstr((row[0], row[1]) for row in rows), + ) + + # we want to keep the update with the highest stream_id for each user. + # + # there might be more than one update (with different device_ids) with the + # same stream_id, so we also delete all but one rows with the max stream id. delete_sql = """ DELETE FROM device_lists_outbound_pokes - WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ? + WHERE destination = ? AND user_id = ? AND ( + stream_id < ? OR + (stream_id = ? AND device_id != ?) + ) """ - - txn.executemany( - delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows) - ) + count = 0 + for (destination, user_id, stream_id, device_id) in rows: + txn.execute( + delete_sql, (destination, user_id, stream_id, stream_id, device_id) + ) + count += txn.rowcount # Since we've deleted unsent deltas, we need to remove the entry # of last successful sent so that the prev_ids are correctly set. 
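The add_device_change_to_streams rework above allocates one stream id per device (and one per host/device pair for the outbound pokes) via get_next_mult, then zips the ids onto the rows so every change gets its own position in the stream. A small sketch of that allocate-and-zip pattern with a toy id generator; the generator here is illustrative, not Synapse's StreamIdGenerator:

from contextlib import contextmanager

class ToyStreamIdGenerator:
    """Hands out contiguous blocks of monotonically increasing stream ids."""

    def __init__(self, current=0):
        self._current = current

    @contextmanager
    def get_next_mult(self, n):
        start = self._current + 1
        self._current += n
        # The real generator only publishes the ids once the block is released;
        # this toy version simply yields the range.
        yield list(range(start, start + n))

gen = ToyStreamIdGenerator()
user_id = "@alice:example.com"
device_ids = ["DEV1", "DEV2", "DEV3"]
hosts = ["remote.example.org", "other.example.org"]

with gen.get_next_mult(len(device_ids)) as stream_ids:
    stream_rows = [
        {"stream_id": sid, "user_id": user_id, "device_id": did}
        for sid, did in zip(stream_ids, device_ids)
    ]

with gen.get_next_mult(len(hosts) * len(device_ids)) as poke_ids:
    next_id = iter(poke_ids)
    poke_rows = [
        {"destination": host, "stream_id": next(next_id), "user_id": user_id, "device_id": did}
        for did in device_ids
        for host in hosts
    ]

print(stream_rows[-1]["stream_id"])  # 3
print(poke_rows[-1]["stream_id"])    # 9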
@@ -902,23 +1275,11 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): """ txn.executemany(sql, ((row[0], row[1]) for row in rows)) - logger.info("Pruned %d device list outbound pokes", txn.rowcount) + logger.info("Pruned %d device list outbound pokes", count) return run_as_background_process( "prune_old_outbound_device_pokes", - self.runInteraction, + self.db.runInteraction, "_prune_old_outbound_device_pokes", _prune_txn, ) - - @defer.inlineCallbacks - def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): - def f(conn): - txn = conn.cursor() - txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") - txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") - txn.close() - - yield self.runWithConnection(f) - yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) - return 1 diff --git a/synapse/storage/directory.py b/synapse/storage/data_stores/main/directory.py index eed7757ed5..e1d1bc3e05 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/data_stores/main/directory.py @@ -14,14 +14,14 @@ # limitations under the License. from collections import namedtuple +from typing import Optional from twisted.internet import defer from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached -from ._base import SQLBaseStore - RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) @@ -37,7 +37,7 @@ class DirectoryWorkerStore(SQLBaseStore): Deferred: results in namedtuple with keys "room_id" and "servers" or None if no association can be found """ - room_id = yield self._simple_select_one_onecol( + room_id = yield self.db.simple_select_one_onecol( "room_aliases", {"room_alias": room_alias.to_string()}, "room_id", @@ -48,7 +48,7 @@ class DirectoryWorkerStore(SQLBaseStore): if not room_id: return None - servers = yield self._simple_select_onecol( + servers = yield self.db.simple_select_onecol( "room_alias_servers", {"room_alias": room_alias.to_string()}, "server", @@ -61,7 +61,7 @@ class DirectoryWorkerStore(SQLBaseStore): return RoomAliasMapping(room_id, room_alias.to_string(), servers) def get_room_alias_creator(self, room_alias): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", @@ -70,7 +70,7 @@ class DirectoryWorkerStore(SQLBaseStore): @cached(max_entries=5000) def get_aliases_for_room(self, room_id): - return self._simple_select_onecol( + return self.db.simple_select_onecol( "room_aliases", {"room_id": room_id}, "room_alias", @@ -94,7 +94,7 @@ class DirectoryStore(DirectoryWorkerStore): """ def alias_txn(txn): - self._simple_insert_txn( + self.db.simple_insert_txn( txn, "room_aliases", { @@ -104,7 +104,7 @@ class DirectoryStore(DirectoryWorkerStore): }, ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="room_alias_servers", values=[ @@ -118,7 +118,9 @@ class DirectoryStore(DirectoryWorkerStore): ) try: - ret = yield self.runInteraction("create_room_alias_association", alias_txn) + ret = yield self.db.runInteraction( + "create_room_alias_association", alias_txn + ) except self.database_engine.module.IntegrityError: raise SynapseError( 409, "Room alias %s already exists" % room_alias.to_string() @@ -127,7 +129,7 @@ class DirectoryStore(DirectoryWorkerStore): @defer.inlineCallbacks def delete_room_alias(self, room_alias): - room_id = yield 
self.runInteraction( + room_id = yield self.db.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias ) @@ -158,10 +160,29 @@ class DirectoryStore(DirectoryWorkerStore): return room_id - def update_aliases_for_room(self, old_room_id, new_room_id, creator): + def update_aliases_for_room( + self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, + ): + """Repoint all of the aliases for a given room, to a different room. + + Args: + old_room_id: + new_room_id: + creator: The user to record as the creator of the new mapping. + If None, the creator will be left unchanged. + """ + def _update_aliases_for_room_txn(txn): - sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?" - txn.execute(sql, (new_room_id, creator, old_room_id)) + update_creator_sql = "" + sql_params = (new_room_id, old_room_id) + if creator: + update_creator_sql = ", creator = ?" + sql_params = (new_room_id, creator, old_room_id) + + sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" % ( + update_creator_sql, + ) + txn.execute(sql, sql_params) self._invalidate_cache_and_stream( txn, self.get_aliases_for_room, (old_room_id,) ) @@ -169,6 +190,6 @@ class DirectoryStore(DirectoryWorkerStore): txn, self.get_aliases_for_room, (new_room_id,) ) - return self.runInteraction( + return self.db.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index be2fe2bab6..23f4570c4b 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,55 +20,13 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace - -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore class EndToEndRoomKeyStore(SQLBaseStore): @defer.inlineCallbacks - def get_e2e_room_key(self, user_id, version, room_id, session_id): - """Get the encrypted E2E room key for a given session from a given - backup version of room_keys. We only store the 'best' room key for a given - session at a given time, as determined by the handler. - - Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): the ID of the room whose keys we're querying. - This is a bit redundant as it's implied by the session_id, but - we include for consistency with the rest of the API. - session_id(str): the session whose room_key we're querying. - - Returns: - A deferred dict giving the session_data and message metadata for - this room key. 
- """ - - row = yield self._simple_select_one( - table="e2e_room_keys", - keyvalues={ - "user_id": user_id, - "version": version, - "room_id": room_id, - "session_id": session_id, - }, - retcols=( - "first_message_index", - "forwarded_count", - "is_verified", - "session_data", - ), - desc="get_e2e_room_key", - ) - - row["session_data"] = json.loads(row["session_data"]) - - return row - - @defer.inlineCallbacks - def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): - """Replaces or inserts the encrypted E2E room key for a given session in - a given backup + def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + """Replaces the encrypted E2E room key for a given session in a given backup Args: user_id(str): the user whose backup we're setting @@ -79,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self._simple_upsert( + yield self.db.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -87,21 +46,51 @@ class EndToEndRoomKeyStore(SQLBaseStore): "room_id": room_id, "session_id": session_id, }, - values={ + updatevalues={ "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], "session_data": json.dumps(room_key["session_data"]), }, - lock=False, + desc="update_e2e_room_key", ) - log_kv( - { - "message": "Set room key", - "room_id": room_id, - "session_id": session_id, - "room_key": room_key, - } + + @defer.inlineCallbacks + def add_e2e_room_keys(self, user_id, version, room_keys): + """Bulk add room keys to a given backup. + + Args: + user_id (str): the user whose backup we're adding to + version (str): the version ID of the backup for the set of keys we're adding to + room_keys (iterable[(str, str, dict)]): the keys to add, in the form + (roomID, sessionID, keyData) + """ + + values = [] + for (room_id, session_id, room_key) in room_keys: + values.append( + { + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + "first_message_index": room_key["first_message_index"], + "forwarded_count": room_key["forwarded_count"], + "is_verified": room_key["is_verified"], + "session_data": json.dumps(room_key["session_data"]), + } + ) + log_kv( + { + "message": "Set room key", + "room_id": room_id, + "session_id": session_id, + "room_key": room_key, + } + ) + + yield self.db.simple_insert_many( + table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @trace @@ -111,11 +100,11 @@ class EndToEndRoomKeyStore(SQLBaseStore): room, or a given session. Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): Optional. the ID of the room whose keys we're querying, if any. + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup for the set of keys we're querying + room_id (str): Optional. the ID of the room whose keys we're querying, if any. If not specified, we return the keys for all the rooms in the backup. - session_id(str): Optional. the session whose room_key we're querying, if any. + session_id (str): Optional. the session whose room_key we're querying, if any. If specified, we also require the room_id to be specified. 
If not specified, we return all the keys in this version of the backup (or for the specified room) @@ -136,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="e2e_room_keys", keyvalues=keyvalues, retcols=( @@ -157,12 +146,102 @@ class EndToEndRoomKeyStore(SQLBaseStore): room_entry["sessions"][row["session_id"]] = { "first_message_index": row["first_message_index"], "forwarded_count": row["forwarded_count"], - "is_verified": row["is_verified"], + # is_verified must be returned to the client as a boolean + "is_verified": bool(row["is_verified"]), "session_data": json.loads(row["session_data"]), } return sessions + def get_e2e_room_keys_multi(self, user_id, version, room_keys): + """Get multiple room keys at a time. The difference between this function and + get_e2e_room_keys is that this function can be used to retrieve + multiple specific keys at a time, whereas get_e2e_room_keys is used for + getting all the keys in a backup version, all the keys for a room, or a + specific key. + + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup we're querying about + room_keys (dict[str, dict[str, iterable[str]]]): a map from + room ID -> {"session": [session ids]} indicating the session IDs + that we want to query + + Returns: + Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key + """ + + return self.db.runInteraction( + "get_e2e_room_keys_multi", + self._get_e2e_room_keys_multi_txn, + user_id, + version, + room_keys, + ) + + @staticmethod + def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys): + if not room_keys: + return {} + + where_clauses = [] + params = [user_id, version] + for room_id, room in room_keys.items(): + sessions = list(room["sessions"]) + if not sessions: + continue + params.append(room_id) + params.extend(sessions) + where_clauses.append( + "(room_id = ? AND session_id IN (%s))" + % (",".join(["?" for _ in sessions]),) + ) + + # check if we're actually querying something + if not where_clauses: + return {} + + sql = """ + SELECT room_id, session_id, first_message_index, forwarded_count, + is_verified, session_data + FROM e2e_room_keys + WHERE user_id = ? AND version = ? AND (%s) + """ % ( + " OR ".join(where_clauses) + ) + + txn.execute(sql, params) + + ret = {} + + for row in txn: + room_id = row[0] + session_id = row[1] + ret.setdefault(room_id, {}) + ret[room_id][session_id] = { + "first_message_index": row[2], + "forwarded_count": row[3], + "is_verified": row[4], + "session_data": json.loads(row[5]), + } + + return ret + + def count_e2e_room_keys(self, user_id, version): + """Get the number of keys in a backup version. 
+ + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup we're querying about + """ + + return self.db.simple_select_one_onecol( + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": version}, + retcol="COUNT(*)", + desc="count_e2e_room_keys", + ) + @trace @defer.inlineCallbacks def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): @@ -189,7 +268,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - yield self._simple_delete( + yield self.db.simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" ) @@ -220,6 +299,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): version(str) algorithm(str) auth_data(object): opaque dict supplied by the client + etag(int): tag of the keys in the backup """ def _get_e2e_room_keys_version_info_txn(txn): @@ -233,17 +313,19 @@ class EndToEndRoomKeyStore(SQLBaseStore): # it isn't there. raise StoreError(404, "No row found") - result = self._simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, - retcols=("version", "algorithm", "auth_data"), + retcols=("version", "algorithm", "auth_data", "etag"), ) result["auth_data"] = json.loads(result["auth_data"]) result["version"] = str(result["version"]) + if result["etag"] is None: + result["etag"] = 0 return result - return self.runInteraction( + return self.db.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn ) @@ -271,7 +353,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): new_version = str(int(current_version) + 1) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="e2e_room_keys_versions", values={ @@ -284,26 +366,38 @@ class EndToEndRoomKeyStore(SQLBaseStore): return new_version - return self.runInteraction( + return self.db.runInteraction( "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn ) @trace - def update_e2e_room_keys_version(self, user_id, version, info): + def update_e2e_room_keys_version( + self, user_id, version, info=None, version_etag=None + ): """Update a given backup version Args: user_id(str): the user whose backup version we're updating version(str): the version ID of the backup version we're updating - info(dict): the new backup version info to store + info (dict): the new backup version info to store. If None, then + the backup version info is not updated + version_etag (Optional[int]): etag of the keys in the backup. 
If + None, then the etag is not updated """ + updatevalues = {} - return self._simple_update( - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": version}, - updatevalues={"auth_data": json.dumps(info["auth_data"])}, - desc="update_e2e_room_keys_version", - ) + if info is not None and "auth_data" in info: + updatevalues["auth_data"] = json.dumps(info["auth_data"]) + if version_etag is not None: + updatevalues["etag"] = version_etag + + if updatevalues: + return self.db.simple_update( + table="e2e_room_keys_versions", + keyvalues={"user_id": user_id, "version": version}, + updatevalues=updatevalues, + desc="update_e2e_room_keys_version", + ) @trace def delete_e2e_room_keys_version(self, user_id, version=None): @@ -322,16 +416,24 @@ class EndToEndRoomKeyStore(SQLBaseStore): def _delete_e2e_room_keys_version_txn(txn): if version is None: this_version = self._get_current_version(txn, user_id) + if this_version is None: + raise StoreError(404, "No current backup version") else: this_version = version - return self._simple_update_one_txn( + self.db.simple_delete_txn( + txn, + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": this_version}, + ) + + return self.db.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, updatevalues={"deleted": 1}, ) - return self.runInteraction( + return self.db.runInteraction( "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn ) diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py new file mode 100644 index 0000000000..20698bfd16 --- /dev/null +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -0,0 +1,721 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List + +from six import iteritems + +from canonicaljson import encode_canonical_json, json + +from twisted.enterprise.adbapi import Connection +from twisted.internet import defer + +from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import make_in_list_sql_clause +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter + + +class EndToEndKeyWorkerStore(SQLBaseStore): + @trace + @defer.inlineCallbacks + def get_e2e_device_keys( + self, query_list, include_all_devices=False, include_deleted_devices=False + ): + """Fetch a list of device keys. + Args: + query_list(list): List of pairs of user_ids and device_ids. + include_all_devices (bool): whether to include entries for devices + that don't have device keys + include_deleted_devices (bool): whether to include null entries for + devices which no longer exist (but were in the query_list). + This option only takes effect if include_all_devices is true. 
+ Returns: + Dict mapping from user-id to dict mapping from device_id to + key data. The key data will be a dict in the same format as the + DeviceKeys type returned by POST /_matrix/client/r0/keys/query. + """ + set_tag("query_list", query_list) + if not query_list: + return {} + + results = yield self.db.runInteraction( + "get_e2e_device_keys", + self._get_e2e_device_keys_txn, + query_list, + include_all_devices, + include_deleted_devices, + ) + + # Build the result structure, un-jsonify the results, and add the + # "unsigned" section + rv = {} + for user_id, device_keys in iteritems(results): + rv[user_id] = {} + for device_id, device_info in iteritems(device_keys): + r = db_to_json(device_info.pop("key_json")) + r["unsigned"] = {} + display_name = device_info["device_display_name"] + if display_name is not None: + r["unsigned"]["device_display_name"] = display_name + if "signatures" in device_info: + for sig_user_id, sigs in device_info["signatures"].items(): + r.setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + rv[user_id][device_id] = r + + return rv + + @trace + def _get_e2e_device_keys_txn( + self, txn, query_list, include_all_devices=False, include_deleted_devices=False + ): + set_tag("include_all_devices", include_all_devices) + set_tag("include_deleted_devices", include_deleted_devices) + + query_clauses = [] + query_params = [] + signature_query_clauses = [] + signature_query_params = [] + + if include_all_devices is False: + include_deleted_devices = False + + if include_deleted_devices: + deleted_devices = set(query_list) + + for (user_id, device_id) in query_list: + query_clause = "user_id = ?" + query_params.append(user_id) + signature_query_clause = "target_user_id = ?" + signature_query_params.append(user_id) + + if device_id is not None: + query_clause += " AND device_id = ?" + query_params.append(device_id) + signature_query_clause += " AND target_device_id = ?" + signature_query_params.append(device_id) + + signature_query_clause += " AND user_id = ?" + signature_query_params.append(user_id) + + query_clauses.append(query_clause) + signature_query_clauses.append(signature_query_clause) + + sql = ( + "SELECT user_id, device_id, " + " d.display_name AS device_display_name, " + " k.key_json" + " FROM devices d" + " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" + " WHERE %s AND NOT d.hidden" + ) % ( + "LEFT" if include_all_devices else "INNER", + " OR ".join("(" + q + ")" for q in query_clauses), + ) + + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + + result = {} + for row in rows: + if include_deleted_devices: + deleted_devices.remove((row["user_id"], row["device_id"])) + result.setdefault(row["user_id"], {})[row["device_id"]] = row + + if include_deleted_devices: + for user_id, device_id in deleted_devices: + result.setdefault(user_id, {})[device_id] = None + + # get signatures on the device + signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % ( + " OR ".join("(" + q + ")" for q in signature_query_clauses) + ) + + txn.execute(signature_sql, signature_query_params) + rows = self.db.cursor_to_dict(txn) + + # add each cross-signing signature to the correct device in the result dict. 
+ for row in rows: + signing_user_id = row["user_id"] + signing_key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + signature = row["signature"] + + target_user_result = result.get(target_user_id) + if not target_user_result: + continue + + target_device_result = target_user_result.get(target_device_id) + if not target_device_result: + # note that target_device_result will be None for deleted devices. + continue + + target_device_signatures = target_device_result.setdefault("signatures", {}) + signing_user_signatures = target_device_signatures.setdefault( + signing_user_id, {} + ) + signing_user_signatures[signing_key_id] = signature + + log_kv(result) + return result + + @defer.inlineCallbacks + def get_e2e_one_time_keys(self, user_id, device_id, key_ids): + """Retrieve a number of one-time keys for a user + + Args: + user_id(str): id of user to get keys for + device_id(str): id of device to get keys for + key_ids(list[str]): list of key ids (excluding algorithm) to + retrieve + + Returns: + deferred resolving to Dict[(str, str), str]: map from (algorithm, + key_id) to json string for key + """ + + rows = yield self.db.simple_select_many_batch( + table="e2e_one_time_keys_json", + column="key_id", + iterable=key_ids, + retcols=("algorithm", "key_id", "key_json"), + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="add_e2e_one_time_keys_check", + ) + result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows} + log_kv({"message": "Fetched one time keys for user", "one_time_keys": result}) + return result + + @defer.inlineCallbacks + def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): + """Insert some new one time keys for a device. Errors if any of the + keys already exist. + + Args: + user_id(str): id of user to get keys for + device_id(str): id of device to get keys for + time_now(long): insertion time to record (ms since epoch) + new_keys(iterable[(str, str, str)]: keys to add - each a tuple of + (algorithm, key_id, key json) + """ + + def _add_e2e_one_time_keys(txn): + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("new_keys", new_keys) + # We are protected from race between lookup and insertion due to + # a unique constraint. If there is a race of two calls to + # `add_e2e_one_time_keys` then they'll conflict and we will only + # insert one set. + self.db.simple_insert_many_txn( + txn, + table="e2e_one_time_keys_json", + values=[ + { + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + "ts_added_ms": time_now, + "key_json": json_bytes, + } + for algorithm, key_id, json_bytes in new_keys + ], + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + + yield self.db.runInteraction( + "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys + ) + + @cached(max_entries=10000) + def count_e2e_one_time_keys(self, user_id, device_id): + """ Count the number of one time keys the server has for a device + Returns: + Dict mapping from algorithm to number of keys for that algorithm. + """ + + def _count_e2e_one_time_keys(txn): + sql = ( + "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ?" 
+ " GROUP BY algorithm" + ) + txn.execute(sql, (user_id, device_id)) + result = {} + for algorithm, key_count in txn: + result[algorithm] = key_count + return result + + return self.db.runInteraction( + "count_e2e_one_time_keys", _count_e2e_one_time_keys + ) + + @defer.inlineCallbacks + def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): + """Returns a user's cross-signing key. + + Args: + user_id (str): the user whose key is being requested + key_type (str): the type of key that is being requested: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key + from_user_id (str): if specified, signatures made by this user on + the self-signing key will be included in the result + + Returns: + dict of the key data or None if not found + """ + res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) + user_keys = res.get(user_id) + if not user_keys: + return None + return user_keys.get(key_type) + + @cached(num_args=1) + def _get_bare_e2e_cross_signing_keys(self, user_id): + """Dummy function. Only used to make a cache for + _get_bare_e2e_cross_signing_keys_bulk. + """ + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_bare_e2e_cross_signing_keys", + list_name="user_ids", + num_args=1, + ) + def _get_bare_e2e_cross_signing_keys_bulk( + self, user_ids: List[str] + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + + """ + return self.db.runInteraction( + "get_bare_e2e_cross_signing_keys_bulk", + self._get_bare_e2e_cross_signing_keys_bulk_txn, + user_ids, + ) + + def _get_bare_e2e_cross_signing_keys_bulk_txn( + self, txn: Connection, user_ids: List[str], + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, their user + ID will not be in the dict. 
+ + """ + result = {} + + for user_chunk in batch_iter(user_ids, 100): + clause, params = make_in_list_sql_clause( + txn.database_engine, "k.user_id", user_chunk + ) + sql = ( + """ + SELECT k.user_id, k.keytype, k.keydata, k.stream_id + FROM e2e_cross_signing_keys k + INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id + FROM e2e_cross_signing_keys + GROUP BY user_id, keytype) s + USING (user_id, stream_id, keytype) + WHERE + """ + + clause + ) + + txn.execute(sql, params) + rows = self.db.cursor_to_dict(txn) + + for row in rows: + user_id = row["user_id"] + key_type = row["keytype"] + key = json.loads(row["keydata"]) + user_info = result.setdefault(user_id, {}) + user_info[key_type] = key + + return result + + def _get_e2e_cross_signing_signatures_txn( + self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str, + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing signatures made by a user on a set of keys. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + keys (dict[str, dict[str, dict]]): a map of user ID to key type to + key data. This dict will be modified to add signatures. + from_user_id (str): fetch the signatures made by this user + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. The return value will be the same as the keys argument, + with the modifications included. + """ + + # find out what cross-signing keys (a.k.a. devices) we need to get + # signatures for. This is a map of (user_id, device_id) to key type + # (device_id is the key's public part). + devices = {} + + for user_id, user_info in keys.items(): + if user_info is None: + continue + for key_type, key in user_info.items(): + device_id = None + for k in key["keys"].values(): + device_id = k + devices[(user_id, device_id)] = key_type + + for batch in batch_iter(devices.keys(), size=100): + sql = """ + SELECT target_user_id, target_device_id, key_id, signature + FROM e2e_cross_signing_signatures + WHERE user_id = ? + AND (%s) + """ % ( + " OR ".join( + "(target_user_id = ? AND target_device_id = ?)" for _ in batch + ) + ) + query_params = [from_user_id] + for item in batch: + # item is a (user_id, device_id) tuple + query_params.extend(item) + + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + + # and add the signatures to the appropriate keys + for row in rows: + key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + key_type = devices[(target_user_id, target_device_id)] + # We need to copy everything, because the result may have come + # from the cache. dict.copy only does a shallow copy, so we + # need to recursively copy the dicts that will be modified. + user_info = keys[target_user_id] = keys[target_user_id].copy() + target_user_key = user_info[key_type] = user_info[key_type].copy() + if "signatures" in target_user_key: + signatures = target_user_key["signatures"] = target_user_key[ + "signatures" + ].copy() + if from_user_id in signatures: + user_sigs = signatures[from_user_id] = signatures[from_user_id] + user_sigs[key_id] = row["signature"] + else: + signatures[from_user_id] = {key_id: row["signature"]} + else: + target_user_key["signatures"] = { + from_user_id: {key_id: row["signature"]} + } + + return keys + + @defer.inlineCallbacks + def get_e2e_cross_signing_keys_bulk( + self, user_ids: List[str], from_user_id: str = None + ) -> defer.Deferred: + """Returns the cross-signing keys for a set of users. 
+ + Args: + user_ids (list[str]): the users whose keys are being requested + from_user_id (str): if specified, signatures made by this user on + the self-signing keys will be included in the result + + Returns: + Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to + key data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + """ + + result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) + + if from_user_id: + result = yield self.db.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + result, + from_user_id, + ) + + return result + + def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit): + """Return a list of changes from the user signature stream to notify remotes. + Note that the user signature stream represents when a user signs their + device with their user-signing key, which is not published to other + users or servers, so no `destination` is needed in the returned + list. However, this is needed to poke workers. + + Args: + from_key (int): the stream ID to start at (exclusive) + to_key (int): the stream ID to end at (inclusive) + + Returns: + Deferred[list[(int,str)]] a list of `(stream_id, user_id)` + """ + sql = """ + SELECT stream_id, from_user_id AS user_id + FROM user_signature_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + return self.db.execute( + "get_all_user_signature_changes_for_remotes", + None, + sql, + from_key, + to_key, + limit, + ) + + +class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): + def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): + """Stores device keys for a device. Returns whether there was a change + or the keys were already in the database. + """ + + def _set_e2e_device_keys_txn(txn): + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("time_now", time_now) + set_tag("device_keys", device_keys) + + old_key_json = self.db.simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="key_json", + allow_none=True, + ) + + # In py3 we need old_key_json to match new_key_json type. The DB + # returns unicode while encode_canonical_json returns bytes. + new_key_json = encode_canonical_json(device_keys).decode("utf-8") + + if old_key_json == new_key_json: + log_kv({"Message": "Device key already stored."}) + return False + + self.db.simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"ts_added_ms": time_now, "key_json": new_key_json}, + ) + log_kv({"message": "Device keys stored."}) + return True + + return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) + + def claim_e2e_one_time_keys(self, query_list): + """Take a list of one time keys out of the database""" + + @trace + def _claim_e2e_one_time_keys(txn): + sql = ( + "SELECT key_id, key_json FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" 
+ " LIMIT 1" + ) + result = {} + delete = [] + for user_id, device_id, algorithm in query_list: + user_result = result.setdefault(user_id, {}) + device_result = user_result.setdefault(device_id, {}) + txn.execute(sql, (user_id, device_id, algorithm)) + for key_id, key_json in txn: + device_result[algorithm + ":" + key_id] = key_json + delete.append((user_id, device_id, algorithm, key_id)) + sql = ( + "DELETE FROM e2e_one_time_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " AND key_id = ?" + ) + for user_id, device_id, algorithm, key_id in delete: + log_kv( + { + "message": "Executing claim e2e_one_time_keys transaction on database." + } + ) + txn.execute(sql, (user_id, device_id, algorithm, key_id)) + log_kv({"message": "finished executing and invalidating cache"}) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + return result + + return self.db.runInteraction( + "claim_e2e_one_time_keys", _claim_e2e_one_time_keys + ) + + def delete_e2e_keys_by_device(self, user_id, device_id): + def delete_e2e_keys_by_device_txn(txn): + log_kv( + { + "message": "Deleting keys for device", + "device_id": device_id, + "user_id": user_id, + } + ) + self.db.simple_delete_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self.db.simple_delete_txn( + txn, + table="e2e_one_time_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + + return self.db.runInteraction( + "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn + ) + + def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): + """Set a user's cross-signing key. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_id (str): the user to set the signing key for + key_type (str): the type of key that is being set: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key + key (dict): the key data + """ + # the 'key' dict will look something like: + # { + # "user_id": "@alice:example.com", + # "usage": ["self_signing"], + # "keys": { + # "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key", + # }, + # "signatures": { + # "@alice:example.com": { + # "ed25519:base64+master+public+key": "base64+signature" + # } + # } + # } + # The "keys" property must only have one entry, which will be the public + # key, so we just grab the first value in there + pubkey = next(iter(key["keys"].values())) + + # The cross-signing keys need to occupy the same namespace as devices, + # since signatures are identified by device ID. So add an entry to the + # device table to make sure that we don't have a collision with device + # IDs. + # We only need to do this for local users, since remote servers should be + # responsible for checking this for their own users. 
+ if self.hs.is_mine_id(user_id): + self.db.simple_insert_txn( + txn, + "devices", + values={ + "user_id": user_id, + "device_id": pubkey, + "display_name": key_type + " signing key", + "hidden": True, + }, + ) + + # and finally, store the key itself + with self._cross_signing_id_gen.get_next() as stream_id: + self.db.simple_insert_txn( + txn, + "e2e_cross_signing_keys", + values={ + "user_id": user_id, + "keytype": key_type, + "keydata": json.dumps(key), + "stream_id": stream_id, + }, + ) + + self._invalidate_cache_and_stream( + txn, self._get_bare_e2e_cross_signing_keys, (user_id,) + ) + + def set_e2e_cross_signing_key(self, user_id, key_type, key): + """Set a user's cross-signing key. + + Args: + user_id (str): the user to set the user-signing key for + key_type (str): the type of cross-signing key to set + key (dict): the key data + """ + return self.db.runInteraction( + "add_e2e_cross_signing_key", + self._set_e2e_cross_signing_key_txn, + user_id, + key_type, + key, + ) + + def store_e2e_cross_signing_signatures(self, user_id, signatures): + """Stores cross-signing signatures. + + Args: + user_id (str): the user who made the signatures + signatures (iterable[SignatureListItem]): signatures to add + """ + return self.db.simple_insert_many( + "e2e_cross_signing_signatures", + [ + { + "user_id": user_id, + "key_id": item.signing_key_id, + "target_user_id": item.target_user_id, + "target_device_id": item.target_device_id, + "signature": item.signature, + } + for item in signatures + ], + "add_e2e_signing_key", + ) diff --git a/synapse/storage/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 4f500d893e..24ce8c4330 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -12,22 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging -import random +from typing import Dict, List, Optional, Set, Tuple -from six.moves import range from six.moves.queue import Empty, PriorityQueue -from unpaddedbase64 import encode_base64 - from twisted.internet import defer from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import SQLBaseStore -from synapse.storage.events_worker import EventsWorkerStore -from synapse.storage.signatures import SignatureWorkerStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.data_stores.main.signatures import SignatureWorkerStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cached +from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) @@ -47,37 +47,55 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_ids, include_given=include_given ).addCallback(self.get_events_as_list) - def get_auth_chain_ids(self, event_ids, include_given=False): + def get_auth_chain_ids( + self, + event_ids: List[str], + include_given: bool = False, + ignore_events: Optional[Set[str]] = None, + ): """Get auth events for given event_ids. The events *must* be state events. 
Args: - event_ids (list): state events - include_given (bool): include the given events in result + event_ids: state events + include_given: include the given events in result + ignore_events: Set of events to exclude from the returned auth + chain. This is useful if the caller will just discard the + given events anyway, and saves us from figuring out their auth + chains if not required. Returns: list of event_ids """ - return self.runInteraction( - "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given + return self.db.runInteraction( + "get_auth_chain_ids", + self._get_auth_chain_ids_txn, + event_ids, + include_given, + ignore_events, ) - def _get_auth_chain_ids_txn(self, txn, event_ids, include_given): + def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events): + if ignore_events is None: + ignore_events = set() + if include_given: results = set(event_ids) else: results = set() - base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)" + base_sql = "SELECT auth_id FROM event_auth WHERE " front = set(event_ids) while front: new_front = set() - front_list = list(front) - chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)] - for chunk in chunks: - txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk) - new_front.update([r[0] for r in txn]) + for chunk in batch_iter(front, 100): + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", chunk + ) + txn.execute(base_sql + clause, args) + new_front.update(r[0] for r in txn) + new_front -= ignore_events new_front -= results front = new_front @@ -85,13 +103,170 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) + def get_auth_chain_difference(self, state_sets: List[Set[str]]): + """Given sets of state events figure out the auth chain difference (as + per state res v2 algorithm). + + This equivalent to fetching the full auth chain for each set of state + and returning the events that don't appear in each and every auth + chain. + + Returns: + Deferred[Set[str]] + """ + + return self.db.runInteraction( + "get_auth_chain_difference", + self._get_auth_chain_difference_txn, + state_sets, + ) + + def _get_auth_chain_difference_txn( + self, txn, state_sets: List[Set[str]] + ) -> Set[str]: + + # Algorithm Description + # ~~~~~~~~~~~~~~~~~~~~~ + # + # The idea here is to basically walk the auth graph of each state set in + # tandem, keeping track of which auth events are reachable by each state + # set. If we reach an auth event we've already visited (via a different + # state set) then we mark that auth event and all ancestors as reachable + # by the state set. This requires that we keep track of the auth chains + # in memory. + # + # Doing it in a such a way means that we can stop early if all auth + # events we're currently walking are reachable by all state sets. + # + # *Note*: We can't stop walking an event's auth chain if it is reachable + # by all state sets. This is because other auth chains we're walking + # might be reachable only via the original auth chain. For example, + # given the following auth chain: + # + # A -> C -> D -> E + # / / + # B -´---------´ + # + # and state sets {A} and {B} then walking the auth chains of A and B + # would immediately show that C is reachable by both. However, if we + # stopped at C then we'd only reach E via the auth chain of B and so E + # would errornously get included in the returned difference. 
+ # + # The other thing that we do is limit the number of auth chains we walk + # at once, due to practical limits (i.e. we can only query the database + # with a limited set of parameters). We pick the auth chains we walk + # each iteration based on their depth, in the hope that events with a + # lower depth are likely reachable by those with higher depths. + # + # We could use any ordering that we believe would give a rough + # topological ordering, e.g. origin server timestamp. If the ordering + # chosen is not topological then the algorithm still produces the right + # result, but perhaps a bit more inefficiently. This is why it is safe + # to use "depth" here. + + initial_events = set(state_sets[0]).union(*state_sets[1:]) + + # Dict from events in auth chains to which sets *cannot* reach them. + # I.e. if the set is empty then all sets can reach the event. + event_to_missing_sets = { + event_id: {i for i, a in enumerate(state_sets) if event_id not in a} + for event_id in initial_events + } + + # The sorted list of events whose auth chains we should walk. + search = [] # type: List[Tuple[int, str]] + + # We need to get the depth of the initial events for sorting purposes. + sql = """ + SELECT depth, event_id FROM events + WHERE %s + """ + # the list can be huge, so let's avoid looking them all up in one massive + # query. + for batch in batch_iter(initial_events, 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", batch + ) + txn.execute(sql % (clause,), args) + + # I think building a temporary list with fetchall is more efficient than + # just `search.extend(txn)`, but this is unconfirmed + search.extend(txn.fetchall()) + + # sort by depth + search.sort() + + # Map from event to its auth events + event_to_auth_events = {} # type: Dict[str, Set[str]] + + base_sql = """ + SELECT a.event_id, auth_id, depth + FROM event_auth AS a + INNER JOIN events AS e ON (e.event_id = a.auth_id) + WHERE + """ + + while search: + # Check whether all our current walks are reachable by all state + # sets. If so we can bail. + if all(not event_to_missing_sets[eid] for _, eid in search): + break + + # Fetch the auth events and their depths of the N last events we're + # currently walking + search, chunk = search[:-100], search[-100:] + clause, args = make_in_list_sql_clause( + txn.database_engine, "a.event_id", [e_id for _, e_id in chunk] + ) + txn.execute(base_sql + clause, args) + + for event_id, auth_event_id, auth_event_depth in txn: + event_to_auth_events.setdefault(event_id, set()).add(auth_event_id) + + sets = event_to_missing_sets.get(auth_event_id) + if sets is None: + # First time we're seeing this event, so we add it to the + # queue of things to fetch. + search.append((auth_event_depth, auth_event_id)) + + # Assume that this event is unreachable from any of the + # state sets until proven otherwise + sets = event_to_missing_sets[auth_event_id] = set( + range(len(state_sets)) + ) + else: + # We've previously seen this event, so look up its auth + # events and recursively mark all ancestors as reachable + # by the current event's state set. + a_ids = event_to_auth_events.get(auth_event_id) + while a_ids: + new_aids = set() + for a_id in a_ids: + event_to_missing_sets[a_id].intersection_update( + event_to_missing_sets[event_id] + ) + + b = event_to_auth_events.get(a_id) + if b: + new_aids.update(b) + + a_ids = new_aids + + # Mark that the auth event is reachable by the approriate sets. 
+ sets.intersection_update(event_to_missing_sets[event_id]) + + search.sort() + + # Return all events where not all sets can reach them. + return {eid for eid, n in event_to_missing_sets.items() if n} + def get_oldest_events_in_room(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id ) def get_oldest_events_with_depth_in_room(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "get_oldest_events_with_depth_in_room", self.get_oldest_events_with_depth_in_room_txn, room_id, @@ -122,7 +297,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns Deferred[int] """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="events", column="event_id", iterable=event_ids, @@ -136,15 +311,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return max(row["depth"] for row in rows) def _get_oldest_events_in_room_txn(self, txn, room_id): - return self._simple_select_onecol_txn( + return self.db.simple_select_onecol_txn( txn, table="event_backward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", ) - @defer.inlineCallbacks - def get_prev_events_for_room(self, room_id): + def get_prev_events_for_room(self, room_id: str): """ Gets a subset of the current forward extremities in the given room. @@ -155,47 +329,37 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas room_id (str): room_id Returns: - Deferred[list[(str, dict[str, str], int)]] - for each event, a tuple of (event_id, hashes, depth) - where *hashes* is a map from algorithm to hash. + Deferred[List[str]]: the event ids of the forward extremites + """ - res = yield self.get_latest_event_ids_and_hashes_in_room(room_id) - if len(res) > 10: - # Sort by reverse depth, so we point to the most recent. - res.sort(key=lambda a: -a[2]) - # we use half of the limit for the actual most recent events, and - # the other half to randomly point to some of the older events, to - # make sure that we don't completely ignore the older events. - res = res[0:5] + random.sample(res[5:], 5) + return self.db.runInteraction( + "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id + ) - return res + def _get_prev_events_for_room_txn(self, txn, room_id: str): + # we just use the 10 newest events. Older events will become + # prev_events of future events. - def get_latest_event_ids_and_hashes_in_room(self, room_id): + sql = """ + SELECT e.event_id FROM event_forward_extremities AS f + INNER JOIN events AS e USING (event_id) + WHERE f.room_id = ? + ORDER BY e.depth DESC + LIMIT 10 """ - Gets the current forward extremities in the given room - Args: - room_id (str): room_id + txn.execute(sql, (room_id,)) - Returns: - Deferred[list[(str, dict[str, str], int)]] - for each event, a tuple of (event_id, hashes, depth) - where *hashes* is a map from algorithm to hash. - """ + return [row[0] for row in txn] - return self.runInteraction( - "get_latest_event_ids_and_hashes_in_room", - self._get_latest_event_ids_and_hashes_in_room, - room_id, - ) - - def get_rooms_with_many_extremities(self, min_count, limit): + def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter): """Get the top rooms with at least N extremities. Args: min_count (int): The minimum number of extremities limit (int): The maximum number of rooms to return. 
+ room_id_filter (iterable[str]): room_ids to exclude from the results Returns: Deferred[list]: At most `limit` room IDs that have at least @@ -203,60 +367,49 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ def _get_rooms_with_many_extremities_txn(txn): + where_clause = "1=1" + if room_id_filter: + where_clause = "room_id NOT IN (%s)" % ( + ",".join("?" for _ in room_id_filter), + ) + sql = """ SELECT room_id FROM event_forward_extremities + WHERE %s GROUP BY room_id HAVING count(*) > ? ORDER BY count(*) DESC LIMIT ? - """ + """ % ( + where_clause, + ) - txn.execute(sql, (min_count, limit)) + query_args = list(itertools.chain(room_id_filter, [min_count, limit])) + txn.execute(sql, query_args) return [room_id for room_id, in txn] - return self.runInteraction( + return self.db.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @cached(max_entries=5000, iterable=True) def get_latest_event_ids_in_room(self, room_id): - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", desc="get_latest_event_ids_in_room", ) - def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id): - sql = ( - "SELECT e.event_id, e.depth FROM events as e " - "INNER JOIN event_forward_extremities as f " - "ON e.event_id = f.event_id " - "AND e.room_id = f.room_id " - "WHERE f.room_id = ?" - ) - - txn.execute(sql, (room_id,)) - - results = [] - for event_id, depth in txn.fetchall(): - hashes = self._get_event_reference_hashes_txn(txn, event_id) - prev_hashes = { - k: encode_base64(v) for k, v in hashes.items() if k == "sha256" - } - results.append((event_id, prev_hashes, depth)) - - return results - def get_min_depth(self, room_id): """ For hte given room, get the minimum depth we have seen for it. 
""" - return self.runInteraction( + return self.db.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id ) def _get_min_depth_interaction(self, txn, room_id): - min_depth = self._simple_select_one_onecol_txn( + min_depth = self.db.simple_select_one_onecol_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -322,7 +475,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] - return self.runInteraction( + return self.db.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) @@ -337,7 +490,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas limit (int) """ return ( - self.runInteraction( + self.db.runInteraction( "get_backfill_events", self._get_backfill_events, room_id, @@ -349,9 +502,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) def _get_backfill_events(self, txn, room_id, event_list, limit): - logger.debug( - "_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit - ) + logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) event_results = set() @@ -370,7 +521,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas queue = PriorityQueue() for event_id in event_list: - depth = self._simple_select_one_onecol_txn( + depth = self.db.simple_select_one_onecol_txn( txn, table="events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -402,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas @defer.inlineCallbacks def get_missing_events(self, room_id, earliest_events, latest_events, limit): - ids = yield self.runInteraction( + ids = yield self.db.runInteraction( "get_missing_events", self._get_missing_events, room_id, @@ -432,7 +583,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas query, (room_id, event_id, False, limit - len(event_results)) ) - new_results = set(t[0] for t in txn) - seen_events + new_results = {t[0] for t in txn} - seen_events new_front |= new_results seen_events |= new_results @@ -455,7 +606,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: Deferred[list[str]] """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="event_edges", column="prev_event_id", iterable=event_ids, @@ -478,10 +629,10 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, db_conn, hs): - super(EventFederationStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventFederationStore, self).__init__(database, db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth ) @@ -489,89 +640,6 @@ class EventFederationStore(EventFederationWorkerStore): self._delete_old_forward_extrem_cache, 60 * 60 * 1000 ) - def _update_min_depth_for_room_txn(self, txn, room_id, depth): - min_depth = self._get_min_depth_interaction(txn, room_id) - - if min_depth and depth >= min_depth: - return - - self._simple_upsert_txn( - txn, - table="room_depth", - keyvalues={"room_id": room_id}, - values={"min_depth": depth}, - ) - - def _handle_mult_prev_events(self, txn, events): - """ - For the given event, 
update the event edges table and forward and - backward extremities tables. - """ - self._simple_insert_many_txn( - txn, - table="event_edges", - values=[ - { - "event_id": ev.event_id, - "prev_event_id": e_id, - "room_id": ev.room_id, - "is_state": False, - } - for ev in events - for e_id in ev.prev_event_ids() - ], - ) - - self._update_backward_extremeties(txn, events) - - def _update_backward_extremeties(self, txn, events): - """Updates the event_backward_extremities tables based on the new/updated - events being persisted. - - This is called for new events *and* for events that were outliers, but - are now being persisted as non-outliers. - - Forward extremities are handled when we first start persisting the events. - """ - events_by_room = {} - for ev in events: - events_by_room.setdefault(ev.room_id, []).append(ev) - - query = ( - "INSERT INTO event_backward_extremities (event_id, room_id)" - " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - " )" - " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" - " )" - ) - - txn.executemany( - query, - [ - (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) - for ev in events - for e_id in ev.prev_event_ids() - if not ev.internal_metadata.is_outlier() - ], - ) - - query = ( - "DELETE FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - ) - txn.executemany( - query, - [ - (ev.event_id, ev.room_id) - for ev in events - if not ev.internal_metadata.is_outlier() - ], - ) - def _delete_old_forward_extrem_cache(self): def _delete_old_forward_extrem_cache_txn(txn): # Delete entries older than a month, while making sure we don't delete @@ -591,13 +659,13 @@ class EventFederationStore(EventFederationWorkerStore): return run_as_background_process( "delete_old_forward_extrem_cache", - self.runInteraction, + self.db.runInteraction, "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn, ) def clean_room_for_join(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) @@ -641,17 +709,17 @@ class EventFederationStore(EventFederationWorkerStore): "max_stream_id_exclusive": min_stream_id, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_AUTH_STATE_ONLY, new_progress ) return min_stream_id >= target_min_stream_id - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_AUTH_STATE_ONLY, delete_event_auth ) if not result: - yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY) + yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY) return batch_size diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index b01a12528f..8ad7a306f8 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - 
super(EventPushActionsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn self.stream_ordering_month_ago = None @@ -93,7 +94,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): - ret = yield self.runInteraction( + ret = yield self.db.runInteraction( "get_unread_event_push_actions_by_room", self._get_unread_counts_by_receipt_txn, room_id, @@ -177,7 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] - ret = yield self.runInteraction("get_push_action_users_in_range", f) + ret = yield self.db.runInteraction("get_push_action_users_in_range", f) return ret @defer.inlineCallbacks @@ -230,7 +231,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.runInteraction( + after_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt ) @@ -259,7 +260,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.runInteraction( + no_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) @@ -332,7 +333,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.runInteraction( + after_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt ) @@ -361,7 +362,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.runInteraction( + no_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt ) @@ -411,7 +412,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, min_stream_ordering)) return bool(txn.fetchone()) - return self.runInteraction( + return self.db.runInteraction( "get_if_maybe_push_in_range_for_user", _get_if_maybe_push_in_range_for_user_txn, ) @@ -446,7 +447,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) def _add_push_actions_to_staging_txn(txn): - # We don't use _simple_insert_many here to avoid the overhead + # We don't use simple_insert_many here to avoid the overhead # of generating lists of dicts. 
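The comment above notes that the staging-table insert deliberately bypasses simple_insert_many so it can feed executemany directly, without building a dict per row. A minimal sketch of that pattern, using the standard-library sqlite3 module and made-up rows rather than Synapse's real schema or API:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE event_push_actions_staging (event_id TEXT, user_id TEXT, actions TEXT)"
    )

    rows = [("$ev1", "@alice:example.com", "[]"), ("$ev1", "@bob:example.com", "[]")]

    # executemany happily consumes a generator of plain tuples, so no
    # intermediate list of dicts is ever materialised.
    conn.executemany(
        "INSERT INTO event_push_actions_staging (event_id, user_id, actions) VALUES (?, ?, ?)",
        (row for row in rows),
    )
    print(conn.execute("SELECT count(*) FROM event_push_actions_staging").fetchone()[0])  # 2
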
sql = """ @@ -463,7 +464,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ), ) - return self.runInteraction( + return self.db.runInteraction( "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) @@ -477,7 +478,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ try: - res = yield self._simple_delete( + res = yield self.db.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", @@ -494,7 +495,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _find_stream_orderings_for_times(self): return run_as_background_process( "event_push_action_stream_orderings", - self.runInteraction, + self.db.runInteraction, "_find_stream_orderings_for_times", self._find_stream_orderings_for_times_txn, ) @@ -530,7 +531,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): Deferred[int]: stream ordering of the first event received on/after the timestamp """ - return self.runInteraction( + return self.db.runInteraction( "_find_first_stream_ordering_after_ts_txn", self._find_first_stream_ordering_after_ts_txn, ts, @@ -612,21 +613,38 @@ class EventPushActionsWorkerStore(SQLBaseStore): return range_end + @defer.inlineCallbacks + def get_time_of_last_push_action_before(self, stream_ordering): + def f(txn): + sql = ( + "SELECT e.received_ts" + " FROM event_push_actions AS ep" + " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" + " WHERE ep.stream_ordering > ?" + " ORDER BY ep.stream_ordering ASC" + " LIMIT 1" + ) + txn.execute(sql, (stream_ordering,)) + return txn.fetchone() + + result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) + return result[0] if result else None + class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, db_conn, hs): - super(EventPushActionsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventPushActionsStore, self).__init__(database, db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, index_name="event_push_actions_u_highlight", table="event_push_actions", columns=["user_id", "stream_ordering"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_push_actions_highlights_index", index_name="event_push_actions_highlights_index", table="event_push_actions", @@ -639,69 +657,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore): self._start_rotate_notifs, 30 * 60 * 1000 ) - def _set_push_actions_for_event_and_users_txn( - self, txn, events_and_contexts, all_events_and_contexts - ): - """Handles moving push actions from staging table to main - event_push_actions table for all events in `events_and_contexts`. - - Also ensures that all events in `all_events_and_contexts` are removed - from the push action staging area. - - Args: - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. - """ - - sql = """ - INSERT INTO event_push_actions ( - room_id, event_id, user_id, actions, stream_ordering, - topological_ordering, notif, highlight - ) - SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight - FROM event_push_actions_staging - WHERE event_id = ? 
- """ - - if events_and_contexts: - txn.executemany( - sql, - ( - ( - event.room_id, - event.internal_metadata.stream_ordering, - event.depth, - event.event_id, - ) - for event, _ in events_and_contexts - ), - ) - - for event, _ in events_and_contexts: - user_ids = self._simple_select_onecol_txn( - txn, - table="event_push_actions_staging", - keyvalues={"event_id": event.event_id}, - retcol="user_id", - ) - - for uid in user_ids: - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (event.room_id, uid), - ) - - # Now we delete the staging area for *all* events that were being - # persisted. - txn.executemany( - "DELETE FROM event_push_actions_staging WHERE event_id = ?", - ((event.event_id,) for event, _ in all_events_and_contexts), - ) - @defer.inlineCallbacks def get_push_actions_for_user( self, user_id, before=None, limit=50, only_highlight=False @@ -732,49 +687,23 @@ class EventPushActionsStore(EventPushActionsWorkerStore): " LIMIT ?" % (before_clause,) ) txn.execute(sql, args) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - push_actions = yield self.runInteraction("get_push_actions_for_user", f) + push_actions = yield self.db.runInteraction("get_push_actions_for_user", f) for pa in push_actions: pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions @defer.inlineCallbacks - def get_time_of_last_push_action_before(self, stream_ordering): - def f(txn): - sql = ( - "SELECT e.received_ts" - " FROM event_push_actions AS ep" - " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" - " WHERE ep.stream_ordering > ?" - " ORDER BY ep.stream_ordering ASC" - " LIMIT 1" - ) - txn.execute(sql, (stream_ordering,)) - return txn.fetchone() - - result = yield self.runInteraction("get_time_of_last_push_action_before", f) - return result[0] if result else None - - @defer.inlineCallbacks def get_latest_push_action_stream_ordering(self): def f(txn): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") return txn.fetchone() - result = yield self.runInteraction("get_latest_push_action_stream_ordering", f) - return result[0] or 0 - - def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): - # Sad that we have to blow away the cache for the whole room here - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id,), - ) - txn.execute( - "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", - (room_id, event_id), + result = yield self.db.runInteraction( + "get_latest_push_action_stream_ordering", f ) + return result[0] or 0 def _remove_old_push_actions_before_txn( self, txn, room_id, user_id, stream_ordering @@ -835,7 +764,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): while True: logger.info("Rotating notifications") - caught_up = yield self.runInteraction( + caught_up = yield self.db.runInteraction( "_rotate_notifs", self._rotate_notifs_txn ) if caught_up: @@ -849,7 +778,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): the archiving process has caught up or not. 
""" - old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -868,7 +797,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) stream_row = txn.fetchone() if stream_row: - offset_stream_ordering, = stream_row + (offset_stream_ordering,) = stream_row rotate_to_stream_ordering = min( self.stream_ordering_day_ago, offset_stream_ordering ) @@ -885,7 +814,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): return caught_up def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): - old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -917,7 +846,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the # existing table. - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_push_summary", values=[ diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py new file mode 100644 index 0000000000..a6572571b4 --- /dev/null +++ b/synapse/storage/data_stores/main/events.py @@ -0,0 +1,1620 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import itertools +import logging +from collections import OrderedDict, namedtuple +from functools import wraps +from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple + +from six import integer_types, iteritems, text_type +from six.moves import range + +import attr +from canonicaljson import json +from prometheus_client import Counter + +from twisted.internet import defer + +import synapse.metrics +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, +) +from synapse.api.room_versions import RoomVersions +from synapse.crypto.event_signing import compute_event_reference_hash +from synapse.events import EventBase # noqa: F401 +from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.logging.utils import log_function +from synapse.storage._base import make_in_list_sql_clause +from synapse.storage.data_stores.main.search import SearchEntry +from synapse.storage.database import Database, LoggingTransaction +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import StateMap, get_domain_from_id +from synapse.util.frozenutils import frozendict_json_encoder +from synapse.util.iterutils import batch_iter + +if TYPE_CHECKING: + from synapse.storage.data_stores.main import DataStore + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + +persist_event_counter = Counter("synapse_storage_events_persisted_events", "") +event_counter = Counter( + "synapse_storage_events_persisted_events_sep", + "", + ["type", "origin_type", "origin_entity"], +) + + +def encode_json(json_object): + """ + Encode a Python object as JSON and return it in a Unicode string. + """ + out = frozendict_json_encoder.encode(json_object) + if isinstance(out, bytes): + out = out.decode("utf8") + return out + + +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + +def _retry_on_integrity_error(func): + """Wraps a database function so that it gets retried on IntegrityError, + with `delete_existing=True` passed in. + + Args: + func: function that returns a Deferred and accepts a `delete_existing` arg + """ + + @wraps(func) + @defer.inlineCallbacks + def f(self, *args, **kwargs): + try: + res = yield func(self, *args, delete_existing=False, **kwargs) + except self.database_engine.module.IntegrityError: + logger.exception("IntegrityError, retrying.") + res = yield func(self, *args, delete_existing=True, **kwargs) + return res + + return f + + +@attr.s(slots=True) +class DeltaState: + """Deltas to use to update the `current_state_events` table. + + Attributes: + to_delete: List of type/state_keys to delete from current state + to_insert: Map of state to upsert into current state + no_longer_in_room: The server is not longer in the room, so the room + should e.g. be removed from `current_state_events` table. + """ + + to_delete = attr.ib(type=List[Tuple[str, str]]) + to_insert = attr.ib(type=StateMap[str]) + no_longer_in_room = attr.ib(type=bool, default=False) + + +class PersistEventsStore: + """Contains all the functions for writing events to the database. + + Should only be instantiated on one process (when using a worker mode setup). + + Note: This is not part of the `DataStore` mixin. 
+ """ + + def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"): + self.hs = hs + self.db = db + self.store = main_data_store + self.database_engine = db.engine + self._clock = hs.get_clock() + + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + self.is_mine_id = hs.is_mine_id + + # Ideally we'd move these ID gens here, unfortunately some other ID + # generators are chained off them so doing so is a bit of a PITA. + self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator + self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator + + # This should only exist on instances that are configured to write + assert ( + hs.config.worker.writers.events == hs.get_instance_name() + ), "Can only instantiate EventsStore on master" + + @_retry_on_integrity_error + @defer.inlineCallbacks + def _persist_events_and_state_updates( + self, + events_and_contexts: List[Tuple[EventBase, EventContext]], + current_state_for_room: Dict[str, StateMap[str]], + state_delta_for_room: Dict[str, DeltaState], + new_forward_extremeties: Dict[str, List[str]], + backfilled: bool = False, + delete_existing: bool = False, + ): + """Persist a set of events alongside updates to the current state and + forward extremities tables. + + Args: + events_and_contexts: + current_state_for_room: Map from room_id to the current state of + the room based on forward extremities + state_delta_for_room: Map from room_id to the delta to apply to + room state + new_forward_extremities: Map from room_id to list of event IDs + that are the new forward extremities of the room. + backfilled + delete_existing + + Returns: + Deferred: resolves when the events have been persisted + """ + + # We want to calculate the stream orderings as late as possible, as + # we only notify after all events with a lesser stream ordering have + # been persisted. I.e. if we spend 10s inside the with block then + # that will delay all subsequent events from being notified about. + # Hence why we do it down here rather than wrapping the entire + # function. + # + # Its safe to do this after calculating the state deltas etc as we + # only need to protect the *persistence* of the events. This is to + # ensure that queries of the form "fetch events since X" don't + # return events and stream positions after events that are still in + # flight, as otherwise subsequent requests "fetch event since Y" + # will not return those events. + # + # Note: Multiple instances of this function cannot be in flight at + # the same time for the same room. + if backfilled: + stream_ordering_manager = self._backfill_id_gen.get_next_mult( + len(events_and_contexts) + ) + else: + stream_ordering_manager = self._stream_id_gen.get_next_mult( + len(events_and_contexts) + ) + + with stream_ordering_manager as stream_orderings: + for (event, context), stream in zip(events_and_contexts, stream_orderings): + event.internal_metadata.stream_ordering = stream + + yield self.db.runInteraction( + "persist_events", + self._persist_events_txn, + events_and_contexts=events_and_contexts, + backfilled=backfilled, + delete_existing=delete_existing, + state_delta_for_room=state_delta_for_room, + new_forward_extremeties=new_forward_extremeties, + ) + persist_event_counter.inc(len(events_and_contexts)) + + if not backfilled: + # backfilled events have negative stream orderings, so we don't + # want to set the event_persisted_position to that. 
+ synapse.metrics.event_persisted_position.set( + events_and_contexts[-1][0].internal_metadata.stream_ordering + ) + + for event, context in events_and_contexts: + if context.app_service: + origin_type = "local" + origin_entity = context.app_service.id + elif self.hs.is_mine_id(event.sender): + origin_type = "local" + origin_entity = "*client*" + else: + origin_type = "remote" + origin_entity = get_domain_from_id(event.sender) + + event_counter.labels(event.type, origin_type, origin_entity).inc() + + for room_id, new_state in iteritems(current_state_for_room): + self.store.get_current_state_ids.prefill((room_id,), new_state) + + for room_id, latest_event_ids in iteritems(new_forward_extremeties): + self.store.get_latest_event_ids_in_room.prefill( + (room_id,), list(latest_event_ids) + ) + + @defer.inlineCallbacks + def _get_events_which_are_prevs(self, event_ids): + """Filter the supplied list of event_ids to get those which are prev_events of + existing (non-outlier/rejected) events. + + Args: + event_ids (Iterable[str]): event ids to filter + + Returns: + Deferred[List[str]]: filtered event ids + """ + results = [] + + def _get_events_which_are_prevs_txn(txn, batch): + sql = """ + SELECT prev_event_id, internal_metadata + FROM event_edges + INNER JOIN events USING (event_id) + LEFT JOIN rejections USING (event_id) + LEFT JOIN event_json USING (event_id) + WHERE + NOT events.outlier + AND rejections.event_id IS NULL + AND + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "prev_event_id", batch + ) + + txn.execute(sql + clause, args) + results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) + + for chunk in batch_iter(event_ids, 100): + yield self.db.runInteraction( + "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk + ) + + return results + + @defer.inlineCallbacks + def _get_prevs_before_rejected(self, event_ids): + """Get soft-failed ancestors to remove from the extremities. + + Given a set of events, find all those that have been soft-failed or + rejected. Returns those soft failed/rejected events and their prev + events (whether soft-failed/rejected or not), and recurses up the + prev-event graph until it finds no more soft-failed/rejected events. + + This is used to find extremities that are ancestors of new events, but + are separated by soft failed events. + + Args: + event_ids (Iterable[str]): Events to find prev events for. Note + that these must have already been persisted. + + Returns: + Deferred[set[str]] + """ + + # The set of event_ids to return. This includes all soft-failed events + # and their prev events. 
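Both helpers above build their WHERE clause with make_in_list_sql_clause and feed event ids through batch_iter in chunks of 100. A rough illustration of what a parameterised IN-clause builder produces (a simplified stand-in, not Synapse's make_in_list_sql_clause):

    def in_list_clause(column, values):
        """Return a "column IN (?, ?, ...)" fragment plus its bind parameters."""
        values = list(values)
        placeholders = ", ".join("?" for _ in values)
        return "%s IN (%s)" % (column, placeholders), values

    clause, args = in_list_clause("prev_event_id", ["$a", "$b", "$c"])
    print(clause)  # prev_event_id IN (?, ?, ?)
    print(args)    # ['$a', '$b', '$c']
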
+ existing_prevs = set() + + def _get_prevs_before_rejected_txn(txn, batch): + to_recursively_check = batch + + while to_recursively_check: + sql = """ + SELECT + event_id, prev_event_id, internal_metadata, + rejections.event_id IS NOT NULL + FROM event_edges + INNER JOIN events USING (event_id) + LEFT JOIN rejections USING (event_id) + LEFT JOIN event_json USING (event_id) + WHERE + NOT events.outlier + AND + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", to_recursively_check + ) + + txn.execute(sql + clause, args) + to_recursively_check = [] + + for event_id, prev_event_id, metadata, rejected in txn: + if prev_event_id in existing_prevs: + continue + + soft_failed = json.loads(metadata).get("soft_failed") + if soft_failed or rejected: + to_recursively_check.append(prev_event_id) + existing_prevs.add(prev_event_id) + + for chunk in batch_iter(event_ids, 100): + yield self.db.runInteraction( + "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk + ) + + return existing_prevs + + @log_function + def _persist_events_txn( + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + backfilled: bool, + delete_existing: bool = False, + state_delta_for_room: Dict[str, DeltaState] = {}, + new_forward_extremeties: Dict[str, List[str]] = {}, + ): + """Insert some number of room events into the necessary database tables. + + Rejected events are only inserted into the events table, the events_json table, + and the rejections table. Things reading from those table will need to check + whether the event was rejected. + + Args: + txn + events_and_contexts: events to persist + backfilled: True if the events were backfilled + delete_existing True to purge existing table rows for the events + from the database. This is useful when retrying due to + IntegrityError. + state_delta_for_room: The current-state delta for each room. + new_forward_extremetie: The new forward extremities for each room. + For each room, a list of the event ids which are the forward + extremities. + + """ + all_events_and_contexts = events_and_contexts + + min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering + max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + + self._update_forward_extremities_txn( + txn, + new_forward_extremities=new_forward_extremeties, + max_stream_order=max_stream_order, + ) + + # Ensure that we don't have the same event twice. + events_and_contexts = self._filter_events_and_contexts_for_duplicates( + events_and_contexts + ) + + self._update_room_depths_txn( + txn, events_and_contexts=events_and_contexts, backfilled=backfilled + ) + + # _update_outliers_txn filters out any events which have already been + # persisted, and returns the filtered list. + events_and_contexts = self._update_outliers_txn( + txn, events_and_contexts=events_and_contexts + ) + + # From this point onwards the events are only events that we haven't + # seen before. + + if delete_existing: + # For paranoia reasons, we go and delete all the existing entries + # for these events so we can reinsert them. + # This gets around any problems with some tables already having + # entries. + self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts) + + self._store_event_txn(txn, events_and_contexts=events_and_contexts) + + # Insert into event_to_state_groups. 
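An in-memory sketch of the walk that _get_prevs_before_rejected (above) performs, using a toy graph and made-up event ids; the real code reads event_edges in batches instead of a dict:

    # Toy graph: event_id -> (prev_event_ids, soft_failed_or_rejected)
    graph = {
        "$c": (["$b"], True),    # soft failed
        "$b": (["$a"], True),    # rejected
        "$a": ([], False),       # a normal, accepted event
    }

    def prevs_before_rejected(event_ids):
        existing_prevs = set()
        to_check = list(event_ids)
        while to_check:
            next_round = []
            for ev in to_check:
                prevs, failed = graph.get(ev, ([], False))
                if not failed:
                    continue
                # Only soft-failed/rejected events pull their prevs into the result.
                for prev in prevs:
                    if prev not in existing_prevs:
                        existing_prevs.add(prev)
                        next_round.append(prev)
            to_check = next_round
        return existing_prevs

    print(prevs_before_rejected(["$c"]))  # {'$b', '$a'} (set order may vary)
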
+ self._store_event_state_mappings_txn(txn, events_and_contexts) + + # We want to store event_auth mappings for rejected events, as they're + # used in state res v2. + # This is only necessary if the rejected event appears in an accepted + # event's auth chain, but its easier for now just to store them (and + # it doesn't take much storage compared to storing the entire event + # anyway). + self.db.simple_insert_many_txn( + txn, + table="event_auth", + values=[ + { + "event_id": event.event_id, + "room_id": event.room_id, + "auth_id": auth_id, + } + for event, _ in events_and_contexts + for auth_id in event.auth_event_ids() + if event.is_state() + ], + ) + + # _store_rejected_events_txn filters out any events which were + # rejected, and returns the filtered list. + events_and_contexts = self._store_rejected_events_txn( + txn, events_and_contexts=events_and_contexts + ) + + # From this point onwards the events are only ones that weren't + # rejected. + + self._update_metadata_tables_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + backfilled=backfilled, + ) + + # We call this last as it assumes we've inserted the events into + # room_memberships, where applicable. + self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + + def _update_current_state_txn( + self, + txn: LoggingTransaction, + state_delta_by_room: Dict[str, DeltaState], + stream_id: int, + ): + for room_id, delta_state in iteritems(state_delta_by_room): + to_delete = delta_state.to_delete + to_insert = delta_state.to_insert + + if delta_state.no_longer_in_room: + # Server is no longer in the room so we delete the room from + # current_state_events, being careful we've already updated the + # rooms.room_version column (which gets populated in a + # background task). + self._upsert_room_version_txn(txn, room_id) + + # Before deleting we populate the current_state_delta_stream + # so that async background tasks get told what happened. + sql = """ + INSERT INTO current_state_delta_stream + (stream_id, room_id, type, state_key, event_id, prev_event_id) + SELECT ?, room_id, type, state_key, null, event_id + FROM current_state_events + WHERE room_id = ? + """ + txn.execute(sql, (stream_id, room_id)) + + self.db.simple_delete_txn( + txn, table="current_state_events", keyvalues={"room_id": room_id}, + ) + else: + # We're still in the room, so we update the current state as normal. + + # First we add entries to the current_state_delta_stream. We + # do this before updating the current_state_events table so + # that we can use it to calculate the `prev_event_id`. (This + # allows us to not have to pull out the existing state + # unnecessarily). + # + # The stream_id for the update is chosen to be the minimum of the stream_ids + # for the batch of the events that we are persisting; that means we do not + # end up in a situation where workers see events before the + # current_state_delta updates. + # + sql = """ + INSERT INTO current_state_delta_stream + (stream_id, room_id, type, state_key, event_id, prev_event_id) + SELECT ?, ?, ?, ?, ?, ( + SELECT event_id FROM current_state_events + WHERE room_id = ? AND type = ? AND state_key = ? 
+ ) + """ + txn.executemany( + sql, + ( + ( + stream_id, + room_id, + etype, + state_key, + to_insert.get((etype, state_key)), + room_id, + etype, + state_key, + ) + for etype, state_key in itertools.chain(to_delete, to_insert) + ), + ) + # Now we actually update the current_state_events table + + txn.executemany( + "DELETE FROM current_state_events" + " WHERE room_id = ? AND type = ? AND state_key = ?", + ( + (room_id, etype, state_key) + for etype, state_key in itertools.chain(to_delete, to_insert) + ), + ) + + # We include the membership in the current state table, hence we do + # a lookup when we insert. This assumes that all events have already + # been inserted into room_memberships. + txn.executemany( + """INSERT INTO current_state_events + (room_id, type, state_key, event_id, membership) + VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + """, + [ + (room_id, key[0], key[1], ev_id, ev_id) + for key, ev_id in iteritems(to_insert) + ], + ) + + # We now update `local_current_membership`. We do this regardless + # of whether we're still in the room or not to handle the case where + # e.g. we just got banned (where we need to record that fact here). + + # Note: Do we really want to delete rows here (that we do not + # subsequently reinsert below)? While technically correct it means + # we have no record of the fact the user *was* a member of the + # room but got, say, state reset out of it. + if to_delete or to_insert: + txn.executemany( + "DELETE FROM local_current_membership" + " WHERE room_id = ? AND user_id = ?", + ( + (room_id, state_key) + for etype, state_key in itertools.chain(to_delete, to_insert) + if etype == EventTypes.Member and self.is_mine_id(state_key) + ), + ) + + if to_insert: + txn.executemany( + """INSERT INTO local_current_membership + (room_id, user_id, event_id, membership) + VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + """, + [ + (room_id, key[1], ev_id, ev_id) + for key, ev_id in to_insert.items() + if key[0] == EventTypes.Member and self.is_mine_id(key[1]) + ], + ) + + txn.call_after( + self.store._curr_state_delta_stream_cache.entity_has_changed, + room_id, + stream_id, + ) + + # Invalidate the various caches + + # Figure out the changes of membership to invalidate the + # `get_rooms_for_user` cache. + # We find out which membership events we may have deleted + # and which we have added, then we invlidate the caches for all + # those users. + members_changed = { + state_key + for ev_type, state_key in itertools.chain(to_delete, to_insert) + if ev_type == EventTypes.Member + } + + for member in members_changed: + txn.call_after( + self.store.get_rooms_for_user_with_stream_ordering.invalidate, + (member,), + ) + + self.store._invalidate_state_caches_and_stream( + txn, room_id, members_changed + ) + + def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str): + """Update the room version in the database based off current state + events. + + This is used when we're about to delete current state and we want to + ensure that the `rooms.room_version` column is up to date. + """ + + sql = """ + SELECT json FROM event_json + INNER JOIN current_state_events USING (room_id, event_id) + WHERE room_id = ? AND type = ? AND state_key = ? 
+ """ + txn.execute(sql, (room_id, EventTypes.Create, "")) + row = txn.fetchone() + if row: + event_json = json.loads(row[0]) + content = event_json.get("content", {}) + creator = content.get("creator") + room_version_id = content.get("room_version", RoomVersions.V1.identifier) + + self.db.simple_upsert_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + values={"room_version": room_version_id}, + insertion_values={"is_public": False, "creator": creator}, + ) + + def _update_forward_extremities_txn( + self, txn, new_forward_extremities, max_stream_order + ): + for room_id, new_extrem in iteritems(new_forward_extremities): + self.db.simple_delete_txn( + txn, table="event_forward_extremities", keyvalues={"room_id": room_id} + ) + txn.call_after( + self.store.get_latest_event_ids_in_room.invalidate, (room_id,) + ) + + self.db.simple_insert_many_txn( + txn, + table="event_forward_extremities", + values=[ + {"event_id": ev_id, "room_id": room_id} + for room_id, new_extrem in iteritems(new_forward_extremities) + for ev_id in new_extrem + ], + ) + # We now insert into stream_ordering_to_exterm a mapping from room_id, + # new stream_ordering to new forward extremeties in the room. + # This allows us to later efficiently look up the forward extremeties + # for a room before a given stream_ordering + self.db.simple_insert_many_txn( + txn, + table="stream_ordering_to_exterm", + values=[ + { + "room_id": room_id, + "event_id": event_id, + "stream_ordering": max_stream_order, + } + for room_id, new_extrem in iteritems(new_forward_extremities) + for event_id in new_extrem + ], + ) + + @classmethod + def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): + """Ensure that we don't have the same event twice. + + Pick the earliest non-outlier if there is one, else the earliest one. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): + Returns: + list[(EventBase, EventContext)]: filtered list + """ + new_events_and_contexts = OrderedDict() + for event, context in events_and_contexts: + prev_event_context = new_events_and_contexts.get(event.event_id) + if prev_event_context: + if not event.internal_metadata.is_outlier(): + if prev_event_context[0].internal_metadata.is_outlier(): + # To ensure correct ordering we pop, as OrderedDict is + # ordered by first insertion. 
+ new_events_and_contexts.pop(event.event_id, None) + new_events_and_contexts[event.event_id] = (event, context) + else: + new_events_and_contexts[event.event_id] = (event, context) + return list(new_events_and_contexts.values()) + + def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): + """Update min_depth for each room + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + backfilled (bool): True if the events were backfilled + """ + depth_updates = {} + for event, context in events_and_contexts: + # Remove the any existing cache entries for the event_ids + txn.call_after(self.store._invalidate_get_event_cache, event.event_id) + if not backfilled: + txn.call_after( + self.store._events_stream_cache.entity_has_changed, + event.room_id, + event.internal_metadata.stream_ordering, + ) + + if not event.internal_metadata.is_outlier() and not context.rejected: + depth_updates[event.room_id] = max( + event.depth, depth_updates.get(event.room_id, event.depth) + ) + + for room_id, depth in iteritems(depth_updates): + self._update_min_depth_for_room_txn(txn, room_id, depth) + + def _update_outliers_txn(self, txn, events_and_contexts): + """Update any outliers with new event info. + + This turns outliers into ex-outliers (unless the new event was + rejected). + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + + Returns: + list[(EventBase, EventContext)] new list, without events which + are already in the events table. + """ + txn.execute( + "SELECT event_id, outlier FROM events WHERE event_id in (%s)" + % (",".join(["?"] * len(events_and_contexts)),), + [event.event_id for event, _ in events_and_contexts], + ) + + have_persisted = {event_id: outlier for event_id, outlier in txn} + + to_remove = set() + for event, context in events_and_contexts: + if event.event_id not in have_persisted: + continue + + to_remove.add(event) + + if context.rejected: + # If the event is rejected then we don't care if the event + # was an outlier or not. + continue + + outlier_persisted = have_persisted[event.event_id] + if not event.internal_metadata.is_outlier() and outlier_persisted: + # We received a copy of an event that we had already stored as + # an outlier in the database. We now have some state at that + # so we need to update the state_groups table with that state. + + # insert into event_to_state_groups. + try: + self._store_event_state_mappings_txn(txn, ((event, context),)) + except Exception: + logger.exception("") + raise + + metadata_json = encode_json(event.internal_metadata.get_dict()) + + sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?" + txn.execute(sql, (metadata_json, event.event_id)) + + # Add an entry to the ex_outlier_stream table to replicate the + # change in outlier status to our workers. + stream_order = event.internal_metadata.stream_ordering + state_group_id = context.state_group + self.db.simple_insert_txn( + txn, + table="ex_outlier_stream", + values={ + "event_stream_ordering": stream_order, + "event_id": event.event_id, + "state_group": state_group_id, + }, + ) + + sql = "UPDATE events SET outlier = ? WHERE event_id = ?" + txn.execute(sql, (False, event.event_id)) + + # Update the event_backward_extremities table now that this + # event isn't an outlier any more. 
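A conceptual, database-free sketch of what _update_outliers_txn decides for each incoming event (hypothetical event ids; the real code also flips the outlier flag and records the transition for workers):

    # event_id -> was it stored as an outlier?
    have_persisted = {"$dup": False, "$was_outlier": True}

    incoming = [
        {"event_id": "$new", "outlier": False, "rejected": False},
        {"event_id": "$dup", "outlier": False, "rejected": False},
        {"event_id": "$was_outlier", "outlier": False, "rejected": False},
    ]

    still_to_persist, now_ex_outliers = [], []
    for ev in incoming:
        stored_as_outlier = have_persisted.get(ev["event_id"])
        if stored_as_outlier is None:
            still_to_persist.append(ev["event_id"])      # never seen before
        elif stored_as_outlier and not ev["outlier"] and not ev["rejected"]:
            now_ex_outliers.append(ev["event_id"])       # stored outlier gains full state

    print(still_to_persist)   # ['$new']
    print(now_ex_outliers)    # ['$was_outlier']
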
+ self._update_backward_extremeties(txn, [event]) + + return [ec for ec in events_and_contexts if ec[0] not in to_remove] + + @classmethod + def _delete_existing_rows_txn(cls, txn, events_and_contexts): + if not events_and_contexts: + # nothing to do here + return + + logger.info("Deleting existing") + + for table in ( + "events", + "event_auth", + "event_json", + "event_edges", + "event_forward_extremities", + "event_reference_hashes", + "event_search", + "event_to_state_groups", + "local_invites", + "state_events", + "rejections", + "redactions", + "room_memberships", + ): + txn.executemany( + "DELETE FROM %s WHERE event_id = ?" % (table,), + [(ev.event_id,) for ev, _ in events_and_contexts], + ) + + for table in ("event_push_actions",): + txn.executemany( + "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,), + [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts], + ) + + def _store_event_txn(self, txn, events_and_contexts): + """Insert new events into the event and event_json tables + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + """ + + if not events_and_contexts: + # nothing to do here + return + + def event_dict(event): + d = event.get_dict() + d.pop("redacted", None) + d.pop("redacted_because", None) + return d + + self.db.simple_insert_many_txn( + txn, + table="event_json", + values=[ + { + "event_id": event.event_id, + "room_id": event.room_id, + "internal_metadata": encode_json( + event.internal_metadata.get_dict() + ), + "json": encode_json(event_dict(event)), + "format_version": event.format_version, + } + for event, _ in events_and_contexts + ], + ) + + self.db.simple_insert_many_txn( + txn, + table="events", + values=[ + { + "stream_ordering": event.internal_metadata.stream_ordering, + "topological_ordering": event.depth, + "depth": event.depth, + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "processed": True, + "outlier": event.internal_metadata.is_outlier(), + "origin_server_ts": int(event.origin_server_ts), + "received_ts": self._clock.time_msec(), + "sender": event.sender, + "contains_url": ( + "url" in event.content + and isinstance(event.content["url"], text_type) + ), + } + for event, _ in events_and_contexts + ], + ) + + for event, _ in events_and_contexts: + if not event.internal_metadata.is_redacted(): + # If we're persisting an unredacted event we go and ensure + # that we mark any redactions that reference this event as + # requiring censoring. + self.db.simple_update_txn( + txn, + table="redactions", + keyvalues={"redacts": event.event_id}, + updatevalues={"have_censored": False}, + ) + + def _store_rejected_events_txn(self, txn, events_and_contexts): + """Add rows to the 'rejections' table for received events which were + rejected + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + + Returns: + list[(EventBase, EventContext)] new list, without the rejected + events. + """ + # Remove the rejected events from the list now that we've added them + # to the events table and the events_json table. 
+ to_remove = set() + for event, context in events_and_contexts: + if context.rejected: + # Insert the event_id into the rejections table + self._store_rejections_txn(txn, event.event_id, context.rejected) + to_remove.add(event) + + return [ec for ec in events_and_contexts if ec[0] not in to_remove] + + def _update_metadata_tables_txn( + self, txn, events_and_contexts, all_events_and_contexts, backfilled + ): + """Update all the miscellaneous tables for new events + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + backfilled (bool): True if the events were backfilled + """ + + # Insert all the push actions into the event_push_actions table. + self._set_push_actions_for_event_and_users_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + ) + + if not events_and_contexts: + # nothing to do here + return + + for event, context in events_and_contexts: + if event.type == EventTypes.Redaction and event.redacts is not None: + # Remove the entries in the event_push_actions table for the + # redacted event. + self._remove_push_actions_for_event_id_txn( + txn, event.room_id, event.redacts + ) + + # Remove from relations table. + self._handle_redaction(txn, event.redacts) + + # Update the event_forward_extremities, event_backward_extremities and + # event_edges tables. + self._handle_mult_prev_events( + txn, events=[event for event, _ in events_and_contexts] + ) + + for event, _ in events_and_contexts: + if event.type == EventTypes.Name: + # Insert into the event_search table. + self._store_room_name_txn(txn, event) + elif event.type == EventTypes.Topic: + # Insert into the event_search table. + self._store_room_topic_txn(txn, event) + elif event.type == EventTypes.Message: + # Insert into the event_search table. + self._store_room_message_txn(txn, event) + elif event.type == EventTypes.Redaction and event.redacts is not None: + # Insert into the redactions table. + self._store_redaction(txn, event) + elif event.type == EventTypes.Retention: + # Update the room_retention table. + self._store_retention_policy_for_room_txn(txn, event) + + self._handle_event_relations(txn, event) + + # Store the labels for this event. + labels = event.content.get(EventContentFields.LABELS) + if labels: + self.insert_labels_for_event_txn( + txn, event.event_id, labels, event.room_id, event.depth + ) + + if self._ephemeral_messages_enabled: + # If there's an expiry timestamp on the event, store it. + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if isinstance(expiry_ts, int) and not event.is_state(): + self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) + + # Insert into the room_memberships table. + self._store_room_members_txn( + txn, + [ + event + for event, _ in events_and_contexts + if event.type == EventTypes.Member + ], + backfilled=backfilled, + ) + + # Insert event_reference_hashes table. 
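The per-type branches above all end up writing a (search key, value) pair into the search table; a compact way to picture that mapping, using plain dicts and a made-up event:

    # Which content field each event type contributes to the search index.
    search_fields = {
        "m.room.name": ("content.name", "name"),
        "m.room.topic": ("content.topic", "topic"),
        "m.room.message": ("content.body", "body"),
    }

    event = {"type": "m.room.message", "content": {"body": "hello world"}}

    entry = search_fields.get(event["type"])
    if entry:
        key, field = entry
        print(key, "->", event["content"].get(field))  # content.body -> hello world
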
+ self._store_event_reference_hashes_txn( + txn, [event for event, _ in events_and_contexts] + ) + + state_events_and_contexts = [ + ec for ec in events_and_contexts if ec[0].is_state() + ] + + state_values = [] + for event, context in state_events_and_contexts: + vals = { + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + } + + # TODO: How does this work with backfilling? + if hasattr(event, "replaces_state"): + vals["prev_state"] = event.replaces_state + + state_values.append(vals) + + self.db.simple_insert_many_txn(txn, table="state_events", values=state_values) + + # Prefill the event cache + self._add_to_cache(txn, events_and_contexts) + + def _add_to_cache(self, txn, events_and_contexts): + to_prefill = [] + + rows = [] + N = 200 + for i in range(0, len(events_and_contexts), N): + ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]} + if not ev_map: + break + + sql = ( + "SELECT " + " e.event_id as event_id, " + " r.redacts as redacts," + " rej.event_id as rejects " + " FROM events as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE " + ) + + clause, args = make_in_list_sql_clause( + self.database_engine, "e.event_id", list(ev_map) + ) + + txn.execute(sql + clause, args) + rows = self.db.cursor_to_dict(txn) + for row in rows: + event = ev_map[row["event_id"]] + if not row["rejects"] and not row["redacts"]: + to_prefill.append( + _EventCacheEntry(event=event, redacted_event=None) + ) + + def prefill(): + for cache_entry in to_prefill: + self.store._get_event_cache.prefill( + (cache_entry[0].event_id,), cache_entry + ) + + txn.call_after(prefill) + + def _store_redaction(self, txn, event): + # invalidate the cache for the redacted event + txn.call_after(self.store._invalidate_get_event_cache, event.redacts) + + self.db.simple_insert_txn( + txn, + table="redactions", + values={ + "event_id": event.event_id, + "redacts": event.redacts, + "received_ts": self._clock.time_msec(), + }, + ) + + def insert_labels_for_event_txn( + self, txn, event_id, labels, room_id, topological_ordering + ): + """Store the mapping between an event's ID and its labels, with one row per + (event_id, label) tuple. + + Args: + txn (LoggingTransaction): The transaction to execute. + event_id (str): The event's ID. + labels (list[str]): A list of text labels. + room_id (str): The ID of the room the event was sent to. + topological_ordering (int): The position of the event in the room's topology. + """ + return self.db.simple_insert_many_txn( + txn=txn, + table="event_labels", + values=[ + { + "event_id": event_id, + "label": label, + "room_id": room_id, + "topological_ordering": topological_ordering, + } + for label in labels + ], + ) + + def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): + """Save the expiry timestamp associated with a given event ID. + + Args: + txn (LoggingTransaction): The database transaction to use. + event_id (str): The event ID the expiry timestamp is associated with. + expiry_ts (int): The timestamp at which to expire (delete) the event. + """ + return self.db.simple_insert_txn( + txn=txn, + table="event_expiry", + values={"event_id": event_id, "expiry_ts": expiry_ts}, + ) + + def _store_event_reference_hashes_txn(self, txn, events): + """Store a hash for a PDU + Args: + txn (cursor): + events (list): list of Events. 
+ """ + + vals = [] + for event in events: + ref_alg, ref_hash_bytes = compute_event_reference_hash(event) + vals.append( + { + "event_id": event.event_id, + "algorithm": ref_alg, + "hash": memoryview(ref_hash_bytes), + } + ) + + self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) + + def _store_room_members_txn(self, txn, events, backfilled): + """Store a room member in the database. + """ + self.db.simple_insert_many_txn( + txn, + table="room_memberships", + values=[ + { + "event_id": event.event_id, + "user_id": event.state_key, + "sender": event.user_id, + "room_id": event.room_id, + "membership": event.membership, + "display_name": event.content.get("displayname", None), + "avatar_url": event.content.get("avatar_url", None), + } + for event in events + ], + ) + + for event in events: + txn.call_after( + self.store._membership_stream_cache.entity_has_changed, + event.state_key, + event.internal_metadata.stream_ordering, + ) + txn.call_after( + self.store.get_invited_rooms_for_local_user.invalidate, + (event.state_key,), + ) + + # We update the local_invites table only if the event is "current", + # i.e., its something that has just happened. If the event is an + # outlier it is only current if its an "out of band membership", + # like a remote invite or a rejection of a remote invite. + is_new_state = not backfilled and ( + not event.internal_metadata.is_outlier() + or event.internal_metadata.is_out_of_band_membership() + ) + is_mine = self.is_mine_id(event.state_key) + if is_new_state and is_mine: + if event.membership == Membership.INVITE: + self.db.simple_insert_txn( + txn, + table="local_invites", + values={ + "event_id": event.event_id, + "invitee": event.state_key, + "inviter": event.sender, + "room_id": event.room_id, + "stream_id": event.internal_metadata.stream_ordering, + }, + ) + else: + sql = ( + "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute( + sql, + ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + ), + ) + + # We also update the `local_current_membership` table with + # latest invite info. This will usually get updated by the + # `current_state_events` handling, unless its an outlier. + if event.internal_metadata.is_outlier(): + # This should only happen for out of band memberships, so + # we add a paranoia check. 
+ assert event.internal_metadata.is_out_of_band_membership() + + self.db.simple_upsert_txn( + txn, + table="local_current_membership", + keyvalues={ + "room_id": event.room_id, + "user_id": event.state_key, + }, + values={ + "event_id": event.event_id, + "membership": event.membership, + }, + ) + + def _handle_event_relations(self, txn, event): + """Handles inserting relation data during peristence of events + + Args: + txn + event (EventBase) + """ + relation = event.content.get("m.relates_to") + if not relation: + # No relations + return + + rel_type = relation.get("rel_type") + if rel_type not in ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, + ): + # Unknown relation type + return + + parent_id = relation.get("event_id") + if not parent_id: + # Invalid relation + return + + aggregation_key = relation.get("key") + + self.db.simple_insert_txn( + txn, + table="event_relations", + values={ + "event_id": event.event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + "aggregation_key": aggregation_key, + }, + ) + + txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,)) + txn.call_after( + self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,) + ) + + if rel_type == RelationTypes.REPLACE: + txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + + def _handle_redaction(self, txn, redacted_event_id): + """Handles receiving a redaction and checking whether we need to remove + any redacted relations from the database. + + Args: + txn + redacted_event_id (str): The event that was redacted. + """ + + self.db.simple_delete_txn( + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} + ) + + def _store_room_topic_txn(self, txn, event): + if hasattr(event, "content") and "topic" in event.content: + self.store_event_search_txn( + txn, event, "content.topic", event.content["topic"] + ) + + def _store_room_name_txn(self, txn, event): + if hasattr(event, "content") and "name" in event.content: + self.store_event_search_txn( + txn, event, "content.name", event.content["name"] + ) + + def _store_room_message_txn(self, txn, event): + if hasattr(event, "content") and "body" in event.content: + self.store_event_search_txn( + txn, event, "content.body", event.content["body"] + ) + + def _store_retention_policy_for_room_txn(self, txn, event): + if hasattr(event, "content") and ( + "min_lifetime" in event.content or "max_lifetime" in event.content + ): + if ( + "min_lifetime" in event.content + and not isinstance(event.content.get("min_lifetime"), integer_types) + ) or ( + "max_lifetime" in event.content + and not isinstance(event.content.get("max_lifetime"), integer_types) + ): + # Ignore the event if one of the value isn't an integer. 
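For reference, _handle_event_relations above expects relation data shaped like the following and stores only the three extracted fields (the example content is made up):

    content = {
        "m.relates_to": {
            "rel_type": "m.annotation",   # annotation, reference or replace
            "event_id": "$parent_event",
            "key": "👍",                  # aggregation key, e.g. a reaction emoji
        }
    }

    relation = content.get("m.relates_to") or {}
    row = {
        "relates_to_id": relation.get("event_id"),
        "relation_type": relation.get("rel_type"),
        "aggregation_key": relation.get("key"),
    }
    print(row)
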
+ return + + self.db.simple_insert_txn( + txn=txn, + table="room_retention", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + "min_lifetime": event.content.get("min_lifetime"), + "max_lifetime": event.content.get("max_lifetime"), + }, + ) + + self.store._invalidate_cache_and_stream( + txn, self.store.get_retention_policy_for_room, (event.room_id,) + ) + + def store_event_search_txn(self, txn, event, key, value): + """Add event to the search table + + Args: + txn (cursor): + event (EventBase): + key (str): + value (str): + """ + self.store.store_search_entries_txn( + txn, + ( + SearchEntry( + key=key, + value=value, + event_id=event.event_id, + room_id=event.room_id, + stream_ordering=event.internal_metadata.stream_ordering, + origin_server_ts=event.origin_server_ts, + ), + ), + ) + + def _set_push_actions_for_event_and_users_txn( + self, txn, events_and_contexts, all_events_and_contexts + ): + """Handles moving push actions from staging table to main + event_push_actions table for all events in `events_and_contexts`. + + Also ensures that all events in `all_events_and_contexts` are removed + from the push action staging area. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + """ + + sql = """ + INSERT INTO event_push_actions ( + room_id, event_id, user_id, actions, stream_ordering, + topological_ordering, notif, highlight + ) + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight + FROM event_push_actions_staging + WHERE event_id = ? + """ + + if events_and_contexts: + txn.executemany( + sql, + ( + ( + event.room_id, + event.internal_metadata.stream_ordering, + event.depth, + event.event_id, + ) + for event, _ in events_and_contexts + ), + ) + + for event, _ in events_and_contexts: + user_ids = self.db.simple_select_onecol_txn( + txn, + table="event_push_actions_staging", + keyvalues={"event_id": event.event_id}, + retcol="user_id", + ) + + for uid in user_ids: + txn.call_after( + self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (event.room_id, uid), + ) + + # Now we delete the staging area for *all* events that were being + # persisted. + txn.executemany( + "DELETE FROM event_push_actions_staging WHERE event_id = ?", + ((event.event_id,) for event, _ in all_events_and_contexts), + ) + + def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): + # Sad that we have to blow away the cache for the whole room here + txn.call_after( + self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (room_id,), + ) + txn.execute( + "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", + (room_id, event_id), + ) + + def _store_rejections_txn(self, txn, event_id, reason): + self.db.simple_insert_txn( + txn, + table="rejections", + values={ + "event_id": event_id, + "reason": reason, + "last_check": self._clock.time_msec(), + }, + ) + + def _store_event_state_mappings_txn( + self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] + ): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + + # if the event was rejected, just give it the same state as its + # predecessor. 
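A database-free sketch of the event-to-state-group mapping built just below, including the fallback for rejected events described in the comment (state group ids are hypothetical):

    events = [
        {"event_id": "$accepted", "rejected": False, "state_group": 12, "state_group_before_event": 11},
        {"event_id": "$rejected", "rejected": True, "state_group": None, "state_group_before_event": 11},
    ]

    state_groups = {}
    for ev in events:
        # A rejected event doesn't get its own state group, so it inherits the
        # state group that held before the event.
        sg = ev["state_group_before_event"] if ev["rejected"] else ev["state_group"]
        state_groups[ev["event_id"]] = sg

    print(state_groups)  # {'$accepted': 12, '$rejected': 11}
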
+ if context.rejected: + state_groups[event.event_id] = context.state_group_before_event + continue + + state_groups[event.event_id] = context.state_group + + self.db.simple_insert_many_txn( + txn, + table="event_to_state_groups", + values=[ + {"state_group": state_group_id, "event_id": event_id} + for event_id, state_group_id in iteritems(state_groups) + ], + ) + + for event_id, state_group_id in iteritems(state_groups): + txn.call_after( + self.store._get_state_group_for_event.prefill, + (event_id,), + state_group_id, + ) + + def _update_min_depth_for_room_txn(self, txn, room_id, depth): + min_depth = self.store._get_min_depth_interaction(txn, room_id) + + if min_depth is not None and depth >= min_depth: + return + + self.db.simple_upsert_txn( + txn, + table="room_depth", + keyvalues={"room_id": room_id}, + values={"min_depth": depth}, + ) + + def _handle_mult_prev_events(self, txn, events): + """ + For the given event, update the event edges table and forward and + backward extremities tables. + """ + self.db.simple_insert_many_txn( + txn, + table="event_edges", + values=[ + { + "event_id": ev.event_id, + "prev_event_id": e_id, + "room_id": ev.room_id, + "is_state": False, + } + for ev in events + for e_id in ev.prev_event_ids() + ], + ) + + self._update_backward_extremeties(txn, events) + + def _update_backward_extremeties(self, txn, events): + """Updates the event_backward_extremities tables based on the new/updated + events being persisted. + + This is called for new events *and* for events that were outliers, but + are now being persisted as non-outliers. + + Forward extremities are handled when we first start persisting the events. + """ + events_by_room = {} + for ev in events: + events_by_room.setdefault(ev.room_id, []).append(ev) + + query = ( + "INSERT INTO event_backward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + " )" + " AND NOT EXISTS (" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" + " )" + ) + + txn.executemany( + query, + [ + (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) + for ev in events + for e_id in ev.prev_event_ids() + if not ev.internal_metadata.is_outlier() + ], + ) + + query = ( + "DELETE FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + ) + txn.executemany( + query, + [ + (ev.event_id, ev.room_id) + for ev in events + if not ev.internal_metadata.is_outlier() + ], + ) + + async def locally_reject_invite(self, user_id: str, room_id: str) -> int: + """Mark the invite as having been rejected even though we failed to + create a leave event for it. + """ + + sql = ( + "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + def f(txn, stream_ordering): + txn.execute(sql, (stream_ordering, True, room_id, user_id)) + + # We also clear this entry from `local_current_membership`. + # Ideally we'd point to a leave event, but we don't have one, so + # nevermind. 
+ self.db.simple_delete_txn( + txn, + table="local_current_membership", + keyvalues={"room_id": room_id, "user_id": user_id}, + ) + + with self._stream_id_gen.get_next() as stream_ordering: + await self.db.runInteraction("locally_reject_invite", f, stream_ordering) + + return stream_ordering diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 6587f31e2b..f54c8b1ee0 100644 --- a/synapse/storage/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -21,29 +21,31 @@ from canonicaljson import json from twisted.internet import defer -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.api.constants import EventContentFields +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database logger = logging.getLogger(__name__) -class EventsBackgroundUpdatesStore(BackgroundUpdateStore): +class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - def __init__(self, db_conn, hs): - super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, self._background_reindex_fields_sender, ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_contains_url_index", index_name="event_contains_url_index", table="events", @@ -54,7 +56,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): # an event_id index on event_search is useful for the purge_history # api. 
Plus it means we get to enforce some integrity with a UNIQUE # clause - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_search_event_id_idx", index_name="event_search_event_id_idx", table="event_search", @@ -63,10 +65,39 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): psql_only=True, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update ) + self.db.updates.register_background_update_handler( + "redactions_received_ts", self._redactions_received_ts + ) + + # This index gets deleted in `event_fix_redactions_bytes` update + self.db.updates.register_background_index_update( + "event_fix_redactions_bytes_create_index", + index_name="redactions_censored_redacts", + table="redactions", + columns=["redacts"], + where_clause="have_censored", + ) + + self.db.updates.register_background_update_handler( + "event_fix_redactions_bytes", self._event_fix_redactions_bytes + ) + + self.db.updates.register_background_update_handler( + "event_store_labels", self._event_store_labels + ) + + self.db.updates.register_background_index_update( + "redactions_have_censored_ts_idx", + index_name="redactions_have_censored_ts", + table="redactions", + columns=["received_ts"], + where_clause="NOT have_censored", + ) + @defer.inlineCallbacks def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] @@ -122,18 +153,20 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(rows), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress ) return len(rows) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn ) if not result: - yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) + yield self.db.updates._end_background_update( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME + ) return result @@ -166,7 +199,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] for chunk in chunks: - ev_rows = self._simple_select_many_txn( + ev_rows = self.db.simple_select_many_txn( txn, table="event_json", column="event_id", @@ -199,18 +232,20 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(rows_to_update), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress ) return len(rows_to_update) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn ) if not result: - yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) + yield self.db.updates._end_background_update( + self.EVENT_ORIGIN_SERVER_TS_NAME + ) return result @@ -308,12 +343,13 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): INNER JOIN event_json USING (event_id) LEFT JOIN rejections USING (event_id) WHERE - prev_event_id IN (%s) - AND NOT events.outlier - """ % ( - ",".join("?" 
for _ in to_check), + NOT events.outlier + AND + """ + clause, args = make_in_list_sql_clause( + self.database_engine, "prev_event_id", to_check ) - txn.execute(sql, to_check) + txn.execute(sql + clause, list(args)) for prev_event_id, event_id, metadata, rejected in txn: if event_id in graph: @@ -342,7 +378,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): to_delete.intersection_update(original_set) - deleted = self._simple_delete_many_txn( + deleted = self.db.simple_delete_many_txn( txn=txn, table="event_forward_extremities", column="event_id", @@ -358,7 +394,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): if deleted: # We now need to invalidate the caches of these rooms - rows = self._simple_select_many_txn( + rows = self.db.simple_select_many_txn( txn, table="events", column="event_id", @@ -366,13 +402,13 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): keyvalues={}, retcols=("room_id",), ) - room_ids = set(row["room_id"] for row in rows) + room_ids = {row["room_id"] for row in rows} for room_id in room_ids: txn.call_after( self.get_latest_event_ids_in_room.invalidate, (room_id,) ) - self._simple_delete_many_txn( + self.db.simple_delete_many_txn( txn=txn, table="_extremities_to_check", column="event_id", @@ -382,18 +418,172 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): return len(original_set) - num_handled = yield self.runInteraction( + num_handled = yield self.db.runInteraction( "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn ) if not num_handled: - yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES) + yield self.db.updates._end_background_update( + self.DELETE_SOFT_FAILED_EXTREMITIES + ) def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") - yield self.runInteraction( + yield self.db.runInteraction( "_cleanup_extremities_bg_update_drop_table", _drop_table_txn ) return num_handled + + @defer.inlineCallbacks + def _redactions_received_ts(self, progress, batch_size): + """Handles filling out the `received_ts` column in redactions. + """ + last_event_id = progress.get("last_event_id", "") + + def _redactions_received_ts_txn(txn): + # Fetch the set of event IDs that we want to update + sql = """ + SELECT event_id FROM redactions + WHERE event_id > ? + ORDER BY event_id ASC + LIMIT ? + """ + + txn.execute(sql, (last_event_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + (upper_event_id,) = rows[-1] + + # Update the redactions with the received_ts. + # + # Note: Not all events have an associated received_ts, so we + # fallback to using origin_server_ts. If we for some reason don't + # have an origin_server_ts, lets just use the current timestamp. + # + # We don't want to leave it null, as then we'll never try and + # censor those redactions. + sql = """ + UPDATE redactions + SET received_ts = ( + SELECT COALESCE(received_ts, origin_server_ts, ?) FROM events + WHERE events.event_id = redactions.event_id + ) + WHERE ? <= event_id AND event_id <= ? 
+ """ + + txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) + + self.db.updates._background_update_progress_txn( + txn, "redactions_received_ts", {"last_event_id": upper_event_id} + ) + + return len(rows) + + count = yield self.db.runInteraction( + "_redactions_received_ts", _redactions_received_ts_txn + ) + + if not count: + yield self.db.updates._end_background_update("redactions_received_ts") + + return count + + @defer.inlineCallbacks + def _event_fix_redactions_bytes(self, progress, batch_size): + """Undoes hex encoded censored redacted event JSON. + """ + + def _event_fix_redactions_bytes_txn(txn): + # This update is quite fast due to new index. + txn.execute( + """ + UPDATE event_json + SET + json = convert_from(json::bytea, 'utf8') + FROM redactions + WHERE + redactions.have_censored + AND event_json.event_id = redactions.redacts + AND json NOT LIKE '{%'; + """ + ) + + txn.execute("DROP INDEX redactions_censored_redacts") + + yield self.db.runInteraction( + "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn + ) + + yield self.db.updates._end_background_update("event_fix_redactions_bytes") + + return 1 + + @defer.inlineCallbacks + def _event_store_labels(self, progress, batch_size): + """Background update handler which will store labels for existing events.""" + last_event_id = progress.get("last_event_id", "") + + def _event_store_labels_txn(txn): + txn.execute( + """ + SELECT event_id, json FROM event_json + LEFT JOIN event_labels USING (event_id) + WHERE event_id > ? AND label IS NULL + ORDER BY event_id LIMIT ? + """, + (last_event_id, batch_size), + ) + + results = list(txn) + + nbrows = 0 + last_row_event_id = "" + for (event_id, event_json_raw) in results: + try: + event_json = json.loads(event_json_raw) + + self.db.simple_insert_many_txn( + txn=txn, + table="event_labels", + values=[ + { + "event_id": event_id, + "label": label, + "room_id": event_json["room_id"], + "topological_ordering": event_json["depth"], + } + for label in event_json["content"].get( + EventContentFields.LABELS, [] + ) + if isinstance(label, str) + ], + ) + except Exception as e: + logger.warning( + "Unable to load event %s (no labels will be imported): %s", + event_id, + e, + ) + + nbrows += 1 + last_row_event_id = event_id + + self.db.updates._background_update_progress_txn( + txn, "event_store_labels", {"last_event_id": last_row_event_id} + ) + + return nbrows + + num_rows = yield self.db.runInteraction( + desc="event_store_labels", func=_event_store_labels_txn + ) + + if not num_rows: + yield self.db.updates._end_background_update("event_store_labels") + + return num_rows diff --git a/synapse/storage/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index c6fa7f82fd..213d69100a 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -17,26 +17,35 @@ from __future__ import division import itertools import logging +import threading from collections import namedtuple +from typing import List, Optional, Tuple from canonicaljson import json +from constantly import NamedConstant, Names from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.api.errors import NotFoundError -from synapse.api.room_versions import EventFormatVersions -from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.api.errors import NotFoundError, SynapseError +from synapse.api.room_versions 
import ( + KNOWN_ROOM_VERSIONS, + EventFormatVersions, + RoomVersions, +) +from synapse.events import make_event_from_dict from synapse.events.utils import prune_event -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util import batch_iter +from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks +from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure -from ._base import SQLBaseStore - logger = logging.getLogger(__name__) @@ -53,7 +62,64 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) +class EventRedactBehaviour(Names): + """ + What to do when retrieving a redacted event from the database. + """ + + AS_IS = NamedConstant() + REDACT = NamedConstant() + BLOCK = NamedConstant() + + class EventsWorkerStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(EventsWorkerStore, self).__init__(database, db_conn, hs) + + if hs.config.worker.writers.events == hs.get_instance_name(): + # We are the process in charge of generating stream ids for events, + # so instantiate ID generators based on the database + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + extra_tables=[("local_invites", "stream_id")], + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + # Another process is in charge of persisting events and generating + # stream IDs: rely on the replication streams to let us know which + # IDs we can process. + self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) + + self._get_event_cache = Cache( + "*getEvent*", + keylen=3, + max_entries=hs.config.caches.event_cache_size, + apply_cache_factor_from_config=False, + ) + + self._event_fetch_lock = threading.Condition() + self._event_fetch_list = [] + self._event_fetch_ongoing = 0 + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == "events": + self._stream_id_gen.advance(token) + elif stream_name == "backfill": + self._backfill_id_gen.advance(-token) + + super().process_replication_rows(stream_name, instance_name, token, rows) + def get_received_ts(self, event_id): """Get received_ts (when it was persisted) for the event. @@ -66,7 +132,7 @@ class EventsWorkerStore(SQLBaseStore): Deferred[int|None]: Timestamp in milliseconds, or None for events that were persisted before received_ts was implemented. 
""" - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", @@ -105,32 +171,41 @@ class EventsWorkerStore(SQLBaseStore): return ts - return self.runInteraction( + return self.db.runInteraction( "get_approximate_received_ts", _get_approximate_received_ts_txn ) @defer.inlineCallbacks def get_event( self, - event_id, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, - allow_none=False, - check_room_id=None, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: bool = False, + check_room_id: Optional[str] = None, ): """Get an event from the database by event_id. Args: - event_id (str): The event_id of the event to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_id: The event_id of the event to fetch + + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events (behave as per allow_none + if the event is redacted) + + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - allow_none (bool): If True, return None if no event found, if + + allow_rejected: If True, return rejected events. Otherwise, + behave as per allow_none. + + allow_none: If True, return None if no event found, if False throw a NotFoundError - check_room_id (str|None): if not None, check the room of the found event. + + check_room_id: if not None, check the room of the found event. If there is a mismatch, behave as per allow_none. Returns: @@ -141,7 +216,7 @@ class EventsWorkerStore(SQLBaseStore): events = yield self.get_events_as_list( [event_id], - check_redacted=check_redacted, + redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) @@ -160,27 +235,34 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_events( self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, ): """Get events from the database Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_ids: The event_ids of the events to fetch + + redact_behaviour: Determine what to do with a redacted event. Possible + values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events (omit them from the response) + + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + + allow_rejected: If True, return rejected events. Otherwise, + omits rejeted events from the response. Returns: Deferred : Dict from event_id to event. 
""" events = yield self.get_events_as_list( event_ids, - check_redacted=check_redacted, + redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) @@ -190,21 +272,29 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_events_as_list( self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, ): """Get events from the database and return in a list in the same order as given by `event_ids` arg. + Unknown events will be omitted from the response. + Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_ids: The event_ids of the events to fetch + + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events (omit them from the response) + + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + + allow_rejected: If True, return rejected events. Otherwise, + omits rejected events from the response. Returns: Deferred[list[EventBase]]: List of events fetched from the database. The @@ -238,6 +328,20 @@ class EventsWorkerStore(SQLBaseStore): # we have to recheck auth now. if not allow_rejected and entry.event.type == EventTypes.Redaction: + if entry.event.redacts is None: + # A redacted redaction doesn't have a `redacts` key, in + # which case lets just withhold the event. + # + # Note: Most of the time if the redactions has been + # redacted we still have the un-redacted event in the DB + # and so we'll still see the `redacts` key. However, this + # isn't always true e.g. if we have censored the event. + logger.debug( + "Withholding redaction event %s as we don't have redacts key", + event_id, + ) + continue + redacted_event_id = entry.event.redacts event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) original_event_entry = event_map.get(redacted_event_id) @@ -292,10 +396,14 @@ class EventsWorkerStore(SQLBaseStore): # Update the cache to save doing the checks again. entry.event.internal_metadata.recheck_redaction = False - if check_redacted and entry.redacted_event: - event = entry.redacted_event - else: - event = entry.event + event = entry.event + + if entry.redacted_event: + if redact_behaviour == EventRedactBehaviour.BLOCK: + # Skip this event + continue + elif redact_behaviour == EventRedactBehaviour.REDACT: + event = entry.redacted_event events.append(event) @@ -319,9 +427,14 @@ class EventsWorkerStore(SQLBaseStore): If events are pulled from the database, they will be cached for future lookups. + Unknown events are omitted from the response. + Args: + event_ids (Iterable[str]): The event_ids of the events to fetch - allow_rejected (bool): Whether to include rejected events + + allow_rejected (bool): Whether to include rejected events. If False, + rejected events are omitted from the response. 
Returns: Deferred[Dict[str, _EventCacheEntry]]: @@ -334,7 +447,7 @@ class EventsWorkerStore(SQLBaseStore): missing_events_ids = [e for e in event_ids if e not in event_entry_map] if missing_events_ids: - log_ctx = LoggingContext.current_context() + log_ctx = current_context() log_ctx.record_event_fetch(len(missing_events_ids)) # Note that _get_events_from_db is also responsible for turning db rows @@ -422,11 +535,11 @@ class EventsWorkerStore(SQLBaseStore): """ with Measure(self._clock, "_fetch_event_list"): try: - events_to_fetch = set( + events_to_fetch = { event_id for events, _ in event_list for event_id in events - ) + } - row_dict = self._new_transaction( + row_dict = self.db.new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch ) @@ -456,9 +569,13 @@ class EventsWorkerStore(SQLBaseStore): Returned events will be added to the cache for future lookups. + Unknown events are omitted from the response. + Args: event_ids (Iterable[str]): The event_ids of the events to fetch - allow_rejected (bool): Whether to include rejected events + + allow_rejected (bool): Whether to include rejected events. If False, + rejected events are omitted from the response. Returns: Deferred[Dict[str, _EventCacheEntry]]: @@ -504,15 +621,56 @@ class EventsWorkerStore(SQLBaseStore): # of a event format version, so it must be a V1 event. format_version = EventFormatVersions.V1 - original_ev = event_type_from_format_version(format_version)( + room_version_id = row["room_version_id"] + + if not room_version_id: + # this should only happen for out-of-band membership events + if not internal_metadata.get("out_of_band_membership"): + logger.warning( + "Room %s for event %s is unknown", d["room_id"], event_id + ) + continue + + # take a wild stab at the room version based on the event format + if format_version == EventFormatVersions.V1: + room_version = RoomVersions.V1 + elif format_version == EventFormatVersions.V2: + room_version = RoomVersions.V3 + else: + room_version = RoomVersions.V5 + else: + room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version: + logger.error( + "Event %s in room %s has unknown room version %s", + event_id, + d["room_id"], + room_version_id, + ) + continue + + if room_version.event_format != format_version: + logger.error( + "Event %s in room %s with version %s has wrong format: " + "expected %s, was %s", + event_id, + d["room_id"], + room_version_id, + room_version.event_format, + format_version, + ) + continue + + original_ev = make_event_from_dict( event_dict=d, + room_version=room_version, internal_metadata_dict=internal_metadata, rejected_reason=rejected_reason, ) event_map[event_id] = original_ev - # finally, we can decide whether each one nededs redacting, and build + # finally, we can decide whether each one needs redacting, and build # the cache entries. result_map = {} for event_id, original_ev in event_map.items(): @@ -558,7 +716,7 @@ class EventsWorkerStore(SQLBaseStore): if should_start: run_as_background_process( - "fetch_events", self.runWithConnection, self._do_fetch + "fetch_events", self.db.runWithConnection, self._do_fetch ) logger.debug("Loading %d events: %s", len(events), events) @@ -585,6 +743,12 @@ class EventsWorkerStore(SQLBaseStore): of EventFormatVersions. 'None' means the event predates EventFormatVersions (so the event is format V1). + * room_version_id (str|None): The version of the room which contains the event. + Hopefully one of RoomVersions. 
+ + Due to historical reasons, there may be a few events in the database which + do not have an associated room; in this case None will be returned here. + * rejected_reason (str|None): if the event was rejected, the reason why. @@ -600,19 +764,24 @@ class EventsWorkerStore(SQLBaseStore): """ event_dict = {} for evs in batch_iter(event_ids, 200): - sql = ( - "SELECT " - " e.event_id, " - " e.internal_metadata," - " e.json," - " e.format_version, " - " rej.reason " - " FROM event_json as e" - " LEFT JOIN rejections as rej USING (event_id)" - " WHERE e.event_id IN (%s)" - ) % (",".join(["?"] * len(evs)),) - - txn.execute(sql, evs) + sql = """\ + SELECT + e.event_id, + e.internal_metadata, + e.json, + e.format_version, + r.room_version, + rej.reason + FROM event_json as e + LEFT JOIN rooms r USING (room_id) + LEFT JOIN rejections as rej USING (event_id) + WHERE """ + + clause, args = make_in_list_sql_clause( + txn.database_engine, "e.event_id", evs + ) + + txn.execute(sql + clause, args) for row in txn: event_id = row[0] @@ -621,16 +790,17 @@ class EventsWorkerStore(SQLBaseStore): "internal_metadata": row[1], "json": row[2], "format_version": row[3], - "rejected_reason": row[4], + "room_version_id": row[4], + "rejected_reason": row[5], "redactions": [], } # check for redactions - redactions_sql = ( - "SELECT event_id, redacts FROM redactions WHERE redacts IN (%s)" - ) % (",".join(["?"] * len(evs)),) + redactions_sql = "SELECT event_id, redacts FROM redactions WHERE " - txn.execute(redactions_sql, evs) + clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs) + + txn.execute(redactions_sql + clause, args) for (redacter, redacted) in txn: d = event_dict.get(redacted) @@ -715,7 +885,7 @@ class EventsWorkerStore(SQLBaseStore): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="events", retcols=("event_id",), column="event_id", @@ -724,7 +894,7 @@ class EventsWorkerStore(SQLBaseStore): desc="have_events_in_timeline", ) - return set(r["event_id"] for r in rows) + return {r["event_id"] for r in rows} @defer.inlineCallbacks def have_seen_events(self, event_ids): @@ -739,52 +909,21 @@ class EventsWorkerStore(SQLBaseStore): results = set() def have_seen_events_txn(txn, chunk): - sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % ( - ",".join("?" * len(chunk)), + sql = "SELECT event_id FROM events as e WHERE " + clause, args = make_in_list_sql_clause( + txn.database_engine, "e.event_id", chunk ) - txn.execute(sql, chunk) + txn.execute(sql + clause, args) for (event_id,) in txn: results.add(event_id) # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk) - return results - - def get_seen_events_with_rejections(self, event_ids): - """Given a list of event ids, check if we rejected them. - - Args: - event_ids (list[str]) - - Returns: - Deferred[dict[str, str|None): - Has an entry for each event id we already have seen. Maps to - the rejected reason string if we rejected the event, else maps - to None. - """ - if not event_ids: - return defer.succeed({}) - - def f(txn): - sql = ( - "SELECT e.event_id, reason FROM events as e " - "LEFT JOIN rejections as r ON e.event_id = r.event_id " - "WHERE e.event_id = ?" 
+ yield self.db.runInteraction( + "have_seen_events", have_seen_events_txn, chunk ) - - res = {} - for event_id in event_ids: - txn.execute(sql, (event_id,)) - row = txn.fetchone() - if row: - _, rejected = row - res[event_id] = rejected - - return res - - return self.runInteraction("get_seen_events_with_rejections", f) + return results def _get_total_state_event_counts_txn(self, txn, room_id): """ @@ -810,7 +949,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: Deferred[int] """ - return self.runInteraction( + return self.db.runInteraction( "get_total_state_event_counts", self._get_total_state_event_counts_txn, room_id, @@ -835,7 +974,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: Deferred[int] """ - return self.runInteraction( + return self.db.runInteraction( "get_current_state_event_counts", self._get_current_state_event_counts_txn, room_id, @@ -862,3 +1001,343 @@ class EventsWorkerStore(SQLBaseStore): complexity_v1 = round(state_events / 500, 2) return {"v1": complexity_v1} + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + return -self._backfill_id_gen.get_current_token() + + def get_current_events_token(self): + """The current maximum token that events have reached""" + return self._stream_id_gen.get_current_token() + + def get_all_new_forward_event_rows(self, last_id, current_id, limit): + """Returns new events, for the Events replication stream + + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: Deferred[List[Tuple]] + a list of events stream rows. Each tuple consists of a stream id as + the first element, followed by fields suitable for casting into an + EventsStreamRow. + """ + + def get_all_new_forward_event_rows(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_ex_outlier_stream_rows(self, last_id, current_id): + """Returns de-outliered events, for the Events replication stream + + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + + Returns: Deferred[List[Tuple]] + a list of events stream rows. Each tuple consists of a stream id as + the first element, followed by fields suitable for casting into an + EventsStreamRow. + """ + + def get_ex_outlier_stream_rows_txn(txn): + sql = ( + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" 
+ " ORDER BY event_stream_ordering ASC" + ) + + txn.execute(sql, (last_id, current_id)) + return txn.fetchall() + + return self.db.runInteraction( + "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn + ) + + def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_backfill_event_rows(txn): + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (-last_id, -current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_id, -upper_bound)) + new_event_updates.extend(txn.fetchall()) + + return new_event_updates + + return self.db.runInteraction( + "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows + ) + + async def get_all_updated_current_state_deltas( + self, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple], int, bool]: + """Fetch updates from current_state_delta_stream + + Args: + from_token: The previous stream token. Updates from this stream id will + be excluded. + + to_token: The current stream token (ie the upper limit). Updates up to this + stream id will be included (modulo the 'limit' param) + + target_row_count: The number of rows to try to return. If more rows are + available, we will set 'limited' in the result. In the event of a large + batch, we may return more rows than this. + Returns: + A triplet `(updates, new_last_token, limited)`, where: + * `updates` is a list of database tuples. + * `new_last_token` is the new position in stream. + * `limited` is whether there are more updates to fetch. + """ + + def get_all_updated_current_state_deltas_txn(txn): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ + txn.execute(sql, (from_token, to_token, target_row_count)) + return txn.fetchall() + + def get_deltas_for_stream_id_txn(txn, stream_id): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE stream_id = ? + """ + txn.execute(sql, [stream_id]) + return txn.fetchall() + + # we need to make sure that, for every stream id in the results, we get *all* + # the rows with that stream id. 
+ + rows = await self.db.runInteraction( + "get_all_updated_current_state_deltas", + get_all_updated_current_state_deltas_txn, + ) # type: List[Tuple] + + # if we've got fewer rows than the limit, we're good + if len(rows) < target_row_count: + return rows, to_token, False + + # we hit the limit, so reduce the upper limit so that we exclude the stream id + # of the last row in the result. + assert rows[-1][0] <= to_token + to_token = rows[-1][0] - 1 + + # search backwards through the list for the point to truncate + for idx in range(len(rows) - 1, 0, -1): + if rows[idx - 1][0] <= to_token: + return rows[:idx], to_token, True + + # bother. We didn't get a full set of changes for even a single + # stream id. let's run the query again, without a row limit, but for + # just one stream id. + to_token += 1 + rows = await self.db.runInteraction( + "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token + ) + + return rows, to_token, True + + @cached(num_args=5, max_entries=10) + def get_all_new_events( + self, + last_backfill_id, + last_forward_id, + current_backfill_id, + current_forward_id, + limit, + ): + """Get all the new events that have arrived at the server either as + new events or as backfilled events""" + have_backfill_events = last_backfill_id != current_backfill_id + have_forward_events = last_forward_id != current_forward_id + + if not have_backfill_events and not have_forward_events: + return defer.succeed(AllNewEventsResult([], [], [], [], [])) + + def get_all_new_events_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + if have_forward_events: + txn.execute(sql, (last_forward_id, current_forward_id, limit)) + new_forward_events = txn.fetchall() + + if len(new_forward_events) == limit: + upper_bound = new_forward_events[-1][0] + else: + upper_bound = current_forward_id + + sql = ( + "SELECT event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_forward_id, upper_bound)) + forward_ex_outliers = txn.fetchall() + else: + new_forward_events = [] + forward_ex_outliers = [] + + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + if have_backfill_events: + txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) + new_backfill_events = txn.fetchall() + + if len(new_backfill_events) == limit: + upper_bound = new_backfill_events[-1][0] + else: + upper_bound = current_backfill_id + + sql = ( + "SELECT -event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" 
+ " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_backfill_id, -upper_bound)) + backward_ex_outliers = txn.fetchall() + else: + new_backfill_events = [] + backward_ex_outliers = [] + + return AllNewEventsResult( + new_forward_events, + new_backfill_events, + forward_ex_outliers, + backward_ex_outliers, + ) + + return self.db.runInteraction("get_all_new_events", get_all_new_events_txn) + + async def is_event_after(self, event_id1, event_id2): + """Returns True if event_id1 is after event_id2 in the stream + """ + to_1, so_1 = await self.get_event_ordering(event_id1) + to_2, so_2 = await self.get_event_ordering(event_id2) + return (to_1, so_1) > (to_2, so_2) + + @cachedInlineCallbacks(max_entries=5000) + def get_event_ordering(self, event_id): + res = yield self.db.simple_select_one( + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + allow_none=True, + ) + + if not res: + raise SynapseError(404, "Could not find event %s" % (event_id,)) + + return (int(res["topological_ordering"]), int(res["stream_ordering"])) + + def get_next_event_to_expire(self): + """Retrieve the entry with the lowest expiry timestamp in the event_expiry + table, or None if there's no more event to expire. + + Returns: Deferred[Optional[Tuple[str, int]]] + A tuple containing the event ID as its first element and an expiry timestamp + as its second one, if there's at least one row in the event_expiry table. + None otherwise. + """ + + def get_next_event_to_expire_txn(txn): + txn.execute( + """ + SELECT event_id, expiry_ts FROM event_expiry + ORDER BY expiry_ts ASC LIMIT 1 + """ + ) + + return txn.fetchone() + + return self.db.runInteraction( + desc="get_next_event_to_expire", func=get_next_event_to_expire_txn + ) + + +AllNewEventsResult = namedtuple( + "AllNewEventsResult", + [ + "new_forward_events", + "new_backfill_events", + "forward_ex_outliers", + "backward_ex_outliers", + ], +) diff --git a/synapse/storage/filtering.py b/synapse/storage/data_stores/main/filtering.py index 23b48f6cea..342d6622a4 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/data_stores/main/filtering.py @@ -16,10 +16,9 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.util.caches.descriptors import cachedInlineCallbacks -from ._base import SQLBaseStore, db_to_json - class FilteringStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) @@ -31,7 +30,7 @@ class FilteringStore(SQLBaseStore): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = yield self._simple_select_one_onecol( + def_json = yield self.db.simple_select_one_onecol( table="user_filters", keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", @@ -51,12 +50,12 @@ class FilteringStore(SQLBaseStore): "SELECT filter_id FROM user_filters " "WHERE user_id = ? AND filter_json = ?" ) - txn.execute(sql, (user_localpart, def_json)) + txn.execute(sql, (user_localpart, bytearray(def_json))) filter_id_response = txn.fetchone() if filter_id_response is not None: return filter_id_response[0] - sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?" + sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" 
txn.execute(sql, (user_localpart,)) max_id = txn.fetchone()[0] if max_id is None: @@ -68,8 +67,8 @@ class FilteringStore(SQLBaseStore): "INSERT INTO user_filters (user_id, filter_id, filter_json)" "VALUES(?, ?, ?)" ) - txn.execute(sql, (user_localpart, filter_id, def_json)) + txn.execute(sql, (user_localpart, filter_id, bytearray(def_json))) return filter_id - return self.runInteraction("add_user_filter", _do_txn) + return self.db.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/group_server.py b/synapse/storage/data_stores/main/group_server.py index 15b01c6958..fb1361f1c1 100644 --- a/synapse/storage/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -19,8 +19,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.errors import SynapseError - -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore # The category ID for the "default" category. We don't store as null in the # database to avoid the fun of null != null @@ -28,23 +27,9 @@ _DEFAULT_CATEGORY_ID = "" _DEFAULT_ROLE_ID = "" -class GroupServerStore(SQLBaseStore): - def set_group_join_policy(self, group_id, join_policy): - """Set the join policy of a group. - - join_policy can be one of: - * "invite" - * "open" - """ - return self._simple_update_one( - table="groups", - keyvalues={"group_id": group_id}, - updatevalues={"join_policy": join_policy}, - desc="set_group_join_policy", - ) - +class GroupServerWorkerStore(SQLBaseStore): def get_group(self, group_id): - return self._simple_select_one( + return self.db.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -66,7 +51,7 @@ class GroupServerStore(SQLBaseStore): if not include_private: keyvalues["is_public"] = True - return self._simple_select_list( + return self.db.simple_select_list( table="group_users", keyvalues=keyvalues, retcols=("user_id", "is_public", "is_admin"), @@ -76,31 +61,85 @@ class GroupServerStore(SQLBaseStore): def get_invited_users_in_group(self, group_id): # TODO: Pagination - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="group_invites", keyvalues={"group_id": group_id}, retcol="user_id", desc="get_invited_users_in_group", ) - def get_rooms_in_group(self, group_id, include_private=False): + def get_rooms_in_group(self, group_id: str, include_private: bool = False): + """Retrieve the rooms that belong to a given group. Does not return rooms that + lack members. + + Args: + group_id: The ID of the group to query for rooms + include_private: Whether to return private rooms in results + + Returns: + Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the + form of: + + { + "room_id": "!a_room_id:example.com", # The ID of the room + "is_public": False # Whether this is a public room or not + } + """ # TODO: Pagination - keyvalues = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True + def _get_rooms_in_group_txn(txn): + sql = """ + SELECT room_id, is_public FROM group_rooms + WHERE group_id = ? + AND room_id IN ( + SELECT group_rooms.room_id FROM group_rooms + LEFT JOIN room_stats_current ON + group_rooms.room_id = room_stats_current.room_id + AND joined_members > 0 + AND local_users_in_room > 0 + LEFT JOIN rooms ON + group_rooms.room_id = rooms.room_id + AND (room_version <> '') = ? 
+ ) + """ + args = [group_id, False] - return self._simple_select_list( - table="group_rooms", - keyvalues=keyvalues, - retcols=("room_id", "is_public"), - desc="get_rooms_in_group", - ) + if not include_private: + sql += " AND is_public = ?" + args += [True] + + txn.execute(sql, args) + + return [ + {"room_id": room_id, "is_public": is_public} + for room_id, is_public in txn + ] - def get_rooms_for_summary_by_category(self, group_id, include_private=False): + return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn) + + def get_rooms_for_summary_by_category( + self, group_id: str, include_private: bool = False, + ): """Get the rooms and categories that should be included in a summary request - Returns ([rooms], [categories]) + Args: + group_id: The ID of the group to query the summary for + include_private: Whether to return private rooms in results + + Returns: + Deferred[Tuple[List, Dict]]: A tuple containing: + + * A list of dictionaries with the keys: + * "room_id": str, the room ID + * "is_public": bool, whether the room is public + * "category_id": str|None, the category ID if set, else None + * "order": int, the sort order of rooms + + * A dictionary with the key: + * category_id (str): a dictionary with the keys: + * "is_public": bool, whether the category is public + * "profile": str, the category profile + * "order": int, the sort order of rooms in this category """ def _get_rooms_for_summary_txn(txn): @@ -112,13 +151,23 @@ class GroupServerStore(SQLBaseStore): SELECT room_id, is_public, category_id, room_order FROM group_summary_rooms WHERE group_id = ? + AND room_id IN ( + SELECT group_rooms.room_id FROM group_rooms + LEFT JOIN room_stats_current ON + group_rooms.room_id = room_stats_current.room_id + AND joined_members > 0 + AND local_users_in_room > 0 + LEFT JOIN rooms ON + group_rooms.room_id = rooms.room_id + AND (room_version <> '') = ? + ) """ if not include_private: sql += " AND is_public = ?" 
- txn.execute(sql, (group_id, True)) + txn.execute(sql, (group_id, False, True)) else: - txn.execute(sql, (group_id,)) + txn.execute(sql, (group_id, False)) rooms = [ { @@ -154,10 +203,372 @@ class GroupServerStore(SQLBaseStore): return rooms, categories - return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn) + return self.db.runInteraction( + "get_rooms_for_summary", _get_rooms_for_summary_txn + ) + + @defer.inlineCallbacks + def get_group_categories(self, group_id): + rows = yield self.db.simple_select_list( + table="group_room_categories", + keyvalues={"group_id": group_id}, + retcols=("category_id", "is_public", "profile"), + desc="get_group_categories", + ) + + return { + row["category_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), + } + for row in rows + } + + @defer.inlineCallbacks + def get_group_category(self, group_id, category_id): + category = yield self.db.simple_select_one( + table="group_room_categories", + keyvalues={"group_id": group_id, "category_id": category_id}, + retcols=("is_public", "profile"), + desc="get_group_category", + ) + + category["profile"] = json.loads(category["profile"]) + + return category + + @defer.inlineCallbacks + def get_group_roles(self, group_id): + rows = yield self.db.simple_select_list( + table="group_roles", + keyvalues={"group_id": group_id}, + retcols=("role_id", "is_public", "profile"), + desc="get_group_roles", + ) + + return { + row["role_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), + } + for row in rows + } + + @defer.inlineCallbacks + def get_group_role(self, group_id, role_id): + role = yield self.db.simple_select_one( + table="group_roles", + keyvalues={"group_id": group_id, "role_id": role_id}, + retcols=("is_public", "profile"), + desc="get_group_role", + ) + + role["profile"] = json.loads(role["profile"]) + + return role + + def get_local_groups_for_room(self, room_id): + """Get all of the local group that contain a given room + Args: + room_id (str): The ID of a room + Returns: + Deferred[list[str]]: A twisted.Deferred containing a list of group ids + containing this room + """ + return self.db.simple_select_onecol( + table="group_rooms", + keyvalues={"room_id": room_id}, + retcol="group_id", + desc="get_local_groups_for_room", + ) + + def get_users_for_summary_by_role(self, group_id, include_private=False): + """Get the users and roles that should be included in a summary request + + Returns ([users], [roles]) + """ + + def _get_users_for_summary_txn(txn): + keyvalues = {"group_id": group_id} + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT user_id, is_public, role_id, user_order + FROM group_summary_users + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + users = [ + { + "user_id": row[0], + "is_public": row[1], + "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT role_id, is_public, profile, role_order + FROM group_summary_roles + INNER JOIN group_roles USING (group_id, role_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" 
+ txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + roles = { + row[0]: { + "is_public": row[1], + "profile": json.loads(row[2]), + "order": row[3], + } + for row in txn + } + + return users, roles + + return self.db.runInteraction( + "get_users_for_summary_by_role", _get_users_for_summary_txn + ) + + def is_user_in_group(self, user_id, group_id): + return self.db.simple_select_one_onecol( + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="is_user_in_group", + ).addCallback(lambda r: bool(r)) + + def is_user_admin_in_group(self, group_id, user_id): + return self.db.simple_select_one_onecol( + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="is_admin", + allow_none=True, + desc="is_user_admin_in_group", + ) + + def is_user_invited_to_local_group(self, group_id, user_id): + """Has the group server invited a user? + """ + return self.db.simple_select_one_onecol( + table="group_invites", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + desc="is_user_invited_to_local_group", + allow_none=True, + ) + + def get_users_membership_info_in_group(self, group_id, user_id): + """Get a dict describing the membership of a user in a group. + + Example if joined: + + { + "membership": "join", + "is_public": True, + "is_privileged": False, + } + + Returns an empty dict if the user is not join/invite/etc + """ + + def _get_users_membership_in_group_txn(txn): + row = self.db.simple_select_one_txn( + txn, + table="group_users", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcols=("is_admin", "is_public"), + allow_none=True, + ) + + if row: + return { + "membership": "join", + "is_public": row["is_public"], + "is_privileged": row["is_admin"], + } + + row = self.db.simple_select_one_onecol_txn( + txn, + table="group_invites", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcol="user_id", + allow_none=True, + ) + + if row: + return {"membership": "invite"} + + return {} + + return self.db.runInteraction( + "get_users_membership_info_in_group", _get_users_membership_in_group_txn + ) + + def get_publicised_groups_for_user(self, user_id): + """Get all groups a user is publicising + """ + return self.db.simple_select_onecol( + table="local_group_membership", + keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, + retcol="group_id", + desc="get_publicised_groups_for_user", + ) + + def get_attestations_need_renewals(self, valid_until_ms): + """Get all attestations that need to be renewed until the given time + """ + + def _get_attestations_need_renewals_txn(txn): + sql = """ + SELECT group_id, user_id FROM group_attestations_renewals + WHERE valid_until_ms <= ? + """ + txn.execute(sql, (valid_until_ms,)) + return self.db.cursor_to_dict(txn) + + return self.db.runInteraction( + "get_attestations_need_renewals", _get_attestations_need_renewals_txn + ) + + @defer.inlineCallbacks + def get_remote_attestation(self, group_id, user_id): + """Get the attestation that proves the remote agrees that the user is + in the group. 
+ """ + row = yield self.db.simple_select_one( + table="group_attestations_remote", + keyvalues={"group_id": group_id, "user_id": user_id}, + retcols=("valid_until_ms", "attestation_json"), + desc="get_remote_attestation", + allow_none=True, + ) + + now = int(self._clock.time_msec()) + if row and now < row["valid_until_ms"]: + return json.loads(row["attestation_json"]) + + return None + + def get_joined_groups(self, user_id): + return self.db.simple_select_onecol( + table="local_group_membership", + keyvalues={"user_id": user_id, "membership": "join"}, + retcol="group_id", + desc="get_joined_groups", + ) + + def get_all_groups_for_user(self, user_id, now_token): + def _get_all_groups_for_user_txn(txn): + sql = """ + SELECT group_id, type, membership, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND membership != 'leave' + AND stream_id <= ? + """ + txn.execute(sql, (user_id, now_token)) + return [ + { + "group_id": row[0], + "type": row[1], + "membership": row[2], + "content": json.loads(row[3]), + } + for row in txn + ] + + return self.db.runInteraction( + "get_all_groups_for_user", _get_all_groups_for_user_txn + ) + + def get_groups_changes_for_user(self, user_id, from_token, to_token): + from_token = int(from_token) + has_changed = self._group_updates_stream_cache.has_entity_changed( + user_id, from_token + ) + if not has_changed: + return defer.succeed([]) + + def _get_groups_changes_for_user_txn(txn): + sql = """ + SELECT group_id, membership, type, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND ? < stream_id AND stream_id <= ? + """ + txn.execute(sql, (user_id, from_token, to_token)) + return [ + { + "group_id": group_id, + "membership": membership, + "type": gtype, + "content": json.loads(content_json), + } + for group_id, membership, gtype, content_json in txn + ] + + return self.db.runInteraction( + "get_groups_changes_for_user", _get_groups_changes_for_user_txn + ) + + def get_all_groups_changes(self, from_token, to_token, limit): + from_token = int(from_token) + has_changed = self._group_updates_stream_cache.has_any_entity_changed( + from_token + ) + if not has_changed: + return defer.succeed([]) + + def _get_all_groups_changes_txn(txn): + sql = """ + SELECT stream_id, group_id, user_id, type, content + FROM local_group_updates + WHERE ? < stream_id AND stream_id <= ? + LIMIT ? + """ + txn.execute(sql, (from_token, to_token, limit)) + return [ + (stream_id, group_id, user_id, gtype, json.loads(content_json)) + for stream_id, group_id, user_id, gtype, content_json in txn + ] + + return self.db.runInteraction( + "get_all_groups_changes", _get_all_groups_changes_txn + ) + + +class GroupServerStore(GroupServerWorkerStore): + def set_group_join_policy(self, group_id, join_policy): + """Set the join policy of a group. + + join_policy can be one of: + * "invite" + * "open" + """ + return self.db.simple_update_one( + table="groups", + keyvalues={"group_id": group_id}, + updatevalues={"join_policy": join_policy}, + desc="set_group_join_policy", + ) def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): - return self.runInteraction( + return self.db.runInteraction( "add_room_to_summary", self._add_room_to_summary_txn, group_id, @@ -181,7 +592,7 @@ class GroupServerStore(SQLBaseStore): an order of 1 will put the room first. Otherwise, the room gets added to the end. 
""" - room_in_group = self._simple_select_one_onecol_txn( + room_in_group = self.db.simple_select_one_onecol_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, @@ -194,7 +605,7 @@ class GroupServerStore(SQLBaseStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID else: - cat_exists = self._simple_select_one_onecol_txn( + cat_exists = self.db.simple_select_one_onecol_txn( txn, table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -205,7 +616,7 @@ class GroupServerStore(SQLBaseStore): raise SynapseError(400, "Category doesn't exist") # TODO: Check category is part of summary already - cat_exists = self._simple_select_one_onecol_txn( + cat_exists = self.db.simple_select_one_onecol_txn( txn, table="group_summary_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -225,7 +636,7 @@ class GroupServerStore(SQLBaseStore): (group_id, category_id, group_id, category_id), ) - existing = self._simple_select_one_txn( + existing = self.db.simple_select_one_txn( txn, table="group_summary_rooms", keyvalues={ @@ -250,7 +661,7 @@ class GroupServerStore(SQLBaseStore): WHERE group_id = ? AND category_id = ? """ txn.execute(sql, (group_id, category_id)) - order, = txn.fetchone() + (order,) = txn.fetchone() if existing: to_update = {} @@ -258,7 +669,7 @@ class GroupServerStore(SQLBaseStore): to_update["room_order"] = order if is_public is not None: to_update["is_public"] = is_public - self._simple_update_txn( + self.db.simple_update_txn( txn, table="group_summary_rooms", keyvalues={ @@ -272,7 +683,7 @@ class GroupServerStore(SQLBaseStore): if is_public is None: is_public = True - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_summary_rooms", values={ @@ -288,7 +699,7 @@ class GroupServerStore(SQLBaseStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID - return self._simple_delete( + return self.db.simple_delete( table="group_summary_rooms", keyvalues={ "group_id": group_id, @@ -298,36 +709,6 @@ class GroupServerStore(SQLBaseStore): desc="remove_room_from_summary", ) - @defer.inlineCallbacks - def get_group_categories(self, group_id): - rows = yield self._simple_select_list( - table="group_room_categories", - keyvalues={"group_id": group_id}, - retcols=("category_id", "is_public", "profile"), - desc="get_group_categories", - ) - - return { - row["category_id"]: { - "is_public": row["is_public"], - "profile": json.loads(row["profile"]), - } - for row in rows - } - - @defer.inlineCallbacks - def get_group_category(self, group_id, category_id): - category = yield self._simple_select_one( - table="group_room_categories", - keyvalues={"group_id": group_id, "category_id": category_id}, - retcols=("is_public", "profile"), - desc="get_group_category", - ) - - category["profile"] = json.loads(category["profile"]) - - return category - def upsert_group_category(self, group_id, category_id, profile, is_public): """Add/update room category for group """ @@ -344,7 +725,7 @@ class GroupServerStore(SQLBaseStore): else: update_values["is_public"] = is_public - return self._simple_upsert( + return self.db.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -353,42 +734,12 @@ class GroupServerStore(SQLBaseStore): ) def remove_group_category(self, group_id, category_id): - return self._simple_delete( + return self.db.simple_delete( table="group_room_categories", keyvalues={"group_id": group_id, 
"category_id": category_id}, desc="remove_group_category", ) - @defer.inlineCallbacks - def get_group_roles(self, group_id): - rows = yield self._simple_select_list( - table="group_roles", - keyvalues={"group_id": group_id}, - retcols=("role_id", "is_public", "profile"), - desc="get_group_roles", - ) - - return { - row["role_id"]: { - "is_public": row["is_public"], - "profile": json.loads(row["profile"]), - } - for row in rows - } - - @defer.inlineCallbacks - def get_group_role(self, group_id, role_id): - role = yield self._simple_select_one( - table="group_roles", - keyvalues={"group_id": group_id, "role_id": role_id}, - retcols=("is_public", "profile"), - desc="get_group_role", - ) - - role["profile"] = json.loads(role["profile"]) - - return role - def upsert_group_role(self, group_id, role_id, profile, is_public): """Add/remove user role """ @@ -405,7 +756,7 @@ class GroupServerStore(SQLBaseStore): else: update_values["is_public"] = is_public - return self._simple_upsert( + return self.db.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -414,14 +765,14 @@ class GroupServerStore(SQLBaseStore): ) def remove_group_role(self, group_id, role_id): - return self._simple_delete( + return self.db.simple_delete( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, desc="remove_group_role", ) def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): - return self.runInteraction( + return self.db.runInteraction( "add_user_to_summary", self._add_user_to_summary_txn, group_id, @@ -445,7 +796,7 @@ class GroupServerStore(SQLBaseStore): an order of 1 will put the user first. Otherwise, the user gets added to the end. """ - user_in_group = self._simple_select_one_onecol_txn( + user_in_group = self.db.simple_select_one_onecol_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -458,7 +809,7 @@ class GroupServerStore(SQLBaseStore): if role_id is None: role_id = _DEFAULT_ROLE_ID else: - role_exists = self._simple_select_one_onecol_txn( + role_exists = self.db.simple_select_one_onecol_txn( txn, table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -469,7 +820,7 @@ class GroupServerStore(SQLBaseStore): raise SynapseError(400, "Role doesn't exist") # TODO: Check role is part of the summary already - role_exists = self._simple_select_one_onecol_txn( + role_exists = self.db.simple_select_one_onecol_txn( txn, table="group_summary_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -489,7 +840,7 @@ class GroupServerStore(SQLBaseStore): (group_id, role_id, group_id, role_id), ) - existing = self._simple_select_one_txn( + existing = self.db.simple_select_one_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, @@ -510,7 +861,7 @@ class GroupServerStore(SQLBaseStore): WHERE group_id = ? AND role_id = ? 
""" txn.execute(sql, (group_id, role_id)) - order, = txn.fetchone() + (order,) = txn.fetchone() if existing: to_update = {} @@ -518,7 +869,7 @@ class GroupServerStore(SQLBaseStore): to_update["user_order"] = order if is_public is not None: to_update["is_public"] = is_public - self._simple_update_txn( + self.db.simple_update_txn( txn, table="group_summary_users", keyvalues={ @@ -532,7 +883,7 @@ class GroupServerStore(SQLBaseStore): if is_public is None: is_public = True - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_summary_users", values={ @@ -548,158 +899,21 @@ class GroupServerStore(SQLBaseStore): if role_id is None: role_id = _DEFAULT_ROLE_ID - return self._simple_delete( + return self.db.simple_delete( table="group_summary_users", keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, desc="remove_user_from_summary", ) - def get_users_for_summary_by_role(self, group_id, include_private=False): - """Get the users and roles that should be included in a summary request - - Returns ([users], [roles]) - """ - - def _get_users_for_summary_txn(txn): - keyvalues = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True - - sql = """ - SELECT user_id, is_public, role_id, user_order - FROM group_summary_users - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - users = [ - { - "user_id": row[0], - "is_public": row[1], - "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None, - "order": row[3], - } - for row in txn - ] - - sql = """ - SELECT role_id, is_public, profile, role_order - FROM group_summary_roles - INNER JOIN group_roles USING (group_id, role_id) - WHERE group_id = ? - """ - - if not include_private: - sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) - else: - txn.execute(sql, (group_id,)) - - roles = { - row[0]: { - "is_public": row[1], - "profile": json.loads(row[2]), - "order": row[3], - } - for row in txn - } - - return users, roles - - return self.runInteraction( - "get_users_for_summary_by_role", _get_users_for_summary_txn - ) - - def is_user_in_group(self, user_id, group_id): - return self._simple_select_one_onecol( - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="is_user_in_group", - ).addCallback(lambda r: bool(r)) - - def is_user_admin_in_group(self, group_id, user_id): - return self._simple_select_one_onecol( - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="is_admin", - allow_none=True, - desc="is_user_admin_in_group", - ) - def add_group_invite(self, group_id, user_id): """Record that the group server has invited a user """ - return self._simple_insert( + return self.db.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", ) - def is_user_invited_to_local_group(self, group_id, user_id): - """Has the group server invited a user? - """ - return self._simple_select_one_onecol( - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - desc="is_user_invited_to_local_group", - allow_none=True, - ) - - def get_users_membership_info_in_group(self, group_id, user_id): - """Get a dict describing the membership of a user in a group. 
- - Example if joined: - - { - "membership": "join", - "is_public": True, - "is_privileged": False, - } - - Returns an empty dict if the user is not join/invite/etc - """ - - def _get_users_membership_in_group_txn(txn): - row = self._simple_select_one_txn( - txn, - table="group_users", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcols=("is_admin", "is_public"), - allow_none=True, - ) - - if row: - return { - "membership": "join", - "is_public": row["is_public"], - "is_privileged": row["is_admin"], - } - - row = self._simple_select_one_onecol_txn( - txn, - table="group_invites", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcol="user_id", - allow_none=True, - ) - - if row: - return {"membership": "invite"} - - return {} - - return self.runInteraction( - "get_users_membership_info_in_group", _get_users_membership_in_group_txn - ) - def add_user_to_group( self, group_id, @@ -724,7 +938,7 @@ class GroupServerStore(SQLBaseStore): """ def _add_user_to_group_txn(txn): - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_users", values={ @@ -735,14 +949,14 @@ class GroupServerStore(SQLBaseStore): }, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) if local_attestation: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -752,7 +966,7 @@ class GroupServerStore(SQLBaseStore): }, ) if remote_attestation: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -763,49 +977,49 @@ class GroupServerStore(SQLBaseStore): }, ) - return self.runInteraction("add_user_to_group", _add_user_to_group_txn) + return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn) def remove_user_from_group(self, group_id, user_id): def _remove_user_from_group_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_user_from_group", _remove_user_from_group_txn ) def add_room_to_group(self, group_id, room_id, is_public): - return self._simple_insert( + return self.db.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", ) def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self._simple_update( + return self.db.simple_update( table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, updatevalues={"is_public": is_public}, @@ -814,36 +1028,26 @@ class GroupServerStore(SQLBaseStore): def remove_room_from_group(self, group_id, room_id): def _remove_room_from_group_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": 
room_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_summary_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_room_from_group", _remove_room_from_group_txn ) - def get_publicised_groups_for_user(self, user_id): - """Get all groups a user is publicising - """ - return self._simple_select_onecol( - table="local_group_membership", - keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, - retcol="group_id", - desc="get_publicised_groups_for_user", - ) - def update_group_publicity(self, group_id, user_id, publicise): """Update whether the user is publicising their membership of the group """ - return self._simple_update_one( + return self.db.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_publicised": publicise}, @@ -879,12 +1083,12 @@ class GroupServerStore(SQLBaseStore): def _register_user_group_membership_txn(txn, next_id): # TODO: Upsert? - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="local_group_membership", values={ @@ -897,7 +1101,7 @@ class GroupServerStore(SQLBaseStore): }, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="local_group_updates", values={ @@ -916,7 +1120,7 @@ class GroupServerStore(SQLBaseStore): if membership == "join": if local_attestation: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -926,7 +1130,7 @@ class GroupServerStore(SQLBaseStore): }, ) if remote_attestation: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -937,12 +1141,12 @@ class GroupServerStore(SQLBaseStore): }, ) else: - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -951,7 +1155,7 @@ class GroupServerStore(SQLBaseStore): return next_id with self._group_updates_id_gen.get_next() as next_id: - res = yield self.runInteraction( + res = yield self.db.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, next_id, @@ -962,7 +1166,7 @@ class GroupServerStore(SQLBaseStore): def create_group( self, group_id, user_id, name, avatar_url, short_description, long_description ): - yield self._simple_insert( + yield self.db.simple_insert( table="groups", values={ "group_id": group_id, @@ -977,33 +1181,17 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def update_group_profile(self, group_id, profile): - yield self._simple_update_one( + yield self.db.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues=profile, desc="update_group_profile", ) - def get_attestations_need_renewals(self, valid_until_ms): - """Get all attestations that need to be renewed until givent time - """ - - def _get_attestations_need_renewals_txn(txn): - sql = """ - SELECT group_id, user_id FROM group_attestations_renewals - WHERE valid_until_ms <= ? 
- """ - txn.execute(sql, (valid_until_ms,)) - return self.cursor_to_dict(txn) - - return self.runInteraction( - "get_attestations_need_renewals", _get_attestations_need_renewals_txn - ) - def update_attestation_renewal(self, group_id, user_id, attestation): """Update an attestation that we have renewed """ - return self._simple_update_one( + return self.db.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, @@ -1013,7 +1201,7 @@ class GroupServerStore(SQLBaseStore): def update_remote_attestion(self, group_id, user_id, attestation): """Update an attestation that a remote has renewed """ - return self._simple_update_one( + return self.db.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ @@ -1032,118 +1220,12 @@ class GroupServerStore(SQLBaseStore): group_id (str) user_id (str) """ - return self._simple_delete( + return self.db.simple_delete( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, desc="remove_attestation_renewal", ) - @defer.inlineCallbacks - def get_remote_attestation(self, group_id, user_id): - """Get the attestation that proves the remote agrees that the user is - in the group. - """ - row = yield self._simple_select_one( - table="group_attestations_remote", - keyvalues={"group_id": group_id, "user_id": user_id}, - retcols=("valid_until_ms", "attestation_json"), - desc="get_remote_attestation", - allow_none=True, - ) - - now = int(self._clock.time_msec()) - if row and now < row["valid_until_ms"]: - return json.loads(row["attestation_json"]) - - return None - - def get_joined_groups(self, user_id): - return self._simple_select_onecol( - table="local_group_membership", - keyvalues={"user_id": user_id, "membership": "join"}, - retcol="group_id", - desc="get_joined_groups", - ) - - def get_all_groups_for_user(self, user_id, now_token): - def _get_all_groups_for_user_txn(txn): - sql = """ - SELECT group_id, type, membership, u.content - FROM local_group_updates AS u - INNER JOIN local_group_membership USING (group_id, user_id) - WHERE user_id = ? AND membership != 'leave' - AND stream_id <= ? - """ - txn.execute(sql, (user_id, now_token)) - return [ - { - "group_id": row[0], - "type": row[1], - "membership": row[2], - "content": json.loads(row[3]), - } - for row in txn - ] - - return self.runInteraction( - "get_all_groups_for_user", _get_all_groups_for_user_txn - ) - - def get_groups_changes_for_user(self, user_id, from_token, to_token): - from_token = int(from_token) - has_changed = self._group_updates_stream_cache.has_entity_changed( - user_id, from_token - ) - if not has_changed: - return [] - - def _get_groups_changes_for_user_txn(txn): - sql = """ - SELECT group_id, membership, type, u.content - FROM local_group_updates AS u - INNER JOIN local_group_membership USING (group_id, user_id) - WHERE user_id = ? AND ? < stream_id AND stream_id <= ? 
- """ - txn.execute(sql, (user_id, from_token, to_token)) - return [ - { - "group_id": group_id, - "membership": membership, - "type": gtype, - "content": json.loads(content_json), - } - for group_id, membership, gtype, content_json in txn - ] - - return self.runInteraction( - "get_groups_changes_for_user", _get_groups_changes_for_user_txn - ) - - def get_all_groups_changes(self, from_token, to_token, limit): - from_token = int(from_token) - has_changed = self._group_updates_stream_cache.has_any_entity_changed( - from_token - ) - if not has_changed: - return [] - - def _get_all_groups_changes_txn(txn): - sql = """ - SELECT stream_id, group_id, user_id, type, content - FROM local_group_updates - WHERE ? < stream_id AND stream_id <= ? - LIMIT ? - """ - txn.execute(sql, (from_token, to_token, limit)) - return [ - (stream_id, group_id, user_id, gtype, json.loads(content_json)) - for stream_id, group_id, user_id, gtype, content_json in txn - ] - - return self.runInteraction( - "get_all_groups_changes", _get_all_groups_changes_txn - ) - def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() @@ -1174,8 +1256,8 @@ class GroupServerStore(SQLBaseStore): ] for table in tables: - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table=table, keyvalues={"group_id": group_id} ) - return self.runInteraction("delete_group", _delete_group_txn) + return self.db.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py new file mode 100644 index 0000000000..4e1642a27a --- /dev/null +++ b/synapse/storage/data_stores/main/keys.py @@ -0,0 +1,208 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import itertools +import logging + +from signedjson.key import decode_verify_key_bytes + +from synapse.storage._base import SQLBaseStore +from synapse.storage.keys import FetchKeyResult +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter + +logger = logging.getLogger(__name__) + + +db_binary_type = memoryview + + +class KeyStore(SQLBaseStore): + """Persistence for signature verification keys + """ + + @cached() + def _get_server_verify_key(self, server_name_and_key_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" + ) + def get_server_verify_keys(self, server_name_and_key_ids): + """ + Args: + server_name_and_key_ids (iterable[Tuple[str, str]]): + iterable of (server_name, key-id) tuples to fetch keys for + + Returns: + Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: + map from (server_name, key_id) -> FetchKeyResult, or None if the key is + unknown + """ + keys = {} + + def _get_keys(txn, batch): + """Processes a batch of keys to fetch, and adds the result to `keys`.""" + + # batch_iter always returns tuples so it's safe to do len(batch) + sql = ( + "SELECT server_name, key_id, verify_key, ts_valid_until_ms " + "FROM server_signature_keys WHERE 1=0" + ) + " OR (server_name=? AND key_id=?)" * len(batch) + + txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) + + for row in txn: + server_name, key_id, key_bytes, ts_valid_until_ms = row + + if ts_valid_until_ms is None: + # Old keys may be stored with a ts_valid_until_ms of null, + # in which case we treat this as if it was set to `0`, i.e. + # it won't match key requests that define a minimum + # `ts_valid_until_ms`. + ts_valid_until_ms = 0 + + res = FetchKeyResult( + verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), + valid_until_ts=ts_valid_until_ms, + ) + keys[(server_name, key_id)] = res + + def _txn(txn): + for batch in batch_iter(server_name_and_key_ids, 50): + _get_keys(txn, batch) + return keys + + return self.db.runInteraction("get_server_verify_keys", _txn) + + def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): + """Stores NACL verification keys for remote servers. + Args: + from_server (str): Where the verification keys were looked up + ts_added_ms (int): The time to record that the key was added + verify_keys (iterable[tuple[str, str, FetchKeyResult]]): + keys to be stored. Each entry is a triplet of + (server_name, key_id, key). + """ + key_values = [] + value_values = [] + invalidations = [] + for server_name, key_id, fetch_result in verify_keys: + key_values.append((server_name, key_id)) + value_values.append( + ( + from_server, + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(fetch_result.verify_key.encode()), + ) + ) + # invalidate takes a tuple corresponding to the params of + # _get_server_verify_key. _get_server_verify_key only takes one + # param, which is itself the 2-tuple (server_name, key_id). 
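An illustrative aside rather than part of the patch, spelling out the comment above; the server name and key id are invented:

    # With @cachedList delegating to the @cached _get_server_verify_key, each
    # cache entry is keyed by that method's single argument, which is itself
    # the (server_name, key_id) pair. Invalidating one entry therefore takes
    # a 1-tuple wrapping that pair, matching f((i,)) in the _invalidate helper
    # that follows. For example:
    #     key = ("matrix.org", "ed25519:abc")   # hypothetical cache key
    #     self._get_server_verify_key.invalidate((key,))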
+ invalidations.append((server_name, key_id)) + + def _invalidate(res): + f = self._get_server_verify_key.invalidate + for i in invalidations: + f((i,)) + return res + + return self.db.runInteraction( + "store_server_verify_keys", + self.db.simple_upsert_many_txn, + table="server_signature_keys", + key_names=("server_name", "key_id"), + key_values=key_values, + value_names=( + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "verify_key", + ), + value_values=value_values, + ).addCallback(_invalidate) + + def store_server_keys_json( + self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes + ): + """Stores the JSON bytes for a set of keys from a server + The JSON should be signed by the originating server, the intermediate + server, and by this server. Updates the value for the + (server_name, key_id, from_server) triplet if one already existed. + Args: + server_name (str): The name of the server. + key_id (str): The identifer of the key this JSON is for. + from_server (str): The server this JSON was fetched from. + ts_now_ms (int): The time now in milliseconds. + ts_valid_until_ms (int): The time when this json stops being valid. + key_json (bytes): The encoded JSON. + """ + return self.db.simple_upsert( + table="server_keys_json", + keyvalues={ + "server_name": server_name, + "key_id": key_id, + "from_server": from_server, + }, + values={ + "server_name": server_name, + "key_id": key_id, + "from_server": from_server, + "ts_added_ms": ts_now_ms, + "ts_valid_until_ms": ts_expires_ms, + "key_json": db_binary_type(key_json_bytes), + }, + desc="store_server_keys_json", + ) + + def get_server_keys_json(self, server_keys): + """Retrive the key json for a list of server_keys and key ids. + If no keys are found for a given server, key_id and source then + that server, key_id, and source triplet entry will be an empty list. + The JSON is returned as a byte array so that it can be efficiently + used in an HTTP response. + Args: + server_keys (list): List of (server_name, key_id, source) triplets. + Returns: + Deferred[dict[Tuple[str, str, str|None], list[dict]]]: + Dict mapping (server_name, key_id, source) triplets to lists of dicts + """ + + def _get_server_keys_json_txn(txn): + results = {} + for server_name, key_id, from_server in server_keys: + keyvalues = {"server_name": server_name} + if key_id is not None: + keyvalues["key_id"] = key_id + if from_server is not None: + keyvalues["from_server"] = from_server + rows = self.db.simple_select_list_txn( + txn, + "server_keys_json", + keyvalues=keyvalues, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + ) + results[(server_name, key_id, from_server)] = rows + return results + + return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn) diff --git a/synapse/storage/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 6b1238ce4a..8aecd414c2 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -12,16 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database -class MediaRepositoryStore(BackgroundUpdateStore): - """Persistence for attachments and avatars""" - - def __init__(self, db_conn, hs): - super(MediaRepositoryStore, self).__init__(db_conn, hs) +class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(MediaRepositoryBackgroundUpdateStore, self).__init__( + database, db_conn, hs + ) - self.register_background_index_update( + self.db.updates.register_background_index_update( update_name="local_media_repository_url_idx", index_name="local_media_repository_url_idx", table="local_media_repository", @@ -29,12 +30,19 @@ class MediaRepositoryStore(BackgroundUpdateStore): where_clause="url_cache IS NOT NULL", ) + +class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): + """Persistence for attachments and avatars""" + + def __init__(self, database: Database, db_conn, hs): + super(MediaRepositoryStore, self).__init__(database, db_conn, hs) + def get_local_media(self, media_id): """Get the metadata for a local piece of media Returns: None if the media_id doesn't exist. """ - return self._simple_select_one( + return self.db.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -59,7 +67,7 @@ class MediaRepositoryStore(BackgroundUpdateStore): user_id, url_cache=None, ): - return self._simple_insert( + return self.db.simple_insert( "local_media_repository", { "media_id": media_id, @@ -119,12 +127,12 @@ class MediaRepositoryStore(BackgroundUpdateStore): ) ) - return self.runInteraction("get_url_cache", get_url_cache_txn) + return self.db.runInteraction("get_url_cache", get_url_cache_txn) def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self._simple_insert( + return self.db.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -139,7 +147,7 @@ class MediaRepositoryStore(BackgroundUpdateStore): ) def get_local_media_thumbnails(self, media_id): - return self._simple_select_list( + return self.db.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -161,7 +169,7 @@ class MediaRepositoryStore(BackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self._simple_insert( + return self.db.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -175,7 +183,7 @@ class MediaRepositoryStore(BackgroundUpdateStore): ) def get_cached_remote_media(self, origin, media_id): - return self._simple_select_one( + return self.db.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( @@ -200,7 +208,7 @@ class MediaRepositoryStore(BackgroundUpdateStore): upload_name, filesystem_id, ): - return self._simple_insert( + return self.db.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -245,10 +253,12 @@ class MediaRepositoryStore(BackgroundUpdateStore): txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) - return self.runInteraction("update_cached_last_access_time", update_cache_txn) + return self.db.runInteraction( + "update_cached_last_access_time", update_cache_txn + ) def get_remote_media_thumbnails(self, origin, media_id): - return self._simple_select_list( + return self.db.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( @@ -273,7 +283,7 @@ class 
MediaRepositoryStore(BackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self._simple_insert( + return self.db.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, @@ -295,24 +305,24 @@ class MediaRepositoryStore(BackgroundUpdateStore): " WHERE last_access_ts < ?" ) - return self._execute( - "get_remote_media_before", self.cursor_to_dict, sql, before_ts + return self.db.execute( + "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts ) def delete_remote_media(self, media_origin, media_id): def delete_remote_media_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, "remote_media_cache", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, "remote_media_cache_thumbnails", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - return self.runInteraction("delete_remote_media", delete_remote_media_txn) + return self.db.runInteraction("delete_remote_media", delete_remote_media_txn) def get_expired_url_cache(self, now_ts): sql = ( @@ -326,18 +336,20 @@ class MediaRepositoryStore(BackgroundUpdateStore): txn.execute(sql, (now_ts,)) return [row[0] for row in txn] - return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn) + return self.db.runInteraction( + "get_expired_url_cache", _get_expired_url_cache_txn + ) - def delete_url_cache(self, media_ids): + async def delete_url_cache(self, media_ids): if len(media_ids) == 0: return - sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?" + sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?" def _delete_url_cache_txn(txn): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return self.runInteraction("delete_url_cache", _delete_url_cache_txn) + return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) def get_url_cache_media_before(self, before_ts): sql = ( @@ -351,23 +363,23 @@ class MediaRepositoryStore(BackgroundUpdateStore): txn.execute(sql, (before_ts,)) return [row[0] for row in txn] - return self.runInteraction( + return self.db.runInteraction( "get_url_cache_media_before", _get_url_cache_media_before_txn ) - def delete_url_cache_media(self, media_ids): + async def delete_url_cache_media(self, media_ids): if len(media_ids) == 0: return def _delete_url_cache_media_txn(txn): - sql = "DELETE FROM local_media_repository" " WHERE media_id = ?" + sql = "DELETE FROM local_media_repository WHERE media_id = ?" txn.executemany(sql, [(media_id,) for media_id in media_ids]) - sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?" + sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?" txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return self.runInteraction( + return await self.db.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn ) diff --git a/synapse/storage/data_stores/main/metrics.py b/synapse/storage/data_stores/main/metrics.py new file mode 100644 index 0000000000..dad5bbc602 --- /dev/null +++ b/synapse/storage/data_stores/main/metrics.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import typing +from collections import Counter + +from twisted.internet import defer + +from synapse.metrics import BucketCollector +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.event_push_actions import ( + EventPushActionsWorkerStore, +) +from synapse.storage.database import Database + + +class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): + """Functions to pull various metrics from the DB, for e.g. phone home + stats and prometheus metrics. + """ + + def __init__(self, database: Database, db_conn, hs): + super().__init__(database, db_conn, hs) + + # Collect metrics on the number of forward extremities that exist. + # Counter of number of extremities to count + self._current_forward_extremities_amount = ( + Counter() + ) # type: typing.Counter[int] + + BucketCollector( + "synapse_forward_extremities", + lambda: self._current_forward_extremities_amount, + buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], + ) + + # Read the extrems every 60 minutes + def read_forward_extremities(): + # run as a background process to make sure that the database transactions + # have a logcontext to report to + return run_as_background_process( + "read_forward_extremities", self._read_forward_extremities + ) + + hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) + + async def _read_forward_extremities(self): + def fetch(txn): + txn.execute( + """ + select count(*) c from event_forward_extremities + group by room_id + """ + ) + return txn.fetchall() + + res = await self.db.runInteraction("read_forward_extremities", fetch) + self._current_forward_extremities_amount = Counter([x[0] for x in res]) + + @defer.inlineCallbacks + def count_daily_messages(self): + """ + Returns an estimate of the number of messages sent in the last day. + + If it has been significantly less or more than one day since the last + call to this function, it will return None. + """ + + def _count_messages(txn): + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.message' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + (count,) = txn.fetchone() + return count + + ret = yield self.db.runInteraction("count_messages", _count_messages) + return ret + + @defer.inlineCallbacks + def count_daily_sent_messages(self): + def _count_messages(txn): + # This is good enough as if you have silly characters in your own + # hostname then thats your own fault. + like_clause = "%:" + self.hs.hostname + + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.message' + AND sender LIKE ? + AND stream_ordering > ? 
+ """ + + txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) + (count,) = txn.fetchone() + return count + + ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages) + return ret + + @defer.inlineCallbacks + def count_daily_active_rooms(self): + def _count(txn): + sql = """ + SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events + WHERE type = 'm.room.message' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + (count,) = txn.fetchone() + return count + + ret = yield self.db.runInteraction("count_daily_active_rooms", _count) + return ret diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index 752e9788a2..e459cf49a0 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List from twisted.internet import defer +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database, make_in_list_sql_clause from synapse.util.caches.descriptors import cached -from ._base import SQLBaseStore - logger = logging.getLogger(__name__) # Number of msec of granularity to store the monthly_active_user timestamp @@ -27,20 +28,113 @@ logger = logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 60 * 60 * 1000 -class MonthlyActiveUsersStore(SQLBaseStore): - def __init__(self, dbconn, hs): - super(MonthlyActiveUsersStore, self).__init__(None, hs) +class MonthlyActiveUsersWorkerStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs - self.reserved_users = () + + @cached(num_args=0) + def get_monthly_active_count(self): + """Generates current count of monthly active users + + Returns: + Defered[int]: Number of current monthly active users + """ + + def _count_users(txn): + sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" + txn.execute(sql) + (count,) = txn.fetchone() + return count + + return self.db.runInteraction("count_users", _count_users) + + @cached(num_args=0) + def get_monthly_active_count_by_service(self): + """Generates current count of monthly active users broken down by service. + A service is typically an appservice but also includes native matrix users. + Since the `monthly_active_users` table is populated from the `user_ips` table + `config.track_appservice_user_ips` must be set to `true` for this + method to return anything other than native matrix users. + + Returns: + Deferred[dict]: dict that includes a mapping between app_service_id + and the number of occurrences. 
+ + """ + + def _count_users_by_service(txn): + sql = """ + SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0) + FROM monthly_active_users + LEFT JOIN users ON monthly_active_users.user_id=users.name + GROUP BY appservice_id; + """ + + txn.execute(sql) + result = txn.fetchall() + return dict(result) + + return self.db.runInteraction("count_users_by_service", _count_users_by_service) + + async def get_registered_reserved_users(self) -> List[str]: + """Of the reserved threepids defined in config, retrieve those that are associated + with registered users + + Returns: + User IDs of actual users that are reserved + """ + users = [] + + for tp in self.hs.config.mau_limits_reserved_threepids[ + : self.hs.config.max_mau_value + ]: + user_id = await self.hs.get_datastore().get_user_id_by_threepid( + tp["medium"], tp["address"] + ) + if user_id: + users.append(user_id) + + return users + + @cached(num_args=1) + def user_last_seen_monthly_active(self, user_id): + """ + Checks if a given user is part of the monthly active user group + Arguments: + user_id (str): user to add/update + Return: + Deferred[int] : timestamp since last seen, None if never seen + + """ + + return self.db.simple_select_one_onecol( + table="monthly_active_users", + keyvalues={"user_id": user_id}, + retcol="timestamp", + allow_none=True, + desc="user_last_seen_monthly_active", + ) + + +class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): + def __init__(self, database: Database, db_conn, hs): + super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) + + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._mau_stats_only = hs.config.mau_stats_only + self._max_mau_value = hs.config.max_mau_value + # Do not add more reserved users than the total allowable number - self._new_transaction( - dbconn, + # cur = LoggingTransaction( + self.db.new_transaction( + db_conn, "initialise_mau_threepids", [], [], self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value], + hs.config.mau_limits_reserved_threepids[: self._max_mau_value], ) def _initialise_reserved_users(self, txn, threepids): @@ -51,7 +145,15 @@ class MonthlyActiveUsersStore(SQLBaseStore): txn (cursor): threepids (list[dict]): List of threepid dicts to reserve """ - reserved_user_list = [] + + # XXX what is this function trying to achieve? It upserts into + # monthly_active_users for each *registered* reserved mau user, but why? + # + # - shouldn't there already be an entry for each reserved user (at least + # if they have been active recently)? + # + # - if it's important that the timestamp is kept up to date, why do we only + # run this at startup? 
for tp in threepids: user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) @@ -59,121 +161,94 @@ class MonthlyActiveUsersStore(SQLBaseStore): if user_id: is_support = self.is_support_user_txn(txn, user_id) if not is_support: - self.upsert_monthly_active_user_txn(txn, user_id) - reserved_user_list.append(user_id) + # We do this manually here to avoid hitting #6791 + self.db.simple_upsert_txn( + txn, + table="monthly_active_users", + keyvalues={"user_id": user_id}, + values={"timestamp": int(self._clock.time_msec())}, + ) else: logger.warning("mau limit reserved threepid %s not found in db" % tp) - self.reserved_users = tuple(reserved_user_list) - @defer.inlineCallbacks - def reap_monthly_active_users(self): + async def reap_monthly_active_users(self): """Cleans out monthly active user table to ensure that no stale entries exist. - - Returns: - Deferred[] """ - def _reap_users(txn): - # Purge stale users + def _reap_users(txn, reserved_users): + """ + Args: + reserved_users (tuple): reserved users to preserve + """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - query_args = [thirty_days_ago] - base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" - - # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres - # when len(reserved_users) == 0. Works fine on sqlite. - if len(self.reserved_users) > 0: - # questionmarks is a hack to overcome sqlite not supporting - # tuples in 'WHERE IN %s' - questionmarks = "?" * len(self.reserved_users) - - query_args.extend(self.reserved_users) - sql = base_sql + """ AND user_id NOT IN ({})""".format( - ",".join(questionmarks) - ) - else: - sql = base_sql - txn.execute(sql, query_args) + in_clause, in_clause_args = make_in_list_sql_clause( + self.database_engine, "user_id", reserved_users + ) + + txn.execute( + "DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s" + % (in_clause,), + [thirty_days_ago] + in_clause_args, + ) - if self.hs.config.limit_usage_by_mau: + if self._limit_usage_by_mau: # If MAU user count still exceeds the MAU threshold, then delete on # a least recently active basis. # Note it is not possible to write this query using OFFSET due to # incompatibilities in how sqlite and postgres support the feature. - # sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present - # While Postgres does not require 'LIMIT', but also does not support + # Sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present, + # while Postgres does not require 'LIMIT', but also does not support # negative LIMIT values. So there is no way to write it that both can # support - safe_guard = self.hs.config.max_mau_value - len(self.reserved_users) - # Must be greater than zero for postgres - safe_guard = safe_guard if safe_guard > 0 else 0 - query_args = [safe_guard] - base_sql = """ + # Limit must be >= 0 for postgres + num_of_non_reserved_users_to_remove = max( + self._max_mau_value - len(reserved_users), 0 + ) + + # It is important to filter reserved users twice to guard + # against the case where the reserved user is present in the + # SELECT, meaning that a legitimate mau is deleted. + sql = """ DELETE FROM monthly_active_users WHERE user_id NOT IN ( SELECT user_id FROM monthly_active_users + WHERE NOT %s ORDER BY timestamp DESC LIMIT ? - ) - """ - # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres - # when len(reserved_users) == 0. Works fine on sqlite. 
- if len(self.reserved_users) > 0: - query_args.extend(self.reserved_users) - sql = base_sql + """ AND user_id NOT IN ({})""".format( - ",".join(questionmarks) ) - else: - sql = base_sql - txn.execute(sql, query_args) - - yield self.runInteraction("reap_monthly_active_users", _reap_users) - # It seems poor to invalidate the whole cache, Postgres supports - # 'Returning' which would allow me to invalidate only the - # specific users, but sqlite has no way to do this and instead - # I would need to SELECT and the DELETE which without locking - # is racy. - # Have resolved to invalidate the whole cache for now and do - # something about it if and when the perf becomes significant - self.user_last_seen_monthly_active.invalidate_all() - self.get_monthly_active_count.invalidate_all() - - @cached(num_args=0) - def get_monthly_active_count(self): - """Generates current count of monthly active users - - Returns: - Defered[int]: Number of current monthly active users - """ - - def _count_users(txn): - sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" - - txn.execute(sql) - count, = txn.fetchone() - return count - - return self.runInteraction("count_users", _count_users) + AND NOT %s + """ % ( + in_clause, + in_clause, + ) - @defer.inlineCallbacks - def get_registered_reserved_users_count(self): - """Of the reserved threepids defined in config, how many are associated - with registered users? + query_args = ( + in_clause_args + + [num_of_non_reserved_users_to_remove] + + in_clause_args + ) + txn.execute(sql, query_args) - Returns: - Defered[int]: Number of real reserved users - """ - count = 0 - for tp in self.hs.config.mau_limits_reserved_threepids: - user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - tp["medium"], tp["address"] + # It seems poor to invalidate the whole cache. Postgres supports + # 'Returning' which would allow me to invalidate only the + # specific users, but sqlite has no way to do this and instead + # I would need to SELECT and the DELETE which without locking + # is racy. + # Have resolved to invalidate the whole cache for now and do + # something about it if and when the perf becomes significant + self._invalidate_all_cache_and_stream( + txn, self.user_last_seen_monthly_active ) - if user_id: - count = count + 1 - return count + self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) + + reserved_users = await self.get_registered_reserved_users() + await self.db.runInteraction( + "reap_monthly_active_users", _reap_users, reserved_users + ) @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): @@ -182,6 +257,9 @@ class MonthlyActiveUsersStore(SQLBaseStore): Args: user_id (str): user to add/update + + Returns: + Deferred """ # Support user never to be included in MAU stats. 
Note I can't easily call this # from upsert_monthly_active_user_txn because then I need a _txn form of @@ -195,27 +273,13 @@ class MonthlyActiveUsersStore(SQLBaseStore): if is_support: return - yield self.runInteraction( + yield self.db.runInteraction( "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) - user_in_mau = self.user_last_seen_monthly_active.cache.get( - (user_id,), None, update_metrics=False - ) - if user_in_mau is None: - self.get_monthly_active_count.invalidate(()) - - self.user_last_seen_monthly_active.invalidate((user_id,)) - def upsert_monthly_active_user_txn(self, txn, user_id): """Updates or inserts monthly active user member - Note that, after calling this method, it will generally be necessary - to invalidate the caches on user_last_seen_monthly_active and - get_monthly_active_count. We can't do that here, because we are running - in a database thread rather than the main thread, and we can't call - txn.call_after because txn may not be a LoggingTransaction. - We consciously do not call is_support_txn from this method because it is not possible to cache the response. is_support_txn will be false in almost all cases, so it seems reasonable to call it only for @@ -239,33 +303,22 @@ class MonthlyActiveUsersStore(SQLBaseStore): # never be a big table and alternative approaches (batching multiple # upserts into a single txn) introduced a lot of extra complexity. # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = self._simple_upsert_txn( + is_insert = self.db.simple_upsert_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, values={"timestamp": int(self._clock.time_msec())}, ) - return is_insert - - @cached(num_args=1) - def user_last_seen_monthly_active(self, user_id): - """ - Checks if a given user is part of the monthly active user group - Arguments: - user_id (str): user to add/update - Return: - Deferred[int] : timestamp since last seen, None if never seen - - """ - - return self._simple_select_one_onecol( - table="monthly_active_users", - keyvalues={"user_id": user_id}, - retcol="timestamp", - allow_none=True, - desc="user_last_seen_monthly_active", + self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) + self._invalidate_cache_and_stream( + txn, self.get_monthly_active_count_by_service, () ) + self._invalidate_cache_and_stream( + txn, self.user_last_seen_monthly_active, (user_id,) + ) + + return is_insert @defer.inlineCallbacks def populate_monthly_active_users(self, user_id): @@ -275,7 +328,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): Args: user_id(str): the user_id to query """ - if self.hs.config.limit_usage_by_mau or self.hs.config.mau_stats_only: + if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group is_guest = yield self.is_guest(user_id) if is_guest: @@ -296,11 +349,11 @@ class MonthlyActiveUsersStore(SQLBaseStore): # In the case where mau_stats_only is True and limit_usage_by_mau is # False, there is no point in checking get_monthly_active_count - it # adds no value and will break the logic if max_mau_value is exceeded. 
- if not self.hs.config.limit_usage_by_mau: + if not self._limit_usage_by_mau: yield self.upsert_monthly_active_user(user_id) else: count = yield self.get_monthly_active_count() - if count < self.hs.config.max_mau_value: + if count < self._max_mau_value: yield self.upsert_monthly_active_user(user_id) elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: yield self.upsert_monthly_active_user(user_id) diff --git a/synapse/storage/openid.py b/synapse/storage/data_stores/main/openid.py index b3318045ee..cc21437e92 100644 --- a/synapse/storage/openid.py +++ b/synapse/storage/data_stores/main/openid.py @@ -1,9 +1,9 @@ -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore class OpenIdStore(SQLBaseStore): def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self._simple_insert( + return self.db.simple_insert( table="open_id_tokens", values={ "token": token, @@ -28,4 +28,6 @@ class OpenIdStore(SQLBaseStore): else: return rows[0][0] - return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn) + return self.db.runInteraction( + "get_user_id_for_token", get_user_id_for_token_txn + ) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py new file mode 100644 index 0000000000..dab31e0c2d --- /dev/null +++ b/synapse/storage/data_stores/main/presence.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
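The presence store added below prunes presence_stream after each update by deleting superseded rows for just the users that changed, chunking the user list (batch_iter, 50 users at a time) so the IN-clause never grows unbounded; the real code builds the clause with make_in_list_sql_clause so it can also take the Postgres array form. Here is a self-contained sqlite3 sketch of that pruning pattern with a toy batch size; batch_iter is re-implemented locally as a stand-in for synapse.util.iterutils.batch_iter.

    import sqlite3
    from itertools import islice

    def batch_iter(iterable, size):
        # simplified stand-in: yield tuples of at most `size` items
        it = iter(iterable)
        while True:
            chunk = tuple(islice(it, size))
            if not chunk:
                return
            yield chunk

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE presence_stream (stream_id INT, user_id TEXT)")
    conn.executemany(
        "INSERT INTO presence_stream VALUES (?, ?)",
        [(i, "@user%d:hs" % (i % 5)) for i in range(1, 21)],
    )

    latest_stream_id = 20
    updated_users = ["@user%d:hs" % i for i in range(5)]

    for chunk in batch_iter(updated_users, 2):  # the real code uses batches of 50
        placeholders = ",".join("?" * len(chunk))
        conn.execute(
            "DELETE FROM presence_stream WHERE stream_id < ? AND user_id IN (%s)"
            % placeholders,
            [latest_stream_id] + list(chunk),
        )

    print(conn.execute("SELECT COUNT(*) FROM presence_stream").fetchone()[0])  # 1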
+ +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.presence import UserPresenceState +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter + + +class PresenceStore(SQLBaseStore): + @defer.inlineCallbacks + def update_presence(self, presence_states): + stream_ordering_manager = self._presence_id_gen.get_next_mult( + len(presence_states) + ) + + with stream_ordering_manager as stream_orderings: + yield self.db.runInteraction( + "update_presence", + self._update_presence_txn, + stream_orderings, + presence_states, + ) + + return stream_orderings[-1], self._presence_id_gen.get_current_token() + + def _update_presence_txn(self, txn, stream_orderings, presence_states): + for stream_id, state in zip(stream_orderings, presence_states): + txn.call_after( + self.presence_stream_cache.entity_has_changed, state.user_id, stream_id + ) + txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) + + # Actually insert new rows + self.db.simple_insert_many_txn( + txn, + table="presence_stream", + values=[ + { + "stream_id": stream_id, + "user_id": state.user_id, + "state": state.state, + "last_active_ts": state.last_active_ts, + "last_federation_update_ts": state.last_federation_update_ts, + "last_user_sync_ts": state.last_user_sync_ts, + "status_msg": state.status_msg, + "currently_active": state.currently_active, + } + for stream_id, state in zip(stream_orderings, presence_states) + ], + ) + + # Delete old rows to stop database from getting really big + sql = "DELETE FROM presence_stream WHERE stream_id < ? AND " + + for states in batch_iter(presence_states, 50): + clause, args = make_in_list_sql_clause( + self.database_engine, "user_id", [s.user_id for s in states] + ) + txn.execute(sql + clause, [stream_id] + list(args)) + + def get_all_presence_updates(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_presence_updates_txn(txn): + sql = """ + SELECT stream_id, user_id, state, last_active_ts, + last_federation_update_ts, last_user_sync_ts, + status_msg, + currently_active + FROM presence_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? 
+ """ + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_presence_updates", get_all_presence_updates_txn + ) + + @cached() + def _get_presence_for_user(self, user_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_presence_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) + def get_presence_for_users(self, user_ids): + rows = yield self.db.simple_select_many_batch( + table="presence_stream", + column="user_id", + iterable=user_ids, + keyvalues={}, + retcols=( + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + ), + desc="get_presence_for_users", + ) + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return {row["user_id"]: UserPresenceState(**row) for row in rows} + + def get_current_presence_token(self): + return self._presence_id_gen.get_current_token() + + def allow_presence_visible(self, observed_localpart, observer_userid): + return self.db.simple_insert( + table="presence_allow_inbound", + values={ + "observed_user_id": observed_localpart, + "observer_user_id": observer_userid, + }, + desc="allow_presence_visible", + or_ignore=True, + ) + + def disallow_presence_visible(self, observed_localpart, observer_userid): + return self.db.simple_delete_one( + table="presence_allow_inbound", + keyvalues={ + "observed_user_id": observed_localpart, + "observer_user_id": observer_userid, + }, + desc="disallow_presence_visible", + ) diff --git a/synapse/storage/profile.py b/synapse/storage/data_stores/main/profile.py index 912c1df6be..bfc9369f0b 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/data_stores/main/profile.py @@ -16,16 +16,15 @@ from twisted.internet import defer from synapse.api.errors import StoreError -from synapse.storage.roommember import ProfileInfo - -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.roommember import ProfileInfo class ProfileWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_profileinfo(self, user_localpart): try: - profile = yield self._simple_select_one( + profile = yield self.db.simple_select_one( table="profiles", keyvalues={"user_id": user_localpart}, retcols=("displayname", "avatar_url"), @@ -43,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_profile_displayname(self, user_localpart): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", @@ -51,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_profile_avatar_url(self, user_localpart): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", @@ -59,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_from_remote_profile_cache(self, user_id): - return self._simple_select_one( + return self.db.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), @@ -68,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore): ) def create_profile(self, user_localpart): - return self._simple_insert( + return self.db.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) def set_profile_displayname(self, user_localpart, new_displayname): - return 
self._simple_update_one( + return self.db.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"displayname": new_displayname}, @@ -81,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self._simple_update_one( + return self.db.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"avatar_url": new_avatar_url}, @@ -96,7 +95,7 @@ class ProfileStore(ProfileWorkerStore): This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self._simple_upsert( + return self.db.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -108,10 +107,10 @@ class ProfileStore(ProfileWorkerStore): ) def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self._simple_update( + return self.db.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, - values={ + updatevalues={ "displayname": displayname, "avatar_url": avatar_url, "last_check": self._clock.time_msec(), @@ -126,7 +125,7 @@ class ProfileStore(ProfileWorkerStore): """ subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) if not subscribed: - yield self._simple_delete( + yield self.db.simple_delete( table="remote_profile_cache", keyvalues={"user_id": user_id}, desc="delete_remote_profile_cache", @@ -145,9 +144,9 @@ class ProfileStore(ProfileWorkerStore): txn.execute(sql, (last_checked,)) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - return self.runInteraction( + return self.db.runInteraction( "get_remote_profile_cache_entries_that_expire", _get_remote_profile_cache_entries_that_expire_txn, ) @@ -156,7 +155,7 @@ class ProfileStore(ProfileWorkerStore): def is_subscribed_remote_profile_for_user(self, user_id): """Check whether we are interested in a remote user's profile. """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -167,7 +166,7 @@ class ProfileStore(ProfileWorkerStore): if res: return True - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="group_invites", keyvalues={"user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py new file mode 100644 index 0000000000..a93e1ef198 --- /dev/null +++ b/synapse/storage/data_stores/main/purge_events.py @@ -0,0 +1,399 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
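purge_history in the file below snapshots the doomed event IDs into a temporary events_to_purge table first (created before anything else, since the Python sqlite driver commits on CREATE) and then reuses that table to trim each per-event table, rather than shuttling a large ID list back and forth over the connection. A much-simplified, self-contained sqlite3 sketch of that temp-table pattern, with a hypothetical two-table schema and none of the extremity or state-group bookkeeping:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE events (event_id TEXT, room_id TEXT, topological_ordering INT, outlier BOOL);
        CREATE TABLE event_json (event_id TEXT, json TEXT);
        INSERT INTO events VALUES ('$old', '!room:hs', 1, 0), ('$new', '!room:hs', 9, 0);
        INSERT INTO event_json VALUES ('$old', '{}'), ('$new', '{}');
        """
    )

    purge_before = 5  # hypothetical topological cut-off

    conn.execute("DROP TABLE IF EXISTS events_to_purge")
    conn.execute(
        "CREATE TEMPORARY TABLE events_to_purge"
        " (event_id TEXT NOT NULL, should_delete BOOLEAN NOT NULL)"
    )

    # collect the doomed event IDs once...
    conn.execute(
        "INSERT INTO events_to_purge"
        " SELECT event_id, 1 FROM events"
        " WHERE room_id = ? AND topological_ordering < ?",
        ("!room:hs", purge_before),
    )

    # ...then reuse that set against every per-event table
    for table in ("events", "event_json"):
        conn.execute(
            "DELETE FROM %s WHERE event_id IN"
            " (SELECT event_id FROM events_to_purge WHERE should_delete)" % (table,)
        )

    conn.execute("DROP TABLE events_to_purge")
    print(conn.execute("SELECT event_id FROM events").fetchall())  # [('$new',)]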
+ +import logging +from typing import Any, Tuple + +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.types import RoomStreamToken + +logger = logging.getLogger(__name__) + + +class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): + def purge_history(self, room_id, token, delete_local_events): + """Deletes room history before a certain point + + Args: + room_id (str): + + token (str): A topological token to delete events before + + delete_local_events (bool): + if True, we will delete local events as well as remote ones + (instead of just marking them as outliers and deleting their + state groups). + + Returns: + Deferred[set[int]]: The set of state groups that are referenced by + deleted events. + """ + + return self.db.runInteraction( + "purge_history", + self._purge_history_txn, + room_id, + token, + delete_local_events, + ) + + def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): + token = RoomStreamToken.parse(token_str) + + # Tables that should be pruned: + # event_auth + # event_backward_extremities + # event_edges + # event_forward_extremities + # event_json + # event_push_actions + # event_reference_hashes + # event_search + # event_to_state_groups + # events + # rejections + # room_depth + # state_groups + # state_groups_state + + # we will build a temporary table listing the events so that we don't + # have to keep shovelling the list back and forth across the + # connection. Annoyingly the python sqlite driver commits the + # transaction on CREATE, so let's do this first. + # + # furthermore, we might already have the table from a previous (failed) + # purge attempt, so let's drop the table first. + + txn.execute("DROP TABLE IF EXISTS events_to_purge") + + txn.execute( + "CREATE TEMPORARY TABLE events_to_purge (" + " event_id TEXT NOT NULL," + " should_delete BOOLEAN NOT NULL" + ")" + ) + + # First ensure that we're not about to delete all the forward extremeties + txn.execute( + "SELECT e.event_id, e.depth FROM events as e " + "INNER JOIN event_forward_extremities as f " + "ON e.event_id = f.event_id " + "AND e.room_id = f.room_id " + "WHERE f.room_id = ?", + (room_id,), + ) + rows = txn.fetchall() + max_depth = max(row[1] for row in rows) + + if max_depth < token.topological: + # We need to ensure we don't delete all the events from the database + # otherwise we wouldn't be able to send any events (due to not + # having any backwards extremeties) + raise SynapseError( + 400, "topological_ordering is greater than forward extremeties" + ) + + logger.info("[purge] looking for events to delete") + + should_delete_expr = "state_key IS NULL" + should_delete_params = () # type: Tuple[Any, ...] + if not delete_local_events: + should_delete_expr += " AND event_id NOT LIKE ?" + + # We include the parameter twice since we use the expression twice + should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname) + + should_delete_params += (room_id, token.topological) + + # Note that we insert events that are outliers and aren't going to be + # deleted, as nothing will happen to them. + txn.execute( + "INSERT INTO events_to_purge" + " SELECT event_id, %s" + " FROM events AS e LEFT JOIN state_events USING (event_id)" + " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" 
+ % (should_delete_expr, should_delete_expr), + should_delete_params, + ) + + # We create the indices *after* insertion as that's a lot faster. + + # create an index on should_delete because later we'll be looking for + # the should_delete / shouldn't_delete subsets + txn.execute( + "CREATE INDEX events_to_purge_should_delete" + " ON events_to_purge(should_delete)" + ) + + # We do joins against events_to_purge for e.g. calculating state + # groups to purge, etc., so lets make an index. + txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)") + + txn.execute("SELECT event_id, should_delete FROM events_to_purge") + event_rows = txn.fetchall() + logger.info( + "[purge] found %i events before cutoff, of which %i can be deleted", + len(event_rows), + sum(1 for e in event_rows if e[1]), + ) + + logger.info("[purge] Finding new backward extremities") + + # We calculate the new entries for the backward extremeties by finding + # events to be purged that are pointed to by events we're not going to + # purge. + txn.execute( + "SELECT DISTINCT e.event_id FROM events_to_purge AS e" + " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" + " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id" + " WHERE ep2.event_id IS NULL" + ) + new_backwards_extrems = txn.fetchall() + + logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) + + txn.execute( + "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) + ) + + # Update backward extremeties + txn.executemany( + "INSERT INTO event_backward_extremities (room_id, event_id)" + " VALUES (?, ?)", + [(room_id, event_id) for event_id, in new_backwards_extrems], + ) + + logger.info("[purge] finding state groups referenced by deleted events") + + # Get all state groups that are referenced by events that are to be + # deleted. + txn.execute( + """ + SELECT DISTINCT state_group FROM events_to_purge + INNER JOIN event_to_state_groups USING (event_id) + """ + ) + + referenced_state_groups = {sg for sg, in txn} + logger.info( + "[purge] found %i referenced state groups", len(referenced_state_groups) + ) + + logger.info("[purge] removing events from event_to_state_groups") + txn.execute( + "DELETE FROM event_to_state_groups " + "WHERE event_id IN (SELECT event_id from events_to_purge)" + ) + for event_id, _ in event_rows: + txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) + + # Delete all remote non-state events + for table in ( + "events", + "event_json", + "event_auth", + "event_edges", + "event_forward_extremities", + "event_reference_hashes", + "event_search", + "rejections", + ): + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,) + ) + + # event_push_actions lacks an index on event_id, and has one on + # (room_id, event_id) instead. + for table in ("event_push_actions",): + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE room_id = ? AND event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,), + (room_id,), + ) + + # Mark all state and own events as outliers + logger.info("[purge] marking remaining events as outliers") + txn.execute( + "UPDATE events SET outlier = ?" 
+ " WHERE event_id IN (" + " SELECT event_id FROM events_to_purge " + " WHERE NOT should_delete" + ")", + (True,), + ) + + # synapse tries to take out an exclusive lock on room_depth whenever it + # persists events (because upsert), and once we run this update, we + # will block that for the rest of our transaction. + # + # So, let's stick it at the end so that we don't block event + # persistence. + # + # We do this by calculating the minimum depth of the backwards + # extremities. However, the events in event_backward_extremities + # are ones we don't have yet so we need to look at the events that + # point to it via event_edges table. + txn.execute( + """ + SELECT COALESCE(MIN(depth), 0) + FROM event_backward_extremities AS eb + INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id + INNER JOIN events AS e ON e.event_id = eg.event_id + WHERE eb.room_id = ? + """, + (room_id,), + ) + (min_depth,) = txn.fetchone() + + logger.info("[purge] updating room_depth to %d", min_depth) + + txn.execute( + "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", + (min_depth, room_id), + ) + + # finally, drop the temp table. this will commit the txn in sqlite, + # so make sure to keep this actually last. + txn.execute("DROP TABLE events_to_purge") + + logger.info("[purge] done") + + return referenced_state_groups + + def purge_room(self, room_id): + """Deletes all record of a room + + Args: + room_id (str) + + Returns: + Deferred[List[int]]: The list of state groups to delete. + """ + + return self.db.runInteraction("purge_room", self._purge_room_txn, room_id) + + def _purge_room_txn(self, txn, room_id): + # First we fetch all the state groups that should be deleted, before + # we delete that information. + txn.execute( + """ + SELECT DISTINCT state_group FROM events + INNER JOIN event_to_state_groups USING(event_id) + WHERE events.room_id = ? + """, + (room_id,), + ) + + state_groups = [row[0] for row in txn] + + # Now we delete tables which lack an index on room_id but have one on event_id + for table in ( + "event_auth", + "event_edges", + "event_push_actions_staging", + "event_reference_hashes", + "event_relations", + "event_to_state_groups", + "redactions", + "rejections", + "state_events", + ): + logger.info("[purge] removing %s from %s", room_id, table) + + txn.execute( + """ + DELETE FROM %s WHERE event_id IN ( + SELECT event_id FROM events WHERE room_id=? + ) + """ + % (table,), + (room_id,), + ) + + # and finally, the tables with an index on room_id (or no useful index) + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + # no useful index, but let's clear them anyway + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "local_invites", + "room_account_data", + "room_tags", + "local_current_membership", + ): + logger.info("[purge] removing %s from %s", room_id, table) + txn.execute("DELETE FROM %s WHERE room_id=?" 
% (table,), (room_id,)) + + # Other tables we do NOT need to clear out: + # + # - blocked_rooms + # This is important, to make sure that we don't accidentally rejoin a blocked + # room after it was purged + # + # - user_directory + # This has a room_id column, but it is unused + # + + # Other tables that we might want to consider clearing out include: + # + # - event_reports + # Given that these are intended for abuse management my initial + # inclination is to leave them in place. + # + # - current_state_delta_stream + # - ex_outlier_stream + # - room_tags_revisions + # The problem with these is that they are largeish and there is no room_id + # index on them. In any case we should be clearing out 'stream' tables + # periodically anyway (#5888) + + # TODO: we could probably usefully do a bunch of cache invalidation here + + logger.info("[purge] done") + + return state_groups diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py new file mode 100644 index 0000000000..ef8f40959f --- /dev/null +++ b/synapse/storage/data_stores/main/push_rule.py @@ -0,0 +1,729 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import logging +from typing import Union + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.push.baserules import list_with_base_rules +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.data_stores.main.pusher import PusherWorkerStore +from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore +from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore +from synapse.storage.database import Database +from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException +from synapse.storage.util.id_generators import ChainedIdGenerator +from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache + +logger = logging.getLogger(__name__) + + +def _load_rules(rawrules, enabled_map): + ruleslist = [] + for rawrule in rawrules: + rule = dict(rawrule) + rule["conditions"] = json.loads(rawrule["conditions"]) + rule["actions"] = json.loads(rawrule["actions"]) + rule["default"] = False + ruleslist.append(rule) + + # We're going to be mutating this a lot, so do a deep copy + rules = list(list_with_base_rules(ruleslist)) + + for i, rule in enumerate(rules): + rule_id = rule["rule_id"] + if rule_id in enabled_map: + if rule.get("enabled", True) != bool(enabled_map[rule_id]): + # Rules are cached across users. 
+ rule = dict(rule) + rule["enabled"] = bool(enabled_map[rule_id]) + rules[i] = rule + + return rules + + +class PushRulesWorkerStore( + ApplicationServiceWorkerStore, + ReceiptsWorkerStore, + PusherWorkerStore, + RoomMemberWorkerStore, + EventsWorkerStore, + SQLBaseStore, +): + """This is an abstract base class where subclasses must implement + `get_max_push_rules_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, database: Database, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) + + if hs.config.worker.worker_app is None: + self._push_rules_stream_id_gen = ChainedIdGenerator( + self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" + ) # type: Union[ChainedIdGenerator, SlavedIdTracker] + else: + self._push_rules_stream_id_gen = SlavedIdTracker( + db_conn, "push_rules_stream", "stream_id" + ) + + push_rules_prefill, push_rules_id = self.db.get_cache_dict( + db_conn, + "push_rules_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self.get_max_push_rules_stream_id(), + ) + + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", + push_rules_id, + prefilled_cache=push_rules_prefill, + ) + + @abc.abstractmethod + def get_max_push_rules_stream_id(self): + """Get the position of the push rules stream. + + Returns: + int + """ + raise NotImplementedError() + + @cachedInlineCallbacks(max_entries=5000) + def get_push_rules_for_user(self, user_id): + rows = yield self.db.simple_select_list( + table="push_rules", + keyvalues={"user_name": user_id}, + retcols=( + "user_name", + "rule_id", + "priority_class", + "priority", + "conditions", + "actions", + ), + desc="get_push_rules_enabled_for_user", + ) + + rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) + + enabled_map = yield self.get_push_rules_enabled_for_user(user_id) + + rules = _load_rules(rows, enabled_map) + + return rules + + @cachedInlineCallbacks(max_entries=5000) + def get_push_rules_enabled_for_user(self, user_id): + results = yield self.db.simple_select_list( + table="push_rules_enable", + keyvalues={"user_name": user_id}, + retcols=("user_name", "rule_id", "enabled"), + desc="get_push_rules_enabled_for_user", + ) + return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} + + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + return defer.succeed(False) + else: + + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? AND ? 
< stream_id" + ) + txn.execute(sql, (user_id, last_id)) + (count,) = txn.fetchone() + return bool(count) + + return self.db.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) + + @cachedList( + cached_method_name="get_push_rules_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) + def bulk_get_push_rules(self, user_ids): + if not user_ids: + return {} + + results = {user_id: [] for user_id in user_ids} + + rows = yield self.db.simple_select_many_batch( + table="push_rules", + column="user_name", + iterable=user_ids, + retcols=("*",), + desc="bulk_get_push_rules", + ) + + rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) + + for row in rows: + results.setdefault(row["user_name"], []).append(row) + + enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) + + for user_id, rules in results.items(): + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) + + return results + + @defer.inlineCallbacks + def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule): + """Copy a single push rule from one room to another for a specific user. + + Args: + new_room_id (str): ID of the new room. + user_id (str): ID of user the push rule belongs to. + rule (Dict): A push rule. + """ + # Create new rule id + rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) + new_rule_id = rule_id_scope + "/" + new_room_id + + # Change room id in each condition + for condition in rule.get("conditions", []): + if condition.get("key") == "room_id": + condition["pattern"] = new_room_id + + # Add the rule for the new room + yield self.add_push_rule( + user_id=user_id, + rule_id=new_rule_id, + priority_class=rule["priority_class"], + conditions=rule["conditions"], + actions=rule["actions"], + ) + + @defer.inlineCallbacks + def copy_push_rules_from_room_to_room_for_user( + self, old_room_id, new_room_id, user_id + ): + """Copy all of the push rules from one room to another for a specific + user. + + Args: + old_room_id (str): ID of the old room. + new_room_id (str): ID of the new room. + user_id (str): ID of user to copy push rules for. + """ + # Retrieve push rules for this user + user_push_rules = yield self.get_push_rules_for_user(user_id) + + # Get rules relating to the old room and copy them to the new room + for rule in user_push_rules: + conditions = rule.get("conditions", []) + if any( + (c.get("key") == "room_id" and c.get("pattern") == old_room_id) + for c in conditions + ): + yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) + + @defer.inlineCallbacks + def bulk_get_push_rules_for_room(self, event, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + current_state_ids = yield context.get_current_state_ids() + result = yield self._bulk_get_push_rules_for_room( + event.room_id, state_group, current_state_ids, event=event + ) + return result + + @cachedInlineCallbacks(num_args=2, cache_context=True) + def _bulk_get_push_rules_for_room( + self, room_id, state_group, current_state_ids, cache_context, event=None + ): + # We don't use `state_group`, its there so that we can cache based + # on it. 
However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + # We also will want to generate notifs for other people in the room so + # their unread countss are correct in the event stream, but to avoid + # generating them for bot / AS users etc, we only do so for people who've + # sent a read receipt into the room. + + users_in_room = yield self._get_joined_users_from_context( + room_id, + state_group, + current_state_ids, + on_invalidate=cache_context.invalidate, + event=event, + ) + + # We ignore app service users for now. This is so that we don't fill + # up the `get_if_users_have_pushers` cache with AS entries that we + # know don't have pushers, nor even read receipts. + local_users_in_room = { + u + for u in users_in_room + if self.hs.is_mine_id(u) + and not self.get_if_app_services_interested_in_user(u) + } + + # users in the room who have pushers need to get push rules run because + # that's how their pushers work + if_users_with_pushers = yield self.get_if_users_have_pushers( + local_users_in_room, on_invalidate=cache_context.invalidate + ) + user_ids = { + uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher + } + + users_with_receipts = yield self.get_users_with_read_receipts_in_room( + room_id, on_invalidate=cache_context.invalidate + ) + + # any users with pushers must be ours: they have pushers + for uid in users_with_receipts: + if uid in local_users_in_room: + user_ids.add(uid) + + rules_by_user = yield self.bulk_get_push_rules( + user_ids, on_invalidate=cache_context.invalidate + ) + + rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} + + return rules_by_user + + @cachedList( + cached_method_name="get_push_rules_enabled_for_user", + list_name="user_ids", + num_args=1, + inlineCallbacks=True, + ) + def bulk_get_push_rules_enabled(self, user_ids): + if not user_ids: + return {} + + results = {user_id: {} for user_id in user_ids} + + rows = yield self.db.simple_select_many_batch( + table="push_rules_enable", + column="user_name", + iterable=user_ids, + retcols=("user_name", "rule_id", "enabled"), + desc="bulk_get_push_rules_enabled", + ) + for row in rows: + enabled = bool(row["enabled"]) + results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled + return results + + def get_all_push_rule_updates(self, last_id, current_id, limit): + """Get all the push rules changes that have happend on the server""" + if last_id == current_id: + return defer.succeed([]) + + def get_all_push_rule_updates_txn(txn): + sql = ( + "SELECT stream_id, event_stream_ordering, user_id, rule_id," + " op, priority_class, priority, conditions, actions" + " FROM push_rules_stream" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" 
+ ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_push_rule_updates", get_all_push_rule_updates_txn + ) + + +class PushRuleStore(PushRulesWorkerStore): + @defer.inlineCallbacks + def add_push_rule( + self, + user_id, + rule_id, + priority_class, + conditions, + actions, + before=None, + after=None, + ): + conditions_json = json.dumps(conditions) + actions_json = json.dumps(actions) + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + if before or after: + yield self.db.runInteraction( + "_add_push_rule_relative_txn", + self._add_push_rule_relative_txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + before, + after, + ) + else: + yield self.db.runInteraction( + "_add_push_rule_highest_priority_txn", + self._add_push_rule_highest_priority_txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + ) + + def _add_push_rule_relative_txn( + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + before, + after, + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. + self.database_engine.lock_table(txn, "push_rules") + + relative_to_rule = before or after + + res = self.db.simple_select_one_txn( + txn, + table="push_rules", + keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, + retcols=["priority_class", "priority"], + allow_none=True, + ) + + if not res: + raise RuleNotFoundException( + "before/after rule not found: %s" % (relative_to_rule,) + ) + + base_priority_class = res["priority_class"] + base_rule_priority = res["priority"] + + if base_priority_class != priority_class: + raise InconsistentRuleException( + "Given priority class does not match class of relative rule" + ) + + if before: + # Higher priority rules are executed first, So adding a rule before + # a rule means giving it a higher priority than that rule. + new_rule_priority = base_rule_priority + 1 + else: + # We increment the priority of the existing rules to make space for + # the new rule. Therefore if we want this rule to appear after + # an existing rule we give it the priority of the existing rule, + # and then increment the priority of the existing rule. + new_rule_priority = base_rule_priority + + sql = ( + "UPDATE push_rules SET priority = priority + 1" + " WHERE user_name = ? AND priority_class = ? AND priority >= ?" + ) + + txn.execute(sql, (user_id, priority_class, new_rule_priority)) + + self._upsert_push_rule_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + new_rule_priority, + conditions_json, + actions_json, + ) + + def _add_push_rule_highest_priority_txn( + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + conditions_json, + actions_json, + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. + self.database_engine.lock_table(txn, "push_rules") + + # find the highest priority rule in that class + sql = ( + "SELECT COUNT(*), MAX(priority) FROM push_rules" + " WHERE user_name = ? and priority_class = ?" 
+ ) + txn.execute(sql, (user_id, priority_class)) + res = txn.fetchall() + (how_many, highest_prio) = res[0] + + new_prio = 0 + if how_many > 0: + new_prio = highest_prio + 1 + + self._upsert_push_rule_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + new_prio, + conditions_json, + actions_json, + ) + + def _upsert_push_rule_txn( + self, + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + priority, + conditions_json, + actions_json, + update_stream=True, + ): + """Specialised version of simple_upsert_txn that picks a push_rule_id + using the _push_rule_id_gen if it needs to insert the rule. It assumes + that the "push_rules" table is locked""" + + sql = ( + "UPDATE push_rules" + " SET priority_class = ?, priority = ?, conditions = ?, actions = ?" + " WHERE user_name = ? AND rule_id = ?" + ) + + txn.execute( + sql, + (priority_class, priority, conditions_json, actions_json, user_id, rule_id), + ) + + if txn.rowcount == 0: + # We didn't update a row with the given rule_id so insert one + push_rule_id = self._push_rule_id_gen.get_next() + + self.db.simple_insert_txn( + txn, + table="push_rules", + values={ + "id": push_rule_id, + "user_name": user_id, + "rule_id": rule_id, + "priority_class": priority_class, + "priority": priority, + "conditions": conditions_json, + "actions": actions_json, + }, + ) + + if update_stream: + self._insert_push_rules_update_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + op="ADD", + data={ + "priority_class": priority_class, + "priority": priority, + "conditions": conditions_json, + "actions": actions_json, + }, + ) + + @defer.inlineCallbacks + def delete_push_rule(self, user_id, rule_id): + """ + Delete a push rule. Args specify the row to be deleted and can be + any of the columns in the push_rule table, but below are the + standard ones + + Args: + user_id (str): The matrix ID of the push rule owner + rule_id (str): The rule_id of the rule to be deleted + """ + + def delete_push_rule_txn(txn, stream_id, event_stream_ordering): + self.db.simple_delete_one_txn( + txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} + ) + + self._insert_push_rules_update_txn( + txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" + ) + + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.db.runInteraction( + "delete_push_rule", + delete_push_rule_txn, + stream_id, + event_stream_ordering, + ) + + @defer.inlineCallbacks + def set_push_rule_enabled(self, user_id, rule_id, enabled): + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.db.runInteraction( + "_set_push_rule_enabled_txn", + self._set_push_rule_enabled_txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + enabled, + ) + + def _set_push_rule_enabled_txn( + self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled + ): + new_id = self._push_rules_enable_id_gen.get_next() + self.db.simple_upsert_txn( + txn, + "push_rules_enable", + {"user_name": user_id, "rule_id": rule_id}, + {"enabled": 1 if enabled else 0}, + {"id": new_id}, + ) + + self._insert_push_rules_update_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + op="ENABLE" if enabled else "DISABLE", + ) + + @defer.inlineCallbacks + def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): + actions_json = json.dumps(actions) + + def set_push_rule_actions_txn(txn, 
stream_id, event_stream_ordering): + if is_default_rule: + # Add a dummy rule to the rules table with the user specified + # actions. + priority_class = -1 + priority = 1 + self._upsert_push_rule_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + priority_class, + priority, + "[]", + actions_json, + update_stream=False, + ) + else: + self.db.simple_update_one_txn( + txn, + "push_rules", + {"user_name": user_id, "rule_id": rule_id}, + {"actions": actions_json}, + ) + + self._insert_push_rules_update_txn( + txn, + stream_id, + event_stream_ordering, + user_id, + rule_id, + op="ACTIONS", + data={"actions": actions_json}, + ) + + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.db.runInteraction( + "set_push_rule_actions", + set_push_rule_actions_txn, + stream_id, + event_stream_ordering, + ) + + def _insert_push_rules_update_txn( + self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None + ): + values = { + "stream_id": stream_id, + "event_stream_ordering": event_stream_ordering, + "user_id": user_id, + "rule_id": rule_id, + "op": op, + } + if data is not None: + values.update(data) + + self.db.simple_insert_txn(txn, "push_rules_stream", values=values) + + txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) + txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) + txn.call_after( + self.push_rules_stream_cache.entity_has_changed, user_id, stream_id + ) + + def get_push_rules_stream_token(self): + """Get the position of the push rules stream. + Returns a pair of a stream id for the push_rules stream and the + room stream ordering it corresponds to.""" + return self._push_rules_stream_id_gen.get_current_token() + + def get_max_push_rules_stream_id(self): + return self.get_push_rules_stream_token()[0] diff --git a/synapse/storage/pusher.py b/synapse/storage/data_stores/main/pusher.py index 3e0e834a62..547b9d69cb 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -15,52 +15,42 @@ # limitations under the License. 
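In the pusher diff below, _decode_pushers_rows becomes a generator that JSON-decodes each row's data column and simply drops (with a warning) any row whose data cannot be decoded, instead of yielding a half-populated dict with data set to None; the Python 2 buffer handling disappears along with the six import. A standalone sketch of that decode-or-drop shape, with made-up rows and a locally named function:

    import json
    import logging

    logger = logging.getLogger("pusher_sketch")

    def decode_pusher_rows(rows):
        # JSON-decode the data column, skipping rows that fail to decode
        for r in rows:
            data_json = r["data"]
            try:
                r["data"] = json.loads(data_json)
            except Exception as e:
                logger.warning(
                    "Invalid JSON in data for pusher %d: %s, %s", r["id"], data_json, e
                )
                continue
            yield r

    rows = [
        {"id": 1, "data": '{"url": "https://push.example.invalid/_matrix/push/v1/notify"}'},
        {"id": 2, "data": "not json"},
    ]
    print([r["id"] for r in decode_pusher_rows(rows)])  # [1]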
import logging - -import six +from typing import Iterable, Iterator from canonicaljson import encode_canonical_json, json from twisted.internet import defer +from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList -from ._base import SQLBaseStore - logger = logging.getLogger(__name__) -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - class PusherWorkerStore(SQLBaseStore): - def _decode_pushers_rows(self, rows): + def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]: + """JSON-decode the data in the rows returned from the `pushers` table + + Drops any rows whose data cannot be decoded + """ for r in rows: dataJson = r["data"] - r["data"] = None try: - if isinstance(dataJson, db_binary_type): - dataJson = str(dataJson).decode("UTF8") - r["data"] = json.loads(dataJson) except Exception as e: - logger.warn( + logger.warning( "Invalid JSON in data for pusher %d: %s, %s", r["id"], dataJson, e.args[0], ) - pass - - if isinstance(r["pushkey"], db_binary_type): - r["pushkey"] = str(r["pushkey"]).decode("UTF8") + continue - return rows + yield r @defer.inlineCallbacks def user_has_pusher(self, user_id): - ret = yield self._simple_select_one_onecol( + ret = yield self.db.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None @@ -73,7 +63,7 @@ class PusherWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_pushers_by(self, keyvalues): - ret = yield self._simple_select_list( + ret = yield self.db.simple_select_list( "pushers", keyvalues, [ @@ -101,11 +91,11 @@ class PusherWorkerStore(SQLBaseStore): def get_all_pushers(self): def get_pushers(txn): txn.execute("SELECT * FROM pushers") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) return self._decode_pushers_rows(rows) - rows = yield self.runInteraction("get_all_pushers", get_pushers) + rows = yield self.db.runInteraction("get_all_pushers", get_pushers) return rows def get_all_updated_pushers(self, last_id, current_id, limit): @@ -135,7 +125,7 @@ class PusherWorkerStore(SQLBaseStore): return updated, deleted - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_pushers", get_all_updated_pushers_txn ) @@ -178,7 +168,7 @@ class PusherWorkerStore(SQLBaseStore): return results - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn ) @@ -194,7 +184,7 @@ class PusherWorkerStore(SQLBaseStore): inlineCallbacks=True, ) def get_if_users_have_pushers(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="pushers", column="user_name", iterable=user_ids, @@ -207,6 +197,84 @@ class PusherWorkerStore(SQLBaseStore): return result + @defer.inlineCallbacks + def update_pusher_last_stream_ordering( + self, app_id, pushkey, user_id, last_stream_ordering + ): + yield self.db.simple_update_one( + "pushers", + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + {"last_stream_ordering": last_stream_ordering}, + desc="update_pusher_last_stream_ordering", + ) + + @defer.inlineCallbacks + def update_pusher_last_stream_ordering_and_success( + self, app_id, pushkey, user_id, last_stream_ordering, last_success + ): + """Update the last stream ordering position we've processed up to for + the given pusher. 
+ + Args: + app_id (str) + pushkey (str) + last_stream_ordering (int) + last_success (int) + + Returns: + Deferred[bool]: True if the pusher still exists; False if it has been deleted. + """ + updated = yield self.db.simple_update( + table="pushers", + keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + updatevalues={ + "last_stream_ordering": last_stream_ordering, + "last_success": last_success, + }, + desc="update_pusher_last_stream_ordering_and_success", + ) + + return bool(updated) + + @defer.inlineCallbacks + def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): + yield self.db.simple_update( + table="pushers", + keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + updatevalues={"failing_since": failing_since}, + desc="update_pusher_failing_since", + ) + + @defer.inlineCallbacks + def get_throttle_params_by_room(self, pusher_id): + res = yield self.db.simple_select_list( + "pusher_throttle", + {"pusher": pusher_id}, + ["room_id", "last_sent_ts", "throttle_ms"], + desc="get_throttle_params_by_room", + ) + + params_by_room = {} + for row in res: + params_by_room[row["room_id"]] = { + "last_sent_ts": row["last_sent_ts"], + "throttle_ms": row["throttle_ms"], + } + + return params_by_room + + @defer.inlineCallbacks + def set_throttle_params(self, pusher_id, room_id, params): + # no need to lock because `pusher_throttle` has a primary key on + # (pusher, room_id) so simple_upsert will retry + yield self.db.simple_upsert( + "pusher_throttle", + {"pusher": pusher_id, "room_id": room_id}, + params, + desc="set_throttle_params", + lock=False, + ) + class PusherStore(PusherWorkerStore): def get_pushers_stream_token(self): @@ -230,8 +298,8 @@ class PusherStore(PusherWorkerStore): ): with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on - # (app_id, pushkey, user_name) so _simple_upsert will retry - yield self._simple_upsert( + # (app_id, pushkey, user_name) so simple_upsert will retry + yield self.db.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, values={ @@ -241,7 +309,7 @@ class PusherStore(PusherWorkerStore): "device_display_name": device_display_name, "ts": pushkey_ts, "lang": lang, - "data": encode_canonical_json(data), + "data": bytearray(encode_canonical_json(data)), "last_stream_ordering": last_stream_ordering, "profile_tag": profile_tag, "id": stream_id, @@ -256,7 +324,7 @@ class PusherStore(PusherWorkerStore): if user_has_pusher is not True: # invalidate, since we the user might not have had a pusher before - yield self.runInteraction( + yield self.db.runInteraction( "add_pusher", self._invalidate_cache_and_stream, self.get_if_user_has_pusher, @@ -270,7 +338,7 @@ class PusherStore(PusherWorkerStore): txn, self.get_if_user_has_pusher, (user_id,) ) - self._simple_delete_one_txn( + self.db.simple_delete_one_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, @@ -279,7 +347,7 @@ class PusherStore(PusherWorkerStore): # it's possible for us to end up with duplicate rows for # (app_id, pushkey, user_id) at different stream_ids, but that # doesn't really matter. 
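add_pusher above leans on the unique (app_id, pushkey, user_name) key so that simple_upsert can run without a table lock: a retry or concurrent attempt collapses onto the same row instead of creating a duplicate. A rough sqlite3 illustration of that idea using SQLite's native UPSERT (requires SQLite 3.24 or newer), with hypothetical values; Synapse's simple_upsert uses the engine's native upsert where available and otherwise emulates it.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE pushers (app_id TEXT, pushkey TEXT, user_name TEXT, kind TEXT,"
        " UNIQUE (app_id, pushkey, user_name))"
    )

    upsert = (
        "INSERT INTO pushers (app_id, pushkey, user_name, kind) VALUES (?, ?, ?, ?)"
        " ON CONFLICT (app_id, pushkey, user_name) DO UPDATE SET kind = excluded.kind"
    )

    conn.execute(upsert, ("io.example.app", "PUSHKEY1", "@alice:hs", "http"))
    conn.execute(upsert, ("io.example.app", "PUSHKEY1", "@alice:hs", "email"))  # retry: same row

    print(conn.execute("SELECT user_name, kind FROM pushers").fetchall())
    # [('@alice:hs', 'email')] -- still a single row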
- self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="deleted_pushers", values={ @@ -291,82 +359,4 @@ class PusherStore(PusherWorkerStore): ) with self._pushers_id_gen.get_next() as stream_id: - yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id) - - @defer.inlineCallbacks - def update_pusher_last_stream_ordering( - self, app_id, pushkey, user_id, last_stream_ordering - ): - yield self._simple_update_one( - "pushers", - {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - {"last_stream_ordering": last_stream_ordering}, - desc="update_pusher_last_stream_ordering", - ) - - @defer.inlineCallbacks - def update_pusher_last_stream_ordering_and_success( - self, app_id, pushkey, user_id, last_stream_ordering, last_success - ): - """Update the last stream ordering position we've processed up to for - the given pusher. - - Args: - app_id (str) - pushkey (str) - last_stream_ordering (int) - last_success (int) - - Returns: - Deferred[bool]: True if the pusher still exists; False if it has been deleted. - """ - updated = yield self._simple_update( - table="pushers", - keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - updatevalues={ - "last_stream_ordering": last_stream_ordering, - "last_success": last_success, - }, - desc="update_pusher_last_stream_ordering_and_success", - ) - - return bool(updated) - - @defer.inlineCallbacks - def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self._simple_update( - table="pushers", - keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - updatevalues={"failing_since": failing_since}, - desc="update_pusher_failing_since", - ) - - @defer.inlineCallbacks - def get_throttle_params_by_room(self, pusher_id): - res = yield self._simple_select_list( - "pusher_throttle", - {"pusher": pusher_id}, - ["room_id", "last_sent_ts", "throttle_ms"], - desc="get_throttle_params_by_room", - ) - - params_by_room = {} - for row in res: - params_by_room[row["room_id"]] = { - "last_sent_ts": row["last_sent_ts"], - "throttle_ms": row["throttle_ms"], - } - - return params_by_room - - @defer.inlineCallbacks - def set_throttle_params(self, pusher_id, room_id, params): - # no need to lock because `pusher_throttle` has a primary key on - # (pusher, room_id) so _simple_upsert will retry - yield self._simple_upsert( - "pusher_throttle", - {"pusher": pusher_id, "room_id": room_id}, - params, - desc="set_throttle_params", - lock=False, - ) + yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id) diff --git a/synapse/storage/receipts.py b/synapse/storage/data_stores/main/receipts.py index 290ddb30e8..cebdcd409f 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -21,12 +21,12 @@ from canonicaljson import json from twisted.internet import defer +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import SQLBaseStore -from .util.id_generators import StreamIdGenerator - logger = logging.getLogger(__name__) @@ -39,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore): # the abstract methods being implemented. 
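ReceiptsWorkerStore keeps the abstract-method arrangement noted here: get_max_receipt_stream_id has to be supplied by the concrete store (which owns the stream ID generator) because the worker constructor immediately uses it to seed the receipts stream change cache. A compressed, standalone sketch of that contract, using the Python 3 abc.ABC spelling rather than the __metaclass__ attribute and with hypothetical class names:

    import abc

    class WorkerStoreSketch(abc.ABC):
        def __init__(self):
            # the real constructor builds a StreamChangeCache from this value
            self.cache_position = self.get_max_receipt_stream_id()

        @abc.abstractmethod
        def get_max_receipt_stream_id(self):
            raise NotImplementedError()

    class StoreSketch(WorkerStoreSketch):
        def get_max_receipt_stream_id(self):
            return 42  # stands in for the stream ID generator's current token

    print(StoreSketch().cache_position)  # 42
    # WorkerStoreSketch() itself would raise TypeError: abstract method not implemented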
__metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(ReceiptsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() @@ -58,11 +58,11 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") - return set(r["user_id"] for r in receipts) + return {r["user_id"] for r in receipts} @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): - return self._simple_select_list( + return self.db.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), @@ -71,7 +71,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=3) def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, @@ -85,7 +85,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) def get_receipts_for_user(self, user_id, receipt_type): - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, retcols=("room_id", "event_id"), @@ -109,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (user_id,)) return txn.fetchall() - rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f) + rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f) return { row[0]: { "event_id": row[1], @@ -188,11 +188,11 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (room_id, to_key)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) return rows - rows = yield self.runInteraction("get_linearized_receipts_for_room", f) + rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f) if not rows: return [] @@ -217,28 +217,32 @@ class ReceiptsWorkerStore(SQLBaseStore): def f(txn): if from_key: - sql = ( - "SELECT * FROM receipts_linearized WHERE" - " room_id IN (%s) AND stream_id > ? AND stream_id <= ?" - ) % (",".join(["?"] * len(room_ids))) - args = list(room_ids) - args.extend([from_key, to_key]) + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id > ? AND stream_id <= ? AND + """ + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) - txn.execute(sql, args) + txn.execute(sql + clause, [from_key, to_key] + list(args)) else: - sql = ( - "SELECT * FROM receipts_linearized WHERE" - " room_id IN (%s) AND stream_id <= ?" - ) % (",".join(["?"] * len(room_ids))) + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id <= ? 
AND + """ - args = list(room_ids) - args.append(to_key) + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) - txn.execute(sql, args) + txn.execute(sql + clause, [to_key] + list(args)) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f) + txn_results = yield self.db.runInteraction( + "_get_linearized_receipts_for_rooms", f + ) results = {} for row in txn_results: @@ -279,9 +283,9 @@ class ReceiptsWorkerStore(SQLBaseStore): args.append(limit) txn.execute(sql, args) - return (r[0:5] + (json.loads(r[5]),) for r in txn) + return [r[0:5] + (json.loads(r[5]),) for r in txn] - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn ) @@ -312,14 +316,14 @@ class ReceiptsWorkerStore(SQLBaseStore): class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) - super(ReceiptsStore, self).__init__(db_conn, hs) + super(ReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() @@ -334,7 +338,7 @@ class ReceiptsStore(ReceiptsWorkerStore): otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ - res = self._simple_select_one_txn( + res = self.db.simple_select_one_txn( txn, table="events", retcols=["stream_ordering", "received_ts"], @@ -387,7 +391,7 @@ class ReceiptsStore(ReceiptsWorkerStore): (user_id, room_id, receipt_type), ) - self._simple_delete_txn( + self.db.simple_upsert_txn( txn, table="receipts_linearized", keyvalues={ @@ -395,19 +399,14 @@ class ReceiptsStore(ReceiptsWorkerStore): "receipt_type": receipt_type, "user_id": user_id, }, - ) - - self._simple_insert_txn( - txn, - table="receipts_linearized", values={ "stream_id": stream_id, - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, "event_id": event_id, "data": json.dumps(data), }, + # receipts_linearized has a unique constraint on + # (user_id, room_id, receipt_type), so no need to lock + lock=False, ) if receipt_type == "m.read" and stream_ordering is not None: @@ -433,26 +432,32 @@ class ReceiptsStore(ReceiptsWorkerStore): # we need to points in graph -> linearized form. # TODO: Make this better. def graph_to_linear(txn): - query = ( - "SELECT event_id WHERE room_id = ? AND stream_ordering IN (" - " SELECT max(stream_ordering) WHERE event_id IN (%s)" - ")" - ) % (",".join(["?"] * len(event_ids))) + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", event_ids + ) + + sql = """ + SELECT event_id WHERE room_id = ? 
AND stream_ordering IN ( + SELECT max(stream_ordering) WHERE %s + ) + """ % ( + clause, + ) - txn.execute(query, [room_id] + event_ids) + txn.execute(sql, [room_id] + list(args)) rows = txn.fetchall() if rows: return rows[0][0] else: raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - linearized_event_id = yield self.runInteraction( + linearized_event_id = yield self.db.runInteraction( "insert_receipt_conv", graph_to_linear ) stream_id_manager = self._receipts_id_gen.get_next() with stream_id_manager as stream_id: - event_ts = yield self.runInteraction( + event_ts = yield self.db.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, room_id, @@ -481,7 +486,7 @@ class ReceiptsStore(ReceiptsWorkerStore): return stream_id, max_persisted_id def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): - return self.runInteraction( + return self.db.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, @@ -507,7 +512,7 @@ class ReceiptsStore(ReceiptsWorkerStore): self._get_linearized_receipts_for_room.invalidate_many, (room_id,) ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="receipts_graph", keyvalues={ @@ -516,7 +521,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "user_id": user_id, }, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="receipts_graph", values={ diff --git a/synapse/storage/registration.py b/synapse/storage/data_stores/main/registration.py index 109052fa41..9768981891 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -17,17 +17,18 @@ import logging import re +from typing import Optional from six import iterkeys -from six.moves import range from twisted.internet import defer +from twisted.internet.defer import Deferred from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, StoreError, ThreepidValidationError +from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage import background_updates from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -37,26 +38,28 @@ logger = logging.getLogger(__name__) class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RegistrationWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) self.config = hs.config self.clock = hs.get_clock() @cached() def get_user_by_id(self, user_id): - return self._simple_select_one( + return self.db.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ "name", "password_hash", "is_guest", + "admin", "consent_version", "consent_server_notice_sent", "appservice_id", "creation_ts", "user_type", + "deactivated", ], allow_none=True, desc="get_user_by_id", @@ -94,7 +97,7 @@ class RegistrationWorkerStore(SQLBaseStore): including the keys `name`, `is_guest`, `device_id`, `token_id`, `valid_until_ms`. """ - return self.runInteraction( + return self.db.runInteraction( "get_user_by_access_token", self._query_for_auth, token ) @@ -109,7 +112,7 @@ class RegistrationWorkerStore(SQLBaseStore): otherwise int representation of the timestamp (as a number of milliseconds since epoch). 
""" - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="expiration_ts_ms", @@ -137,7 +140,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ def set_account_validity_for_user_txn(txn): - self._simple_update_txn( + self.db.simple_update_txn( txn=txn, table="account_validity", keyvalues={"user_id": user_id}, @@ -151,7 +154,7 @@ class RegistrationWorkerStore(SQLBaseStore): txn, self.get_expiration_ts_for_user, (user_id,) ) - yield self.runInteraction( + yield self.db.runInteraction( "set_account_validity_for_user", set_account_validity_for_user_txn ) @@ -167,7 +170,7 @@ class RegistrationWorkerStore(SQLBaseStore): Raises: StoreError: The provided token is already set for another user. """ - yield self._simple_update_one( + yield self.db.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"renewal_token": renewal_token}, @@ -184,7 +187,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: defer.Deferred[str]: The ID of the user to which the token belongs. """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"renewal_token": renewal_token}, retcol="user_id", @@ -203,7 +206,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: defer.Deferred[str]: The renewal token associated with this user ID. """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="renewal_token", @@ -229,9 +232,9 @@ class RegistrationWorkerStore(SQLBaseStore): ) values = [False, now_ms, renew_at] txn.execute(sql, values) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - res = yield self.runInteraction( + res = yield self.db.runInteraction( "get_users_expiring_soon", select_users_txn, self.clock.time_msec(), @@ -250,7 +253,7 @@ class RegistrationWorkerStore(SQLBaseStore): email_sent (bool): Flag which indicates whether a renewal email has been sent to this user. """ - yield self._simple_update_one( + yield self.db.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"email_sent": email_sent}, @@ -265,14 +268,13 @@ class RegistrationWorkerStore(SQLBaseStore): Args: user_id (str): ID of the user to remove from the account validity table. """ - yield self._simple_delete_one( + yield self.db.simple_delete_one( table="account_validity", keyvalues={"user_id": user_id}, desc="delete_account_validity_for_user", ) - @defer.inlineCallbacks - def is_server_admin(self, user): + async def is_server_admin(self, user): """Determines if a user is an admin of this homeserver. Args: @@ -281,7 +283,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns (bool): true iff the user is a server admin, false otherwise. """ - res = yield self._simple_select_one_onecol( + res = await self.db.simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", @@ -289,7 +291,7 @@ class RegistrationWorkerStore(SQLBaseStore): desc="is_server_admin", ) - return res if res else False + return bool(res) if res else False def set_server_admin(self, user, admin): """Sets whether a user is an admin of this homeserver. @@ -299,12 +301,16 @@ class RegistrationWorkerStore(SQLBaseStore): admin (bool): true iff the user is to be a server admin, false otherwise. 
""" - return self._simple_update_one( - table="users", - keyvalues={"name": user.to_string()}, - updatevalues={"admin": 1 if admin else 0}, - desc="set_server_admin", - ) + + def set_server_admin_txn(txn): + self.db.simple_update_one_txn( + txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user.to_string(),) + ) + + return self.db.runInteraction("set_server_admin", set_server_admin_txn) def _query_for_auth(self, txn, token): sql = ( @@ -316,7 +322,7 @@ class RegistrationWorkerStore(SQLBaseStore): ) txn.execute(sql, (token,)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0] @@ -332,10 +338,12 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if user 'user_type' is null or empty string """ - res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id) + res = yield self.db.runInteraction( + "is_real_user", self.is_real_user_txn, user_id + ) return res - @cachedInlineCallbacks() + @cached() def is_support_user(self, user_id): """Determines if the user is of type UserTypes.SUPPORT @@ -345,13 +353,12 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if user is of type UserTypes.SUPPORT """ - res = yield self.runInteraction( + return self.db.runInteraction( "is_support_user", self.is_support_user_txn, user_id ) - return res def is_real_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( + res = self.db.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -361,7 +368,7 @@ class RegistrationWorkerStore(SQLBaseStore): return res is None def is_support_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( + res = self.db.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -376,13 +383,31 @@ class RegistrationWorkerStore(SQLBaseStore): """ def f(txn): - sql = ( - "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)" - ) + sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)" txn.execute(sql, (user_id,)) return dict(txn) - return self.runInteraction("get_users_by_id_case_insensitive", f) + return self.db.runInteraction("get_users_by_id_case_insensitive", f) + + async def get_user_by_external_id( + self, auth_provider: str, external_id: str + ) -> str: + """Look up a user by their external auth id + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + + Returns: + str|None: the mxid of the user, or None if they are not known + """ + return await self.db.simple_select_one_onecol( + table="user_external_ids", + keyvalues={"auth_provider": auth_provider, "external_id": external_id}, + retcol="user_id", + allow_none=True, + desc="get_user_by_external_id", + ) @defer.inlineCallbacks def count_all_users(self): @@ -390,12 +415,12 @@ class RegistrationWorkerStore(SQLBaseStore): def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0]["users"] return 0 - ret = yield self.runInteraction("count_users", _count_users) + ret = yield self.db.runInteraction("count_users", _count_users) return ret def count_daily_user_type(self): @@ -427,7 +452,7 @@ class RegistrationWorkerStore(SQLBaseStore): results[row[0]] = row[1] return results - return self.runInteraction("count_daily_user_type", _count_daily_user_type) 
+ return self.db.runInteraction("count_daily_user_type", _count_daily_user_type) @defer.inlineCallbacks def count_nonbridged_users(self): @@ -438,10 +463,10 @@ class RegistrationWorkerStore(SQLBaseStore): WHERE appservice_id IS NULL """ ) - count, = txn.fetchone() + (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_users", _count_users) + ret = yield self.db.runInteraction("count_users", _count_users) return ret @defer.inlineCallbacks @@ -450,12 +475,12 @@ class RegistrationWorkerStore(SQLBaseStore): def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0]["users"] return 0 - ret = yield self.runInteraction("count_real_users", _count_users) + ret = yield self.db.runInteraction("count_real_users", _count_users) return ret @defer.inlineCallbacks @@ -463,49 +488,45 @@ class RegistrationWorkerStore(SQLBaseStore): """ Gets the localpart of the next generated user ID. - Generated user IDs are integers, and we aim for them to be as small as - we can. Unfortunately, it's possible some of them are already taken by - existing users, and there may be gaps in the already taken range. This - function returns the start of the first allocatable gap. This is to - avoid the case of ID 10000000 being pre-allocated, so us wasting the - first (and shortest) many generated user IDs. + Generated user IDs are integers, so we find the largest integer user ID + already taken and return that plus one. """ def _find_next_generated_user_id(txn): - txn.execute("SELECT name FROM users") + # We bound between '@0' and '@a' to avoid pulling the entire table + # out. + txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") regex = re.compile(r"^@(\d+):") - found = set() + max_found = 0 for (user_id,) in txn: match = regex.search(user_id) if match: - found.add(int(match.group(1))) - for i in range(len(found) + 1): - if i not in found: - return i + max_found = max(int(match.group(1)), max_found) + + return max_found + 1 return ( ( - yield self.runInteraction( + yield self.db.runInteraction( "find_next_generated_user_id", _find_next_generated_user_id ) ) ) - @defer.inlineCallbacks - def get_user_id_by_threepid(self, medium, address, require_verified=False): + async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]: """Returns user id from threepid Args: - medium (str): threepid medium e.g. email - address (str): threepid address e.g. me@example.com + medium: threepid medium e.g. email + address: threepid address e.g. 
me@example.com Returns: - Deferred[str|None]: user id or None if no user id/threepid mapping exists + The user ID or None if no user id/threepid mapping exists """ - user_id = yield self.runInteraction( + user_id = await self.db.runInteraction( "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address ) return user_id @@ -521,7 +542,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: user id or None if no user id/threepid mapping exists """ - ret = self._simple_select_one_txn( + ret = self.db.simple_select_one_txn( txn, "user_threepids", {"medium": medium, "address": address}, @@ -534,7 +555,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self._simple_upsert( + yield self.db.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, @@ -542,7 +563,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_get_threepids(self, user_id): - ret = yield self._simple_select_list( + ret = yield self.db.simple_select_list( "user_threepids", {"user_id": user_id}, ["medium", "address", "validated_at", "added_at"], @@ -551,9 +572,22 @@ class RegistrationWorkerStore(SQLBaseStore): return ret def user_delete_threepid(self, user_id, medium, address): - return self._simple_delete( + return self.db.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, + desc="user_delete_threepid", + ) + + def user_delete_threepids(self, user_id: str): + """Delete all threepid this user has bound + + Args: + user_id: The user id to delete all threepids of + + """ + return self.db.simple_delete( + "user_threepids", + keyvalues={"user_id": user_id}, desc="user_delete_threepids", ) @@ -573,7 +607,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ # We need to use an upsert, in case they user had already bound the # threepid - return self._simple_upsert( + return self.db.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -586,6 +620,26 @@ class RegistrationWorkerStore(SQLBaseStore): desc="add_user_bound_threepid", ) + def user_get_bound_threepids(self, user_id): + """Get the threepids that a user has bound to an identity server through the homeserver + The homeserver remembers where binds to an identity server occurred. Using this + method can retrieve those threepids. 
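For illustration, the user_threepid_id_server methods above (add_user_bound_threepid, user_get_bound_threepids, remove_user_bound_threepid) maintain a small bookkeeping table of which identity servers a 3PID was bound through. The sketch below mirrors that bookkeeping with plain sqlite3; it does not use Synapse's Database/simple_upsert helpers and the schema shown is only an approximation.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE user_threepid_id_server (
        user_id TEXT, medium TEXT, address TEXT, id_server TEXT,
        UNIQUE (user_id, medium, address, id_server)
    )
    """
)

def add_user_bound_threepid(user_id, medium, address, id_server):
    # Upsert: the UNIQUE constraint makes re-recording the same bind a no-op.
    conn.execute(
        "INSERT OR IGNORE INTO user_threepid_id_server VALUES (?, ?, ?, ?)",
        (user_id, medium, address, id_server),
    )

def user_get_bound_threepids(user_id):
    rows = conn.execute(
        "SELECT medium, address FROM user_threepid_id_server WHERE user_id = ?",
        (user_id,),
    )
    return [{"medium": medium, "address": address} for medium, address in rows]

def remove_user_bound_threepid(user_id, medium, address, id_server):
    # Called once an unbind has been proxied to the identity server.
    conn.execute(
        "DELETE FROM user_threepid_id_server "
        "WHERE user_id = ? AND medium = ? AND address = ? AND id_server = ?",
        (user_id, medium, address, id_server),
    )

add_user_bound_threepid("@bob:example.com", "email", "bob@example.com", "matrix.org")
print(user_get_bound_threepids("@bob:example.com"))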
+ + Args: + user_id (str): The ID of the user to retrieve threepids for + + Returns: + Deferred[list[dict]]: List of dictionaries containing the following: + medium (str): The medium of the threepid (e.g "email") + address (str): The address of the threepid (e.g "bob@example.com") + """ + return self.db.simple_select_list( + table="user_threepid_id_server", + keyvalues={"user_id": user_id}, + retcols=["medium", "address"], + desc="user_get_bound_threepids", + ) + def remove_user_bound_threepid(self, user_id, medium, address, id_server): """The server proxied an unbind request to the given identity server on behalf of the given user, so we remove the mapping of threepid to @@ -600,7 +654,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred """ - return self._simple_delete( + return self.db.simple_delete( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -623,7 +677,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[list[str]]: Resolves to a list of identity servers """ - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", @@ -641,7 +695,7 @@ class RegistrationWorkerStore(SQLBaseStore): defer.Deferred(bool): The requested value. """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="deactivated", @@ -655,24 +709,37 @@ class RegistrationWorkerStore(SQLBaseStore): self, medium, client_secret, address=None, sid=None, validated=True ): """Gets a session_id and last_send_attempt (if available) for a - client_secret/medium/(address|session_id) combo + combination of validation metadata Args: medium (str|None): The medium of the 3PID address (str|None): The address of the 3PID sid (str|None): The ID of the validation session - client_secret (str|None): A unique string provided by the client to - help identify this validation attempt + client_secret (str): A unique string provided by the client to help identify this + validation attempt validated (bool|None): Whether sessions should be filtered by whether they have been validated already or not. None to perform no filtering Returns: - deferred {str, int}|None: A dict containing the - latest session_id and send_attempt count for this 3PID. 
- Otherwise None if there hasn't been a previous attempt + Deferred[dict|None]: A dict containing the following: + * address - address of the 3pid + * medium - medium of the 3pid + * client_secret - a secret provided by the client for this validation session + * session_id - ID of the validation session + * send_attempt - a number serving to dedupe send attempts for this session + * validated_at - timestamp of when this session was validated if so + + Otherwise None if a validation session is not found """ - keyvalues = {"medium": medium, "client_secret": client_secret} + if not client_secret: + raise SynapseError( + 400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM + ) + + keyvalues = {"client_secret": client_secret} + if medium: + keyvalues["medium"] = medium if address: keyvalues["address"] = address if sid: @@ -695,13 +762,13 @@ class RegistrationWorkerStore(SQLBaseStore): sql += " LIMIT 1" txn.execute(sql, list(keyvalues.values())) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return None return rows[0] - return self.runInteraction( + return self.db.runInteraction( "get_threepid_validation_session", get_threepid_validation_session_txn ) @@ -715,70 +782,56 @@ class RegistrationWorkerStore(SQLBaseStore): """ def delete_threepid_session_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, ) - return self.runInteraction( + return self.db.runInteraction( "delete_threepid_session", delete_threepid_session_txn ) -class RegistrationStore( - RegistrationWorkerStore, background_updates.BackgroundUpdateStore -): - def __init__(self, db_conn, hs): - super(RegistrationStore, self).__init__(db_conn, hs) +class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): + def __init__(self, database: Database, db_conn, hs): + super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.clock = hs.get_clock() + self.config = hs.config - self.register_background_index_update( + self.db.updates.register_background_index_update( "access_tokens_device_index", index_name="access_tokens_device_id", table="access_tokens", columns=["user_id", "device_id"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "users_creation_ts", index_name="users_creation_ts", table="users", columns=["creation_ts"], ) - self._account_validity = hs.config.account_validity - # we no longer use refresh tokens, but it's possible that some people # might have a background update queued to build this index. Just # clear the background update. 
- self.register_noop_background_update("refresh_tokens_device_index") + self.db.updates.register_noop_background_update("refresh_tokens_device_index") - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_threepids_grandfather", self._bg_user_threepids_grandfather ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) - # Create a background job for culling expired 3PID validity tokens - def start_cull(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "cull_expired_threepid_validation_tokens", - self.cull_expired_threepid_validation_tokens, - ) - - hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) - @defer.inlineCallbacks def _background_update_set_deactivated_flag(self, progress, batch_size): """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 @@ -808,10 +861,10 @@ class RegistrationStore( (last_user, batch_size), ) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: - return True + return True, 0 rows_processed_nb = 0 @@ -822,23 +875,79 @@ class RegistrationStore( logger.info("Marked %d rows as deactivated", rows_processed_nb) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} ) if batch_size > len(rows): - return True + return True, len(rows) else: - return False + return False, len(rows) - end = yield self.runInteraction( + end, nb_processed = yield self.db.runInteraction( "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn ) if end: - yield self._end_background_update("users_set_deactivated_flag") + yield self.db.updates._end_background_update("users_set_deactivated_flag") - return batch_size + return nb_processed + + @defer.inlineCallbacks + def _bg_user_threepids_grandfather(self, progress, batch_size): + """We now track which identity servers a user binds their 3PID to, so + we need to handle the case of existing bindings where we didn't track + this. + + We do this by grandfathering in existing user threepids assuming that + they used one of the server configured trusted identity servers. + """ + id_servers = set(self.config.trusted_third_party_id_servers) + + def _bg_user_threepids_grandfather_txn(txn): + sql = """ + INSERT INTO user_threepid_id_server + (user_id, medium, address, id_server) + SELECT user_id, medium, address, ? 
+ FROM user_threepids + """ + + txn.executemany(sql, [(id_server,) for id_server in id_servers]) + + if id_servers: + yield self.db.runInteraction( + "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn + ) + + yield self.db.updates._end_background_update("user_threepids_grandfather") + + return 1 + + +class RegistrationStore(RegistrationBackgroundUpdateStore): + def __init__(self, database: Database, db_conn, hs): + super(RegistrationStore, self).__init__(database, db_conn, hs) + + self._account_validity = hs.config.account_validity + + if self._account_validity.enabled: + self._clock.call_later( + 0.0, + run_as_background_process, + "account_validity_set_expiration_dates", + self._set_expiration_date_when_missing, + ) + + # Create a background job for culling expired 3PID validity tokens + def start_cull(): + # run as a background process to make sure that the database transactions + # have a logcontext to report to + return run_as_background_process( + "cull_expired_threepid_validation_tokens", + self.cull_expired_threepid_validation_tokens, + ) + + hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): @@ -856,7 +965,7 @@ class RegistrationStore( """ next_id = self._access_tokens_id_gen.get_next() - yield self._simple_insert( + yield self.db.simple_insert( "access_tokens", { "id": next_id, @@ -883,7 +992,7 @@ class RegistrationStore( Args: user_id (str): The desired user ID to register. - password_hash (str): Optional. The password hash for this user. + password_hash (str|None): Optional. The password hash for this user. was_guest (bool): Optional. Whether this is a guest account being upgraded to a non-guest account. make_guest (boolean): True if the the new user should be guest, @@ -897,8 +1006,11 @@ class RegistrationStore( Raises: StoreError if the user_id could not be registered. + + Returns: + Deferred """ - return self.runInteraction( + return self.db.runInteraction( "register_user", self._register_user, user_id, @@ -932,7 +1044,7 @@ class RegistrationStore( # Ensure that the guest user actually exists # ``allow_none=False`` makes this raise an exception # if the row isn't in the database. - self._simple_select_one_txn( + self.db.simple_select_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -940,7 +1052,7 @@ class RegistrationStore( allow_none=False, ) - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -954,7 +1066,7 @@ class RegistrationStore( }, ) else: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, "users", values={ @@ -999,6 +1111,26 @@ class RegistrationStore( self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) + def record_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> Deferred: + """Record a mapping from an external user id to a mxid + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + return self.db.simple_insert( + table="user_external_ids", + values={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="record_user_external_id", + ) + def user_set_password_hash(self, user_id, password_hash): """ NB. 
This does *not* evict any cache because the one use for this @@ -1007,12 +1139,14 @@ class RegistrationStore( """ def user_set_password_hash_txn(txn): - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, "users", {"name": user_id}, {"password_hash": password_hash} ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_password_hash", user_set_password_hash_txn) + return self.db.runInteraction( + "user_set_password_hash", user_set_password_hash_txn + ) def user_set_consent_version(self, user_id, consent_version): """Updates the user table to record privacy policy consent @@ -1027,7 +1161,7 @@ class RegistrationStore( """ def f(txn): - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1035,7 +1169,7 @@ class RegistrationStore( ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_consent_version", f) + return self.db.runInteraction("user_set_consent_version", f) def user_set_consent_server_notice_sent(self, user_id, consent_version): """Updates the user table to record that we have sent the user a server @@ -1051,7 +1185,7 @@ class RegistrationStore( """ def f(txn): - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1059,7 +1193,7 @@ class RegistrationStore( ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_consent_server_notice_sent", f) + return self.db.runInteraction("user_set_consent_server_notice_sent", f) def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): """ @@ -1105,11 +1239,11 @@ class RegistrationStore( return tokens_and_devices - return self.runInteraction("user_delete_access_tokens", f) + return self.db.runInteraction("user_delete_access_tokens", f) def delete_access_token(self, access_token): def f(txn): - self._simple_delete_one_txn( + self.db.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} ) @@ -1117,11 +1251,11 @@ class RegistrationStore( txn, self.get_user_by_access_token, (access_token,) ) - return self.runInteraction("delete_access_token", f) + return self.db.runInteraction("delete_access_token", f) @cachedInlineCallbacks() def is_guest(self, user_id): - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="is_guest", @@ -1136,7 +1270,7 @@ class RegistrationStore( Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self._simple_insert( + return self.db.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", @@ -1149,7 +1283,7 @@ class RegistrationStore( """ # XXX: This should be simple_delete_one but we failed to put a unique index on # the table, so somehow duplicate entries have ended up in it. - return self._simple_delete( + return self.db.simple_delete( "users_pending_deactivation", keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", @@ -1160,7 +1294,7 @@ class RegistrationStore( Gets one user from the table of users waiting to be parted from all the rooms they're in. 
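For illustration, several of the writes above (user_set_password_hash, user_set_consent_version, delete_access_token) run inside a transaction mainly so the matching caches can be invalidated alongside the write. A toy sketch of that read-through-cache-plus-invalidation idea follows, using functools.lru_cache; Synapse's @cached descriptors and _invalidate_cache_and_stream are per-key and replication-aware, whereas this clears the whole cache.

from functools import lru_cache

_users = {"@alice:example.com": {"password_hash": "old"}}

@lru_cache(maxsize=None)
def get_user_by_id(user_id):
    # Cached read; stands in for the @cached get_user_by_id store method.
    return dict(_users.get(user_id, {}))

def user_set_password_hash(user_id, password_hash):
    _users[user_id]["password_hash"] = password_hash
    # Invalidate alongside the write, otherwise readers keep seeing stale rows.
    get_user_by_id.cache_clear()

print(get_user_by_id("@alice:example.com")["password_hash"])  # old
user_set_password_hash("@alice:example.com", "new")
print(get_user_by_id("@alice:example.com")["password_hash"])  # new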
""" - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", @@ -1168,36 +1302,6 @@ class RegistrationStore( desc="get_users_pending_deactivation", ) - @defer.inlineCallbacks - def _bg_user_threepids_grandfather(self, progress, batch_size): - """We now track which identity servers a user binds their 3PID to, so - we need to handle the case of existing bindings where we didn't track - this. - - We do this by grandfathering in existing user threepids assuming that - they used one of the server configured trusted identity servers. - """ - id_servers = set(self.config.trusted_third_party_id_servers) - - def _bg_user_threepids_grandfather_txn(txn): - sql = """ - INSERT INTO user_threepid_id_server - (user_id, medium, address, id_server) - SELECT user_id, medium, address, ? - FROM user_threepids - """ - - txn.executemany(sql, [(id_server,) for id_server in id_servers]) - - if id_servers: - yield self.runInteraction( - "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn - ) - - yield self._end_background_update("user_threepids_grandfather") - - return 1 - def validate_threepid_session(self, session_id, client_secret, token, current_ts): """Attempt to validate a threepid session using a token @@ -1209,6 +1313,10 @@ class RegistrationStore( current_ts (int): The current unix time in milliseconds. Used for checking token expiry status + Raises: + ThreepidValidationError: if a matching validation token was not found or has + expired + Returns: deferred str|None: A str representing a link to redirect the user to if there is one. @@ -1216,7 +1324,7 @@ class RegistrationStore( # Insert everything into a transaction in order to run atomically def validate_threepid_session_txn(txn): - row = self._simple_select_one_txn( + row = self.db.simple_select_one_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1234,7 +1342,7 @@ class RegistrationStore( 400, "This client_secret does not match the provided session_id" ) - row = self._simple_select_one_txn( + row = self.db.simple_select_one_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id, "token": token}, @@ -1259,7 +1367,7 @@ class RegistrationStore( ) # Looks good. 
Validate the session - self._simple_update_txn( + self.db.simple_update_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1269,7 +1377,7 @@ class RegistrationStore( return next_link # Return next_link if it exists - return self.runInteraction( + return self.db.runInteraction( "validate_threepid_session_txn", validate_threepid_session_txn ) @@ -1302,7 +1410,7 @@ class RegistrationStore( if validated_at: insertion_values["validated_at"] = validated_at - return self._simple_upsert( + return self.db.simple_upsert( table="threepid_validation_session", keyvalues={"session_id": session_id}, values={"last_send_attempt": send_attempt}, @@ -1340,7 +1448,7 @@ class RegistrationStore( def start_or_continue_validation_session_txn(txn): # Create or update a validation session - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1353,7 +1461,7 @@ class RegistrationStore( ) # Create a new validation token with this session ID - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="threepid_validation_token", values={ @@ -1364,7 +1472,7 @@ class RegistrationStore( }, ) - return self.runInteraction( + return self.db.runInteraction( "start_or_continue_validation_session", start_or_continue_validation_session_txn, ) @@ -1379,14 +1487,30 @@ class RegistrationStore( """ return txn.execute(sql, (ts,)) - return self.runInteraction( + return self.db.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, self.clock.time_msec(), ) + @defer.inlineCallbacks + def set_user_deactivated_status(self, user_id, deactivated): + """Set the `deactivated` property for the provided user to the provided value. + + Args: + user_id (str): The ID of the user to set the status for. + deactivated (bool): The value to set for `deactivated`. + """ + + yield self.db.runInteraction( + "set_user_deactivated_status", + self.set_user_deactivated_status_txn, + user_id, + deactivated, + ) + def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -1397,17 +1521,57 @@ class RegistrationStore( ) @defer.inlineCallbacks - def set_user_deactivated_status(self, user_id, deactivated): - """Set the `deactivated` property for the provided user to the provided value. + def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database, filtering out deactivated users. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" + ) + txn.execute(sql, []) + + res = self.db.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn( + txn, user["name"], use_delta=True + ) + + yield self.db.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): + """Sets an expiration date to the account with the given user ID. Args: - user_id (str): The ID of the user to set the status for. 
- deactivated (bool): The value to set for `deactivated`. + user_id (str): User ID to set an expiration date for. + use_delta (bool): If set to False, the expiration date for the user will be + now + validity period. If set to True, this expiration date will be a + random value in the [now + period - d ; now + period] range, d being a + delta equal to 10% of the validity period. """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period - yield self.runInteraction( - "set_user_deactivated_status", - self.set_user_deactivated_status_txn, - user_id, - deactivated, + if use_delta: + expiration_ts = self.rand.randrange( + expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts, + ) + + self.db.simple_upsert_txn( + txn, + "account_validity", + keyvalues={"user_id": user_id}, + values={"expiration_ts_ms": expiration_ts, "email_sent": False}, ) diff --git a/synapse/storage/rejections.py b/synapse/storage/data_stores/main/rejections.py index f4c1c2a457..27e5a2084a 100644 --- a/synapse/storage/rejections.py +++ b/synapse/storage/data_stores/main/rejections.py @@ -15,25 +15,14 @@ import logging -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): - def _store_rejections_txn(self, txn, event_id, reason): - self._simple_insert_txn( - txn, - table="rejections", - values={ - "event_id": event_id, - "reason": reason, - "last_check": self._clock.time_msec(), - }, - ) - def get_rejection_reason(self, event_id): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py new file mode 100644 index 0000000000..7d477f8d01 --- /dev/null +++ b/synapse/storage/data_stores/main/relations.py @@ -0,0 +1,327 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import attr + +from synapse.api.constants import RelationTypes +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.stream import generate_pagination_where_clause +from synapse.storage.relations import ( + AggregationPaginationToken, + PaginationChunk, + RelationPaginationToken, +) +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks + +logger = logging.getLogger(__name__) + + +class RelationsWorkerStore(SQLBaseStore): + @cached(tree=True) + def get_relations_for_event( + self, + event_id, + relation_type=None, + event_type=None, + aggregation_key=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of relations for an event, ordered by topological ordering. + + Args: + event_id (str): Fetch events that relate to this event ID. + relation_type (str|None): Only fetch events with this relation + type, if given. 
+ event_type (str|None): Only fetch events with this event type, if + given. + aggregation_key (str|None): Only fetch events with this aggregation + key, if given. + limit (int): Only fetch the most recent `limit` events. + direction (str): Whether to fetch the most recent first (`"b"`) or + the oldest first (`"f"`). + from_token (RelationPaginationToken|None): Fetch rows from the given + token, or from the start if None. + to_token (RelationPaginationToken|None): Fetch rows up to the given + token, or up to the end if None. + + Returns: + Deferred[PaginationChunk]: List of event IDs that match relations + requested. The rows are of the form `{"event_id": "..."}`. + """ + + where_clause = ["relates_to_id = ?"] + where_args = [event_id] + + if relation_type is not None: + where_clause.append("relation_type = ?") + where_args.append(relation_type) + + if event_type is not None: + where_clause.append("type = ?") + where_args.append(event_type) + + if aggregation_key: + where_clause.append("aggregation_key = ?") + where_args.append(aggregation_key) + + pagination_clause = generate_pagination_where_clause( + direction=direction, + column_names=("topological_ordering", "stream_ordering"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if pagination_clause: + where_clause.append(pagination_clause) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE %s + ORDER BY topological_ordering %s, stream_ordering %s + LIMIT ? + """ % ( + " AND ".join(where_clause), + order, + order, + ) + + def _get_recent_references_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + last_topo_id = None + last_stream_id = None + events = [] + for row in txn: + events.append({"event_id": row[0]}) + last_topo_id = row[1] + last_stream_id = row[2] + + next_batch = None + if len(events) > limit and last_topo_id and last_stream_id: + next_batch = RelationPaginationToken(last_topo_id, last_stream_id) + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.db.runInteraction( + "get_recent_references_for_event", _get_recent_references_for_event_txn + ) + + @cached(tree=True) + def get_aggregation_groups_for_event( + self, + event_id, + event_type=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of annotations on the event, grouped by event type and + aggregation key, sorted by count. + + This is used e.g. to get the what and how many reactions have happend + on an event. + + Args: + event_id (str): Fetch events that relate to this event ID. + event_type (str|None): Only fetch events with this event type, if + given. + limit (int): Only fetch the `limit` groups. + direction (str): Whether to fetch the highest count first (`"b"`) or + the lowest count first (`"f"`). + from_token (AggregationPaginationToken|None): Fetch rows from the + given token, or from the start if None. + to_token (AggregationPaginationToken|None): Fetch rows up to the + given token, or up to the end if None. + + + Returns: + Deferred[PaginationChunk]: List of groups of annotations that + match. Each row is a dict with `type`, `key` and `count` fields. 
+ """ + + where_clause = ["relates_to_id = ?", "relation_type = ?"] + where_args = [event_id, RelationTypes.ANNOTATION] + + if event_type: + where_clause.append("type = ?") + where_args.append(event_type) + + having_clause = generate_pagination_where_clause( + direction=direction, + column_names=("COUNT(*)", "MAX(stream_ordering)"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + if having_clause: + having_clause = "HAVING " + having_clause + else: + having_clause = "" + + sql = """ + SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE {where_clause} + GROUP BY relation_type, type, aggregation_key + {having_clause} + ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} + LIMIT ? + """.format( + where_clause=" AND ".join(where_clause), + order=order, + having_clause=having_clause, + ) + + def _get_aggregation_groups_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + next_batch = None + events = [] + for row in txn: + events.append({"type": row[0], "key": row[1], "count": row[2]}) + next_batch = AggregationPaginationToken(row[2], row[3]) + + if len(events) <= limit: + next_batch = None + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.db.runInteraction( + "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + ) + + @cachedInlineCallbacks() + def get_applicable_edit(self, event_id): + """Get the most recent edit (if any) that has happened for the given + event. + + Correctly handles checking whether edits were allowed to happen. + + Args: + event_id (str): The original event ID + + Returns: + Deferred[EventBase|None]: Returns the most recent edit, if any. + """ + + # We only allow edits for `m.room.message` events that have the same sender + # and event type. We can't assert these things during regular event auth so + # we have to do the checks post hoc. + + # Fetches latest edit that has the same type and sender as the + # original, and is an `m.room.message`. + sql = """ + SELECT edit.event_id FROM events AS edit + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS original ON + original.event_id = relates_to_id + AND edit.type = original.type + AND edit.sender = original.sender + WHERE + relates_to_id = ? + AND relation_type = ? + AND edit.type = 'm.room.message' + ORDER by edit.origin_server_ts DESC, edit.event_id DESC + LIMIT 1 + """ + + def _get_applicable_edit_txn(txn): + txn.execute(sql, (event_id, RelationTypes.REPLACE)) + row = txn.fetchone() + if row: + return row[0] + + edit_id = yield self.db.runInteraction( + "get_applicable_edit", _get_applicable_edit_txn + ) + + if not edit_id: + return + + edit_event = yield self.get_event(edit_id, allow_none=True) + return edit_event + + def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): + """Check if a user has already annotated an event with the same key + (e.g. already liked an event). 
+ + Args: + parent_id (str): The event being annotated + event_type (str): The event type of the annotation + aggregation_key (str): The aggregation key of the annotation + sender (str): The sender of the annotation + + Returns: + Deferred[bool] + """ + + sql = """ + SELECT 1 FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + AND type = ? + AND sender = ? + AND aggregation_key = ? + LIMIT 1; + """ + + def _get_if_user_has_annotated_event(txn): + txn.execute( + sql, + ( + parent_id, + RelationTypes.ANNOTATION, + event_type, + sender, + aggregation_key, + ), + ) + + return bool(txn.fetchone()) + + return self.db.runInteraction( + "get_if_user_has_annotated_event", _get_if_user_has_annotated_event + ) + + +class RelationsStore(RelationsWorkerStore): + pass diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py new file mode 100644 index 0000000000..46f643c6b9 --- /dev/null +++ b/synapse/storage/data_stores/main/room.py @@ -0,0 +1,1433 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import logging +import re +from abc import abstractmethod +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.api.errors import StoreError +from synapse.api.room_versions import RoomVersion, RoomVersions +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.search import SearchStore +from synapse.storage.database import Database, LoggingTransaction +from synapse.types import ThirdPartyInstanceID +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks + +logger = logging.getLogger(__name__) + + +OpsLevel = collections.namedtuple( + "OpsLevel", ("ban_level", "kick_level", "redact_level") +) + +RatelimitOverride = collections.namedtuple( + "RatelimitOverride", ("messages_per_second", "burst_count") +) + + +class RoomSortOrder(Enum): + """ + Enum to define the sorting method used when returning rooms with get_rooms_paginate + + NAME = sort rooms alphabetically by name + JOINED_MEMBERS = sort rooms by membership size, highest to lowest + """ + + # ALPHABETICAL and SIZE are deprecated. + # ALPHABETICAL is the same as NAME. + ALPHABETICAL = "alphabetical" + # SIZE is the same as JOINED_MEMBERS. 
+ SIZE = "size" + NAME = "name" + CANONICAL_ALIAS = "canonical_alias" + JOINED_MEMBERS = "joined_members" + JOINED_LOCAL_MEMBERS = "joined_local_members" + VERSION = "version" + CREATOR = "creator" + ENCRYPTION = "encryption" + FEDERATABLE = "federatable" + PUBLIC = "public" + JOIN_RULES = "join_rules" + GUEST_ACCESS = "guest_access" + HISTORY_VISIBILITY = "history_visibility" + STATE_EVENTS = "state_events" + + +class RoomWorkerStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(RoomWorkerStore, self).__init__(database, db_conn, hs) + + self.config = hs.config + + def get_room(self, room_id): + """Retrieve a room. + + Args: + room_id (str): The ID of the room to retrieve. + Returns: + A dict containing the room information, or None if the room is unknown. + """ + return self.db.simple_select_one( + table="rooms", + keyvalues={"room_id": room_id}, + retcols=("room_id", "is_public", "creator"), + desc="get_room", + allow_none=True, + ) + + def get_room_with_stats(self, room_id: str): + """Retrieve room with statistics. + + Args: + room_id: The ID of the room to retrieve. + Returns: + A dict containing the room information, or None if the room is unknown. + """ + + def get_room_with_stats_txn(txn, room_id): + sql = """ + SELECT room_id, state.name, state.canonical_alias, curr.joined_members, + curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, + rooms.creator, state.encryption, state.is_federatable AS federatable, + rooms.is_public AS public, state.join_rules, state.guest_access, + state.history_visibility, curr.current_state_events AS state_events + FROM rooms + LEFT JOIN room_stats_state state USING (room_id) + LEFT JOIN room_stats_current curr USING (room_id) + WHERE room_id = ? + """ + txn.execute(sql, [room_id]) + res = self.db.cursor_to_dict(txn)[0] + res["federatable"] = bool(res["federatable"]) + res["public"] = bool(res["public"]) + return res + + return self.db.runInteraction( + "get_room_with_stats", get_room_with_stats_txn, room_id + ) + + def get_public_room_ids(self): + return self.db.simple_select_onecol( + table="rooms", + keyvalues={"is_public": True}, + retcol="room_id", + desc="get_public_room_ids", + ) + + def count_public_rooms(self, network_tuple, ignore_non_federatable): + """Counts the number of public rooms as tracked in the room_stats_current + and room_stats_state table. + + Args: + network_tuple (ThirdPartyInstanceID|None) + ignore_non_federatable (bool): If true filters out non-federatable rooms + """ + + def _count_public_rooms_txn(txn): + query_args = [] + + if network_tuple: + if network_tuple.appservice_id: + published_sql = """ + SELECT room_id from appservice_room_list + WHERE appservice_id = ? AND network_id = ? 
+ """ + query_args.append(network_tuple.appservice_id) + query_args.append(network_tuple.network_id) + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + """ + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + UNION SELECT room_id from appservice_room_list + """ + + sql = """ + SELECT + COALESCE(COUNT(*), 0) + FROM ( + %(published_sql)s + ) published + INNER JOIN room_stats_state USING (room_id) + INNER JOIN room_stats_current USING (room_id) + WHERE + ( + join_rules = 'public' OR history_visibility = 'world_readable' + ) + AND joined_members > 0 + """ % { + "published_sql": published_sql + } + + txn.execute(sql, query_args) + return txn.fetchone()[0] + + return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) + + @defer.inlineCallbacks + def get_largest_public_rooms( + self, + network_tuple: Optional[ThirdPartyInstanceID], + search_filter: Optional[dict], + limit: Optional[int], + bounds: Optional[Tuple[int, str]], + forwards: bool, + ignore_non_federatable: bool = False, + ): + """Gets the largest public rooms (where largest is in terms of joined + members, as tracked in the statistics table). + + Args: + network_tuple + search_filter + limit: Maxmimum number of rows to return, unlimited otherwise. + bounds: An uppoer or lower bound to apply to result set if given, + consists of a joined member count and room_id (these are + excluded from result set). + forwards: true iff going forwards, going backwards otherwise + ignore_non_federatable: If true filters out non-federatable rooms. + + Returns: + Rooms in order: biggest number of joined users first. + We then arbitrarily use the room_id as a tie breaker. + + """ + + where_clauses = [] + query_args = [] + + if network_tuple: + if network_tuple.appservice_id: + published_sql = """ + SELECT room_id from appservice_room_list + WHERE appservice_id = ? AND network_id = ? + """ + query_args.append(network_tuple.appservice_id) + query_args.append(network_tuple.network_id) + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + """ + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + UNION SELECT room_id from appservice_room_list + """ + + # Work out the bounds if we're given them, these bounds look slightly + # odd, but are designed to help query planner use indices by pulling + # out a common bound. + if bounds: + last_joined_members, last_room_id = bounds + if forwards: + where_clauses.append( + """ + joined_members <= ? AND ( + joined_members < ? OR room_id < ? + ) + """ + ) + else: + where_clauses.append( + """ + joined_members >= ? AND ( + joined_members > ? OR room_id > ? + ) + """ + ) + + query_args += [last_joined_members, last_joined_members, last_room_id] + + if ignore_non_federatable: + where_clauses.append("is_federatable") + + if search_filter and search_filter.get("generic_search_term", None): + search_term = "%" + search_filter["generic_search_term"] + "%" + + where_clauses.append( + """ + ( + LOWER(name) LIKE ? + OR LOWER(topic) LIKE ? + OR LOWER(canonical_alias) LIKE ? 
+ ) + """ + ) + query_args += [ + search_term.lower(), + search_term.lower(), + search_term.lower(), + ] + + where_clause = "" + if where_clauses: + where_clause = " AND " + " AND ".join(where_clauses) + + sql = """ + SELECT + room_id, name, topic, canonical_alias, joined_members, + avatar, history_visibility, joined_members, guest_access + FROM ( + %(published_sql)s + ) published + INNER JOIN room_stats_state USING (room_id) + INNER JOIN room_stats_current USING (room_id) + WHERE + ( + join_rules = 'public' OR history_visibility = 'world_readable' + ) + AND joined_members > 0 + %(where_clause)s + ORDER BY joined_members %(dir)s, room_id %(dir)s + """ % { + "published_sql": published_sql, + "where_clause": where_clause, + "dir": "DESC" if forwards else "ASC", + } + + if limit is not None: + query_args.append(limit) + + sql += """ + LIMIT ? + """ + + def _get_largest_public_rooms_txn(txn): + txn.execute(sql, query_args) + + results = self.db.cursor_to_dict(txn) + + if not forwards: + results.reverse() + + return results + + ret_val = yield self.db.runInteraction( + "get_largest_public_rooms", _get_largest_public_rooms_txn + ) + defer.returnValue(ret_val) + + @cached(max_entries=10000) + def is_room_blocked(self, room_id): + return self.db.simple_select_one_onecol( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + retcol="1", + allow_none=True, + desc="is_room_blocked", + ) + + async def get_rooms_paginate( + self, + start: int, + limit: int, + order_by: RoomSortOrder, + reverse_order: bool, + search_term: Optional[str], + ) -> Tuple[List[Dict[str, Any]], int]: + """Function to retrieve a paginated list of rooms as json. + + Args: + start: offset in the list + limit: maximum amount of rooms to retrieve + order_by: the sort order of the returned list + reverse_order: whether to reverse the room list + search_term: a string to filter room names by + Returns: + A list of room dicts and an integer representing the total number of + rooms that exist given this query + """ + # Filter room names by a string + where_statement = "" + if search_term: + where_statement = "WHERE state.name LIKE ?" + + # Our postgres db driver converts ? -> %s in SQL strings as that's the + # placeholder for postgres. + # HOWEVER, if you put a % into your SQL then everything goes wibbly. 
+ # To get around this, we're going to surround search_term with %'s + # before giving it to the database in python instead + search_term = "%" + search_term + "%" + + # Set ordering + if RoomSortOrder(order_by) == RoomSortOrder.SIZE: + # Deprecated in favour of RoomSortOrder.JOINED_MEMBERS + order_by_column = "curr.joined_members" + order_by_asc = False + elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL: + # Deprecated in favour of RoomSortOrder.NAME + order_by_column = "state.name" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.NAME: + order_by_column = "state.name" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.CANONICAL_ALIAS: + order_by_column = "state.canonical_alias" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_MEMBERS: + order_by_column = "curr.joined_members" + order_by_asc = False + elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_LOCAL_MEMBERS: + order_by_column = "curr.local_users_in_room" + order_by_asc = False + elif RoomSortOrder(order_by) == RoomSortOrder.VERSION: + order_by_column = "rooms.room_version" + order_by_asc = False + elif RoomSortOrder(order_by) == RoomSortOrder.CREATOR: + order_by_column = "rooms.creator" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.ENCRYPTION: + order_by_column = "state.encryption" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.FEDERATABLE: + order_by_column = "state.is_federatable" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.PUBLIC: + order_by_column = "rooms.is_public" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.JOIN_RULES: + order_by_column = "state.join_rules" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.GUEST_ACCESS: + order_by_column = "state.guest_access" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.HISTORY_VISIBILITY: + order_by_column = "state.history_visibility" + order_by_asc = True + elif RoomSortOrder(order_by) == RoomSortOrder.STATE_EVENTS: + order_by_column = "curr.current_state_events" + order_by_asc = False + else: + raise StoreError( + 500, "Incorrect value for order_by provided: %s" % order_by + ) + + # Whether to return the list in reverse order + if reverse_order: + # Flip the boolean + order_by_asc = not order_by_asc + + # Create one query for getting the limited number of events that the user asked + # for, and another query for getting the total number of events that could be + # returned. Thus allowing us to see if there are more events to paginate through + info_sql = """ + SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members, + curr.local_users_in_room, rooms.room_version, rooms.creator, + state.encryption, state.is_federatable, rooms.is_public, state.join_rules, + state.guest_access, state.history_visibility, curr.current_state_events + FROM room_stats_state state + INNER JOIN room_stats_current curr USING (room_id) + INNER JOIN rooms USING (room_id) + %s + ORDER BY %s %s + LIMIT ? + OFFSET ? 
+ """ % ( + where_statement, + order_by_column, + "ASC" if order_by_asc else "DESC", + ) + + # Use a nested SELECT statement as SQL can't count(*) with an OFFSET + count_sql = """ + SELECT count(*) FROM ( + SELECT room_id FROM room_stats_state state + %s + ) AS get_room_ids + """ % ( + where_statement, + ) + + def _get_rooms_paginate_txn(txn): + # Execute the data query + sql_values = (limit, start) + if search_term: + # Add the search term into the WHERE clause + sql_values = (search_term,) + sql_values + txn.execute(info_sql, sql_values) + + # Refactor room query data into a structured dictionary + rooms = [] + for room in txn: + rooms.append( + { + "room_id": room[0], + "name": room[1], + "canonical_alias": room[2], + "joined_members": room[3], + "joined_local_members": room[4], + "version": room[5], + "creator": room[6], + "encryption": room[7], + "federatable": room[8], + "public": room[9], + "join_rules": room[10], + "guest_access": room[11], + "history_visibility": room[12], + "state_events": room[13], + } + ) + + # Execute the count query + + # Add the search term into the WHERE clause if present + sql_values = (search_term,) if search_term else () + txn.execute(count_sql, sql_values) + + room_count = txn.fetchone() + return rooms, room_count[0] + + return await self.db.runInteraction( + "get_rooms_paginate", _get_rooms_paginate_txn, + ) + + @cachedInlineCallbacks(max_entries=10000) + def get_ratelimit_for_user(self, user_id): + """Check if there are any overrides for ratelimiting for the given + user + + Args: + user_id (str) + + Returns: + RatelimitOverride if there is an override, else None. If the contents + of RatelimitOverride are None or 0 then ratelimitng has been + disabled for that user entirely. + """ + row = yield self.db.simple_select_one( + table="ratelimit_override", + keyvalues={"user_id": user_id}, + retcols=("messages_per_second", "burst_count"), + allow_none=True, + desc="get_ratelimit_for_user", + ) + + if row: + return RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], + ) + else: + return None + + @cachedInlineCallbacks() + def get_retention_policy_for_room(self, room_id): + """Get the retention policy for a given room. + + If no retention policy has been found for this room, returns a policy defined + by the configured default policy (which has None as both the 'min_lifetime' and + the 'max_lifetime' if no default policy has been defined in the server's + configuration). + + Args: + room_id (str): The ID of the room to get the retention policy of. + + Returns: + dict[int, int]: "min_lifetime" and "max_lifetime" for this room. + """ + + def get_retention_policy_for_room_txn(txn): + txn.execute( + """ + SELECT min_lifetime, max_lifetime FROM room_retention + INNER JOIN current_state_events USING (event_id, room_id) + WHERE room_id = ?; + """, + (room_id,), + ) + + return self.db.cursor_to_dict(txn) + + ret = yield self.db.runInteraction( + "get_retention_policy_for_room", get_retention_policy_for_room_txn, + ) + + # If we don't know this room ID, ret will be None, in this case return the default + # policy. + if not ret: + defer.returnValue( + { + "min_lifetime": self.config.retention_default_min_lifetime, + "max_lifetime": self.config.retention_default_max_lifetime, + } + ) + + row = ret[0] + + # If one of the room's policy's attributes isn't defined, use the matching + # attribute from the default policy. 
+ # The default values will be None if no default policy has been defined, or if one + # of the attributes is missing from the default policy. + if row["min_lifetime"] is None: + row["min_lifetime"] = self.config.retention_default_min_lifetime + + if row["max_lifetime"] is None: + row["max_lifetime"] = self.config.retention_default_max_lifetime + + defer.returnValue(row) + + def get_media_mxcs_in_room(self, room_id): + """Retrieves all the local and remote media MXC URIs in a given room + + Args: + room_id (str) + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + + def _get_media_mxcs_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + local_media_mxcs = [] + remote_media_mxcs = [] + + # Convert the IDs to MXC URIs + for media_id in local_mxcs: + local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id)) + for hostname, media_id in remote_mxcs: + remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs + + return self.db.runInteraction( + "get_media_ids_in_room", _get_media_mxcs_in_room_txn + ) + + def quarantine_media_ids_in_room(self, room_id, quarantined_by): + """For a room loops through all events with media and quarantines + the associated media + """ + + logger.info("Quarantining media in room: %s", room_id) + + def _quarantine_media_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + total_media_quarantined = 0 + + # Now update all the tables to set the quarantined_by flag + + txn.executemany( + """ + UPDATE local_media_repository + SET quarantined_by = ? + WHERE media_id = ? + """, + ((quarantined_by, media_id) for media_id in local_mxcs), + ) + + txn.executemany( + """ + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE media_origin = ? AND media_id = ? + """, + ( + (quarantined_by, origin, media_id) + for origin, media_id in remote_mxcs + ), + ) + + total_media_quarantined += len(local_mxcs) + total_media_quarantined += len(remote_mxcs) + + return total_media_quarantined + + return self.db.runInteraction( + "quarantine_media_in_room", _quarantine_media_in_room_txn + ) + + def _get_media_mxcs_in_room_txn(self, txn, room_id): + """Retrieves all the local and remote media MXC URIs in a given room + + Args: + txn (cursor) + room_id (str) + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") + + sql = """ + SELECT stream_ordering, json FROM events + JOIN event_json USING (room_id, event_id) + WHERE room_id = ? + %(where_clause)s + AND contains_url = ? AND outlier = ? + ORDER BY stream_ordering DESC + LIMIT ? 
+ """ + txn.execute(sql % {"where_clause": ""}, (room_id, True, False, 100)) + + local_media_mxcs = [] + remote_media_mxcs = [] + + while True: + next_token = None + for stream_ordering, content_json in txn: + next_token = stream_ordering + event_json = json.loads(content_json) + content = event_json["content"] + content_url = content.get("url") + thumbnail_url = content.get("info", {}).get("thumbnail_url") + + for url in (content_url, thumbnail_url): + if not url: + continue + matches = mxc_re.match(url) + if matches: + hostname = matches.group(1) + media_id = matches.group(2) + if hostname == self.hs.hostname: + local_media_mxcs.append(media_id) + else: + remote_media_mxcs.append((hostname, media_id)) + + if next_token is None: + # We've gone through the whole room, so we're finished. + break + + txn.execute( + sql % {"where_clause": "AND stream_ordering < ?"}, + (room_id, next_token, True, False, 100), + ) + + return local_media_mxcs, remote_media_mxcs + + def quarantine_media_by_id( + self, server_name: str, media_id: str, quarantined_by: str, + ): + """quarantines a single local or remote media id + + Args: + server_name: The name of the server that holds this media + media_id: The ID of the media to be quarantined + quarantined_by: The user ID that initiated the quarantine request + """ + logger.info("Quarantining media: %s/%s", server_name, media_id) + is_local = server_name == self.config.server_name + + def _quarantine_media_by_id_txn(txn): + local_mxcs = [media_id] if is_local else [] + remote_mxcs = [(server_name, media_id)] if not is_local else [] + + return self._quarantine_media_txn( + txn, local_mxcs, remote_mxcs, quarantined_by + ) + + return self.db.runInteraction( + "quarantine_media_by_user", _quarantine_media_by_id_txn + ) + + def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str): + """quarantines all local media associated with a single user + + Args: + user_id: The ID of the user to quarantine media of + quarantined_by: The ID of the user who made the quarantine request + """ + + def _quarantine_media_by_user_txn(txn): + local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) + return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) + + return self.db.runInteraction( + "quarantine_media_by_user", _quarantine_media_by_user_txn + ) + + def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True): + """Retrieves local media IDs by a given user + + Args: + txn (cursor) + user_id: The ID of the user to retrieve media IDs of + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + # Local media + sql = """ + SELECT media_id + FROM local_media_repository + WHERE user_id = ? 
+ """ + if filter_quarantined: + sql += "AND quarantined_by IS NULL" + txn.execute(sql, (user_id,)) + + local_media_ids = [row[0] for row in txn] + + # TODO: Figure out all remote media a user has referenced in a message + + return local_media_ids + + def _quarantine_media_txn( + self, + txn, + local_mxcs: List[str], + remote_mxcs: List[Tuple[str, str]], + quarantined_by: str, + ) -> int: + """Quarantine local and remote media items + + Args: + txn (cursor) + local_mxcs: A list of local mxc URLs + remote_mxcs: A list of (remote server, media id) tuples representing + remote mxc URLs + quarantined_by: The ID of the user who initiated the quarantine request + Returns: + The total number of media items quarantined + """ + total_media_quarantined = 0 + + # Update all the tables to set the quarantined_by flag + txn.executemany( + """ + UPDATE local_media_repository + SET quarantined_by = ? + WHERE media_id = ? + """, + ((quarantined_by, media_id) for media_id in local_mxcs), + ) + + txn.executemany( + """ + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE media_origin = ? AND media_id = ? + """, + ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs), + ) + + total_media_quarantined += len(local_mxcs) + total_media_quarantined += len(remote_mxcs) + + return total_media_quarantined + + def get_all_new_public_rooms(self, prev_id, current_id, limit): + def get_all_new_public_rooms(txn): + sql = """ + SELECT stream_id, room_id, visibility, appservice_id, network_id + FROM public_room_list_stream + WHERE stream_id > ? AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + + txn.execute(sql, (prev_id, current_id, limit)) + return txn.fetchall() + + if prev_id == current_id: + return defer.succeed([]) + + return self.db.runInteraction( + "get_all_new_public_rooms", get_all_new_public_rooms + ) + + +class RoomBackgroundUpdateStore(SQLBaseStore): + REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" + ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" + + def __init__(self, database: Database, db_conn, hs): + super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) + + self.config = hs.config + + self.db.updates.register_background_update_handler( + "insert_room_retention", self._background_insert_retention, + ) + + self.db.updates.register_background_update_handler( + self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, + self._remove_tombstoned_rooms_from_directory, + ) + + self.db.updates.register_background_update_handler( + self.ADD_ROOMS_ROOM_VERSION_COLUMN, + self._background_add_rooms_room_version_column, + ) + + @defer.inlineCallbacks + def _background_insert_retention(self, progress, batch_size): + """Retrieves a list of all rooms within a range and inserts an entry for each of + them into the room_retention table. + NULLs the property's columns if missing from the retention event in the room's + state (or NULLs all of them if there's no retention event in the room's state), + so that we fall back to the server's retention policy. + """ + + last_room = progress.get("room_id", "") + + def _background_insert_retention_txn(txn): + txn.execute( + """ + SELECT state.room_id, state.event_id, events.json + FROM current_state_events as state + LEFT JOIN event_json AS events ON (state.event_id = events.event_id) + WHERE state.room_id > ? 
AND state.type = '%s'
+ ORDER BY state.room_id ASC
+ LIMIT ?;
+ """
+ % EventTypes.Retention,
+ (last_room, batch_size),
+ )
+
+ rows = self.db.cursor_to_dict(txn)
+
+ if not rows:
+ return True
+
+ for row in rows:
+ if not row["json"]:
+ retention_policy = {}
+ else:
+ ev = json.loads(row["json"])
+ retention_policy = ev["content"]
+
+ self.db.simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": row["room_id"],
+ "event_id": row["event_id"],
+ "min_lifetime": retention_policy.get("min_lifetime"),
+ "max_lifetime": retention_policy.get("max_lifetime"),
+ },
+ )
+
+ logger.info("Inserted %d rows into room_retention", len(rows))
+
+ self.db.updates._background_update_progress_txn(
+ txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
+ )
+
+ if batch_size > len(rows):
+ return True
+ else:
+ return False
+
+ end = yield self.db.runInteraction(
+ "insert_room_retention", _background_insert_retention_txn,
+ )
+
+ if end:
+ yield self.db.updates._end_background_update("insert_room_retention")
+
+ defer.returnValue(batch_size)
+
+ async def _background_add_rooms_room_version_column(
+ self, progress: dict, batch_size: int
+ ):
+ """Background update to go and add room version information to `rooms`
+ table from `current_state_events` table.
+ """
+
+ last_room_id = progress.get("room_id", "")
+
+ def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction):
+ sql = """
+ SELECT room_id, json FROM current_state_events
+ INNER JOIN event_json USING (room_id, event_id)
+ WHERE room_id > ? AND type = 'm.room.create' AND state_key = ''
+ ORDER BY room_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_room_id, batch_size))
+
+ updates = []
+ for room_id, event_json in txn:
+ event_dict = json.loads(event_json)
+ room_version_id = event_dict.get("content", {}).get(
+ "room_version", RoomVersions.V1.identifier
+ )
+
+ creator = event_dict.get("content").get("creator")
+
+ updates.append((room_id, creator, room_version_id))
+
+ if not updates:
+ return True
+
+ new_last_room_id = ""
+ for room_id, creator, room_version_id in updates:
+ # We upsert here just in case we don't already have a row,
+ # mainly for paranoia as much badness would happen if we don't
+ # insert the row and then try and get the room version for the
+ # room.
+ self.db.simple_upsert_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ values={"room_version": room_version_id},
+ insertion_values={"is_public": False, "creator": creator},
+ )
+ new_last_room_id = room_id
+
+ self.db.updates._background_update_progress_txn(
+ txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id}
+ )
+
+ return False
+
+ end = await self.db.runInteraction(
+ "_background_add_rooms_room_version_column",
+ _background_add_rooms_room_version_column_txn,
+ )
+
+ if end:
+ await self.db.updates._end_background_update(
+ self.ADD_ROOMS_ROOM_VERSION_COLUMN
+ )
+
+ return batch_size
+
+ async def _remove_tombstoned_rooms_from_directory(
+ self, progress, batch_size
+ ) -> int:
+ """Removes any rooms with tombstone events from the room directory
+
+ Nowadays this is handled by the room upgrade handler, but we may have some
+ that got left behind
+ """
+
+ last_room = progress.get("room_id", "")
+
+ def _get_rooms(txn):
+ txn.execute(
+ """
+ SELECT room_id
+ FROM rooms r
+ INNER JOIN current_state_events cse USING (room_id)
+ WHERE room_id > ?
AND r.is_public + AND cse.type = '%s' AND cse.state_key = '' + ORDER BY room_id ASC + LIMIT ?; + """ + % EventTypes.Tombstone, + (last_room, batch_size), + ) + + return [row[0] for row in txn] + + rooms = await self.db.runInteraction( + "get_tombstoned_directory_rooms", _get_rooms + ) + + if not rooms: + await self.db.updates._end_background_update( + self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE + ) + return 0 + + for room_id in rooms: + logger.info("Removing tombstoned room %s from the directory", room_id) + await self.set_room_is_public(room_id, False) + + await self.db.updates._background_update_progress( + self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]} + ) + + return len(rooms) + + @abstractmethod + def set_room_is_public(self, room_id, is_public): + # this will need to be implemented if a background update is performed with + # existing (tombstoned, public) rooms in the database. + # + # It's overridden by RoomStore for the synapse master. + raise NotImplementedError() + + +class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): + def __init__(self, database: Database, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) + + self.config = hs.config + + async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion): + """Ensure that the room is stored in the table + + Called when we join a room over federation, and overwrites any room version + currently in the table. + """ + await self.db.simple_upsert( + desc="upsert_room_on_join", + table="rooms", + keyvalues={"room_id": room_id}, + values={"room_version": room_version.identifier}, + insertion_values={"is_public": False, "creator": ""}, + # rooms has a unique constraint on room_id, so no need to lock when doing an + # emulated upsert. + lock=False, + ) + + @defer.inlineCallbacks + def store_room( + self, + room_id: str, + room_creator_user_id: str, + is_public: bool, + room_version: RoomVersion, + ): + """Stores a room. + + Args: + room_id: The desired room ID, can be None. + room_creator_user_id: The user ID of the room creator. + is_public: True to indicate that this room should appear in + public room lists. + room_version: The version of the room + Raises: + StoreError if the room could not be stored. + """ + try: + + def store_room_txn(txn, next_id): + self.db.simple_insert_txn( + txn, + "rooms", + { + "room_id": room_id, + "creator": room_creator_user_id, + "is_public": is_public, + "room_version": room_version.identifier, + }, + ) + if is_public: + self.db.simple_insert_txn( + txn, + table="public_room_list_stream", + values={ + "stream_id": next_id, + "room_id": room_id, + "visibility": is_public, + }, + ) + + with self._public_room_id_gen.get_next() as next_id: + yield self.db.runInteraction("store_room_txn", store_room_txn, next_id) + except Exception as e: + logger.error("store_room with room_id=%s failed: %s", room_id, e) + raise StoreError(500, "Problem creating room.") + + async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion): + """ + When we receive an invite over federation, store the version of the room if we + don't already know the room version. + """ + await self.db.simple_upsert( + desc="maybe_store_room_on_invite", + table="rooms", + keyvalues={"room_id": room_id}, + values={}, + insertion_values={ + "room_version": room_version.identifier, + "is_public": False, + "creator": "", + }, + # rooms has a unique constraint on room_id, so no need to lock when doing an + # emulated upsert. 
+ lock=False, + ) + + @defer.inlineCallbacks + def set_room_is_public(self, room_id, is_public): + def set_room_is_public_txn(txn, next_id): + self.db.simple_update_one_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"is_public": is_public}, + ) + + entries = self.db.simple_select_list_txn( + txn, + table="public_room_list_stream", + keyvalues={ + "room_id": room_id, + "appservice_id": None, + "network_id": None, + }, + retcols=("stream_id", "visibility"), + ) + + entries.sort(key=lambda r: r["stream_id"]) + + add_to_stream = True + if entries: + add_to_stream = bool(entries[-1]["visibility"]) != is_public + + if add_to_stream: + self.db.simple_insert_txn( + txn, + table="public_room_list_stream", + values={ + "stream_id": next_id, + "room_id": room_id, + "visibility": is_public, + "appservice_id": None, + "network_id": None, + }, + ) + + with self._public_room_id_gen.get_next() as next_id: + yield self.db.runInteraction( + "set_room_is_public", set_room_is_public_txn, next_id + ) + self.hs.get_notifier().on_new_replication_data() + + @defer.inlineCallbacks + def set_room_is_public_appservice( + self, room_id, appservice_id, network_id, is_public + ): + """Edit the appservice/network specific public room list. + + Each appservice can have a number of published room lists associated + with them, keyed off of an appservice defined `network_id`, which + basically represents a single instance of a bridge to a third party + network. + + Args: + room_id (str) + appservice_id (str) + network_id (str) + is_public (bool): Whether to publish or unpublish the room from the + list. + """ + + def set_room_is_public_appservice_txn(txn, next_id): + if is_public: + try: + self.db.simple_insert_txn( + txn, + table="appservice_room_list", + values={ + "appservice_id": appservice_id, + "network_id": network_id, + "room_id": room_id, + }, + ) + except self.database_engine.module.IntegrityError: + # We've already inserted, nothing to do. 
+ return
+ else:
+ self.db.simple_delete_txn(
+ txn,
+ table="appservice_room_list",
+ keyvalues={
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ "room_id": room_id,
+ },
+ )
+
+ entries = self.db.simple_select_list_txn(
+ txn,
+ table="public_room_list_stream",
+ keyvalues={
+ "room_id": room_id,
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ },
+ retcols=("stream_id", "visibility"),
+ )
+
+ entries.sort(key=lambda r: r["stream_id"])
+
+ add_to_stream = True
+ if entries:
+ add_to_stream = bool(entries[-1]["visibility"]) != is_public
+
+ if add_to_stream:
+ self.db.simple_insert_txn(
+ txn,
+ table="public_room_list_stream",
+ values={
+ "stream_id": next_id,
+ "room_id": room_id,
+ "visibility": is_public,
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ },
+ )
+
+ with self._public_room_id_gen.get_next() as next_id:
+ yield self.db.runInteraction(
+ "set_room_is_public_appservice",
+ set_room_is_public_appservice_txn,
+ next_id,
+ )
+ self.hs.get_notifier().on_new_replication_data()
+
+ def get_room_count(self):
+ """Retrieve the total number of rooms.
+ """
+
+ def f(txn):
+ sql = "SELECT count(*) FROM rooms"
+ txn.execute(sql)
+ row = txn.fetchone()
+ return row[0] or 0
+
+ return self.db.runInteraction("get_rooms", f)
+
+ def add_event_report(
+ self, room_id, event_id, user_id, reason, content, received_ts
+ ):
+ next_id = self._event_reports_id_gen.get_next()
+ return self.db.simple_insert(
+ table="event_reports",
+ values={
+ "id": next_id,
+ "received_ts": received_ts,
+ "room_id": room_id,
+ "event_id": event_id,
+ "user_id": user_id,
+ "reason": reason,
+ "content": json.dumps(content),
+ },
+ desc="add_event_report",
+ )
+
+ def get_current_public_room_stream_id(self):
+ return self._public_room_id_gen.get_current_token()
+
+ @defer.inlineCallbacks
+ def block_room(self, room_id, user_id):
+ """Marks the room as blocked. Can be called multiple times.
+
+ Args:
+ room_id (str): Room to block
+ user_id (str): Who blocked it
+
+ Returns:
+ Deferred
+ """
+ yield self.db.simple_upsert(
+ table="blocked_rooms",
+ keyvalues={"room_id": room_id},
+ values={},
+ insertion_values={"user_id": user_id},
+ desc="block_room",
+ )
+ yield self.db.runInteraction(
+ "block_room_invalidation",
+ self._invalidate_cache_and_stream,
+ self.is_room_blocked,
+ (room_id,),
+ )
+
+ @defer.inlineCallbacks
+ def get_rooms_for_retention_period_in_range(
+ self, min_ms, max_ms, include_null=False
+ ):
+ """Retrieves all of the rooms within the given retention range.
+
+ Optionally includes the rooms which don't have a retention policy.
+
+ Args:
+ min_ms (int|None): Duration in milliseconds that defines the lower limit of
+ the range to handle (exclusive). If None, doesn't set a lower limit.
+ max_ms (int|None): Duration in milliseconds that defines the upper limit of
+ the range to handle (inclusive). If None, doesn't set an upper limit.
+ include_null (bool): Whether to include rooms whose retention policy is NULL
+ in the returned set.
+
+ Returns:
+ dict[str, dict]: The rooms within this range, along with their retention
+ policy. The key is "room_id", and maps to a dict describing the retention
+ policy associated with this room ID. The keys for this nested dict are
+ "min_lifetime" (int|None), and "max_lifetime" (int|None).
+ """ + + def get_rooms_for_retention_period_in_range_txn(txn): + range_conditions = [] + args = [] + + if min_ms is not None: + range_conditions.append("max_lifetime > ?") + args.append(min_ms) + + if max_ms is not None: + range_conditions.append("max_lifetime <= ?") + args.append(max_ms) + + # Do a first query which will retrieve the rooms that have a retention policy + # in their current state. + sql = """ + SELECT room_id, min_lifetime, max_lifetime FROM room_retention + INNER JOIN current_state_events USING (event_id, room_id) + """ + + if len(range_conditions): + sql += " WHERE (" + " AND ".join(range_conditions) + ")" + + if include_null: + sql += " OR max_lifetime IS NULL" + + txn.execute(sql, args) + + rows = self.db.cursor_to_dict(txn) + rooms_dict = {} + + for row in rows: + rooms_dict[row["room_id"]] = { + "min_lifetime": row["min_lifetime"], + "max_lifetime": row["max_lifetime"], + } + + if include_null: + # If required, do a second query that retrieves all of the rooms we know + # of so we can handle rooms with no retention policy. + sql = "SELECT DISTINCT room_id FROM current_state_events" + + txn.execute(sql) + + rows = self.db.cursor_to_dict(txn) + + # If a room isn't already in the dict (i.e. it doesn't have a retention + # policy in its state), add it with a null policy. + for row in rows: + if row["room_id"] not in rooms_dict: + rooms_dict[row["room_id"]] = { + "min_lifetime": None, + "max_lifetime": None, + } + + return rooms_dict + + rooms = yield self.db.runInteraction( + "get_rooms_for_retention_period_in_range", + get_rooms_for_retention_period_in_range_txn, + ) + + defer.returnValue(rooms) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py new file mode 100644 index 0000000000..137ebac833 --- /dev/null +++ b/synapse/storage/data_stores/main/roommember.py @@ -0,0 +1,1138 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from typing import Iterable, List, Set + +from six import iteritems, itervalues + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import ( + LoggingTransaction, + SQLBaseStore, + make_in_list_sql_clause, +) +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database +from synapse.storage.engines import Sqlite3Engine +from synapse.storage.roommember import ( + GetRoomsForUserWithStreamOrdering, + MemberSummary, + ProfileInfo, + RoomsForUser, +) +from synapse.types import Collection, get_domain_from_id +from synapse.util.async_helpers import Linearizer +from synapse.util.caches import intern_string +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + + +_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" +_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" + + +class RoomMemberWorkerStore(EventsWorkerStore): + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) + + # Is the current_state_events.membership up to date? Or is the + # background update still running? + self._current_state_events_membership_up_to_date = False + + txn = LoggingTransaction( + db_conn.cursor(), + name="_check_safe_current_state_events_membership_updated", + database_engine=self.database_engine, + ) + self._check_safe_current_state_events_membership_updated_txn(txn) + txn.close() + + if self.hs.config.metrics_flags.known_servers: + self._known_servers_count = 1 + self.hs.get_clock().looping_call( + run_as_background_process, + 60 * 1000, + "_count_known_servers", + self._count_known_servers, + ) + self.hs.get_clock().call_later( + 1000, + run_as_background_process, + "_count_known_servers", + self._count_known_servers, + ) + LaterGauge( + "synapse_federation_known_servers", + "", + [], + lambda: self._known_servers_count, + ) + + @defer.inlineCallbacks + def _count_known_servers(self): + """ + Count the servers that this server knows about. + + The statistic is stored on the class for the + `synapse_federation_known_servers` LaterGauge to collect. + """ + + def _transact(txn): + if isinstance(self.database_engine, Sqlite3Engine): + query = """ + SELECT COUNT(DISTINCT substr(out.user_id, pos+1)) + FROM ( + SELECT rm.user_id as user_id, instr(rm.user_id, ':') + AS pos FROM room_memberships as rm + INNER JOIN current_state_events as c ON rm.event_id = c.event_id + WHERE c.type = 'm.room.member' + ) as out + """ + else: + query = """ + SELECT COUNT(DISTINCT split_part(state_key, ':', 2)) + FROM current_state_events + WHERE type = 'm.room.member' AND membership = 'join'; + """ + txn.execute(query) + return list(txn)[0][0] + + count = yield self.db.runInteraction("get_known_servers", _transact) + + # We always know about ourselves, even if we have nothing in + # room_memberships (for example, the server is new). 
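+ # (This also ensures the gauge never reports zero.)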
+ self._known_servers_count = max([count, 1])
+ return self._known_servers_count
+
+ def _check_safe_current_state_events_membership_updated_txn(self, txn):
+ """Checks if it is safe to assume the new current_state_events
+ membership column is up to date
+ """
+
+ pending_update = self.db.simple_select_one_txn(
+ txn,
+ table="background_updates",
+ keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
+ retcols=["update_name"],
+ allow_none=True,
+ )
+
+ self._current_state_events_membership_up_to_date = not pending_update
+
+ # If the update is still running, reschedule to run.
+ if pending_update:
+ self._clock.call_later(
+ 15.0,
+ run_as_background_process,
+ "_check_safe_current_state_events_membership_updated",
+ self.db.runInteraction,
+ "_check_safe_current_state_events_membership_updated",
+ self._check_safe_current_state_events_membership_updated_txn,
+ )
+
+ @cached(max_entries=100000, iterable=True)
+ def get_users_in_room(self, room_id):
+ return self.db.runInteraction(
+ "get_users_in_room", self.get_users_in_room_txn, room_id
+ )
+
+ def get_users_in_room_txn(self, txn, room_id):
+ # If we can assume current_state_events.membership is up to date
+ # then we can avoid a join, which is a Very Good Thing given how
+ # frequently this function gets called.
+ if self._current_state_events_membership_up_to_date:
+ sql = """
+ SELECT state_key FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
+ """
+ else:
+ sql = """
+ SELECT state_key FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
+ """
+
+ txn.execute(sql, (room_id, Membership.JOIN))
+ return [r[0] for r in txn]
+
+ @cached(max_entries=100000)
+ def get_room_summary(self, room_id):
+ """ Get the details of a room roughly suitable for use by the room
+ summary extension to /sync. Useful when lazy loading room members.
+ Args:
+ room_id (str): The room ID to query
+ Returns:
+ Deferred[dict[str, MemberSummary]]:
+ dict of membership states, pointing to a MemberSummary named tuple.
+ """
+
+ def _get_room_summary_txn(txn):
+ # first get counts.
+ # We do this all in one transaction to keep the cache small.
+ # FIXME: get rid of this when we have room_stats
+
+ # If we can assume current_state_events.membership is up to date
+ # then we can avoid a join, which is a Very Good Thing given how
+ # frequently this function gets called.
+ if self._current_state_events_membership_up_to_date:
+ # Note, rejected events will have a null membership field, so
+ # we manually filter them out.
+ sql = """
+ SELECT count(*), membership FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ?
+ AND membership IS NOT NULL
+ GROUP BY membership
+ """
+ else:
+ sql = """
+ SELECT count(*), m.membership FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ GROUP BY m.membership
+ """
+
+ txn.execute(sql, (room_id,))
+ res = {}
+ for count, membership in txn:
+ summary = res.setdefault(membership, MemberSummary([], count))
+
+ # we order by membership and then fairly arbitrarily by event_id so
+ # heroes are consistent
+ if self._current_state_events_membership_up_to_date:
+ # Note, rejected events will have a null membership field, so
+ # we manually filter them out.
+ sql = """
+ SELECT state_key, membership, event_id
+ FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ?
+ AND membership IS NOT NULL
+ ORDER BY
+ CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+ event_id ASC
+ LIMIT ?
+ """
+ else:
+ sql = """
+ SELECT c.state_key, m.membership, c.event_id
+ FROM room_memberships as m
+ INNER JOIN current_state_events as c USING (room_id, event_id)
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ ORDER BY
+ CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+ c.event_id ASC
+ LIMIT ?
+ """
+
+ # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
+ txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
+ for user_id, membership, event_id in txn:
+ summary = res[membership]
+ # we will always have a summary for this membership type at this
+ # point given the summary currently contains the counts.
+ members = summary.members
+ members.append((user_id, event_id))
+
+ return res
+
+ return self.db.runInteraction("get_room_summary", _get_room_summary_txn)
+
+ def _get_user_counts_in_room_txn(self, txn, room_id):
+ """
+ Get the user counts in a room by membership.
+
+ Args:
+ txn (cursor)
+ room_id (str)
+
+ Returns:
+ dict[str, int]: Map from membership to the number of users with
+ that membership in the room.
+ """
+ sql = """
+ SELECT m.membership, count(*) FROM room_memberships as m
+ INNER JOIN current_state_events as c USING(event_id)
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ GROUP BY m.membership
+ """
+
+ txn.execute(sql, (room_id,))
+ return {row[0]: row[1] for row in txn}
+
+ @cached()
+ def get_invited_rooms_for_local_user(self, user_id):
+ """ Get all the rooms the *local* user is invited to
+
+ Args:
+ user_id (str): The user ID.
+ Returns:
+ A deferred list of RoomsForUser.
+ """
+
+ return self.get_rooms_for_local_user_where_membership_is(
+ user_id, [Membership.INVITE]
+ )
+
+ @defer.inlineCallbacks
+ def get_invite_for_local_user_in_room(self, user_id, room_id):
+ """Gets the invite for the given *local* user and room
+
+ Args:
+ user_id (str)
+ room_id (str)
+
+ Returns:
+ Deferred: Resolves to either a RoomsForUser or None if no invite was
+ found.
+ """
+ invites = yield self.get_invited_rooms_for_local_user(user_id)
+ for invite in invites:
+ if invite.room_id == room_id:
+ return invite
+ return None
+
+ @defer.inlineCallbacks
+ def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
+ """ Get all the rooms for this *local* user where the membership for this user
+ matches one in the membership list.
+
+ Filters out forgotten rooms.
+
+ Args:
+ user_id (str): The user ID.
+ membership_list (list): A list of synapse.api.constants.Membership
+ values which the user must be in.
+ + Returns: + Deferred[list[RoomsForUser]] + """ + if not membership_list: + return defer.succeed(None) + + rooms = yield self.db.runInteraction( + "get_rooms_for_local_user_where_membership_is", + self._get_rooms_for_local_user_where_membership_is_txn, + user_id, + membership_list, + ) + + # Now we filter out forgotten rooms + forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) + return [room for room in rooms if room.room_id not in forgotten_rooms] + + def _get_rooms_for_local_user_where_membership_is_txn( + self, txn, user_id, membership_list + ): + # Paranoia check. + if not self.hs.is_mine_id(user_id): + raise Exception( + "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r" + % (user_id,), + ) + + clause, args = make_in_list_sql_clause( + self.database_engine, "c.membership", membership_list + ) + + sql = """ + SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering + FROM local_current_membership AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + user_id = ? + AND %s + """ % ( + clause, + ) + + txn.execute(sql, (user_id, *args)) + results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] + + return results + + @cached(max_entries=500000, iterable=True) + def get_rooms_for_user_with_stream_ordering(self, user_id): + """Returns a set of room_ids the user is currently joined to. + + If a remote user only returns rooms this server is currently + participating in. + + Args: + user_id (str) + + Returns: + Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns + the rooms the user is in currently, along with the stream ordering + of the most recent join for that user and room. + """ + return self.db.runInteraction( + "get_rooms_for_user_with_stream_ordering", + self._get_rooms_for_user_with_stream_ordering_txn, + user_id, + ) + + def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id): + # We use `current_state_events` here and not `local_current_membership` + # as a) this gets called with remote users and b) this only gets called + # for rooms the server is participating in. + if self._current_state_events_membership_up_to_date: + sql = """ + SELECT room_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND c.membership = ? + """ + else: + sql = """ + SELECT room_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (room_id, event_id) + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND m.membership = ? + """ + + txn.execute(sql, (user_id, Membership.JOIN)) + results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) + + return results + + async def get_users_server_still_shares_room_with( + self, user_ids: Collection[str] + ) -> Set[str]: + """Given a list of users return the set that the server still share a + room with. 
+ """ + + if not user_ids: + return set() + + def _get_users_server_still_shares_room_with_txn(txn): + sql = """ + SELECT state_key FROM current_state_events + WHERE + type = 'm.room.member' + AND membership = 'join' + AND %s + GROUP BY state_key + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "state_key", user_ids + ) + + txn.execute(sql % (clause,), args) + + return {row[0] for row in txn} + + return await self.db.runInteraction( + "get_users_server_still_shares_room_with", + _get_users_server_still_shares_room_with_txn, + ) + + @defer.inlineCallbacks + def get_rooms_for_user(self, user_id, on_invalidate=None): + """Returns a set of room_ids the user is currently joined to. + + If a remote user only returns rooms this server is currently + participating in. + """ + rooms = yield self.get_rooms_for_user_with_stream_ordering( + user_id, on_invalidate=on_invalidate + ) + return frozenset(r.room_id for r in rooms) + + @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) + def get_users_who_share_room_with_user(self, user_id, cache_context): + """Returns the set of users who share a room with `user_id` + """ + room_ids = yield self.get_rooms_for_user( + user_id, on_invalidate=cache_context.invalidate + ) + + user_who_share_room = set() + for room_id in room_ids: + user_ids = yield self.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate + ) + user_who_share_room.update(user_ids) + + return user_who_share_room + + @defer.inlineCallbacks + def get_joined_users_from_context(self, event, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + current_state_ids = yield context.get_current_state_ids() + result = yield self._get_joined_users_from_context( + event.room_id, state_group, current_state_ids, event=event, context=context + ) + return result + + @defer.inlineCallbacks + def get_joined_users_from_state(self, room_id, state_entry): + state_group = state_entry.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + with Measure(self._clock, "get_joined_users_from_state"): + return ( + yield self._get_joined_users_from_context( + room_id, state_group, state_entry.state, context=state_entry + ) + ) + + @cachedInlineCallbacks( + num_args=2, cache_context=True, iterable=True, max_entries=100000 + ) + def _get_joined_users_from_context( + self, + room_id, + state_group, + current_state_ids, + cache_context, + event=None, + context=None, + ): + # We don't use `state_group`, it's there so that we can cache based + # on it. However, it's important that it's never None, since two current_states + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. 
+ assert state_group is not None + + users_in_room = {} + member_event_ids = [ + e_id + for key, e_id in iteritems(current_state_ids) + if key[0] == EventTypes.Member + ] + + if context is not None: + # If we have a context with a delta from a previous state group, + # check if we also have the result from the previous group in cache. + # If we do then we can reuse that result and simply update it with + # any membership changes in `delta_ids` + if context.prev_group and context.delta_ids: + prev_res = self._get_joined_users_from_context.cache.get( + (room_id, context.prev_group), None + ) + if prev_res and isinstance(prev_res, dict): + users_in_room = dict(prev_res) + member_event_ids = [ + e_id + for key, e_id in iteritems(context.delta_ids) + if key[0] == EventTypes.Member + ] + for etype, state_key in context.delta_ids: + if etype == EventTypes.Member: + users_in_room.pop(state_key, None) + + # We check if we have any of the member event ids in the event cache + # before we ask the DB + + # We don't update the event cache hit ratio as it completely throws off + # the hit ratio counts. After all, we don't populate the cache if we + # miss it here + event_map = self._get_events_from_cache( + member_event_ids, allow_rejected=False, update_metrics=False + ) + + missing_member_event_ids = [] + for event_id in member_event_ids: + ev_entry = event_map.get(event_id) + if ev_entry: + if ev_entry.event.membership == Membership.JOIN: + users_in_room[ev_entry.event.state_key] = ProfileInfo( + display_name=ev_entry.event.content.get("displayname", None), + avatar_url=ev_entry.event.content.get("avatar_url", None), + ) + else: + missing_member_event_ids.append(event_id) + + if missing_member_event_ids: + event_to_memberships = yield self._get_joined_profiles_from_event_ids( + missing_member_event_ids + ) + users_in_room.update((row for row in event_to_memberships.values() if row)) + + if event is not None and event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + if event.event_id in member_event_ids: + users_in_room[event.state_key] = ProfileInfo( + display_name=event.content.get("displayname", None), + avatar_url=event.content.get("avatar_url", None), + ) + + return users_in_room + + @cached(max_entries=10000) + def _get_joined_profile_from_event_id(self, event_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_joined_profile_from_event_id", + list_name="event_ids", + inlineCallbacks=True, + ) + def _get_joined_profiles_from_event_ids(self, event_ids): + """For given set of member event_ids check if they point to a join + event and if so return the associated user and profile info. + + Args: + event_ids (Iterable[str]): The member event IDs to lookup + + Returns: + Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID + to `user_id` and ProfileInfo (or None if not join event). 
+ """ + + rows = yield self.db.simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=event_ids, + retcols=("user_id", "display_name", "avatar_url", "event_id"), + keyvalues={"membership": Membership.JOIN}, + batch_size=500, + desc="_get_membership_from_event_ids", + ) + + return { + row["event_id"]: ( + row["user_id"], + ProfileInfo( + avatar_url=row["avatar_url"], display_name=row["display_name"] + ), + ) + for row in rows + } + + @cachedInlineCallbacks(max_entries=10000) + def is_host_joined(self, room_id, host): + if "%" in host or "_" in host: + raise Exception("Invalid host name") + + sql = """ + SELECT state_key FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (event_id) + WHERE m.membership = 'join' + AND type = 'm.room.member' + AND c.room_id = ? + AND state_key LIKE ? + LIMIT 1 + """ + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + like_clause = "%:" + host + + rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause) + + if not rows: + return False + + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") + + return True + + @cachedInlineCallbacks() + def was_host_joined(self, room_id, host): + """Check whether the server is or ever was in the room. + + Args: + room_id (str) + host (str) + + Returns: + Deferred: Resolves to True if the host is/was in the room, otherwise + False. + """ + if "%" in host or "_" in host: + raise Exception("Invalid host name") + + sql = """ + SELECT user_id FROM room_memberships + WHERE room_id = ? + AND user_id LIKE ? + AND membership = 'join' + LIMIT 1 + """ + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + like_clause = "%:" + host + + rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause) + + if not rows: + return False + + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") + + return True + + @defer.inlineCallbacks + def get_joined_hosts(self, room_id, state_entry): + state_group = state_entry.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + with Measure(self._clock, "get_joined_hosts"): + return ( + yield self._get_joined_hosts( + room_id, state_group, state_entry.state, state_entry=state_entry + ) + ) + + @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) + # @defer.inlineCallbacks + def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. 
+ assert state_group is not None + + cache = yield self._get_joined_hosts_cache(room_id) + joined_hosts = yield cache.get_destinations(state_entry) + + return joined_hosts + + @cached(max_entries=10000) + def _get_joined_hosts_cache(self, room_id): + return _JoinedHostsCache(self, room_id) + + @cachedInlineCallbacks(num_args=2) + def did_forget(self, user_id, room_id): + """Returns whether user_id has elected to discard history for room_id. + + Returns False if they have since re-joined.""" + + def f(txn): + sql = ( + "SELECT" + " COUNT(*)" + " FROM" + " room_memberships" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + " AND" + " forgotten = 0" + ) + txn.execute(sql, (user_id, room_id)) + rows = txn.fetchall() + return rows[0][0] + + count = yield self.db.runInteraction("did_forget_membership", f) + return count == 0 + + @cached() + def get_forgotten_rooms_for_user(self, user_id): + """Gets all rooms the user has forgotten. + + Args: + user_id (str) + + Returns: + Deferred[set[str]] + """ + + def _get_forgotten_rooms_for_user_txn(txn): + # This is a slightly convoluted query that first looks up all rooms + # that the user has forgotten in the past, then rechecks that list + # to see if any have subsequently been updated. This is done so that + # we can use a partial index on `forgotten = 1` on the assumption + # that few users will actually forget many rooms. + # + # Note that a room is considered "forgotten" if *all* membership + # events for that user and room have the forgotten field set (as + # when a user forgets a room we update all rows for that user and + # room, not just the current one). + sql = """ + SELECT room_id, ( + SELECT count(*) FROM room_memberships + WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0 + ) AS count + FROM room_memberships AS m + WHERE user_id = ? AND forgotten = 1 + GROUP BY room_id, user_id; + """ + txn.execute(sql, (user_id,)) + return {row[0] for row in txn if row[1] == 0} + + return self.db.runInteraction( + "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn + ) + + @defer.inlineCallbacks + def get_rooms_user_has_been_in(self, user_id): + """Get all rooms that the user has ever been in. + + Args: + user_id (str) + + Returns: + Deferred[set[str]]: Set of room IDs. + """ + + room_ids = yield self.db.simple_select_onecol( + table="room_memberships", + keyvalues={"membership": Membership.JOIN, "user_id": user_id}, + retcol="room_id", + desc="get_rooms_user_has_been_in", + ) + + return set(room_ids) + + def get_membership_from_event_ids( + self, member_event_ids: Iterable[str] + ) -> List[dict]: + """Get user_id and membership of a set of event IDs. + """ + + return self.db.simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids, + retcols=("user_id", "membership", "event_id"), + keyvalues={}, + batch_size=500, + desc="get_membership_from_event_ids", + ) + + async def is_local_host_in_room_ignoring_users( + self, room_id: str, ignore_users: Collection[str] + ) -> bool: + """Check if there are any local users, excluding those in the given + list, in the room. + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "user_id", ignore_users + ) + + sql = """ + SELECT 1 FROM local_current_membership + WHERE + room_id = ? AND membership = ? 
+ AND NOT (%s) + LIMIT 1 + """ % ( + clause, + ) + + def _is_local_host_in_room_ignoring_users_txn(txn): + txn.execute(sql, (room_id, Membership.JOIN, *args)) + + return bool(txn.fetchone()) + + return await self.db.runInteraction( + "is_local_host_in_room_ignoring_users", + _is_local_host_in_room_ignoring_users_txn, + ) + + +class RoomMemberBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_update_handler( + _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile + ) + self.db.updates.register_background_update_handler( + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, + self._background_current_state_membership, + ) + self.db.updates.register_background_index_update( + "room_membership_forgotten_idx", + index_name="room_memberships_user_room_forgotten", + table="room_memberships", + columns=["user_id", "room_id"], + where_clause="forgotten = 1", + ) + + @defer.inlineCallbacks + def _background_add_membership_profile(self, progress, batch_size): + target_min_stream_id = progress.get( + "target_min_stream_id_inclusive", self._min_stream_order_on_start + ) + max_stream_id = progress.get( + "max_stream_id_exclusive", self._stream_order_on_start + 1 + ) + + INSERT_CLUMP_SIZE = 1000 + + def add_membership_profile_txn(txn): + sql = """ + SELECT stream_ordering, event_id, events.room_id, event_json.json + FROM events + INNER JOIN event_json USING (event_id) + INNER JOIN room_memberships USING (event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND type = 'm.room.member' + ORDER BY stream_ordering DESC + LIMIT ? + """ + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = self.db.cursor_to_dict(txn) + if not rows: + return 0 + + min_stream_id = rows[-1]["stream_ordering"] + + to_update = [] + for row in rows: + event_id = row["event_id"] + room_id = row["room_id"] + try: + event_json = json.loads(row["json"]) + content = event_json["content"] + except Exception: + continue + + display_name = content.get("displayname", None) + avatar_url = content.get("avatar_url", None) + + if display_name or avatar_url: + to_update.append((display_name, avatar_url, event_id, room_id)) + + to_update_sql = """ + UPDATE room_memberships SET display_name = ?, avatar_url = ? + WHERE event_id = ? AND room_id = ? + """ + for index in range(0, len(to_update), INSERT_CLUMP_SIZE): + clump = to_update[index : index + INSERT_CLUMP_SIZE] + txn.executemany(to_update_sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + } + + self.db.updates._background_update_progress_txn( + txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress + ) + + return len(rows) + + result = yield self.db.runInteraction( + _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn + ) + + if not result: + yield self.db.updates._end_background_update( + _MEMBERSHIP_PROFILE_UPDATE_NAME + ) + + return result + + @defer.inlineCallbacks + def _background_current_state_membership(self, progress, batch_size): + """Update the new membership column on current_state_events. + + This works by iterating over all rooms in alphebetical order. + """ + + def _background_current_state_membership_txn(txn, last_processed_room): + processed = 0 + while processed < batch_size: + txn.execute( + """ + SELECT MIN(room_id) FROM current_state_events WHERE room_id > ? 
+ """, + (last_processed_room,), + ) + row = txn.fetchone() + if not row or not row[0]: + return processed, True + + (next_room,) = row + + sql = """ + UPDATE current_state_events + SET membership = ( + SELECT membership FROM room_memberships + WHERE event_id = current_state_events.event_id + ) + WHERE room_id = ? + """ + txn.execute(sql, (next_room,)) + processed += txn.rowcount + + last_processed_room = next_room + + self.db.updates._background_update_progress_txn( + txn, + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, + {"last_processed_room": last_processed_room}, + ) + + return processed, False + + # If we haven't got a last processed room then just use the empty + # string, which will compare before all room IDs correctly. + last_processed_room = progress.get("last_processed_room", "") + + row_count, finished = yield self.db.runInteraction( + "_background_current_state_membership_update", + _background_current_state_membership_txn, + last_processed_room, + ) + + if finished: + yield self.db.updates._end_background_update( + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME + ) + + return row_count + + +class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberStore, self).__init__(database, db_conn, hs) + + def forget(self, user_id, room_id): + """Indicate that user_id wishes to discard history for room_id.""" + + def f(txn): + sql = ( + "UPDATE" + " room_memberships" + " SET" + " forgotten = 1" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + ) + txn.execute(sql, (user_id, room_id)) + + self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) + self._invalidate_cache_and_stream( + txn, self.get_forgotten_rooms_for_user, (user_id,) + ) + + return self.db.runInteraction("forget_membership", f) + + +class _JoinedHostsCache(object): + """Cache for joined hosts in a room that is optimised to handle updates + via state deltas. 
+ """ + + def __init__(self, store, room_id): + self.store = store + self.room_id = room_id + + self.hosts_to_joined_users = {} + + self.state_group = object() + + self.linearizer = Linearizer("_JoinedHostsCache") + + self._len = 0 + + @defer.inlineCallbacks + def get_destinations(self, state_entry): + """Get set of destinations for a state entry + + Args: + state_entry(synapse.state._StateCacheEntry) + """ + if state_entry.state_group == self.state_group: + return frozenset(self.hosts_to_joined_users) + + with (yield self.linearizer.queue(())): + if state_entry.state_group == self.state_group: + pass + elif state_entry.prev_group == self.state_group: + for (typ, state_key), event_id in iteritems(state_entry.delta_ids): + if typ != EventTypes.Member: + continue + + host = intern_string(get_domain_from_id(state_key)) + user_id = state_key + known_joins = self.hosts_to_joined_users.setdefault(host, set()) + + event = yield self.store.get_event(event_id) + if event.membership == Membership.JOIN: + known_joins.add(user_id) + else: + known_joins.discard(user_id) + + if not known_joins: + self.hosts_to_joined_users.pop(host, None) + else: + joined_users = yield self.store.get_joined_users_from_state( + self.room_id, state_entry + ) + + self.hosts_to_joined_users = {} + for user_id in joined_users: + host = intern_string(get_domain_from_id(user_id)) + self.hosts_to_joined_users.setdefault(host, set()).add(user_id) + + if state_entry.state_group: + self.state_group = state_entry.state_group + else: + self.state_group = object() + self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users)) + return frozenset(self.hosts_to_joined_users) + + def __len__(self): + return self._len diff --git a/synapse/storage/schema/delta/12/v12.sql b/synapse/storage/data_stores/main/schema/delta/12/v12.sql index 5964c5aaac..5964c5aaac 100644 --- a/synapse/storage/schema/delta/12/v12.sql +++ b/synapse/storage/data_stores/main/schema/delta/12/v12.sql diff --git a/synapse/storage/schema/delta/13/v13.sql b/synapse/storage/data_stores/main/schema/delta/13/v13.sql index f8649e5d99..f8649e5d99 100644 --- a/synapse/storage/schema/delta/13/v13.sql +++ b/synapse/storage/data_stores/main/schema/delta/13/v13.sql diff --git a/synapse/storage/schema/delta/14/v14.sql b/synapse/storage/data_stores/main/schema/delta/14/v14.sql index a831920da6..a831920da6 100644 --- a/synapse/storage/schema/delta/14/v14.sql +++ b/synapse/storage/data_stores/main/schema/delta/14/v14.sql diff --git a/synapse/storage/schema/delta/15/appservice_txns.sql b/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql index e4f5e76aec..e4f5e76aec 100644 --- a/synapse/storage/schema/delta/15/appservice_txns.sql +++ b/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql diff --git a/synapse/storage/schema/delta/15/presence_indices.sql b/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql index 6b8d0f1ca7..6b8d0f1ca7 100644 --- a/synapse/storage/schema/delta/15/presence_indices.sql +++ b/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql diff --git a/synapse/storage/schema/delta/15/v15.sql b/synapse/storage/data_stores/main/schema/delta/15/v15.sql index 9523d2bcc3..9523d2bcc3 100644 --- a/synapse/storage/schema/delta/15/v15.sql +++ b/synapse/storage/data_stores/main/schema/delta/15/v15.sql diff --git a/synapse/storage/schema/delta/16/events_order_index.sql b/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql index a48f215170..a48f215170 100644 --- 
a/synapse/storage/schema/delta/16/events_order_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql diff --git a/synapse/storage/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql index 7a15265cb1..7a15265cb1 100644 --- a/synapse/storage/schema/delta/16/remote_media_cache_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql diff --git a/synapse/storage/schema/delta/16/remove_duplicates.sql b/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql index 65c97b5e2f..65c97b5e2f 100644 --- a/synapse/storage/schema/delta/16/remove_duplicates.sql +++ b/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql diff --git a/synapse/storage/schema/delta/16/room_alias_index.sql b/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql index f82486132b..f82486132b 100644 --- a/synapse/storage/schema/delta/16/room_alias_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql diff --git a/synapse/storage/schema/delta/16/unique_constraints.sql b/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql index 5b8de52c33..5b8de52c33 100644 --- a/synapse/storage/schema/delta/16/unique_constraints.sql +++ b/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql diff --git a/synapse/storage/schema/delta/16/users.sql b/synapse/storage/data_stores/main/schema/delta/16/users.sql index cd0709250d..cd0709250d 100644 --- a/synapse/storage/schema/delta/16/users.sql +++ b/synapse/storage/data_stores/main/schema/delta/16/users.sql diff --git a/synapse/storage/schema/delta/17/drop_indexes.sql b/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql index 7c9a90e27f..7c9a90e27f 100644 --- a/synapse/storage/schema/delta/17/drop_indexes.sql +++ b/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql diff --git a/synapse/storage/schema/delta/17/server_keys.sql b/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql index 70b247a06b..70b247a06b 100644 --- a/synapse/storage/schema/delta/17/server_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql diff --git a/synapse/storage/schema/delta/17/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql index c17715ac80..c17715ac80 100644 --- a/synapse/storage/schema/delta/17/user_threepids.sql +++ b/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql diff --git a/synapse/storage/schema/delta/18/server_keys_bigger_ints.sql b/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql index 6e0871c92b..6e0871c92b 100644 --- a/synapse/storage/schema/delta/18/server_keys_bigger_ints.sql +++ b/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql diff --git a/synapse/storage/schema/delta/19/event_index.sql b/synapse/storage/data_stores/main/schema/delta/19/event_index.sql index 18b97b4332..18b97b4332 100644 --- a/synapse/storage/schema/delta/19/event_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/19/event_index.sql diff --git a/synapse/storage/schema/delta/20/dummy.sql b/synapse/storage/data_stores/main/schema/delta/20/dummy.sql index e0ac49d1ec..e0ac49d1ec 100644 --- a/synapse/storage/schema/delta/20/dummy.sql +++ b/synapse/storage/data_stores/main/schema/delta/20/dummy.sql diff --git a/synapse/storage/schema/delta/20/pushers.py 
b/synapse/storage/data_stores/main/schema/delta/20/pushers.py index 3edfcfd783..3edfcfd783 100644 --- a/synapse/storage/schema/delta/20/pushers.py +++ b/synapse/storage/data_stores/main/schema/delta/20/pushers.py diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql index 4c2fb20b77..4c2fb20b77 100644 --- a/synapse/storage/schema/delta/21/end_to_end_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql diff --git a/synapse/storage/schema/delta/21/receipts.sql b/synapse/storage/data_stores/main/schema/delta/21/receipts.sql index d070845477..d070845477 100644 --- a/synapse/storage/schema/delta/21/receipts.sql +++ b/synapse/storage/data_stores/main/schema/delta/21/receipts.sql diff --git a/synapse/storage/schema/delta/22/receipts_index.sql b/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql index bfc0b3bcaa..bfc0b3bcaa 100644 --- a/synapse/storage/schema/delta/22/receipts_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql diff --git a/synapse/storage/schema/delta/22/user_threepids_unique.sql b/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql index 87edfa454c..87edfa454c 100644 --- a/synapse/storage/schema/delta/22/user_threepids_unique.sql +++ b/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql diff --git a/synapse/storage/schema/delta/24/stats_reporting.sql b/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql index acea7483bd..acea7483bd 100644 --- a/synapse/storage/schema/delta/24/stats_reporting.sql +++ b/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/data_stores/main/schema/delta/25/fts.py index 4b2ffd35fd..4b2ffd35fd 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/data_stores/main/schema/delta/25/fts.py diff --git a/synapse/storage/schema/delta/25/guest_access.sql b/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql index 1ea389b471..1ea389b471 100644 --- a/synapse/storage/schema/delta/25/guest_access.sql +++ b/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql diff --git a/synapse/storage/schema/delta/25/history_visibility.sql b/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql index f468fc1897..f468fc1897 100644 --- a/synapse/storage/schema/delta/25/history_visibility.sql +++ b/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql diff --git a/synapse/storage/schema/delta/25/tags.sql b/synapse/storage/data_stores/main/schema/delta/25/tags.sql index 7a32ce68e4..7a32ce68e4 100644 --- a/synapse/storage/schema/delta/25/tags.sql +++ b/synapse/storage/data_stores/main/schema/delta/25/tags.sql diff --git a/synapse/storage/schema/delta/26/account_data.sql b/synapse/storage/data_stores/main/schema/delta/26/account_data.sql index e395de2b5e..e395de2b5e 100644 --- a/synapse/storage/schema/delta/26/account_data.sql +++ b/synapse/storage/data_stores/main/schema/delta/26/account_data.sql diff --git a/synapse/storage/schema/delta/27/account_data.sql b/synapse/storage/data_stores/main/schema/delta/27/account_data.sql index bf0558b5b3..bf0558b5b3 100644 --- a/synapse/storage/schema/delta/27/account_data.sql +++ b/synapse/storage/data_stores/main/schema/delta/27/account_data.sql diff --git a/synapse/storage/schema/delta/27/forgotten_memberships.sql 
b/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql index e2094f37fe..e2094f37fe 100644 --- a/synapse/storage/schema/delta/27/forgotten_memberships.sql +++ b/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/data_stores/main/schema/delta/27/ts.py index 414f9f5aa0..414f9f5aa0 100644 --- a/synapse/storage/schema/delta/27/ts.py +++ b/synapse/storage/data_stores/main/schema/delta/27/ts.py diff --git a/synapse/storage/schema/delta/28/event_push_actions.sql b/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql index 4d519849df..4d519849df 100644 --- a/synapse/storage/schema/delta/28/event_push_actions.sql +++ b/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql index 36609475f1..36609475f1 100644 --- a/synapse/storage/schema/delta/28/events_room_stream.sql +++ b/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql index 6c1fd68c5b..6c1fd68c5b 100644 --- a/synapse/storage/schema/delta/28/public_roms_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql diff --git a/synapse/storage/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql index cb84c69baa..cb84c69baa 100644 --- a/synapse/storage/schema/delta/28/receipts_user_id_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql diff --git a/synapse/storage/schema/delta/28/upgrade_times.sql b/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql index 3e4a9ab455..3e4a9ab455 100644 --- a/synapse/storage/schema/delta/28/upgrade_times.sql +++ b/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql diff --git a/synapse/storage/schema/delta/28/users_is_guest.sql b/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql index 21d2b420bf..21d2b420bf 100644 --- a/synapse/storage/schema/delta/28/users_is_guest.sql +++ b/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql index 84b21cf813..84b21cf813 100644 --- a/synapse/storage/schema/delta/29/push_actions.sql +++ b/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql diff --git a/synapse/storage/schema/delta/30/alias_creator.sql b/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql index c9d0dde638..c9d0dde638 100644 --- a/synapse/storage/schema/delta/30/alias_creator.sql +++ b/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/data_stores/main/schema/delta/30/as_users.py index 9b95411fb6..9b95411fb6 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/data_stores/main/schema/delta/30/as_users.py diff --git a/synapse/storage/schema/delta/30/deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql index 712c454aa1..712c454aa1 100644 --- a/synapse/storage/schema/delta/30/deleted_pushers.sql +++ b/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql 
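
[Editorial note on the roommember.py hunks above] Those hunks rest on one invariant that is easy to miss: a room only counts as forgotten for a user once *every* room_memberships row for that user and room has forgotten = 1, which is exactly what did_forget and get_forgotten_rooms_for_user test for. The sketch below is editorial, not Synapse code: it re-runs those two queries with Python's built-in sqlite3 module against a cut-down, made-up room_memberships table, purely to illustrate the semantics.

# Editorial sketch (not Synapse code): exercises the "forgotten" semantics from
# the roommember.py hunks above against a simplified, made-up table.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE room_memberships ("
    " user_id TEXT, room_id TEXT, event_id TEXT, forgotten INTEGER DEFAULT 0)"
)
conn.executemany(
    "INSERT INTO room_memberships VALUES (?, ?, ?, ?)",
    [
        # @alice has forgotten !abc (every row is marked), but not !def.
        ("@alice:test", "!abc:test", "$event1", 1),
        ("@alice:test", "!abc:test", "$event2", 1),
        ("@alice:test", "!def:test", "$event3", 0),
    ],
)

def did_forget(user_id, room_id):
    # Same shape as the did_forget query: the room is forgotten only if no
    # row for this user/room still has forgotten = 0.
    (count,) = conn.execute(
        "SELECT COUNT(*) FROM room_memberships"
        " WHERE user_id = ? AND room_id = ? AND forgotten = 0",
        (user_id, room_id),
    ).fetchone()
    return count == 0

def get_forgotten_rooms_for_user(user_id):
    # Same shape as the get_forgotten_rooms_for_user query: start from rooms
    # with at least one forgotten = 1 row, then keep only those with no
    # remaining forgotten = 0 rows.
    sql = """
        SELECT room_id, (
            SELECT count(*) FROM room_memberships
            WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0
        ) AS count
        FROM room_memberships AS m
        WHERE user_id = ? AND forgotten = 1
        GROUP BY room_id, user_id
    """
    return {room_id for room_id, count in conn.execute(sql, (user_id,)) if count == 0}

print(did_forget("@alice:test", "!abc:test"))       # True
print(did_forget("@alice:test", "!def:test"))       # False
print(get_forgotten_rooms_for_user("@alice:test"))  # {'!abc:test'}

The outer WHERE user_id = ? AND forgotten = 1 is what lets the real query lean on the partial index room_memberships_user_room_forgotten registered above (where_clause "forgotten = 1"), on the patch's own assumption that few users forget many rooms.
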
diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql index 606bbb037d..606bbb037d 100644 --- a/synapse/storage/schema/delta/30/presence_stream.sql +++ b/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql diff --git a/synapse/storage/schema/delta/30/public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql index f09db4faa6..f09db4faa6 100644 --- a/synapse/storage/schema/delta/30/public_rooms.sql +++ b/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql diff --git a/synapse/storage/schema/delta/30/push_rule_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql index 735aa8d5f6..735aa8d5f6 100644 --- a/synapse/storage/schema/delta/30/push_rule_stream.sql +++ b/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql diff --git a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql index 0dd2f1360c..0dd2f1360c 100644 --- a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql +++ b/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/data_stores/main/schema/delta/31/invites.sql index 2c57846d5a..2c57846d5a 100644 --- a/synapse/storage/schema/delta/31/invites.sql +++ b/synapse/storage/data_stores/main/schema/delta/31/invites.sql diff --git a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql index 9efb4280eb..9efb4280eb 100644 --- a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql +++ b/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/data_stores/main/schema/delta/31/pushers.py index 9bb504aad5..9bb504aad5 100644 --- a/synapse/storage/schema/delta/31/pushers.py +++ b/synapse/storage/data_stores/main/schema/delta/31/pushers.py diff --git a/synapse/storage/schema/delta/31/pushers_index.sql b/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql index a82add88fd..a82add88fd 100644 --- a/synapse/storage/schema/delta/31/pushers_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/data_stores/main/schema/delta/31/search_update.py index 7d8ca5f93f..7d8ca5f93f 100644 --- a/synapse/storage/schema/delta/31/search_update.py +++ b/synapse/storage/data_stores/main/schema/delta/31/search_update.py diff --git a/synapse/storage/schema/delta/32/events.sql b/synapse/storage/data_stores/main/schema/delta/32/events.sql index 1dd0f9e170..1dd0f9e170 100644 --- a/synapse/storage/schema/delta/32/events.sql +++ b/synapse/storage/data_stores/main/schema/delta/32/events.sql diff --git a/synapse/storage/schema/delta/32/openid.sql b/synapse/storage/data_stores/main/schema/delta/32/openid.sql index 36f37b11c8..36f37b11c8 100644 --- a/synapse/storage/schema/delta/32/openid.sql +++ b/synapse/storage/data_stores/main/schema/delta/32/openid.sql diff --git a/synapse/storage/schema/delta/32/pusher_throttle.sql b/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql index d86d30c13c..d86d30c13c 100644 --- a/synapse/storage/schema/delta/32/pusher_throttle.sql +++ 
b/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql diff --git a/synapse/storage/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql index 4219cdd06a..2de50d408c 100644 --- a/synapse/storage/schema/delta/32/remove_indices.sql +++ b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql @@ -20,7 +20,6 @@ DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY -DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT diff --git a/synapse/storage/schema/delta/32/reports.sql b/synapse/storage/data_stores/main/schema/delta/32/reports.sql index d13609776f..d13609776f 100644 --- a/synapse/storage/schema/delta/32/reports.sql +++ b/synapse/storage/data_stores/main/schema/delta/32/reports.sql diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql index 61ad3fe3e8..61ad3fe3e8 100644 --- a/synapse/storage/schema/delta/33/access_tokens_device_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/data_stores/main/schema/delta/33/devices.sql index eca7268d82..eca7268d82 100644 --- a/synapse/storage/schema/delta/33/devices.sql +++ b/synapse/storage/data_stores/main/schema/delta/33/devices.sql diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql index aa4a3b9f2f..aa4a3b9f2f 100644 --- a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql index 6671573398..6671573398 100644 --- a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql +++ b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py index bff1256a7b..bff1256a7b 100644 --- a/synapse/storage/schema/delta/33/event_fields.py +++ b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py index a26057dfb6..a26057dfb6 100644 --- a/synapse/storage/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py diff --git a/synapse/storage/schema/delta/33/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql index 473f75a78e..473f75a78e 100644 --- a/synapse/storage/schema/delta/33/user_ips_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql diff --git a/synapse/storage/schema/delta/34/appservice_stream.sql 
b/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql index 69e16eda0f..69e16eda0f 100644 --- a/synapse/storage/schema/delta/34/appservice_stream.sql +++ b/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py index cf09e43e2b..cf09e43e2b 100644 --- a/synapse/storage/schema/delta/34/cache_stream.py +++ b/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py diff --git a/synapse/storage/schema/delta/34/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql index e68844c74a..e68844c74a 100644 --- a/synapse/storage/schema/delta/34/device_inbox.sql +++ b/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql diff --git a/synapse/storage/schema/delta/34/push_display_name_rename.sql b/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql index 0d9fe1a99a..0d9fe1a99a 100644 --- a/synapse/storage/schema/delta/34/push_display_name_rename.sql +++ b/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql diff --git a/synapse/storage/schema/delta/34/received_txn_purge.py b/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py index 67d505e68b..67d505e68b 100644 --- a/synapse/storage/schema/delta/34/received_txn_purge.py +++ b/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py diff --git a/synapse/storage/schema/delta/35/contains_url.sql b/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql index 6cd123027b..6cd123027b 100644 --- a/synapse/storage/schema/delta/35/contains_url.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql diff --git a/synapse/storage/schema/delta/35/device_outbox.sql b/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql index 17e6c43105..17e6c43105 100644 --- a/synapse/storage/schema/delta/35/device_outbox.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql diff --git a/synapse/storage/schema/delta/35/device_stream_id.sql b/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql index 7ab7d942e2..7ab7d942e2 100644 --- a/synapse/storage/schema/delta/35/device_stream_id.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql diff --git a/synapse/storage/schema/delta/35/event_push_actions_index.sql b/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql index 2e836d8e9c..2e836d8e9c 100644 --- a/synapse/storage/schema/delta/35/event_push_actions_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql diff --git a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql index dd2bf2e28a..dd2bf2e28a 100644 --- a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql diff --git a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql index 2b945d8a57..2b945d8a57 100644 --- a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql +++ b/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql diff --git a/synapse/storage/schema/delta/36/readd_public_rooms.sql 
b/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql index 90d8fd18f9..90d8fd18f9 100644 --- a/synapse/storage/schema/delta/36/readd_public_rooms.sql +++ b/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py index a377884169..a377884169 100644 --- a/synapse/storage/schema/delta/37/remove_auth_idx.py +++ b/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py diff --git a/synapse/storage/schema/delta/37/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql index cf7a90dd10..cf7a90dd10 100644 --- a/synapse/storage/schema/delta/37/user_threepids.sql +++ b/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql index 515e6b8e84..515e6b8e84 100644 --- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql +++ b/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql diff --git a/synapse/storage/schema/delta/39/appservice_room_list.sql b/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql index 74bdc49073..74bdc49073 100644 --- a/synapse/storage/schema/delta/39/appservice_room_list.sql +++ b/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql diff --git a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql index 00be801e90..00be801e90 100644 --- a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql diff --git a/synapse/storage/schema/delta/39/event_push_index.sql b/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql index de2ad93e5c..de2ad93e5c 100644 --- a/synapse/storage/schema/delta/39/event_push_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql diff --git a/synapse/storage/schema/delta/39/federation_out_position.sql b/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql index 5af814290b..5af814290b 100644 --- a/synapse/storage/schema/delta/39/federation_out_position.sql +++ b/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql diff --git a/synapse/storage/schema/delta/39/membership_profile.sql b/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql index 1bf911c8ab..1bf911c8ab 100644 --- a/synapse/storage/schema/delta/39/membership_profile.sql +++ b/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql diff --git a/synapse/storage/schema/delta/40/current_state_idx.sql b/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql index 7ffa189f39..7ffa189f39 100644 --- a/synapse/storage/schema/delta/40/current_state_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql index b9fe1f0480..b9fe1f0480 100644 --- a/synapse/storage/schema/delta/40/device_inbox.sql +++ b/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql 
b/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql index dd6dcb65f1..dd6dcb65f1 100644 --- a/synapse/storage/schema/delta/40/device_list_streams.sql +++ b/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql index 3918f0b794..3918f0b794 100644 --- a/synapse/storage/schema/delta/40/event_push_summary.sql +++ b/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/data_stores/main/schema/delta/40/pushers.sql index 054a223f14..054a223f14 100644 --- a/synapse/storage/schema/delta/40/pushers.sql +++ b/synapse/storage/data_stores/main/schema/delta/40/pushers.sql diff --git a/synapse/storage/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql index b7bee8b692..b7bee8b692 100644 --- a/synapse/storage/schema/delta/41/device_list_stream_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql index 62f0b9892b..62f0b9892b 100644 --- a/synapse/storage/schema/delta/41/device_outbound_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql diff --git a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql index 5d9cfecf36..5d9cfecf36 100644 --- a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql diff --git a/synapse/storage/schema/delta/41/ratelimit.sql b/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql index a194bf0238..a194bf0238 100644 --- a/synapse/storage/schema/delta/41/ratelimit.sql +++ b/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql index d28851aff8..d28851aff8 100644 --- a/synapse/storage/schema/delta/42/current_state_delta.sql +++ b/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql index 9ab8c14fa3..9ab8c14fa3 100644 --- a/synapse/storage/schema/delta/42/device_list_last_id.sql +++ b/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql diff --git a/synapse/storage/schema/delta/42/event_auth_state_only.sql b/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql index b8821ac759..b8821ac759 100644 --- a/synapse/storage/schema/delta/42/event_auth_state_only.sql +++ b/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/data_stores/main/schema/delta/42/user_dir.py index 506f326f4d..506f326f4d 100644 --- a/synapse/storage/schema/delta/42/user_dir.py +++ b/synapse/storage/data_stores/main/schema/delta/42/user_dir.py diff --git a/synapse/storage/schema/delta/43/blocked_rooms.sql b/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql index 0e3cd143ff..0e3cd143ff 100644 --- 
a/synapse/storage/schema/delta/43/blocked_rooms.sql +++ b/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql diff --git a/synapse/storage/schema/delta/43/quarantine_media.sql b/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql index 630907ec4f..630907ec4f 100644 --- a/synapse/storage/schema/delta/43/quarantine_media.sql +++ b/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql diff --git a/synapse/storage/schema/delta/43/url_cache.sql b/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql index 45ebe020da..45ebe020da 100644 --- a/synapse/storage/schema/delta/43/url_cache.sql +++ b/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/data_stores/main/schema/delta/43/user_share.sql index ee7062abe4..ee7062abe4 100644 --- a/synapse/storage/schema/delta/43/user_share.sql +++ b/synapse/storage/data_stores/main/schema/delta/43/user_share.sql diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql index b12f9b2ebf..b12f9b2ebf 100644 --- a/synapse/storage/schema/delta/44/expire_url_cache.sql +++ b/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/data_stores/main/schema/delta/45/group_server.sql index b2333848a0..b2333848a0 100644 --- a/synapse/storage/schema/delta/45/group_server.sql +++ b/synapse/storage/data_stores/main/schema/delta/45/group_server.sql diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql index e5ddc84df0..e5ddc84df0 100644 --- a/synapse/storage/schema/delta/45/profile_cache.sql +++ b/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql diff --git a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql index 68c48a89a9..68c48a89a9 100644 --- a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql +++ b/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql index bb307889c1..bb307889c1 100644 --- a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql +++ b/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/data_stores/main/schema/delta/46/group_server.sql index 097679bc9a..097679bc9a 100644 --- a/synapse/storage/schema/delta/46/group_server.sql +++ b/synapse/storage/data_stores/main/schema/delta/46/group_server.sql diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql index bbfc7f5d1a..bbfc7f5d1a 100644 --- a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql index cb0d5a2576..cb0d5a2576 100644 --- a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql +++ 
b/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql index d9505f8da1..d9505f8da1 100644 --- a/synapse/storage/schema/delta/46/user_dir_typos.sql +++ b/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql index f505fb22b5..f505fb22b5 100644 --- a/synapse/storage/schema/delta/47/last_access_media.sql +++ b/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql diff --git a/synapse/storage/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql index 31d7a817eb..31d7a817eb 100644 --- a/synapse/storage/schema/delta/47/postgres_fts_gin.sql +++ b/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql index edccf4a96f..edccf4a96f 100644 --- a/synapse/storage/schema/delta/47/push_actions_staging.sql +++ b/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql diff --git a/synapse/storage/schema/delta/48/add_user_consent.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql index 5237491506..5237491506 100644 --- a/synapse/storage/schema/delta/48/add_user_consent.sql +++ b/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql index 9248b0b24a..9248b0b24a 100644 --- a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql diff --git a/synapse/storage/schema/delta/48/deactivated_users.sql b/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql index e9013a6969..e9013a6969 100644 --- a/synapse/storage/schema/delta/48/deactivated_users.sql +++ b/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py index 49f5f2c003..49f5f2c003 100644 --- a/synapse/storage/schema/delta/48/group_unique_indexes.py +++ b/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql index ce26eaf0c9..ce26eaf0c9 100644 --- a/synapse/storage/schema/delta/48/groups_joinable.sql +++ b/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql diff --git a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql index 14dcf18d73..14dcf18d73 100644 --- a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql +++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql diff --git a/synapse/storage/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql index 3dd478196f..3dd478196f 100644 --- 
a/synapse/storage/schema/delta/49/add_user_daily_visits.sql +++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql diff --git a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql index 3a4ed59b5b..3a4ed59b5b 100644 --- a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql diff --git a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql index c93ae47532..c93ae47532 100644 --- a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql diff --git a/synapse/storage/schema/delta/50/erasure_store.sql b/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql index 5d8641a9ab..5d8641a9ab 100644 --- a/synapse/storage/schema/delta/50/erasure_store.sql +++ b/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py index b1684a8441..b1684a8441 100644 --- a/synapse/storage/schema/delta/50/make_event_content_nullable.py +++ b/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py diff --git a/synapse/storage/schema/delta/51/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql index c0e66a697d..c0e66a697d 100644 --- a/synapse/storage/schema/delta/51/e2e_room_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql index c9d537d5a3..c9d537d5a3 100644 --- a/synapse/storage/schema/delta/51/monthly_active_users.sql +++ b/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql diff --git a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql index 91e03d13e1..91e03d13e1 100644 --- a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql diff --git a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql index bfa49e6f92..bfa49e6f92 100644 --- a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql diff --git a/synapse/storage/schema/delta/52/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql index db687cccae..db687cccae 100644 --- a/synapse/storage/schema/delta/52/e2e_room_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql diff --git a/synapse/storage/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql index 88ec2f83e5..88ec2f83e5 100644 --- a/synapse/storage/schema/delta/53/add_user_type_to_users.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql diff --git 
a/synapse/storage/schema/delta/53/drop_sent_transactions.sql b/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql index e372f5a44a..e372f5a44a 100644 --- a/synapse/storage/schema/delta/53/drop_sent_transactions.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql diff --git a/synapse/storage/schema/delta/53/event_format_version.sql b/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql index 1d977c2834..1d977c2834 100644 --- a/synapse/storage/schema/delta/53/event_format_version.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql diff --git a/synapse/storage/schema/delta/53/user_dir_populate.sql b/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql index ffcc896b58..ffcc896b58 100644 --- a/synapse/storage/schema/delta/53/user_dir_populate.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql diff --git a/synapse/storage/schema/delta/53/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql index b812c5794f..b812c5794f 100644 --- a/synapse/storage/schema/delta/53/user_ips_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql diff --git a/synapse/storage/schema/delta/53/user_share.sql b/synapse/storage/data_stores/main/schema/delta/53/user_share.sql index 5831b1a6f8..5831b1a6f8 100644 --- a/synapse/storage/schema/delta/53/user_share.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/user_share.sql diff --git a/synapse/storage/schema/delta/53/user_threepid_id.sql b/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql index 80c2c573b6..80c2c573b6 100644 --- a/synapse/storage/schema/delta/53/user_threepid_id.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql diff --git a/synapse/storage/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql index f7827ca6d2..f7827ca6d2 100644 --- a/synapse/storage/schema/delta/53/users_in_public_rooms.sql +++ b/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql diff --git a/synapse/storage/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql index 0adb2ad55e..0adb2ad55e 100644 --- a/synapse/storage/schema/delta/54/account_validity_with_renewal.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql index c01aa9d2d9..c01aa9d2d9 100644 --- a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql index b062ec840c..b062ec840c 100644 --- a/synapse/storage/schema/delta/54/delete_forward_extremities.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql diff --git a/synapse/storage/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql index dbbe682697..dbbe682697 100644 --- a/synapse/storage/schema/delta/54/drop_legacy_tables.sql +++ 
b/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql diff --git a/synapse/storage/schema/delta/54/drop_presence_list.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql index e6ee70c623..e6ee70c623 100644 --- a/synapse/storage/schema/delta/54/drop_presence_list.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/data_stores/main/schema/delta/54/relations.sql index 134862b870..134862b870 100644 --- a/synapse/storage/schema/delta/54/relations.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/relations.sql diff --git a/synapse/storage/schema/delta/54/stats.sql b/synapse/storage/data_stores/main/schema/delta/54/stats.sql index 652e58308e..652e58308e 100644 --- a/synapse/storage/schema/delta/54/stats.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/stats.sql diff --git a/synapse/storage/schema/delta/54/stats2.sql b/synapse/storage/data_stores/main/schema/delta/54/stats2.sql index 3b2d48447f..3b2d48447f 100644 --- a/synapse/storage/schema/delta/54/stats2.sql +++ b/synapse/storage/data_stores/main/schema/delta/54/stats2.sql diff --git a/synapse/storage/schema/delta/55/access_token_expiry.sql b/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql index 4590604bfd..4590604bfd 100644 --- a/synapse/storage/schema/delta/55/access_token_expiry.sql +++ b/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql diff --git a/synapse/storage/schema/delta/55/track_threepid_validations.sql b/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql index a8eced2e0a..a8eced2e0a 100644 --- a/synapse/storage/schema/delta/55/track_threepid_validations.sql +++ b/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql diff --git a/synapse/storage/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql index dabdde489b..dabdde489b 100644 --- a/synapse/storage/schema/delta/55/users_alter_deactivated.sql +++ b/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql diff --git a/synapse/storage/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql index 41807eb1e7..41807eb1e7 100644 --- a/synapse/storage/schema/delta/56/add_spans_to_device_lists.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql diff --git a/synapse/storage/schema/delta/56/current_state_events_membership.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql index 473018676f..473018676f 100644 --- a/synapse/storage/schema/delta/56/current_state_events_membership.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql diff --git a/synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql index 3133d42d4a..3133d42d4a 100644 --- a/synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql new file mode 100644 index 0000000000..1d2ddb1b1a --- 
/dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql @@ -0,0 +1,25 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* delete room keys that belong to deleted room key version, or to room key + * versions that don't exist (anymore) + */ +DELETE FROM e2e_room_keys +WHERE version NOT IN ( + SELECT version + FROM e2e_room_keys_versions + WHERE e2e_room_keys.user_id = e2e_room_keys_versions.user_id + AND e2e_room_keys_versions.deleted = 0 +); diff --git a/synapse/storage/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql index f00889290b..f00889290b 100644 --- a/synapse/storage/schema/delta/56/destinations_failure_ts.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres new file mode 100644 index 0000000000..b9bbb18a91 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres @@ -0,0 +1,18 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We want to store large retry intervals so we upgrade the column from INT +-- to BIGINT. We don't need to do this on SQLite. +ALTER TABLE destinations ALTER retry_interval SET DATA TYPE BIGINT; diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql new file mode 100644 index 0000000000..c2f557fde9 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql @@ -0,0 +1,20 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This line already existed in deltas/35/device_stream_id but was not included in the +-- 54 full schema SQL. Add some SQL here to insert the missing row if it does not exist +INSERT INTO device_max_stream_id (stream_id) SELECT 0 WHERE NOT EXISTS ( + SELECT * from device_max_stream_id +); \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql b/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql new file mode 100644 index 0000000000..dfa902d0ba --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql @@ -0,0 +1,24 @@ +/* Copyright 2019 Matrix.org Foundation CIC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Track last seen information for a device in the devices table, rather +-- than relying on it being in the user_ips table (which we want to be able +-- to purge old entries from) +ALTER TABLE devices ADD COLUMN last_seen BIGINT; +ALTER TABLE devices ADD COLUMN ip TEXT; +ALTER TABLE devices ADD COLUMN user_agent TEXT; + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('devices_last_seen', '{}'); diff --git a/synapse/storage/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql index 9f09922c67..9f09922c67 100644 --- a/synapse/storage/schema/delta/56/drop_unused_event_tables.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql new file mode 100644 index 0000000000..81a36a8b1d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +CREATE TABLE IF NOT EXISTS event_expiry ( + event_id TEXT PRIMARY KEY, + expiry_ts BIGINT NOT NULL +); + +CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts); diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql new file mode 100644 index 0000000000..5e29c1da19 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql @@ -0,0 +1,30 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- room_id and topoligical_ordering are denormalised from the events table in order to +-- make the index work. +CREATE TABLE IF NOT EXISTS event_labels ( + event_id TEXT, + label TEXT, + room_id TEXT NOT NULL, + topological_ordering BIGINT NOT NULL, + PRIMARY KEY(event_id, label) +); + + +-- This index enables an event pagination looking for a particular label to index the +-- event_labels table first, which is much quicker than scanning the events table and then +-- filtering by label, if the label is rarely used relative to the size of the room. +CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering); diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql new file mode 100644 index 0000000000..5f5e0499ae --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
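The denormalised `room_id` and `topological_ordering` columns in `event_labels` above exist so that label-filtered pagination can be answered from the new index alone. A minimal illustrative query of that shape (the room and label values are made up):

    SELECT event_id
    FROM event_labels
    WHERE room_id = '!abc123:example.org'
      AND label = 'org.example.project'
    ORDER BY topological_ordering DESC
    LIMIT 10;
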
+ */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('event_store_labels', '{}'); diff --git a/synapse/storage/schema/delta/56/fix_room_keys_index.sql b/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql index 014cb3b538..014cb3b538 100644 --- a/synapse/storage/schema/delta/56/fix_room_keys_index.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql new file mode 100644 index 0000000000..67f8b20297 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql @@ -0,0 +1,18 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- device list needs to know which ones are "real" devices, and which ones are +-- just used to avoid collisions +ALTER TABLE devices ADD COLUMN hidden BOOLEAN DEFAULT FALSE; diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite new file mode 100644 index 0000000000..e8b1fd35d8 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite @@ -0,0 +1,42 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Change the hidden column from a default value of FALSE to a default value of + * 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the + * string 'FALSE', which is truthy. + * + * Since sqlite doesn't allow us to just change the default value, we have to + * recreate the table, copy the data, fix the rows that have incorrect data, and + * replace the old table with the new table. 
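Before the SQLite fix described above is applied, an affected database holds the literal string 'FALSE' in `devices.hidden`, which SQLite treats as truthy. A quick read-only way to check whether a given database is affected (illustrative only, using the same predicate as the fix-up below):

    SELECT COUNT(*) AS bad_rows
    FROM devices
    WHERE hidden = 'FALSE';
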
+ */ + +CREATE TABLE IF NOT EXISTS devices2 ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + last_seen BIGINT, + ip TEXT, + user_agent TEXT, + hidden BOOLEAN DEFAULT 0, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); + +INSERT INTO devices2 SELECT * FROM devices; + +UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE'; + +DROP TABLE devices; + +ALTER TABLE devices2 RENAME TO devices; diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql new file mode 100644 index 0000000000..4f24c1405d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql @@ -0,0 +1,29 @@ +/* Copyright 2019 Werner Sembach + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Groups/communities now get deleted when the last member leaves. This is a one time cleanup to remove old groups/communities that were already empty before that change was made. +DELETE FROM group_attestations_remote WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_attestations_renewals WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_invites WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_users WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM local_group_membership WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM local_group_updates WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE 
group_id = groups.group_id)); +DELETE FROM groups WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); diff --git a/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql new file mode 100644 index 0000000000..7be31ffebb --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX public_room_list_stream_network ON public_room_list_stream (appservice_id, network_id, room_id); diff --git a/synapse/storage/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql index fe51b02309..ea95db0ed7 100644 --- a/synapse/storage/schema/delta/56/redaction_censor.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql @@ -14,4 +14,3 @@ */ ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false; -CREATE INDEX redactions_have_censored ON redactions(event_id) WHERE not have_censored; diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql new file mode 100644 index 0000000000..49ce35d794 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql @@ -0,0 +1,22 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE redactions ADD COLUMN received_ts BIGINT; + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('redactions_received_ts', '{}'); + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('redactions_have_censored_ts_idx', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres new file mode 100644 index 0000000000..67471f3ef5 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres @@ -0,0 +1,25 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
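Every DELETE in the empty-communities cleanup above repeats the same "groups with no members" subquery. For review purposes, the affected set can be previewed once with a read-only query (a sketch, not part of the delta):

    SELECT group_id
    FROM groups
    WHERE NOT EXISTS (
        SELECT 1 FROM group_users
        WHERE group_users.group_id = groups.group_id
    );
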
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- There was a bug where we may have updated censored redactions as bytes, +-- which can (somehow) cause json to be inserted hex encoded. These updates go +-- and undoes any such hex encoded JSON. + +INSERT into background_updates (update_name, progress_json) + VALUES ('event_fix_redactions_bytes_create_index', '{}'); + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('event_fix_redactions_bytes', '{}', 'event_fix_redactions_bytes_create_index'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql new file mode 100644 index 0000000000..b7550f6f4e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP INDEX IF EXISTS redactions_have_censored; diff --git a/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql b/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql new file mode 100644 index 0000000000..aeb17813d3 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Now that #6232 is a thing, we can remove old rooms from the directory. +INSERT INTO background_updates (update_name, progress_json) VALUES + ('remove_tombstoned_rooms_from_directory', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql new file mode 100644 index 0000000000..7d70dd071e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 Matrix.org Foundation C.I.C. 
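The redaction fix-up above shows the general pattern for ordering background updates: the `depends_on` column names another update that must finish first. A generic sketch of that pattern, with purely hypothetical update names:

    INSERT INTO background_updates (update_name, progress_json, depends_on)
        VALUES ('example_second_step', '{}', 'example_first_step');
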
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- store the current etag of backup version +ALTER TABLE e2e_room_keys_versions ADD COLUMN etag BIGINT; diff --git a/synapse/storage/schema/delta/56/room_membership_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql index 92ab1f5e65..92ab1f5e65 100644 --- a/synapse/storage/schema/delta/56/room_membership_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql new file mode 100644 index 0000000000..ee6cdf7a14 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql @@ -0,0 +1,33 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Tracks the retention policy of a room. +-- A NULL max_lifetime or min_lifetime means that the matching property is not defined in +-- the room's retention policy state event. +-- If a room doesn't have a retention policy state event in its state, both max_lifetime +-- and min_lifetime are NULL. +CREATE TABLE IF NOT EXISTS room_retention( + room_id TEXT, + event_id TEXT, + min_lifetime BIGINT, + max_lifetime BIGINT, + + PRIMARY KEY(room_id, event_id) +); + +CREATE INDEX room_retention_max_lifetime_idx on room_retention(max_lifetime); + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('insert_room_retention', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql new file mode 100644 index 0000000000..5c5fffcafb --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql @@ -0,0 +1,56 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
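The `room_retention_max_lifetime_idx` index added above is aimed at the purge jobs, which need to find rooms whose policy caps event lifetime. An illustrative query of that shape, assuming lifetimes are stored in milliseconds (the delta does not state the unit):

    SELECT room_id, max_lifetime
    FROM room_retention
    WHERE max_lifetime IS NOT NULL
      AND max_lifetime <= 86400000;  -- policies of one day or less, if ms
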
+ */ + +-- cross-signing keys +CREATE TABLE IF NOT EXISTS e2e_cross_signing_keys ( + user_id TEXT NOT NULL, + -- the type of cross-signing key (master, user_signing, or self_signing) + keytype TEXT NOT NULL, + -- the full key information, as a json-encoded dict + keydata TEXT NOT NULL, + -- for keeping the keys in order, so that we can fetch the latest one + stream_id BIGINT NOT NULL +); + +CREATE UNIQUE INDEX e2e_cross_signing_keys_idx ON e2e_cross_signing_keys(user_id, keytype, stream_id); + +-- cross-signing signatures +CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures ( + -- user who did the signing + user_id TEXT NOT NULL, + -- key used to sign + key_id TEXT NOT NULL, + -- user who was signed + target_user_id TEXT NOT NULL, + -- device/key that was signed + target_device_id TEXT NOT NULL, + -- the actual signature + signature TEXT NOT NULL +); + +-- replaced by the index created in signing_keys_nonunique_signatures.sql +-- CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); + +-- stream of user signature updates +CREATE TABLE IF NOT EXISTS user_signature_stream ( + -- uses the same stream ID as device list stream + stream_id BIGINT NOT NULL, + -- user who did the signing + from_user_id TEXT NOT NULL, + -- list of users who were signed, as a JSON array + user_ids TEXT NOT NULL +); + +CREATE UNIQUE INDEX user_signature_stream_idx ON user_signature_stream(stream_id); diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql new file mode 100644 index 0000000000..0aa90ebf0c --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql @@ -0,0 +1,22 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* The cross-signing signatures index should not be a unique index, because a + * user may upload multiple signatures for the same target user. The previous + * index was unique, so delete it if it's there and create a new non-unique + * index. */ + +DROP INDEX IF EXISTS e2e_cross_signing_signatures_idx; CREATE INDEX IF NOT +EXISTS e2e_cross_signing_signatures2_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); diff --git a/synapse/storage/schema/delta/56/stats_separated.sql b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql index 163529c071..bbdde121e8 100644 --- a/synapse/storage/schema/delta/56/stats_separated.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql @@ -35,9 +35,13 @@ DELETE FROM background_updates WHERE update_name IN ( 'populate_stats_cleanup' ); +-- this relies on current_state_events.membership having been populated, so add +-- a dependency on current_state_events_membership. 
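As the comments in signing_keys.sql above note, `stream_id` orders successive uploads of a user's cross-signing keys so the latest one can be fetched. A minimal sketch of such a lookup (the user ID is made up):

    SELECT keydata
    FROM e2e_cross_signing_keys
    WHERE user_id = '@alice:example.org'
      AND keytype = 'master'
    ORDER BY stream_id DESC
    LIMIT 1;
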
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('populate_stats_process_rooms', '{}', ''); + ('populate_stats_process_rooms', '{}', 'current_state_events_membership'); +-- this also relies on current_state_events.membership having been populated, but +-- we get that as a side-effect of depending on populate_stats_process_rooms. INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES ('populate_stats_process_users', '{}', 'populate_stats_process_rooms'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py new file mode 100644 index 0000000000..1de8b54961 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py @@ -0,0 +1,52 @@ +import logging + +from synapse.storage.engines import PostgresEngine + +logger = logging.getLogger(__name__) + + +""" +This migration updates the user_filters table as follows: + + - drops any (user_id, filter_id) duplicates + - makes the columns NON-NULLable + - turns the index into a UNIQUE index +""" + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + select_clause = """ + SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json + FROM user_filters + """ + else: + select_clause = """ + SELECT * FROM user_filters GROUP BY user_id, filter_id + """ + sql = """ + DROP TABLE IF EXISTS user_filters_migration; + DROP INDEX IF EXISTS user_filters_unique; + CREATE TABLE user_filters_migration ( + user_id TEXT NOT NULL, + filter_id BIGINT NOT NULL, + filter_json BYTEA NOT NULL + ); + INSERT INTO user_filters_migration (user_id, filter_id, filter_json) + %s; + CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration + (user_id, filter_id); + DROP TABLE user_filters; + ALTER TABLE user_filters_migration RENAME TO user_filters; + """ % ( + select_clause, + ) + + if isinstance(database_engine, PostgresEngine): + cur.execute(sql) + else: + cur.executescript(sql) diff --git a/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql b/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql new file mode 100644 index 0000000000..91390c4527 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql @@ -0,0 +1,24 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
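In the unique_user_filter_index.py migration above, the Postgres branch relies on `DISTINCT ON` to keep exactly one row per `(user_id, filter_id)` pair, while the SQLite branch gets the same effect from `GROUP BY`. A standalone Postgres sketch of the deduplicating select; note that which row survives within each duplicate group is arbitrary unless further `ORDER BY` columns pin the choice:

    SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json
    FROM user_filters
    ORDER BY user_id, filter_id;
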
+ */ + +/* + * a table which records mappings from external auth providers to mxids + */ +CREATE TABLE IF NOT EXISTS user_external_ids ( + auth_provider TEXT NOT NULL, + external_id TEXT NOT NULL, + user_id TEXT NOT NULL, + UNIQUE (auth_provider, external_id) +); diff --git a/synapse/storage/schema/delta/56/users_in_public_rooms_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql index 149f8be8b6..149f8be8b6 100644 --- a/synapse/storage/schema/delta/56/users_in_public_rooms_idx.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql diff --git a/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql b/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql new file mode 100644 index 0000000000..aec06c8261 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add background update to go and delete current state events for rooms the +-- server is no longer in. +-- +-- this relies on the 'membership' column of current_state_events, so make sure +-- that's populated first! +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('delete_old_current_state_events', '{}', 'current_state_events_membership'); diff --git a/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql b/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql new file mode 100644 index 0000000000..c3b6de2099 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql @@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Records whether the server thinks that the remote users cached device lists +-- may be out of date (e.g. if we have received a to device message from a +-- device we don't know about). 
+CREATE TABLE IF NOT EXISTS device_lists_remote_resync ( + user_id TEXT NOT NULL, + added_ts BIGINT NOT NULL +); + +CREATE UNIQUE INDEX device_lists_remote_resync_idx ON device_lists_remote_resync (user_id); +CREATE INDEX device_lists_remote_resync_ts_idx ON device_lists_remote_resync (added_ts); diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py new file mode 100644 index 0000000000..63b5acdcf7 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# We create a new table called `local_current_membership` that stores the latest +# membership state of local users in rooms, which helps track leaves/bans/etc +# even if the server has left the room (and so has deleted the room from +# `current_state_events`). This will also include outstanding invites for local +# users for rooms the server isn't in. +# +# If the server isn't and hasn't been in the room then it will only include +# outsstanding invites, and not e.g. pre-emptive bans of local users. +# +# If the server later rejoins a room `local_current_membership` can simply be +# replaced with the new current state of the room (which results in the +# equivalent behaviour as if the server had remained in the room). + + +def run_upgrade(cur, database_engine, config, *args, **kwargs): + # We need to do the insert in `run_upgrade` section as we don't have access + # to `config` in `run_create`. + + # This upgrade may take a bit of time for large servers (e.g. one minute for + # matrix.org) but means we avoid a lots of book keeping required to do it as + # a background update. + + # We check if the `current_state_events.membership` is up to date by + # checking if the relevant background update has finished. If it has + # finished we can avoid doing a join against `room_memberships`, which + # speesd things up. + cur.execute( + """SELECT 1 FROM background_updates + WHERE update_name = 'current_state_events_membership' + """ + ) + current_state_membership_up_to_date = not bool(cur.fetchone()) + + # Cheekily drop and recreate indices, as that is faster. + cur.execute("DROP INDEX local_current_membership_idx") + cur.execute("DROP INDEX local_current_membership_room_idx") + + if current_state_membership_up_to_date: + sql = """ + INSERT INTO local_current_membership (room_id, user_id, event_id, membership) + SELECT c.room_id, state_key AS user_id, event_id, c.membership + FROM current_state_events AS c + WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key LIKE ? + """ + else: + # We can't rely on the membership column, so we need to join against + # `room_memberships`. 
+ sql = """ + INSERT INTO local_current_membership (room_id, user_id, event_id, membership) + SELECT c.room_id, state_key AS user_id, event_id, r.membership + FROM current_state_events AS c + INNER JOIN room_memberships AS r USING (event_id) + WHERE type = 'm.room.member' AND state_key LIKE ? + """ + sql = database_engine.convert_param_style(sql) + cur.execute(sql, ("%:" + config.server_name,)) + + cur.execute( + "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)" + ) + cur.execute( + "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)" + ) + + +def run_create(cur, database_engine, *args, **kwargs): + cur.execute( + """ + CREATE TABLE local_current_membership ( + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + membership TEXT NOT NULL + )""" + ) + + cur.execute( + "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)" + ) + cur.execute( + "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)" + ) diff --git a/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql b/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql new file mode 100644 index 0000000000..133d80af35 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql @@ -0,0 +1,21 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- we no longer keep sent outbound device pokes in the db; clear them out +-- so that we don't have to worry about them. +-- +-- This is a sequence scan, but it doesn't take too long. + +DELETE FROM device_lists_outbound_pokes WHERE sent; diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql new file mode 100644 index 0000000000..352a66f5b0 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql @@ -0,0 +1,24 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- We want to start storing the room version independently of +-- `current_state_events` so that we can delete stale entries from it without +-- losing the information. 
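Once `local_current_membership` has been populated by the migration above, a local user's current state in a room becomes a single lookup against the new unique index. Illustrative only, with made-up identifiers:

    SELECT membership, event_id
    FROM local_current_membership
    WHERE user_id = '@alice:example.org'
      AND room_id = '!abc123:example.org';
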
+ALTER TABLE rooms ADD COLUMN room_version TEXT; + + +INSERT into background_updates (update_name, progress_json) + VALUES ('add_rooms_room_version_column', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres new file mode 100644 index 0000000000..c601cff6de --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres @@ -0,0 +1,35 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- when we first added the room_version column, it was populated via a background +-- update. We now need it to be populated before synapse starts, so we populate +-- any remaining rows with a NULL room version now. For servers which have completed +-- the background update, this will be pretty quick. + +-- the following query will set room_version to NULL if no create event is found for +-- the room in current_state_events, and will set it to '1' if a create event with no +-- room_version is found. + +UPDATE rooms SET room_version=( + SELECT COALESCE(json::json->'content'->>'room_version','1') + FROM current_state_events cse INNER JOIN event_json ej USING (event_id) + WHERE cse.room_id=rooms.room_id AND cse.type='m.room.create' AND cse.state_key='' +) WHERE rooms.room_version IS NULL; + +-- we still allow the background update to complete: it has the useful side-effect of +-- populating `rooms` with any missing rooms (based on the current_state_events table). + +-- see also rooms_version_column_2.sql.sqlite which has a copy of the above query, using +-- sqlite syntax for the json extraction. diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite new file mode 100644 index 0000000000..335c6f2074 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- see rooms_version_column_2.sql.postgres for details of what's going on here. 
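After the synchronous backfill above runs, any remaining gaps — rooms with no usable m.room.create event in `current_state_events` — are easy to spot with a read-only check (not part of any delta); those are exactly the rows the `rooms_version_column_3` deltas further below go on to fix from `state_events`:

    SELECT room_id
    FROM rooms
    WHERE room_version IS NULL;
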
+ +UPDATE rooms SET room_version=( + SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1') + FROM current_state_events cse INNER JOIN event_json ej USING (event_id) + WHERE cse.room_id=rooms.room_id AND cse.type='m.room.create' AND cse.state_key='' +) WHERE rooms.room_version IS NULL; diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres new file mode 100644 index 0000000000..92aaadde0d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres @@ -0,0 +1,39 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- When we first added the room_version column to the rooms table, it was populated from +-- the current_state_events table. However, there was an issue causing a background +-- update to clean up the current_state_events table for rooms where the server is no +-- longer participating, before that column could be populated. Therefore, some rooms had +-- a NULL room_version. + +-- The rooms_version_column_2.sql.* delta files were introduced to make the populating +-- synchronous instead of running it in a background update, which fixed this issue. +-- However, all of the instances of Synapse installed or updated in the meantime got +-- their rooms table corrupted with NULL room_versions. + +-- This query fishes out the room versions from the create event using the state_events +-- table instead of the current_state_events one, as the former still have all of the +-- create events. + +UPDATE rooms SET room_version=( + SELECT COALESCE(json::json->'content'->>'room_version','1') + FROM state_events se INNER JOIN event_json ej USING (event_id) + WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key='' + LIMIT 1 +) WHERE rooms.room_version IS NULL; + +-- see also rooms_version_column_3.sql.sqlite which has a copy of the above query, using +-- sqlite syntax for the json extraction. diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite new file mode 100644 index 0000000000..e19dab97cb --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite @@ -0,0 +1,23 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- see rooms_version_column_3.sql.postgres for details of what's going on here. + +UPDATE rooms SET room_version=( + SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1') + FROM state_events se INNER JOIN event_json ej USING (event_id) + WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key='' + LIMIT 1 +) WHERE rooms.room_version IS NULL; diff --git a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql new file mode 100644 index 0000000000..fdc39e9ba5 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + /* for some reason, we have accumulated duplicate entries in + * device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less + * efficient. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) + VALUES (5800, 'remove_dup_outbound_pokes', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql new file mode 100644 index 0000000000..dcb593fc2d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql @@ -0,0 +1,36 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS ui_auth_sessions( + session_id TEXT NOT NULL, -- The session ID passed to the client. + creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds). + serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse. + clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client. + uri TEXT NOT NULL, -- The URI the UI authentication session is using. + method TEXT NOT NULL, -- The HTTP method the UI authentication session is using. + -- The clientdict, uri, and method make up an tuple that must be immutable + -- throughout the lifetime of the UI Auth session. + description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur. 
+ UNIQUE (session_id) +); + +CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials( + session_id TEXT NOT NULL, -- The corresponding UI Auth session. + stage_type TEXT NOT NULL, -- The stage type. + result TEXT NOT NULL, -- The result of the stage verification, stored as JSON. + UNIQUE (session_id, stage_type), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres new file mode 100644 index 0000000000..aa46eb0e10 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres @@ -0,0 +1,30 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We keep the old table here to enable us to roll back. It doesn't matter +-- that we have dropped all the data here. +TRUNCATE cache_invalidation_stream; + +CREATE TABLE cache_invalidation_stream_by_instance ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + cache_func TEXT NOT NULL, + keys TEXT[], + invalidation_ts BIGINT +); + +CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id); + +CREATE SEQUENCE cache_invalidation_stream_seq; diff --git a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py new file mode 100644 index 0000000000..d353f2bcb3 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py @@ -0,0 +1,80 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This migration rebuilds the device_lists_outbound_last_success table without duplicate +entries, and with a UNIQUE index. 
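The new `cache_invalidation_stream_by_instance` table above is keyed by a Postgres sequence. As a hedged sketch only of how a writer might record an invalidation — the instance name, cache name, and timestamp are invented for illustration and not taken from this diff:

    INSERT INTO cache_invalidation_stream_by_instance
        (stream_id, instance_name, cache_func, keys, invalidation_ts)
    VALUES (
        nextval('cache_invalidation_stream_seq'),
        'master',                      -- writer instance (hypothetical)
        'get_user_by_id',              -- invalidated cache (hypothetical)
        ARRAY['@alice:example.org'],   -- cache key(s)
        1591790550000                  -- timestamp, assumed to be in ms
    );
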
+""" + +import logging +from io import StringIO + +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.prepare_database import execute_statements_from_stream +from synapse.storage.types import Cursor + +logger = logging.getLogger(__name__) + + +def run_upgrade(*args, **kwargs): + pass + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): + # some instances might already have this index, in which case we can skip this + if isinstance(database_engine, PostgresEngine): + cur.execute( + """ + SELECT 1 FROM pg_class WHERE relkind = 'i' + AND relname = 'device_lists_outbound_last_success_unique_idx' + """ + ) + + if cur.rowcount: + logger.info( + "Unique index exists on device_lists_outbound_last_success: " + "skipping rebuild" + ) + return + + logger.info("Rebuilding device_lists_outbound_last_success with unique index") + execute_statements_from_stream(cur, StringIO(_rebuild_commands)) + + +# there might be duplicates, so the easiest way to achieve this is to create a new +# table with the right data, and renaming it into place + +_rebuild_commands = """ +DROP TABLE IF EXISTS device_lists_outbound_last_success_new; + +CREATE TABLE device_lists_outbound_last_success_new ( + destination TEXT NOT NULL, + user_id TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +-- this took about 30 seconds on matrix.org's 16 million rows. +INSERT INTO device_lists_outbound_last_success_new + SELECT destination, user_id, MAX(stream_id) FROM device_lists_outbound_last_success + GROUP BY destination, user_id; + +-- and this another 30 seconds. +CREATE UNIQUE INDEX device_lists_outbound_last_success_unique_idx + ON device_lists_outbound_last_success_new (destination, user_id); + +DROP TABLE device_lists_outbound_last_success; + +ALTER TABLE device_lists_outbound_last_success_new + RENAME TO device_lists_outbound_last_success; +""" diff --git a/synapse/storage/schema/full_schemas/16/application_services.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql index 883fcd10b2..883fcd10b2 100644 --- a/synapse/storage/schema/full_schemas/16/application_services.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql diff --git a/synapse/storage/schema/full_schemas/16/event_edges.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql index 10ce2aa7a0..10ce2aa7a0 100644 --- a/synapse/storage/schema/full_schemas/16/event_edges.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql diff --git a/synapse/storage/schema/full_schemas/16/event_signatures.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql index 95826da431..95826da431 100644 --- a/synapse/storage/schema/full_schemas/16/event_signatures.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql diff --git a/synapse/storage/schema/full_schemas/16/im.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql index a1a2aa8e5b..a1a2aa8e5b 100644 --- a/synapse/storage/schema/full_schemas/16/im.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql diff --git a/synapse/storage/schema/full_schemas/16/keys.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql index 11cdffdbb3..11cdffdbb3 100644 --- a/synapse/storage/schema/full_schemas/16/keys.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql diff --git 
a/synapse/storage/schema/full_schemas/16/media_repository.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql index 8f3759bb2a..8f3759bb2a 100644 --- a/synapse/storage/schema/full_schemas/16/media_repository.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql diff --git a/synapse/storage/schema/full_schemas/16/presence.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql index 01d2d8f833..01d2d8f833 100644 --- a/synapse/storage/schema/full_schemas/16/presence.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql diff --git a/synapse/storage/schema/full_schemas/16/profiles.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql index c04f4747d9..c04f4747d9 100644 --- a/synapse/storage/schema/full_schemas/16/profiles.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql diff --git a/synapse/storage/schema/full_schemas/16/push.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql index e44465cf45..e44465cf45 100644 --- a/synapse/storage/schema/full_schemas/16/push.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql diff --git a/synapse/storage/schema/full_schemas/16/redactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql index 318f0d9aa5..318f0d9aa5 100644 --- a/synapse/storage/schema/full_schemas/16/redactions.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql diff --git a/synapse/storage/schema/full_schemas/16/room_aliases.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql index d47da3b12f..d47da3b12f 100644 --- a/synapse/storage/schema/full_schemas/16/room_aliases.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql diff --git a/synapse/storage/schema/full_schemas/16/state.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql index 96391a8f0e..96391a8f0e 100644 --- a/synapse/storage/schema/full_schemas/16/state.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql diff --git a/synapse/storage/schema/full_schemas/16/transactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql index 17e67bedac..17e67bedac 100644 --- a/synapse/storage/schema/full_schemas/16/transactions.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql diff --git a/synapse/storage/schema/full_schemas/16/users.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql index f013aa8b18..f013aa8b18 100644 --- a/synapse/storage/schema/full_schemas/16/users.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql diff --git a/synapse/storage/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres index 098434356f..889a9a0ce4 100644 --- a/synapse/storage/schema/full_schemas/54/full.sql.postgres +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres @@ -70,15 +70,6 @@ CREATE TABLE appservice_stream_position ( ); - -CREATE TABLE background_updates ( - update_name text NOT NULL, - progress_json text NOT NULL, - depends_on text -); - - - CREATE TABLE blocked_rooms ( room_id text NOT NULL, user_id text NOT NULL @@ -984,40 +975,6 @@ CREATE TABLE state_events ( -CREATE TABLE state_group_edges ( - state_group bigint NOT NULL, - prev_state_group bigint NOT NULL -); - - - -CREATE SEQUENCE state_group_id_seq - 
START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1; - - - -CREATE TABLE state_groups ( - id bigint NOT NULL, - room_id text NOT NULL, - event_id text NOT NULL -); - - - -CREATE TABLE state_groups_state ( - state_group bigint NOT NULL, - room_id text NOT NULL, - type text NOT NULL, - state_key text NOT NULL, - event_id text NOT NULL -); - - - CREATE TABLE stats_stream_pos ( lock character(1) DEFAULT 'X'::bpchar NOT NULL, stream_id bigint, @@ -1202,11 +1159,6 @@ ALTER TABLE ONLY appservice_stream_position -ALTER TABLE ONLY background_updates - ADD CONSTRAINT background_updates_uniqueness UNIQUE (update_name); - - - ALTER TABLE ONLY current_state_events ADD CONSTRAINT current_state_events_event_id_key UNIQUE (event_id); @@ -1496,12 +1448,6 @@ ALTER TABLE ONLY state_events ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id); - -ALTER TABLE ONLY state_groups - ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id); - - - ALTER TABLE ONLY stats_stream_pos ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock); @@ -1942,18 +1888,6 @@ CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts); -CREATE INDEX state_group_edges_idx ON state_group_edges USING btree (state_group); - - - -CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group); - - - -CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key); - - - CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering); @@ -2047,6 +1981,3 @@ CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_room CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms USING btree (user_id, other_user_id, room_id); - - - diff --git a/synapse/storage/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite index be9295e4c9..a0411ede7e 100644 --- a/synapse/storage/schema/full_schemas/54/full.sql.sqlite +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite @@ -42,8 +42,6 @@ CREATE INDEX ev_edges_id ON event_edges(event_id); CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) ); CREATE INDEX room_depth_room ON room_depth(room_id); -CREATE TABLE state_groups( id BIGINT PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); -CREATE TABLE state_groups_state( state_group BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT NOT NULL ); CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) ); CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) ); CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); @@ -67,7 +65,6 @@ CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id ); CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, validated_at 
BIGINT NOT NULL, added_at BIGINT NOT NULL, CONSTRAINT medium_address UNIQUE (medium, address) ); CREATE INDEX user_threepids_user_id ON user_threepids(user_id); -CREATE TABLE background_updates( update_name TEXT NOT NULL, progress_json TEXT NOT NULL, depends_on TEXT, CONSTRAINT background_updates_uniqueness UNIQUE (update_name) ); CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value ) /* event_search(event_id,room_id,sender,"key",value) */; CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value'); @@ -121,9 +118,6 @@ CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL ); CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT); CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id ); CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id ); -CREATE TABLE state_group_edges( state_group BIGINT NOT NULL, prev_state_group BIGINT NOT NULL ); -CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); -CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering ); CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering ); @@ -255,6 +249,5 @@ CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen); CREATE INDEX users_creation_ts ON users (creation_ts); CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group); CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id); -CREATE INDEX state_groups_state_type_idx ON state_groups_state(state_group, type, state_key); CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id); CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip); diff --git a/synapse/storage/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql index c265fd20e2..91d21b2921 100644 --- a/synapse/storage/schema/full_schemas/54/stream_positions.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql @@ -5,3 +5,4 @@ INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coales INSERT INTO user_directory_stream_pos (stream_id) VALUES (0); INSERT INTO stats_stream_pos (stream_id) VALUES (0); INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); +-- device_max_stream_id is handled separately in 56/device_stream_id_insert.sql \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md new file mode 100644 index 0000000000..c00f287190 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md @@ -0,0 +1,21 @@ +# Synapse Database Schemas + +These schemas are used as a basis to create brand new Synapse databases, on both +SQLite3 and Postgres. + +## Building full schema dumps + +If you want to recreate these schemas, they need to be made from a database that +has had all background updates run. 
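Before taking a dump it is also worth confirming that the `background_updates` table is empty, since any remaining rows mean the schema has not yet reached its final shape. A minimal pre-flight check, assuming a SQLite homeserver database at `homeserver.db` (the path and the direct-connection approach are illustrative, not part of the script):

    import sqlite3

    # An empty result means every registered background update has finished.
    conn = sqlite3.connect("homeserver.db")
    pending = conn.execute("SELECT update_name FROM background_updates").fetchall()
    conn.close()
    if pending:
        raise RuntimeError("Background updates still pending: %r" % (pending,))
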
+ +To do so, use `scripts-dev/make_full_schema.sh`. This will produce new +`full.sql.postgres ` and `full.sql.sqlite` files. + +Ensure postgres is installed and your user has the ability to run bash commands +such as `createdb`, then call + + ./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/ + +There are currently two folders with full-schema snapshots. `16` is a snapshot +from 2015, for historical reference. The other contains the most recent full +schema snapshot. diff --git a/synapse/storage/search.py b/synapse/storage/data_stores/main/search.py index df87ab6a6d..13f49d8060 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -24,10 +24,11 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from .background_updates import BackgroundUpdateStore - logger = logging.getLogger(__name__) SearchEntry = namedtuple( @@ -36,23 +37,71 @@ SearchEntry = namedtuple( ) -class SearchStore(BackgroundUpdateStore): +class SearchWorkerStore(SQLBaseStore): + def store_search_entries_txn(self, txn, entries): + """Add entries to the search table + + Args: + txn (cursor): + entries (iterable[SearchEntry]): + entries to be added to the table + """ + if not self.hs.config.enable_search: + return + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "INSERT INTO event_search" + " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" + " VALUES (?,?,?,to_tsvector('english', ?),?,?)" + ) + + args = ( + ( + entry.event_id, + entry.room_id, + entry.key, + entry.value, + entry.stream_ordering, + entry.origin_server_ts, + ) + for entry in entries + ) + + txn.executemany(sql, args) + + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "INSERT INTO event_search (event_id, room_id, key, value)" + " VALUES (?,?,?,?)" + ) + args = ( + (entry.event_id, entry.room_id, entry.key, entry.value) + for entry in entries + ) + + txn.executemany(sql, args) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + +class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, db_conn, hs): - super(SearchStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs) if not hs.config.enable_search: return - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order ) @@ -61,9 +110,11 @@ class SearchStore(BackgroundUpdateStore): # a GIN index. However, it's possible that some people might still have # the background update queued, so we register a handler to clear the # background update. 
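The Postgres branch of `store_search_entries_txn` above stores a precomputed `tsvector`, which is what the GIN index discussed in this comment exists to accelerate. A hedged sketch of the kind of statement that index serves; the real SQL is assembled dynamically in `search_msgs` further down, so the exact ranking and placeholders may differ:

    # Illustrative query shape only; 'meeting & minutes' is a made-up search term.
    sql = (
        "SELECT event_id,"
        " ts_rank_cd(vector, to_tsquery('english', ?)) AS rank"
        " FROM event_search"
        " WHERE vector @@ to_tsquery('english', ?)"
        " ORDER BY rank DESC LIMIT 500"
    )
    args = ("meeting & minutes", "meeting & minutes")
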
- self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME) + self.db.updates.register_noop_background_update( + self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME + ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) @@ -93,7 +144,7 @@ class SearchStore(BackgroundUpdateStore): # store_search_entries_txn with a generator function, but that # would mean having two cursors open on the database at once. # Instead we just build a list of results. - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return 0 @@ -153,18 +204,18 @@ class SearchStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(event_search_rows), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_UPDATE_NAME, progress ) return len(event_search_rows) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn ) if not result: - yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) + yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) return result @@ -196,7 +247,7 @@ class SearchStore(BackgroundUpdateStore): " ON event_search USING GIN (vector)" ) except psycopg2.ProgrammingError as e: - logger.warn( + logger.warning( "Ignoring error %r when trying to switch from GIST to GIN", e ) @@ -206,9 +257,11 @@ class SearchStore(BackgroundUpdateStore): conn.set_session(autocommit=False) if isinstance(self.database_engine, PostgresEngine): - yield self.runWithConnection(create_index) + yield self.db.runWithConnection(create_index) - yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME) + yield self.db.updates._end_background_update( + self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME + ) return 1 @defer.inlineCallbacks @@ -237,14 +290,14 @@ class SearchStore(BackgroundUpdateStore): ) conn.set_session(autocommit=False) - yield self.runWithConnection(create_index) + yield self.db.runWithConnection(create_index) pg = dict(progress) pg["have_added_indexes"] = True - yield self.runInteraction( + yield self.db.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg, ) @@ -274,89 +327,27 @@ class SearchStore(BackgroundUpdateStore): "have_added_indexes": True, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress ) return len(rows), True - num_rows, finished = yield self.runInteraction( + num_rows, finished = yield self.db.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn ) if not finished: - yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME) - - return num_rows - - def store_event_search_txn(self, txn, event, key, value): - """Add event to the search table - - Args: - txn (cursor): - event (EventBase): - key (str): - value (str): - """ - self.store_search_entries_txn( - txn, - ( - SearchEntry( - key=key, - value=value, - event_id=event.event_id, - room_id=event.room_id, - stream_ordering=event.internal_metadata.stream_ordering, - origin_server_ts=event.origin_server_ts, - ), - ), - ) - - def store_search_entries_txn(self, txn, entries): - """Add entries to the search table - - Args: - txn (cursor): - entries (iterable[SearchEntry]): - 
entries to be added to the table - """ - if not self.hs.config.enable_search: - return - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search" - " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" - " VALUES (?,?,?,to_tsvector('english', ?),?,?)" + yield self.db.updates._end_background_update( + self.EVENT_SEARCH_ORDER_UPDATE_NAME ) - args = ( - ( - entry.event_id, - entry.room_id, - entry.key, - entry.value, - entry.stream_ordering, - entry.origin_server_ts, - ) - for entry in entries - ) - - txn.executemany(sql, args) + return num_rows - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - args = ( - (entry.event_id, entry.room_id, entry.key, entry.value) - for entry in entries - ) - txn.executemany(sql, args) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") +class SearchStore(SearchBackgroundUpdateStore): + def __init__(self, database: Database, db_conn, hs): + super(SearchStore, self).__init__(database, db_conn, hs) @defer.inlineCallbacks def search_msgs(self, room_ids, search_term, keys): @@ -373,15 +364,17 @@ class SearchStore(BackgroundUpdateStore): """ clauses = [] - search_query = search_query = _parse_query(self.database_engine, search_term) + search_query = _parse_query(self.database_engine, search_term) args = [] # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. if len(room_ids) < 500: - clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)) - args.extend(room_ids) + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + clauses = [clause] local_clauses = [] for key in keys: @@ -434,11 +427,18 @@ class SearchStore(BackgroundUpdateStore): # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" - results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args) + results = yield self.db.execute( + "search_msgs", self.db.cursor_to_dict, sql, *args + ) results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self.get_events_as_list([r["event_id"] for r in results]) + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) event_map = {ev.event_id: ev for ev in events} @@ -448,8 +448,8 @@ class SearchStore(BackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = yield self._execute( - "search_rooms_count", self.cursor_to_dict, count_sql, *count_args + count_results = yield self.db.execute( + "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) @@ -480,15 +480,17 @@ class SearchStore(BackgroundUpdateStore): """ clauses = [] - search_query = search_query = _parse_query(self.database_engine, search_term) + search_query = _parse_query(self.database_engine, search_term) args = [] # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. 
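This hunk and the matching one just below replace a hand-built `room_id IN (...)` string with the shared `make_in_list_sql_clause` helper, so the clause and its bind parameters are produced together. A sketch of the idea, not the exact Synapse helper (which can also emit a Postgres `= ANY(?)` form):

    def make_in_list_clause_sketch(column, values):
        # Build "col IN (?,?,...)" plus the matching argument list in one place,
        # so the placeholder count can never drift from the values passed in.
        values = list(values)
        placeholders = ",".join("?" for _ in values)
        return "%s IN (%s)" % (column, placeholders), values

    clause, args = make_in_list_clause_sketch("room_id", ["!abc:example.com", "!def:example.com"])
    # clause == "room_id IN (?,?)"; args == ["!abc:example.com", "!def:example.com"]
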
if len(room_ids) < 500: - clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)) - args.extend(room_ids) + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + clauses = [clause] local_clauses = [] for key in keys: @@ -577,11 +579,18 @@ class SearchStore(BackgroundUpdateStore): args.append(limit) - results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args) + results = yield self.db.execute( + "search_rooms", self.db.cursor_to_dict, sql, *args + ) results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self.get_events_as_list([r["event_id"] for r in results]) + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) event_map = {ev.event_id: ev for ev in events} @@ -591,8 +600,8 @@ class SearchStore(BackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = yield self._execute( - "search_rooms_count", self.cursor_to_dict, count_sql, *count_args + count_results = yield self.db.execute( + "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) @@ -663,7 +672,7 @@ class SearchStore(BackgroundUpdateStore): ) ) txn.execute(query, (value, search_query)) - headline, = txn.fetchall()[0] + (headline,) = txn.fetchall()[0] # Now we need to pick the possible highlights out of the haedline # result. @@ -677,7 +686,7 @@ class SearchStore(BackgroundUpdateStore): return highlight_words - return self.runInteraction("_find_highlights", f) + return self.db.runInteraction("_find_highlights", f) def _to_postgres_options(options_dict): diff --git a/synapse/storage/signatures.py b/synapse/storage/data_stores/main/signatures.py index fb83218f90..36244d9f5d 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/data_stores/main/signatures.py @@ -13,24 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import six - from unpaddedbase64 import encode_base64 from twisted.internet import defer -from synapse.crypto.event_signing import compute_event_reference_hash +from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList -from ._base import SQLBaseStore - -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - class SignatureWorkerStore(SQLBaseStore): @cached() @@ -49,7 +38,7 @@ class SignatureWorkerStore(SQLBaseStore): for event_id in event_ids } - return self.runInteraction("get_event_reference_hashes", f) + return self.db.runInteraction("get_event_reference_hashes", f) @defer.inlineCallbacks def add_event_hashes(self, event_ids): @@ -80,23 +69,3 @@ class SignatureWorkerStore(SQLBaseStore): class SignatureStore(SignatureWorkerStore): """Persistence for event signatures and hashes""" - - def _store_event_reference_hashes_txn(self, txn, events): - """Store a hash for a PDU - Args: - txn (cursor): - events (list): list of Events. 
- """ - - vals = [] - for event in events: - ref_alg, ref_hash_bytes = compute_event_reference_hash(event) - vals.append( - { - "event_id": event.event_id, - "algorithm": ref_alg, - "hash": db_binary_type(ref_hash_bytes), - } - ) - - self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py new file mode 100644 index 0000000000..347cc50778 --- /dev/null +++ b/synapse/storage/data_stores/main/state.py @@ -0,0 +1,467 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections.abc +import logging +from collections import namedtuple + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore +from synapse.storage.database import Database +from synapse.storage.state import StateFilter +from synapse.util.caches import intern_string +from synapse.util.caches.descriptors import cached, cachedList + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class _GetStateGroupDelta( + namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) +): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + +# this inherits from EventsWorkerStore because it calls self.get_events +class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): + """The parts of StateGroupStore that can be called from workers. + """ + + def __init__(self, database: Database, db_conn, hs): + super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) + + async def get_room_version(self, room_id: str) -> RoomVersion: + """Get the room_version of a given room + + Raises: + NotFoundError: if the room is unknown + + UnsupportedRoomVersionError: if the room uses an unknown room version. + Typically this happens if support for the room's version has been + removed from Synapse. 
+ """ + room_version_id = await self.get_room_version_id(room_id) + v = KNOWN_ROOM_VERSIONS.get(room_version_id) + + if not v: + raise UnsupportedRoomVersionError( + "Room %s uses a room version %s which is no longer supported" + % (room_id, room_version_id) + ) + + return v + + @cached(max_entries=10000) + async def get_room_version_id(self, room_id: str) -> str: + """Get the room_version of a given room + + Raises: + NotFoundError: if the room is unknown + """ + + # First we try looking up room version from the database, but for old + # rooms we might not have added the room version to it yet so we fall + # back to previous behaviour and look in current state events. + + # We really should have an entry in the rooms table for every room we + # care about, but let's be a bit paranoid (at least while the background + # update is happening) to avoid breaking existing rooms. + version = await self.db.simple_select_one_onecol( + table="rooms", + keyvalues={"room_id": room_id}, + retcol="room_version", + desc="get_room_version", + allow_none=True, + ) + + if version is not None: + return version + + # Retrieve the room's create event + create_event = await self.get_create_event_for_room(room_id) + return create_event.content.get("room_version", "1") + + @defer.inlineCallbacks + def get_room_predecessor(self, room_id): + """Get the predecessor of an upgraded room if it exists. + Otherwise return None. + + Args: + room_id (str) + + Returns: + Deferred[dict|None]: A dictionary containing the structure of the predecessor + field from the room's create event. The structure is subject to other servers, + but it is expected to be: + * room_id (str): The room ID of the predecessor room + * event_id (str): The ID of the tombstone event in the predecessor room + + None if a predecessor key is not found, or is not a dictionary. + + Raises: + NotFoundError if the given room is unknown + """ + # Retrieve the room's create event + create_event = yield self.get_create_event_for_room(room_id) + + # Retrieve the predecessor key of the create event + predecessor = create_event.content.get("predecessor", None) + + # Ensure the key is a dictionary + if not isinstance(predecessor, collections.abc.Mapping): + return None + + return predecessor + + @defer.inlineCallbacks + def get_create_event_for_room(self, room_id): + """Get the create state event for a room. + + Args: + room_id (str) + + Returns: + Deferred[EventBase]: The room creation event. + + Raises: + NotFoundError if the room is unknown + """ + state_ids = yield self.get_current_state_ids(room_id) + create_id = state_ids.get((EventTypes.Create, "")) + + # If we can't find the create event, assume we've hit a dead end + if not create_id: + raise NotFoundError("Unknown room %s" % (room_id,)) + + # Retrieve the room's create event and return + create_event = yield self.get_event(create_id) + return create_event + + @cached(max_entries=100000, iterable=True) + def get_current_state_ids(self, room_id): + """Get the current state event ids for a room based on the + current_state_events table. + + Args: + room_id (str) + + Returns: + deferred: dict of (type, state_key) -> event_id + """ + + def _get_current_state_ids_txn(txn): + txn.execute( + """SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """, + (room_id,), + ) + + return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} + + return self.db.runInteraction( + "get_current_state_ids", _get_current_state_ids_txn + ) + + # FIXME: how should this be cached? 
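The map built by `get_current_state_ids` above is keyed on `(event type, state key)` pairs, which is what lets `get_create_event_for_room` pull out the create event with a single lookup. A hypothetical illustration of the result shape (all IDs and names are made up):

    state_ids = {
        ("m.room.create", ""): "$create-event-id",
        ("m.room.member", "@alice:example.com"): "$alice-join-event-id",
        ("m.room.name", ""): "$room-name-event-id",
    }

    # Essentially the lookup get_create_event_for_room performs:
    create_id = state_ids.get(("m.room.create", ""))
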
+ def get_filtered_current_state_ids( + self, room_id: str, state_filter: StateFilter = StateFilter.all() + ): + """Get the current state event of a given type for a room based on the + current_state_events table. This may not be as up-to-date as the result + of doing a fresh state resolution as per state_handler.get_current_state + + Args: + room_id + state_filter: The state filter used to fetch state + from the database. + + Returns: + defer.Deferred[StateMap[str]]: Map from type/state_key to event ID. + """ + + where_clause, where_args = state_filter.make_sql_filter_clause() + + if not where_clause: + # We delegate to the cached version + return self.get_current_state_ids(room_id) + + def _get_filtered_current_state_ids_txn(txn): + results = {} + sql = """ + SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """ + + if where_clause: + sql += " AND (%s)" % (where_clause,) + + args = [room_id] + args.extend(where_args) + txn.execute(sql, args) + for row in txn: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[key] = event_id + + return results + + return self.db.runInteraction( + "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn + ) + + @defer.inlineCallbacks + def get_canonical_alias_for_room(self, room_id): + """Get canonical alias for room, if any + + Args: + room_id (str) + + Returns: + Deferred[str|None]: The canonical alias, if any + """ + + state = yield self.get_filtered_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) + ) + + event_id = state.get((EventTypes.CanonicalAlias, "")) + if not event_id: + return + + event = yield self.get_event(event_id, allow_none=True) + if not event: + return + + return event.content.get("canonical_alias") + + @cached(max_entries=50000) + def _get_state_group_for_event(self, event_id): + return self.db.simple_select_one_onecol( + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + retcol="state_group", + allow_none=True, + desc="_get_state_group_for_event", + ) + + @cachedList( + cached_method_name="_get_state_group_for_event", + list_name="event_ids", + num_args=1, + inlineCallbacks=True, + ) + def _get_state_group_for_events(self, event_ids): + """Returns mapping event_id -> state_group + """ + rows = yield self.db.simple_select_many_batch( + table="event_to_state_groups", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=("event_id", "state_group"), + desc="_get_state_group_for_events", + ) + + return {row["event_id"]: row["state_group"] for row in rows} + + @defer.inlineCallbacks + def get_referenced_state_groups(self, state_groups): + """Check if the state groups are referenced by events. + + Args: + state_groups (Iterable[int]) + + Returns: + Deferred[set[int]]: The subset of state groups that are + referenced. 
+ """ + + rows = yield self.db.simple_select_many_batch( + table="event_to_state_groups", + column="state_group", + iterable=state_groups, + keyvalues={}, + retcols=("DISTINCT state_group",), + desc="get_referenced_state_groups", + ) + + return {row["state_group"] for row in rows} + + +class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): + + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" + DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" + + def __init__(self, database: Database, db_conn, hs): + super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + + self.server_name = hs.hostname + + self.db.updates.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) + self.db.updates.register_background_index_update( + self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, + index_name="event_to_state_groups_sg_index", + table="event_to_state_groups", + columns=["state_group"], + ) + self.db.updates.register_background_update_handler( + self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms, + ) + + async def _background_remove_left_rooms(self, progress, batch_size): + """Background update to delete rows from `current_state_events` and + `event_forward_extremities` tables of rooms that the server is no + longer joined to. + """ + + last_room_id = progress.get("last_room_id", "") + + def _background_remove_left_rooms_txn(txn): + sql = """ + SELECT DISTINCT room_id FROM current_state_events + WHERE room_id > ? ORDER BY room_id LIMIT ? + """ + + txn.execute(sql, (last_room_id, batch_size)) + room_ids = [row[0] for row in txn] + if not room_ids: + return True, set() + + sql = """ + SELECT room_id + FROM current_state_events + WHERE + room_id > ? AND room_id <= ? + AND type = 'm.room.member' + AND membership = 'join' + AND state_key LIKE ? + GROUP BY room_id + """ + + txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name)) + + joined_room_ids = {row[0] for row in txn} + + left_rooms = set(room_ids) - joined_room_ids + + logger.info("Deleting current state left rooms: %r", left_rooms) + + # First we get all users that we still think were joined to the + # room. This is so that we can mark those device lists as + # potentially stale, since there may have been a period where the + # server didn't share a room with the remote user and therefore may + # have missed any device updates. + rows = self.db.simple_select_many_txn( + txn, + table="current_state_events", + column="room_id", + iterable=left_rooms, + keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN}, + retcols=("state_key",), + ) + + potentially_left_users = {row["state_key"] for row in rows} + + # Now lets actually delete the rooms from the DB. 
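The `state_key LIKE ?` clause earlier in this transaction works because Matrix user IDs end in `:servername`, so binding `"%:" + self.server_name` keeps only membership rows belonging to local users; rooms with no such rows are the ones treated as left. A hedged illustration with made-up names:

    server_name = "example.com"
    members = ["@alice:example.com", "@bob:remote.org", "@carol:example.com"]
    local_members = [m for m in members if m.endswith(":" + server_name)]
    # local_members == ["@alice:example.com", "@carol:example.com"]
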
+ self.db.simple_delete_many_txn( + txn, + table="current_state_events", + column="room_id", + iterable=left_rooms, + keyvalues={}, + ) + + self.db.simple_delete_many_txn( + txn, + table="event_forward_extremities", + column="room_id", + iterable=left_rooms, + keyvalues={}, + ) + + self.db.updates._background_update_progress_txn( + txn, + self.DELETE_CURRENT_STATE_UPDATE_NAME, + {"last_room_id": room_ids[-1]}, + ) + + return False, potentially_left_users + + finished, potentially_left_users = await self.db.runInteraction( + "_background_remove_left_rooms", _background_remove_left_rooms_txn + ) + + if finished: + await self.db.updates._end_background_update( + self.DELETE_CURRENT_STATE_UPDATE_NAME + ) + + # Now go and check if we still share a room with the remote users in + # the deleted rooms. If not mark their device lists as stale. + joined_users = await self.get_users_server_still_shares_room_with( + potentially_left_users + ) + + for user_id in potentially_left_users - joined_users: + await self.mark_remote_user_device_list_as_unsubscribed(user_id) + + return batch_size + + +class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): + """ Keeps track of the state at a given event. + + This is done by the concept of `state groups`. Every event is a assigned + a state group (identified by an arbitrary string), which references a + collection of state events. The current state of an event is then the + collection of state events referenced by the event's state group. + + Hence, every change in the current state causes a new state group to be + generated. However, if no change happens (e.g., if we get a message event + with only one parent it inherits the state group from its parent.) + + There are three tables: + * `state_groups`: Stores group name, first event with in the group and + room id. + * `event_to_state_groups`: Maps events to state groups. + * `state_groups_state`: Maps state group to state events. + """ + + def __init__(self, database: Database, db_conn, hs): + super(StateStore, self).__init__(database, db_conn, hs) diff --git a/synapse/storage/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py index 5fdb442104..725e12507f 100644 --- a/synapse/storage/state_deltas.py +++ b/synapse/storage/data_stores/main/state_deltas.py @@ -15,13 +15,15 @@ import logging +from twisted.internet import defer + from synapse.storage._base import SQLBaseStore logger = logging.getLogger(__name__) class StateDeltasStore(SQLBaseStore): - def get_current_state_deltas(self, prev_stream_id): + def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int): """Fetch a list of room state changes since the given stream id Each entry in the result contains the following fields: @@ -36,15 +38,27 @@ class StateDeltasStore(SQLBaseStore): Args: prev_stream_id (int): point to get changes since (exclusive) + max_stream_id (int): the point that we know has been correctly persisted + - ie, an upper limit to return changes from. Returns: - Deferred[list[dict]]: results + Deferred[tuple[int, list[dict]]: A tuple consisting of: + - the stream id which these results go up to + - list of current_state_delta_stream rows. If it is empty, we are + up to date. 
""" prev_stream_id = int(prev_stream_id) + + # check we're not going backwards + assert prev_stream_id <= max_stream_id + if not self._curr_state_delta_stream_cache.has_any_entity_changed( prev_stream_id ): - return [] + # if the CSDs haven't changed between prev_stream_id and now, we + # know for certain that they haven't changed between prev_stream_id and + # max_stream_id. + return defer.succeed((max_stream_id, [])) def get_current_state_deltas_txn(txn): # First we calculate the max stream id that will give us less than @@ -54,21 +68,29 @@ class StateDeltasStore(SQLBaseStore): sql = """ SELECT stream_id, count(*) FROM current_state_delta_stream - WHERE stream_id > ? + WHERE stream_id > ? AND stream_id <= ? GROUP BY stream_id ORDER BY stream_id ASC LIMIT 100 """ - txn.execute(sql, (prev_stream_id,)) + txn.execute(sql, (prev_stream_id, max_stream_id)) total = 0 - max_stream_id = prev_stream_id - for max_stream_id, count in txn: + + for stream_id, count in txn: total += count if total > 100: # We arbitarily limit to 100 entries to ensure we don't # select toooo many. + logger.debug( + "Clipping current_state_delta_stream rows to stream_id %i", + stream_id, + ) + clipped_stream_id = stream_id break + else: + # if there's no problem, we may as well go right up to the max_stream_id + clipped_stream_id = max_stream_id # Now actually get the deltas sql = """ @@ -77,15 +99,15 @@ class StateDeltasStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC """ - txn.execute(sql, (prev_stream_id, max_stream_id)) - return self.cursor_to_dict(txn) + txn.execute(sql, (prev_stream_id, clipped_stream_id)) + return clipped_stream_id, self.db.cursor_to_dict(txn) - return self.runInteraction( + return self.db.runInteraction( "get_current_state_deltas", get_current_state_deltas_txn ) def _get_max_stream_id_in_current_state_deltas_txn(self, txn): - return self._simple_select_one_onecol_txn( + return self.db.simple_select_one_onecol_txn( txn, table="current_state_delta_stream", keyvalues={}, @@ -93,7 +115,7 @@ class StateDeltasStore(SQLBaseStore): ) def get_max_stream_id_in_current_state_deltas(self): - return self.runInteraction( + return self.db.runInteraction( "get_max_stream_id_in_current_state_deltas", self._get_max_stream_id_in_current_state_deltas_txn, ) diff --git a/synapse/storage/stats.py b/synapse/storage/data_stores/main/stats.py index 09190d684e..380c1ec7da 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -21,8 +21,9 @@ from twisted.internet import defer from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership -from synapse.storage import PostgresEngine -from synapse.storage.state_deltas import StateDeltasStore +from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.database import Database +from synapse.storage.engines import PostgresEngine from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -58,8 +59,8 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} class StatsStore(StateDeltasStore): - def __init__(self, db_conn, hs): - super(StatsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StatsStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname self.clock = self.hs.get_clock() @@ -68,17 +69,17 @@ class StatsStore(StateDeltasStore): self.stats_delta_processing_lock = DeferredLock() - 
self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_stats_process_rooms", self._populate_stats_process_rooms ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_stats_process_users", self._populate_stats_process_users ) # we no longer need to perform clean-up, but we will give ourselves # the potential to reintroduce it in the future – so documentation # will still encourage the use of this no-op handler. - self.register_noop_background_update("populate_stats_cleanup") - self.register_noop_background_update("populate_stats_prepare") + self.db.updates.register_noop_background_update("populate_stats_cleanup") + self.db.updates.register_noop_background_update("populate_stats_prepare") def quantise_stats_time(self, ts): """ @@ -102,7 +103,7 @@ class StatsStore(StateDeltasStore): This is a background update which regenerates statistics for users. """ if not self.stats_enabled: - yield self._end_background_update("populate_stats_process_users") + yield self.db.updates._end_background_update("populate_stats_process_users") return 1 last_user_id = progress.get("last_user_id", "") @@ -117,22 +118,22 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_user_id, batch_size)) return [r for r, in txn] - users_to_work_on = yield self.runInteraction( + users_to_work_on = yield self.db.runInteraction( "_populate_stats_process_users", _get_next_batch ) # No more rooms -- complete the transaction. if not users_to_work_on: - yield self._end_background_update("populate_stats_process_users") + yield self.db.updates._end_background_update("populate_stats_process_users") return 1 for user_id in users_to_work_on: yield self._calculate_and_set_initial_state_for_user(user_id) progress["last_user_id"] = user_id - yield self.runInteraction( + yield self.db.runInteraction( "populate_stats_process_users", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_stats_process_users", progress, ) @@ -145,7 +146,7 @@ class StatsStore(StateDeltasStore): This is a background update which regenerates statistics for rooms. """ if not self.stats_enabled: - yield self._end_background_update("populate_stats_process_rooms") + yield self.db.updates._end_background_update("populate_stats_process_rooms") return 1 last_room_id = progress.get("last_room_id", "") @@ -160,22 +161,22 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_room_id, batch_size)) return [r for r, in txn] - rooms_to_work_on = yield self.runInteraction( + rooms_to_work_on = yield self.db.runInteraction( "populate_stats_rooms_get_batch", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self._end_background_update("populate_stats_process_rooms") + yield self.db.updates._end_background_update("populate_stats_process_rooms") return 1 for room_id in rooms_to_work_on: yield self._calculate_and_set_initial_state_for_room(room_id) progress["last_room_id"] = room_id - yield self.runInteraction( + yield self.db.runInteraction( "_populate_stats_process_rooms", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_stats_process_rooms", progress, ) @@ -186,7 +187,7 @@ class StatsStore(StateDeltasStore): """ Returns the stats processor positions. 
""" - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -215,7 +216,7 @@ class StatsStore(StateDeltasStore): if field and "\0" in field: fields[col] = None - return self._simple_upsert( + return self.db.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, @@ -236,7 +237,7 @@ class StatsStore(StateDeltasStore): Deferred[list[dict]], where the dict has the keys of ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". """ - return self.runInteraction( + return self.db.runInteraction( "get_statistics_for_subject", self._get_statistics_for_subject_txn, stats_type, @@ -257,44 +258,19 @@ class StatsStore(StateDeltasStore): ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] ) - slice_list = self._simple_select_list_paginate_txn( + slice_list = self.db.simple_select_list_paginate_txn( txn, table + "_historical", - {id_col: stats_id}, "end_ts", start, size, retcols=selected_columns + ["bucket_size", "end_ts"], + keyvalues={id_col: stats_id}, order_direction="DESC", ) return slice_list - def get_room_stats_state(self, room_id): - """ - Returns the current room_stats_state for a room. - - Args: - room_id (str): The ID of the room to return state for. - - Returns (dict): - Dictionary containing these keys: - "name", "topic", "canonical_alias", "avatar", "join_rules", - "history_visibility" - """ - return self._simple_select_one( - "room_stats_state", - {"room_id": room_id}, - retcols=( - "name", - "topic", - "canonical_alias", - "avatar", - "join_rules", - "history_visibility", - ), - ) - @cached() def get_earliest_token_for_stats(self, stats_type, id): """ @@ -308,7 +284,7 @@ class StatsStore(StateDeltasStore): """ table, id_col = TYPE_TO_TABLE[stats_type] - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", @@ -332,6 +308,9 @@ class StatsStore(StateDeltasStore): def _bulk_update_stats_delta_txn(txn): for stats_type, stats_updates in updates.items(): for stats_id, fields in stats_updates.items(): + logger.debug( + "Updating %s stats for %s: %s", stats_type, stats_id, fields + ) self._update_stats_delta_txn( txn, ts=ts, @@ -341,14 +320,14 @@ class StatsStore(StateDeltasStore): complete_with_stream_id=stream_id, ) - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": stream_id}, ) - return self.runInteraction( + return self.db.runInteraction( "bulk_update_stats_delta", _bulk_update_stats_delta_txn ) @@ -379,7 +358,7 @@ class StatsStore(StateDeltasStore): Does not work with per-slice fields. 
""" - return self.runInteraction( + return self.db.runInteraction( "update_stats_delta", self._update_stats_delta_txn, ts, @@ -514,17 +493,17 @@ class StatsStore(StateDeltasStore): else: self.database_engine.lock_table(txn, table) retcols = list(chain(absolutes.keys(), additive_relatives.keys())) - current_row = self._simple_select_one_txn( + current_row = self.db.simple_select_one_txn( txn, table, keyvalues, retcols, allow_none=True ) if current_row is None: merged_dict = {**keyvalues, **absolutes, **additive_relatives} - self._simple_insert_txn(txn, table, merged_dict) + self.db.simple_insert_txn(txn, table, merged_dict) else: for (key, val) in additive_relatives.items(): current_row[key] += val current_row.update(absolutes) - self._simple_update_one_txn(txn, table, keyvalues, current_row) + self.db.simple_update_one_txn(txn, table, keyvalues, current_row) def _upsert_copy_from_table_with_additive_relatives_txn( self, @@ -611,11 +590,11 @@ class StatsStore(StateDeltasStore): txn.execute(sql, qargs) else: self.database_engine.lock_table(txn, into_table) - src_row = self._simple_select_one_txn( + src_row = self.db.simple_select_one_txn( txn, src_table, keyvalues, copy_columns ) all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self._simple_select_one_txn( + dest_current_row = self.db.simple_select_one_txn( txn, into_table, keyvalues=all_dest_keyvalues, @@ -631,11 +610,11 @@ class StatsStore(StateDeltasStore): **src_row, **additive_relatives, } - self._simple_insert_txn(txn, into_table, merged_dict) + self.db.simple_insert_txn(txn, into_table, merged_dict) else: for (key, val) in additive_relatives.items(): src_row[key] = dest_current_row[key] + val - self._simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) + self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): """Fetches the counts of events in the given range of stream IDs. @@ -649,7 +628,7 @@ class StatsStore(StateDeltasStore): changes. 
""" - return self.runInteraction( + return self.db.runInteraction( "stats_incremental_total_events_and_bytes", self.get_changes_room_total_events_and_bytes_txn, min_pos, @@ -732,7 +711,7 @@ class StatsStore(StateDeltasStore): def _fetch_current_state_stats(txn): pos = self.get_room_max_stream_ordering() - rows = self._simple_select_many_txn( + rows = self.db.simple_select_many_txn( txn, table="current_state_events", column="type", @@ -740,7 +719,7 @@ class StatsStore(StateDeltasStore): EventTypes.Create, EventTypes.JoinRules, EventTypes.RoomHistoryVisibility, - EventTypes.Encryption, + EventTypes.RoomEncryption, EventTypes.Name, EventTypes.Topic, EventTypes.RoomAvatar, @@ -770,7 +749,7 @@ class StatsStore(StateDeltasStore): (room_id,), ) - current_state_events_count, = txn.fetchone() + (current_state_events_count,) = txn.fetchone() users_in_room = self.get_users_in_room_txn(txn, room_id) @@ -788,7 +767,7 @@ class StatsStore(StateDeltasStore): current_state_events_count, users_in_room, pos, - ) = yield self.runInteraction( + ) = yield self.db.runInteraction( "get_initial_state_for_room", _fetch_current_state_stats ) @@ -812,7 +791,7 @@ class StatsStore(StateDeltasStore): room_state["history_visibility"] = event.content.get( "history_visibility" ) - elif event.type == EventTypes.Encryption: + elif event.type == EventTypes.RoomEncryption: room_state["encryption"] = event.content.get("algorithm") elif event.type == EventTypes.Name: room_state["name"] = event.content.get("name") @@ -860,10 +839,10 @@ class StatsStore(StateDeltasStore): """, (user_id,), ) - count, = txn.fetchone() + (count,) = txn.fetchone() return count, pos - joined_rooms, pos = yield self.runInteraction( + joined_rooms, pos = yield self.db.runInteraction( "calculate_and_set_initial_state_for_user", _calculate_and_set_initial_state_for_user_txn, ) diff --git a/synapse/storage/stream.py b/synapse/storage/data_stores/main/stream.py index 490454f19a..e89f0bffb5 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,8 +46,9 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine -from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -229,6 +233,14 @@ def filter_to_clause(event_filter): clauses.append("contains_url = ?") args.append(event_filter.contains_url) + # We're only applying the "labels" filter on the database query, because applying the + # "not_labels" filter via a SQL query is non-trivial. Instead, we let + # event_filter.check_fields apply it, which is not as efficient but makes the + # implementation simpler. + if event_filter.labels: + clauses.append("(%s)" % " OR ".join("label = ?" 
for _ in event_filter.labels)) + args.extend(event_filter.labels) + return " AND ".join(clauses), args @@ -240,11 +252,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(StreamWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StreamWorkerStore, self).__init__(database, db_conn, hs) events_max = self.get_room_max_stream_ordering() - event_cache_prefill, min_event_val = self._get_cache_dict( + event_cache_prefill, min_event_val = self.db.get_cache_dict( db_conn, "events", entity_column="room_id", @@ -334,11 +346,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): from_key (str): The room_key portion of a StreamToken """ from_key = RoomStreamToken.parse_stream_token(from_key).stream - return set( + return { room_id for room_id in room_ids if self._events_stream_cache.has_entity_changed(room_id, from_key) - ) + } @defer.inlineCallbacks def get_room_events_stream_for_room( @@ -389,7 +401,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows - rows = yield self.runInteraction("get_room_events_stream_for_room", f) + rows = yield self.db.runInteraction("get_room_events_stream_for_room", f) ret = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True @@ -439,7 +451,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows - rows = yield self.runInteraction("get_membership_changes_for_user", f) + rows = yield self.db.runInteraction("get_membership_changes_for_user", f) ret = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True @@ -469,11 +481,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id, limit, end_token ) - logger.debug("stream before") events = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) - logger.debug("stream after") self._set_before_and_after(events, rows) @@ -500,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token = RoomStreamToken.parse(end_token) - rows, token = yield self.runInteraction( + rows, token = yield self.db.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, @@ -513,8 +523,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, token - def get_room_event_after_stream_ordering(self, room_id, stream_ordering): - """Gets details of the first event in a room at or after a stream ordering + def get_room_event_before_stream_ordering(self, room_id, stream_ordering): + """Gets details of the first event in a room at or before a stream ordering Args: room_id (str): @@ -529,15 +539,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): sql = ( "SELECT stream_ordering, topological_ordering, event_id" " FROM events" - " WHERE room_id = ? AND stream_ordering >= ?" + " WHERE room_id = ? AND stream_ordering <= ?" 
" AND NOT outlier" - " ORDER BY stream_ordering" + " ORDER BY stream_ordering DESC" " LIMIT 1" ) txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() - return self.runInteraction("get_room_event_after_stream_ordering", _f) + return self.db.runInteraction("get_room_event_before_stream_ordering", _f) @defer.inlineCallbacks def get_room_events_max_id(self, room_id=None): @@ -551,7 +561,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if room_id is None: return "s%d" % (token,) else: - topo = yield self.runInteraction( + topo = yield self.db.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id ) return "t%d-%d" % (topo, token) @@ -565,7 +575,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred "s%d" stream token. """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" ).addCallback(lambda row: "s%d" % (row,)) @@ -578,7 +588,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred "t%d-%d" topological token. """ - return self._simple_select_one( + return self.db.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), @@ -602,13 +612,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) - return self._execute( + return self.db.execute( "get_max_topological_token", None, sql, room_id, stream_key ).addCallback(lambda r: r[0][0] if r else 0) def _get_max_topological_txn(self, txn, room_id): txn.execute( - "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?", + "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?", (room_id,), ) @@ -656,7 +666,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = yield self.runInteraction( + results = yield self.db.runInteraction( "get_events_around", self._get_events_around_txn, room_id, @@ -667,11 +677,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) events_before = yield self.get_events_as_list( - [e for e in results["before"]["event_ids"]], get_prev_content=True + list(results["before"]["event_ids"]), get_prev_content=True ) events_after = yield self.get_events_as_list( - [e for e in results["after"]["event_ids"]], get_prev_content=True + list(results["after"]["event_ids"]), get_prev_content=True ) return { @@ -698,7 +708,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = self._simple_select_one_txn( + results = self.db.simple_select_one_txn( txn, "events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -777,7 +787,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.runInteraction( + upper_bound, event_ids = yield self.db.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) @@ -786,7 +796,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, events def get_federation_out_pos(self, typ): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", keyvalues={"type": typ}, @@ -794,7 +804,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) def update_federation_out_pos(self, typ, stream_id): - return 
self._simple_update_one( + return self.db.simple_update_one( table="federation_stream_position", keyvalues={"type": typ}, updatevalues={"stream_id": stream_id}, @@ -863,13 +873,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): args.append(int(limit)) - sql = ( - "SELECT event_id, topological_ordering, stream_ordering" - " FROM events" - " WHERE outlier = ? AND room_id = ? AND %(bounds)s" - " ORDER BY topological_ordering %(order)s," - " stream_ordering %(order)s LIMIT ?" - ) % {"bounds": bounds, "order": order} + select_keywords = "SELECT" + join_clause = "" + if event_filter and event_filter.labels: + # If we're not filtering on a label, then joining on event_labels will + # return as many row for a single event as the number of labels it has. To + # avoid this, only join if we're filtering on at least one label. + join_clause = """ + LEFT JOIN event_labels + USING (event_id, room_id, topological_ordering) + """ + if len(event_filter.labels) > 1: + # Using DISTINCT in this SELECT query is quite expensive, because it + # requires the engine to sort on the entire (not limited) result set, + # i.e. the entire events table. We only need to use it when we're + # filtering on more than two labels, because that's the only scenario + # in which we can possibly to get multiple times the same event ID in + # the results. + select_keywords += "DISTINCT" + + sql = """ + %(select_keywords)s event_id, topological_ordering, stream_ordering + FROM events + %(join_clause)s + WHERE outlier = ? AND room_id = ? AND %(bounds)s + ORDER BY topological_ordering %(order)s, + stream_ordering %(order)s LIMIT ? + """ % { + "select_keywords": select_keywords, + "join_clause": join_clause, + "bounds": bounds, + "order": order, + } txn.execute(sql, args) @@ -920,7 +955,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if to_key: to_key = RoomStreamToken.parse(to_key) - rows, token = yield self.runInteraction( + rows, token = yield self.db.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, diff --git a/synapse/storage/tags.py b/synapse/storage/data_stores/main/tags.py index 20dd6bd53d..4219018302 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -22,7 +22,7 @@ from canonicaljson import json from twisted.internet import defer -from synapse.storage.account_data import AccountDataWorkerStore +from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore): tag strings to tag content. """ - deferred = self._simple_select_list( + deferred = self.db.simple_select_list( "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) @@ -78,14 +78,12 @@ class TagsWorkerStore(AccountDataWorkerStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - tag_ids = yield self.runInteraction( + tag_ids = yield self.db.runInteraction( "get_all_updated_tags", get_all_updated_tags_txn ) def get_tag_content(txn, tag_ids): - sql = ( - "SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?" - ) + sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?" 
results = [] for stream_id, user_id, room_id in tag_ids: txn.execute(sql, (user_id, room_id)) @@ -100,7 +98,7 @@ class TagsWorkerStore(AccountDataWorkerStore): batch_size = 50 results = [] for i in range(0, len(tag_ids), batch_size): - tags = yield self.runInteraction( + tags = yield self.db.runInteraction( "get_all_updated_tag_content", get_tag_content, tag_ids[i : i + batch_size], @@ -137,7 +135,9 @@ class TagsWorkerStore(AccountDataWorkerStore): if not changed: return {} - room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn) + room_ids = yield self.db.runInteraction( + "get_updated_tags", get_updated_tags_txn + ) results = {} if room_ids: @@ -155,7 +155,7 @@ class TagsWorkerStore(AccountDataWorkerStore): Returns: A deferred list of string tags. """ - return self._simple_select_list( + return self.db.simple_select_list( table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id}, retcols=("tag", "content"), @@ -180,7 +180,7 @@ class TagsStore(TagsWorkerStore): content_json = json.dumps(content) def add_tag_txn(txn, next_id): - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, @@ -189,7 +189,7 @@ class TagsStore(TagsWorkerStore): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction("add_tag", add_tag_txn, next_id) + yield self.db.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -212,7 +212,7 @@ class TagsStore(TagsWorkerStore): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction("remove_tag", remove_tag_txn, next_id) + yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -233,6 +233,9 @@ class TagsStore(TagsWorkerStore): self._account_data_stream_cache.entity_has_changed, user_id, next_id ) + # Note: This is only here for backwards compat to allow admins to + # roll back to a previous Synapse version. Next time we update the + # database version we can remove this table. update_max_id_sql = ( "UPDATE account_data_max_stream_id" " SET stream_id = ?" diff --git a/synapse/storage/transactions.py b/synapse/storage/data_stores/main/transactions.py index 289c117396..a9bf457939 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -16,23 +16,16 @@ import logging from collections import namedtuple -import six - from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache -from ._base import SQLBaseStore, db_to_json - -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview +db_binary_type = memoryview logger = logging.getLogger(__name__) @@ -53,8 +46,8 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. 
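The tag-content hunk above walks tag_ids in slices of 50 (tag_ids[i : i + batch_size]); the slicing idiom in isolation, with illustrative data rather than real tag rows:

def batches(items, batch_size=50):
    # Yield consecutive slices of at most batch_size items.
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]

tag_ids = list(range(7))
print([b for b in batches(tag_ids, batch_size=3)])
# [[0, 1, 2], [3, 4, 5], [6]]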
""" - def __init__(self, db_conn, hs): - super(TransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TransactionStore, self).__init__(database, db_conn, hs) self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) @@ -78,7 +71,7 @@ class TransactionStore(SQLBaseStore): this transaction or a 2-tuple of (int, dict) """ - return self.runInteraction( + return self.db.runInteraction( "get_received_txn_response", self._get_received_txn_response, transaction_id, @@ -86,7 +79,7 @@ class TransactionStore(SQLBaseStore): ) def _get_received_txn_response(self, txn, transaction_id, origin): - result = self._simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="received_transactions", keyvalues={"transaction_id": transaction_id, "origin": origin}, @@ -120,7 +113,7 @@ class TransactionStore(SQLBaseStore): response_json (str) """ - return self._simple_insert( + return self.db.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id, @@ -149,7 +142,7 @@ class TransactionStore(SQLBaseStore): if result is not SENTINEL: return result - result = yield self.runInteraction( + result = yield self.db.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination, @@ -161,7 +154,7 @@ class TransactionStore(SQLBaseStore): return result def _get_destination_retry_timings(self, txn, destination): - result = self._simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -188,7 +181,7 @@ class TransactionStore(SQLBaseStore): """ self._destination_retry_cache.pop(destination, None) - return self.runInteraction( + return self.db.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, destination, @@ -228,7 +221,7 @@ class TransactionStore(SQLBaseStore): # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. - prev_row = self._simple_select_one_txn( + prev_row = self.db.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -237,7 +230,7 @@ class TransactionStore(SQLBaseStore): ) if not prev_row: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="destinations", values={ @@ -248,7 +241,7 @@ class TransactionStore(SQLBaseStore): }, ) elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, "destinations", keyvalues={"destination": destination}, @@ -271,4 +264,6 @@ class TransactionStore(SQLBaseStore): def _cleanup_transactions_txn(txn): txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn) + return self.db.runInteraction( + "_cleanup_transactions", _cleanup_transactions_txn + ) diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py new file mode 100644 index 0000000000..1d8ee22fb1 --- /dev/null +++ b/synapse/storage/data_stores/main/ui_auth.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from typing import Any, Dict, Optional, Union + +import attr + +import synapse.util.stringutils as stringutils +from synapse.api.errors import StoreError +from synapse.storage._base import SQLBaseStore +from synapse.types import JsonDict + + +@attr.s +class UIAuthSessionData: + session_id = attr.ib(type=str) + # The dictionary from the client root level, not the 'auth' key. + clientdict = attr.ib(type=JsonDict) + # The URI and method the session was intiatied with. These are checked at + # each stage of the authentication to ensure that the asked for operation + # has not changed. + uri = attr.ib(type=str) + method = attr.ib(type=str) + # A string description of the operation that the current authentication is + # authorising. + description = attr.ib(type=str) + + +class UIAuthWorkerStore(SQLBaseStore): + """ + Manage user interactive authentication sessions. + """ + + async def create_ui_auth_session( + self, clientdict: JsonDict, uri: str, method: str, description: str, + ) -> UIAuthSessionData: + """ + Creates a new user interactive authentication session. + + The session can be used to track the stages necessary to authenticate a + user across multiple HTTP requests. + + Args: + clientdict: + The dictionary from the client root level, not the 'auth' key. + uri: + The URI this session was initiated with, this is checked at each + stage of the authentication to ensure that the asked for + operation has not changed. + method: + The method this session was initiated with, this is checked at each + stage of the authentication to ensure that the asked for + operation has not changed. + description: + A string description of the operation that the current + authentication is authorising. + Returns: + The newly created session. + Raises: + StoreError if a unique session ID cannot be generated. + """ + # The clientdict gets stored as JSON. + clientdict_json = json.dumps(clientdict) + + # autogen a session ID and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually. + attempts = 0 + while attempts < 5: + session_id = stringutils.random_string(24) + + try: + await self.db.simple_insert( + table="ui_auth_sessions", + values={ + "session_id": session_id, + "clientdict": clientdict_json, + "uri": uri, + "method": method, + "description": description, + "serverdict": "{}", + "creation_time": self.hs.get_clock().time_msec(), + }, + desc="create_ui_auth_session", + ) + return UIAuthSessionData( + session_id, clientdict, uri, method, description + ) + except self.db.engine.module.IntegrityError: + attempts += 1 + raise StoreError(500, "Couldn't generate a session ID.") + + async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData: + """Retrieve a UI auth session. + + Args: + session_id: The ID of the session. + Returns: + A dict containing the device information. + Raises: + StoreError if the session is not found. 
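create_ui_auth_session above retries a handful of times when a randomly generated session ID collides, and gives up with an error after five attempts. The same control flow in a standalone sketch; the in-memory dict and KeyError stand in for the real table insert and database IntegrityError, and random_string here is a local helper rather than the Synapse one:

import random
import string

def random_string(n):
    return "".join(random.choice(string.ascii_letters) for _ in range(n))

def create_session(insert, max_attempts=5):
    attempts = 0
    while attempts < max_attempts:
        session_id = random_string(24)
        try:
            insert(session_id)
            return session_id
        except KeyError:
            # Simulated unique-constraint violation: try a fresh ID.
            attempts += 1
    raise RuntimeError("Couldn't generate a session ID.")

sessions = {}

def insert(session_id):
    if session_id in sessions:
        raise KeyError(session_id)
    sessions[session_id] = {"serverdict": "{}"}

print(create_session(insert))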
+ """ + result = await self.db.simple_select_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("clientdict", "uri", "method", "description"), + desc="get_ui_auth_session", + ) + + result["clientdict"] = json.loads(result["clientdict"]) + + return UIAuthSessionData(session_id, **result) + + async def mark_ui_auth_stage_complete( + self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict], + ): + """ + Mark a session stage as completed. + + Args: + session_id: The ID of the corresponding session. + stage_type: The completed stage type. + result: The result of the stage verification. + Raises: + StoreError if the session cannot be found. + """ + # Add (or update) the results of the current stage to the database. + # + # Note that we need to allow for the same stage to complete multiple + # times here so that registration is idempotent. + try: + await self.db.simple_upsert( + table="ui_auth_sessions_credentials", + keyvalues={"session_id": session_id, "stage_type": stage_type}, + values={"result": json.dumps(result)}, + desc="mark_ui_auth_stage_complete", + ) + except self.db.engine.module.IntegrityError: + raise StoreError(400, "Unknown session ID: %s" % (session_id,)) + + async def get_completed_ui_auth_stages( + self, session_id: str + ) -> Dict[str, Union[str, bool, JsonDict]]: + """ + Retrieve the completed stages of a UI authentication session. + + Args: + session_id: The ID of the session. + Returns: + The completed stages mapped to the result of the verification of + that auth-type. + """ + results = {} + for row in await self.db.simple_select_list( + table="ui_auth_sessions_credentials", + keyvalues={"session_id": session_id}, + retcols=("stage_type", "result"), + desc="get_completed_ui_auth_stages", + ): + results[row["stage_type"]] = json.loads(row["result"]) + + return results + + async def set_ui_auth_clientdict( + self, session_id: str, clientdict: JsonDict + ) -> None: + """ + Store an updated clientdict for a given session ID. + + Args: + session_id: The ID of this session as returned from check_auth + clientdict: + The dictionary from the client root level, not the 'auth' key. + """ + # The clientdict gets stored as JSON. + clientdict_json = json.dumps(clientdict) + + self.db.simple_update_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + updatevalues={"clientdict": clientdict_json}, + desc="set_ui_auth_client_dict", + ) + + async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): + """ + Store a key-value pair into the sessions data associated with this + request. This data is stored server-side and cannot be modified by + the client. + + Args: + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + value: The data to store + Raises: + StoreError if the session cannot be found. + """ + await self.db.runInteraction( + "set_ui_auth_session_data", + self._set_ui_auth_session_data_txn, + session_id, + key, + value, + ) + + def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any): + # Get the current value. + result = self.db.simple_select_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + ) + + # Update it and add it back to the database. 
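get_completed_ui_auth_stages above just JSON-decodes each credentials row into a mapping keyed by stage type; the same reshaping over illustrative row dicts (the stage types and results here are examples, not data from the diff):

import json

rows = [
    {"stage_type": "m.login.password", "result": "true"},
    {"stage_type": "m.login.terms", "result": "true"},
]

results = {}
for row in rows:
    results[row["stage_type"]] = json.loads(row["result"])

print(results)  # {'m.login.password': True, 'm.login.terms': True}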
+ serverdict = json.loads(result["serverdict"]) + serverdict[key] = value + + self.db.simple_update_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + updatevalues={"serverdict": json.dumps(serverdict)}, + ) + + async def get_ui_auth_session_data( + self, session_id: str, key: str, default: Optional[Any] = None + ) -> Any: + """ + Retrieve data stored with set_session_data + + Args: + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + default: Value to return if the key has not been set + Raises: + StoreError if the session cannot be found. + """ + result = await self.db.simple_select_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + desc="get_ui_auth_session_data", + ) + + serverdict = json.loads(result["serverdict"]) + + return serverdict.get(key, default) + + +class UIAuthStore(UIAuthWorkerStore): + def delete_old_ui_auth_sessions(self, expiration_time: int): + """ + Remove sessions which were last used earlier than the expiration time. + + Args: + expiration_time: The latest time that is still considered valid. + This is an epoch time in milliseconds. + + """ + return self.db.runInteraction( + "delete_old_ui_auth_sessions", + self._delete_old_ui_auth_sessions_txn, + expiration_time, + ) + + def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): + # Get the expired sessions. + sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" + txn.execute(sql, [expiration_time]) + session_ids = [r[0] for r in txn.fetchall()] + + # Delete the corresponding completed credentials. + self.db.simple_delete_many_txn( + txn, + table="ui_auth_sessions_credentials", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + + # Finally, delete the sessions. + self.db.simple_delete_many_txn( + txn, + table="ui_auth_sessions", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) diff --git a/synapse/storage/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index b5188d9bee..6b8130bf0f 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -19,10 +19,10 @@ import re from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.data_stores.main.state import StateFilter +from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.storage.state import StateFilter -from synapse.storage.state_deltas import StateDeltasStore from synapse.types import get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached @@ -32,30 +32,30 @@ logger = logging.getLogger(__name__) TEMP_TABLE = "_temp_populate_user_directory" -class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): +class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # How many records do we calculate before sending it to # add_users_who_share_private_rooms? 
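delete_old_ui_auth_sessions above first collects the expired session IDs, then removes the dependent credentials rows before the sessions themselves. An in-memory sketch of that ordering, with two dicts standing in for the two tables and made-up session data:

def delete_old_sessions(sessions, credentials, expiration_time):
    expired = [
        sid for sid, s in sessions.items() if s["creation_time"] <= expiration_time
    ]
    for sid in expired:
        credentials.pop(sid, None)  # ui_auth_sessions_credentials rows first
        sessions.pop(sid, None)     # then the ui_auth_sessions row itself
    return expired

sessions = {"aaa": {"creation_time": 10}, "bbb": {"creation_time": 99}}
credentials = {"aaa": {"m.login.password": True}}
print(delete_old_sessions(sessions, credentials, expiration_time=50))  # ['aaa']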
SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, db_conn, hs): - super(UserDirectoryStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_createtables", self._populate_user_directory_createtables, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_process_rooms", self._populate_user_directory_process_rooms, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_process_users", self._populate_user_directory_process_users, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_cleanup", self._populate_user_directory_cleanup ) @@ -85,7 +85,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): """ txn.execute(sql) rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] - self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) del rooms # If search all users is on, get all the users we want to add. @@ -100,15 +100,17 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): txn.execute("SELECT name FROM users") users = [{"user_id": x[0]} for x in txn.fetchall()] - self._simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) + self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) new_pos = yield self.get_max_stream_id_in_current_state_deltas() - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory_temp_build", _make_staging_area ) - yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) - yield self._end_background_update("populate_user_directory_createtables") + yield self.db.updates._end_background_update( + "populate_user_directory_createtables" + ) return 1 @defer.inlineCallbacks @@ -116,7 +118,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): """ Update the user directory stream position, then clean up the old tables. """ - position = yield self._simple_select_one_onecol( + position = yield self.db.simple_select_one_onecol( TEMP_TABLE + "_position", None, "position" ) yield self.update_user_directory_stream_pos(position) @@ -126,11 +128,11 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory_cleanup", _delete_staging_area ) - yield self._end_background_update("populate_user_directory_cleanup") + yield self.db.updates._end_background_update("populate_user_directory_cleanup") return 1 @defer.inlineCallbacks @@ -170,16 +172,18 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): return rooms_to_work_on - rooms_to_work_on = yield self.runInteraction( + rooms_to_work_on = yield self.db.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) # No more rooms -- complete the transaction. 
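The constructor change visible here, from (db_conn, hs) to (database: Database, db_conn, hs), recurs in every store touched by this diff, with the Database instance threaded through super() and then reached as self.db (and self.db.updates for background updates). A stripped-down illustration; Database and BaseStore below are dummies, not the Synapse classes:

class Database:
    pass

class BaseStore:
    def __init__(self, database, db_conn, hs):
        # Helpers such as runInteraction/simple_select_* now hang off self.db.
        self.db = database

class StoreSketch(BaseStore):
    def __init__(self, database, db_conn, hs):
        super(StoreSketch, self).__init__(database, db_conn, hs)

store = StoreSketch(Database(), db_conn=None, hs=None)
print(type(store.db).__name__)  # Database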
if not rooms_to_work_on: - yield self._end_background_update("populate_user_directory_process_rooms") + yield self.db.updates._end_background_update( + "populate_user_directory_process_rooms" + ) return 1 - logger.info( + logger.debug( "Processing the next %d rooms of %d remaining" % (len(rooms_to_work_on), progress["remaining"]) ) @@ -243,12 +247,12 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): to_insert.clear() # We've finished a room. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) + yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) # Update the remaining counter. progress["remaining"] -= 1 - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_user_directory_process_rooms", progress, ) @@ -267,7 +271,9 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): If search_all_users is enabled, add all of the users to the user directory. """ if not self.hs.config.user_directory_search_all_users: - yield self._end_background_update("populate_user_directory_process_users") + yield self.db.updates._end_background_update( + "populate_user_directory_process_users" + ) return 1 def _get_next_batch(txn): @@ -291,16 +297,18 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): return users_to_work_on - users_to_work_on = yield self.runInteraction( + users_to_work_on = yield self.db.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) # No more users -- complete the transaction. if not users_to_work_on: - yield self._end_background_update("populate_user_directory_process_users") + yield self.db.updates._end_background_update( + "populate_user_directory_process_users" + ) return 1 - logger.info( + logger.debug( "Processing the next %d users of %d remaining" % (len(users_to_work_on), progress["remaining"]) ) @@ -312,12 +320,12 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): ) # We've finished processing a user. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) + yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) # Update the remaining counter. 
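The room-processing background update above drains a staging table in batches, decrementing a "remaining" counter and ending the update once nothing is left. The control flow in miniature, with an in-memory list and dict standing in for the staging table and the progress JSON:

def process_next_batch(staging, progress, batch_size=2):
    batch = staging[:batch_size]
    if not batch:
        return False  # no more rooms: the real code ends the background update here
    for room_id in batch:
        staging.remove(room_id)       # "we've finished a room, delete it"
        progress["remaining"] -= 1
    return True

staging = ["!a:hs", "!b:hs", "!c:hs"]
progress = {"remaining": len(staging)}
while process_next_batch(staging, progress):
    print(progress)
# {'remaining': 1}
# {'remaining': 0}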
progress["remaining"] -= 1 - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_user_directory_process_users", progress, ) @@ -361,7 +369,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): """ def _update_profile_in_user_dir_txn(txn): - new_entry = self._simple_upsert_txn( + new_entry = self.db.simple_upsert_txn( txn, table="user_directory", keyvalues={"user_id": user_id}, @@ -435,7 +443,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name) if display_name else user_id - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id}, @@ -448,59 +456,10 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.runInteraction( + return self.db.runInteraction( "update_profile_in_user_dir", _update_profile_in_user_dir_txn ) - def remove_from_user_dir(self, user_id): - def _remove_from_user_dir_txn(txn): - self._simple_delete_txn( - txn, table="user_directory", keyvalues={"user_id": user_id} - ) - self._simple_delete_txn( - txn, table="user_directory_search", keyvalues={"user_id": user_id} - ) - self._simple_delete_txn( - txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} - ) - self._simple_delete_txn( - txn, - table="users_who_share_private_rooms", - keyvalues={"user_id": user_id}, - ) - self._simple_delete_txn( - txn, - table="users_who_share_private_rooms", - keyvalues={"other_user_id": user_id}, - ) - txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - - return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) - - @defer.inlineCallbacks - def get_users_in_dir_due_to_room(self, room_id): - """Get all user_ids that are in the room directory because they're - in the given room_id - """ - user_ids_share_pub = yield self._simple_select_onecol( - table="users_in_public_rooms", - keyvalues={"room_id": room_id}, - retcol="user_id", - desc="get_users_in_dir_due_to_room", - ) - - user_ids_share_priv = yield self._simple_select_onecol( - table="users_who_share_private_rooms", - keyvalues={"room_id": room_id}, - retcol="other_user_id", - desc="get_users_in_dir_due_to_room", - ) - - user_ids = set(user_ids_share_pub) - user_ids.update(user_ids_share_priv) - - return user_ids - def add_users_who_share_private_room(self, room_id, user_id_tuples): """Insert entries into the users_who_share_private_rooms table. The first user should be a local user. 
@@ -511,7 +470,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): """ def _add_users_who_share_room_txn(txn): - self._simple_upsert_many_txn( + self.db.simple_upsert_many_txn( txn, table="users_who_share_private_rooms", key_names=["user_id", "other_user_id", "room_id"], @@ -523,7 +482,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): value_values=None, ) - return self.runInteraction( + return self.db.runInteraction( "add_users_who_share_room", _add_users_who_share_room_txn ) @@ -538,7 +497,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): def _add_users_in_public_rooms_txn(txn): - self._simple_upsert_many_txn( + self.db.simple_upsert_many_txn( txn, table="users_in_public_rooms", key_names=["user_id", "room_id"], @@ -547,10 +506,102 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): value_values=None, ) - return self.runInteraction( + return self.db.runInteraction( "add_users_in_public_rooms", _add_users_in_public_rooms_txn ) + def delete_all_from_user_dir(self): + """Delete the entire user directory + """ + + def _delete_all_from_user_dir_txn(txn): + txn.execute("DELETE FROM user_directory") + txn.execute("DELETE FROM user_directory_search") + txn.execute("DELETE FROM users_in_public_rooms") + txn.execute("DELETE FROM users_who_share_private_rooms") + txn.call_after(self.get_user_in_directory.invalidate_all) + + return self.db.runInteraction( + "delete_all_from_user_dir", _delete_all_from_user_dir_txn + ) + + @cached() + def get_user_in_directory(self, user_id): + return self.db.simple_select_one( + table="user_directory", + keyvalues={"user_id": user_id}, + retcols=("display_name", "avatar_url"), + allow_none=True, + desc="get_user_in_directory", + ) + + def update_user_directory_stream_pos(self, stream_id): + return self.db.simple_update_one( + table="user_directory_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_user_directory_stream_pos", + ) + + +class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): + + # How many records do we calculate before sending it to + # add_users_who_share_private_rooms? 
+ SHARE_PRIVATE_WORKING_SET = 500 + + def __init__(self, database: Database, db_conn, hs): + super(UserDirectoryStore, self).__init__(database, db_conn, hs) + + def remove_from_user_dir(self, user_id): + def _remove_from_user_dir_txn(txn): + self.db.simple_delete_txn( + txn, table="user_directory", keyvalues={"user_id": user_id} + ) + self.db.simple_delete_txn( + txn, table="user_directory_search", keyvalues={"user_id": user_id} + ) + self.db.simple_delete_txn( + txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} + ) + self.db.simple_delete_txn( + txn, + table="users_who_share_private_rooms", + keyvalues={"user_id": user_id}, + ) + self.db.simple_delete_txn( + txn, + table="users_who_share_private_rooms", + keyvalues={"other_user_id": user_id}, + ) + txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) + + return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) + + @defer.inlineCallbacks + def get_users_in_dir_due_to_room(self, room_id): + """Get all user_ids that are in the room directory because they're + in the given room_id + """ + user_ids_share_pub = yield self.db.simple_select_onecol( + table="users_in_public_rooms", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids_share_priv = yield self.db.simple_select_onecol( + table="users_who_share_private_rooms", + keyvalues={"room_id": room_id}, + retcol="other_user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids = set(user_ids_share_pub) + user_ids.update(user_ids_share_priv) + + return user_ids + def remove_user_who_share_room(self, user_id, room_id): """ Deletes entries in the users_who_share_*_rooms table. The first @@ -562,23 +613,23 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): """ def _remove_user_who_share_room_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_user_who_share_room", _remove_user_who_share_room_txn ) @@ -593,14 +644,14 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): Returns: list: user_id """ - rows = yield self._simple_select_onecol( + rows = yield self.db.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, retcol="room_id", desc="get_rooms_user_is_in", ) - pub_rows = yield self._simple_select_onecol( + pub_rows = yield self.db.simple_select_onecol( table="users_in_public_rooms", keyvalues={"user_id": user_id}, retcol="room_id", @@ -631,53 +682,20 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): ) f2 USING (room_id) """ - rows = yield self._execute( + rows = yield self.db.execute( "get_rooms_in_common_for_users", None, sql, user_id, other_user_id ) return [room_id for room_id, in rows] - def delete_all_from_user_dir(self): - """Delete the entire user directory - """ - - def _delete_all_from_user_dir_txn(txn): - txn.execute("DELETE FROM user_directory") - txn.execute("DELETE FROM user_directory_search") - txn.execute("DELETE FROM users_in_public_rooms") - txn.execute("DELETE FROM 
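get_users_in_dir_due_to_room above is a union of two single-column queries, one over users_in_public_rooms and one over users_who_share_private_rooms; the set arithmetic on its own, with lists standing in for the query results:

user_ids_share_pub = ["@a:hs", "@b:hs"]    # users_in_public_rooms.user_id
user_ids_share_priv = ["@b:hs", "@c:hs"]   # users_who_share_private_rooms.other_user_id

user_ids = set(user_ids_share_pub)
user_ids.update(user_ids_share_priv)
print(sorted(user_ids))  # ['@a:hs', '@b:hs', '@c:hs']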
users_who_share_private_rooms") - txn.call_after(self.get_user_in_directory.invalidate_all) - - return self.runInteraction( - "delete_all_from_user_dir", _delete_all_from_user_dir_txn - ) - - @cached() - def get_user_in_directory(self, user_id): - return self._simple_select_one( - table="user_directory", - keyvalues={"user_id": user_id}, - retcols=("display_name", "avatar_url"), - allow_none=True, - desc="get_user_in_directory", - ) - def get_user_directory_stream_pos(self): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", desc="get_user_directory_stream_pos", ) - def update_user_directory_stream_pos(self, stream_id): - return self._simple_update_one( - table="user_directory_stream_pos", - keyvalues={}, - updatevalues={"stream_id": stream_id}, - desc="update_user_directory_stream_pos", - ) - @defer.inlineCallbacks def search_user_dir(self, user_id, search_term, limit): """Searches for users in directory @@ -776,8 +794,8 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): # This should be unreachable. raise Exception("Unrecognized database engine") - results = yield self._execute( - "search_user_dir", self.cursor_to_dict, sql, *args + results = yield self.db.execute( + "search_user_dir", self.db.cursor_to_dict, sql, *args ) limited = len(results) > limit diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py index 05cabc2282..ec6b8a4ffd 100644 --- a/synapse/storage/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py @@ -31,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if the user has requested erasure """ - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="erased_users", keyvalues={"user_id": user_id}, retcol="1", @@ -56,16 +56,16 @@ class UserErasureWorkerStore(SQLBaseStore): # iterate it multiple times, and (b) avoiding duplicates. user_ids = tuple(set(user_ids)) - def _get_erased_users(txn): - txn.execute( - "SELECT user_id FROM erased_users WHERE user_id IN (%s)" - % (",".join("?" * len(user_ids))), - user_ids, - ) - return set(r[0] for r in txn) - - erased_users = yield self.runInteraction("are_users_erased", _get_erased_users) - res = dict((u, u in erased_users) for u in user_ids) + rows = yield self.db.simple_select_many_batch( + table="erased_users", + column="user_id", + iterable=user_ids, + retcols=("user_id",), + desc="are_users_erased", + ) + erased_users = {row["user_id"] for row in rows} + + res = {u: u in erased_users for u in user_ids} return res @@ -88,4 +88,4 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.runInteraction("mark_user_erased", f) + return self.db.runInteraction("mark_user_erased", f) diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/data_stores/state/__init__.py new file mode 100644 index 0000000000..86e09f6229 --- /dev/null +++ b/synapse/storage/data_stores/state/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
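The rewritten are_users_erased above finishes with two comprehensions: the set of erased users out of the returned rows, and a per-user boolean map over the requested IDs. The same post-processing with illustrative inputs:

user_ids = ("@a:hs", "@b:hs", "@c:hs")
rows = [{"user_id": "@a:hs"}, {"user_id": "@c:hs"}]  # rows from erased_users

erased_users = {row["user_id"] for row in rows}
res = {u: u in erased_users for u in user_ids}
print(res)  # {'@a:hs': True, '@b:hs': False, '@c:hs': True}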
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401 diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py new file mode 100644 index 0000000000..ff000bc9ec --- /dev/null +++ b/synapse/storage/data_stores/state/bg_updates.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import iteritems + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database +from synapse.storage.engines import PostgresEngine +from synapse.storage.state import StateFilter + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class StateGroupBackgroundUpdateStore(SQLBaseStore): + """Defines functions related to state groups needed to run the state backgroud + updates. + """ + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """ + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + def _get_state_groups_from_groups_txn( + self, txn, groups, state_filter=StateFilter.all() + ): + results = {group: {} for group in groups} + + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + if isinstance(self.database_engine, PostgresEngine): + # Temporarily disable sequential scans in this transaction. 
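The SQLite branch of _count_state_group_hops_txn above walks state_group_edges one prev_state_group at a time, because older SQLite builds lack WITH RECURSIVE. The same walk over an in-memory edge map; the numbers are made up:

edges = {4: 3, 3: 2, 2: None}  # state_group -> prev_state_group

def count_hops(state_group, edges):
    next_group = state_group
    count = 0
    while next_group:
        next_group = edges.get(next_group)
        if next_group:
            count += 1
    return count

print(count_hops(4, edges))  # 2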
This is + # a temporary hack until we can add the right indices in + txn.execute("SET LOCAL enable_seqscan=off") + + # The below query walks the state_group tree so that the "state" + # table includes all state_groups in the tree. It then joins + # against `state_groups_state` to fetch the latest state. + # It assumes that previous state groups are always numerically + # lesser. + # The PARTITION is used to get the event_id in the greatest state + # group for the given type, state_key. + # This may return multiple rows per (type, state_key), but last_value + # should be the same. + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT DISTINCT ON (type, state_key) + type, state_key, event_id + FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM state + ) %s + ORDER BY type, state_key, state_group DESC + """ + + for group in groups: + args = [group] + args.extend(where_args) + + txn.execute(sql % (where_clause,), args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id + else: + max_entries_returned = state_filter.max_entries_returned() + + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + for group in groups: + next_group = group + + while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indices (which we can't add until + # after we finish deduping state, which requires this func) + args = [next_group] + args.extend(where_args) + + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? " + where_clause, + args, + ) + results[group].update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn + if (typ, state_key) not in results[group] + ) + + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. 
Nones) in which case we have to do an exhaustive + # search + if ( + max_entries_returned is not None + and len(results[group]) == max_entries_returned + ): + break + + next_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + + return results + + +class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" + + def __init__(self, database: Database, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.db.updates.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state + ) + self.db.updates.register_background_index_update( + self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME, + index_name="state_groups_room_id_idx", + table="state_groups", + columns=["room_id"], + ) + + @defer.inlineCallbacks + def _background_deduplicate_state(self, progress, batch_size): + """This background update will slowly deduplicate state by reencoding + them as deltas. + """ + last_state_group = progress.get("last_state_group", 0) + rows_inserted = progress.get("rows_inserted", 0) + max_group = progress.get("max_group", None) + + BATCH_SIZE_SCALE_FACTOR = 100 + + batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) + + if max_group is None: + rows = yield self.db.execute( + "_background_deduplicate_state", + None, + "SELECT coalesce(max(id), 0) FROM state_groups", + ) + max_group = rows[0][0] + + def reindex_txn(txn): + new_last_state_group = last_state_group + for count in range(batch_size): + txn.execute( + "SELECT id, room_id FROM state_groups" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC" + " LIMIT 1", + (new_last_state_group, max_group), + ) + row = txn.fetchone() + if row: + state_group, room_id = row + + if not row or not state_group: + return True, count + + txn.execute( + "SELECT state_group FROM state_group_edges" + " WHERE state_group = ?", + (state_group,), + ) + + # If we reach a point where we've already started inserting + # edges we should stop. + if txn.fetchall(): + return True, count + + txn.execute( + "SELECT coalesce(max(id), 0) FROM state_groups" + " WHERE id < ? AND room_id = ?", + (state_group, room_id), + ) + (prev_group,) = txn.fetchone() + new_last_state_group = state_group + + if prev_group: + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if potential_hops >= MAX_STATE_DELTA_HOPS: + # We want to ensure chains are at most this long,# + # otherwise read performance degrades. 
+ continue + + prev_state = self._get_state_groups_from_groups_txn( + txn, [prev_group] + ) + prev_state = prev_state[prev_group] + + curr_state = self._get_state_groups_from_groups_txn( + txn, [state_group] + ) + curr_state = curr_state[state_group] + + if not set(prev_state.keys()) - set(curr_state.keys()): + # We can only do a delta if the current has a strict super set + # of keys + + delta_state = { + key: value + for key, value in iteritems(curr_state) + if prev_state.get(key, None) != value + } + + self.db.simple_delete_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + ) + + self.db.simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + }, + ) + + self.db.simple_delete_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_state) + ], + ) + + progress = { + "last_state_group": state_group, + "rows_inserted": rows_inserted + batch_size, + "max_group": max_group, + } + + self.db.updates._background_update_progress_txn( + txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress + ) + + return False, batch_size + + finished, result = yield self.db.runInteraction( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn + ) + + if finished: + yield self.db.updates._end_background_update( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME + ) + + return result * BATCH_SIZE_SCALE_FACTOR + + @defer.inlineCallbacks + def _background_index_state(self, progress, batch_size): + def reindex_txn(conn): + conn.rollback() + if isinstance(self.database_engine, PostgresEngine): + # postgres insists on autocommit for the index + conn.set_session(autocommit=True) + try: + txn = conn.cursor() + txn.execute( + "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + finally: + conn.set_session(autocommit=False) + else: + txn = conn.cursor() + txn.execute( + "CREATE INDEX state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + + yield self.db.runWithConnection(reindex_txn) + + yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + + return 1 diff --git a/synapse/storage/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql index ae09fa0065..ae09fa0065 100644 --- a/synapse/storage/schema/delta/23/drop_state_index.sql +++ b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql diff --git a/synapse/storage/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql index e85699e82e..e85699e82e 100644 --- a/synapse/storage/schema/delta/30/state_stream.sql +++ b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql new file mode 100644 index 0000000000..1450313bfa --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql @@ -0,0 +1,19 @@ +/* Copyright 2016 
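The deduplication step above only re-encodes a state group as a delta when the previous group's keys are a subset of the current group's. The subset check and delta comprehension by themselves, over illustrative (type, state_key) -> event_id maps:

prev_state = {("m.room.member", "@a:hs"): "$e1", ("m.room.name", ""): "$e2"}
curr_state = {("m.room.member", "@a:hs"): "$e1", ("m.room.name", ""): "$e3"}

if not set(prev_state.keys()) - set(curr_state.keys()):
    # Keep only the entries that actually changed relative to prev_state.
    delta_state = {
        key: value
        for key, value in curr_state.items()
        if prev_state.get(key, None) != value
    }
    print(delta_state)  # {('m.room.name', ''): '$e3'}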
OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- The following indices are redundant, other indices are equivalent or +-- supersets +DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY diff --git a/synapse/storage/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql index 0fce26345b..33980d02f0 100644 --- a/synapse/storage/schema/delta/35/add_state_index.sql +++ b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql @@ -13,8 +13,5 @@ * limitations under the License. */ - -ALTER TABLE background_updates ADD COLUMN depends_on TEXT; - INSERT into background_updates (update_name, progress_json, depends_on) VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication'); diff --git a/synapse/storage/schema/delta/35/state.sql b/synapse/storage/data_stores/state/schema/delta/35/state.sql index 0f1fa68a89..0f1fa68a89 100644 --- a/synapse/storage/schema/delta/35/state.sql +++ b/synapse/storage/data_stores/state/schema/delta/35/state.sql diff --git a/synapse/storage/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql index 97e5067ef4..97e5067ef4 100644 --- a/synapse/storage/schema/delta/35/state_dedupe.sql +++ b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py index 9fd1ccf6f7..9fd1ccf6f7 100644 --- a/synapse/storage/schema/delta/47/state_group_seq.py +++ b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql new file mode 100644 index 0000000000..7916ef18b2 --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('state_groups_room_id_idx', '{}'); diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql new file mode 100644 index 0000000000..35f97d6b3d --- /dev/null +++ b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql @@ -0,0 +1,37 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE state_groups ( + id BIGINT PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_groups_state ( + state_group BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_group_edges ( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges (state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group); +CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key); diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres new file mode 100644 index 0000000000..fcd926c9fb --- /dev/null +++ b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE state_group_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py new file mode 100644 index 0000000000..f3ad1e4369 --- /dev/null +++ b/synapse/storage/data_stores/state/store.py @@ -0,0 +1,642 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import namedtuple +from typing import Dict, Iterable, List, Set, Tuple + +from six import iteritems +from six.moves import range + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore +from synapse.storage.database import Database +from synapse.storage.state import StateFilter +from synapse.types import StateMap +from synapse.util.caches.descriptors import cached +from synapse.util.caches.dictionary_cache import DictionaryCache + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class _GetStateGroupDelta( + namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) +): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + +class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): + """A data store for fetching/storing state groups. + """ + + def __init__(self, database: Database, db_conn, hs): + super(StateGroupDataStore, self).__init__(database, db_conn, hs) + + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. + + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000, + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", 500000, + ) + + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. 
+ + Returns: + (prev_group, delta_ids), where both may be None. + """ + + def _get_state_group_delta_txn(txn): + prev_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self.db.simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + retcols=("type", "state_key", "event_id"), + ) + + return _GetStateGroupDelta( + prev_group, + {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, + ) + + return self.db.runInteraction( + "get_state_group_delta", _get_state_group_delta_txn + ) + + @defer.inlineCallbacks + def _get_state_groups_from_groups( + self, groups: List[int], state_filter: StateFilter + ): + """Returns the state groups for a given set of groups from the + database, filtering on types of state events. + + Args: + groups: list of state group IDs to query + state_filter: The state filter used to fetch state + from the database. + Returns: + Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + """ + results = {} + + chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] + for chunk in chunks: + res = yield self.db.runInteraction( + "_get_state_groups_from_groups", + self._get_state_groups_from_groups_txn, + chunk, + state_filter, + ) + results.update(res) + + return results + + def _get_state_for_group_using_cache(self, cache, group, state_filter): + """Checks if group is in cache. See `_get_state_for_groups` + + Args: + cache(DictionaryCache): the state group cache to use + group(int): The state group to lookup + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns 2-tuple (`state_dict`, `got_all`). + `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + """ + is_all, known_absent, state_dict_ids = cache.get(group) + + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all + + # tracks whether any of our requested types are missing from the cache + missing_types = False + + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. + for key in state_filter.concrete_types(): + if key not in state_dict_ids and key not in known_absent: + missing_types = True + break + + return state_filter.filter_state(state_dict_ids), not missing_types + + @defer.inlineCallbacks + def _get_state_for_groups( + self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() + ): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups: list of state groups for which we want + to get the state. + state_filter: The state filter used to fetch state + from the database. + Returns: + Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. 
+ """ + + member_filter, non_member_filter = state_filter.get_member_split() + + # Now we look them up in the member and non-member caches + ( + non_member_state, + incomplete_groups_nm, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, state_filter=non_member_filter + ) + + ( + member_state, + incomplete_groups_m, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, state_filter=member_filter + ) + + state = dict(non_member_state) + for group in groups: + state[group].update(member_state[group]) + + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + return state + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = yield self._get_state_groups_from_groups( + list(incomplete_groups), state_filter=db_state_filter + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in iteritems(group_to_state_dict): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + + return state + + def _get_state_for_groups_using_cache( + self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter + ) -> Tuple[Dict[int, StateMap[str]], Set[int]]: + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key, querying from a specific cache. + + Args: + groups: list of state groups for which we want to get the state. + cache: the cache of group ids to state dicts which + we will pass through - either the normal state cache or the + specific members state cache. + state_filter: The state filter used to fetch state from the + database. + + Returns: + Tuple of dict of state_group_id to state map of entries in the + cache, and the state group ids either missing from the cache or + incomplete. + """ + results = {} + incomplete_groups = set() + for group in set(groups): + state_dict_ids, got_all = self._get_state_for_group_using_cache( + cache, group, state_filter + ) + results[group] = state_dict_ids + + if not got_all: + incomplete_groups.add(group) + + return results, incomplete_groups + + def _insert_into_cache( + self, + group_to_state_dict, + state_filter, + cache_seq_num_members, + cache_seq_num_non_members, + ): + """Inserts results from querying the database into the relevant cache. + + Args: + group_to_state_dict (dict): The new entries pulled from database. + Map from state group to state dict + state_filter (StateFilter): The state filter used to fetch state + from the database. + cache_seq_num_members (int): Sequence number of member cache since + last lookup in cache + cache_seq_num_non_members (int): Sequence number of member cache since + last lookup in cache + """ + + # We need to work out which types we've fetched from the DB for the + # member vs non-member caches. This should be as accurate as possible, + # but can be an underestimate (e.g. 
when we have wild cards) + + member_filter, non_member_filter = state_filter.get_member_split() + if member_filter.is_full(): + # We fetched all member events + member_types = None + else: + # `concrete_types()` will only return a subset when there are wild + # cards in the filter, but that's fine. + member_types = member_filter.concrete_types() + + if non_member_filter.is_full(): + # We fetched all non member events + non_member_types = None + else: + non_member_types = non_member_filter.concrete_types() + + for group, group_state_dict in iteritems(group_to_state_dict): + state_dict_members = {} + state_dict_non_members = {} + + for k, v in iteritems(group_state_dict): + if k[0] == EventTypes.Member: + state_dict_members[k] = v + else: + state_dict_non_members[k] = v + + self._state_group_members_cache.update( + cache_seq_num_members, + key=group, + value=state_dict_members, + fetched_keys=member_types, + ) + + self._state_group_cache.update( + cache_seq_num_non_members, + key=group, + value=state_dict_non_members, + fetched_keys=non_member_types, + ) + + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + + def _store_state_group_txn(txn): + if current_state_ids is None: + # AFAIK, this can never happen + raise Exception("current_state_ids cannot be None") + + state_group = self.database_engine.get_next_state_group_id(txn) + + self.db.simple_insert_txn( + txn, + table="state_groups", + values={"id": state_group, "room_id": room_id, "event_id": event_id}, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if prev_group: + is_in_db = self.db.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self.db.simple_insert_txn( + txn, + table="state_group_edges", + values={"state_group": state_group, "prev_state_group": prev_group}, + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_ids) + ], + ) + else: + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(current_state_ids) + ], + ) + + # Prefill the state group caches with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. 
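`_store_state_group_txn` above persists a new group either as a set of delta rows pointing at `prev_group`, or as a full snapshot once the chain of deltas would exceed `MAX_STATE_DELTA_HOPS`. A toy model of how state is then resolved by walking the `state_group_edges` chain and applying deltas oldest-first (plain dicts stand in for the two tables, and all IDs are invented):

MAX_STATE_DELTA_HOPS = 100

edges = {3: 2, 2: 1}  # state_group -> prev_state_group (state_group_edges)
deltas = {            # state_group -> rows in state_groups_state
    1: {("m.room.create", ""): "$create"},   # group 1 is a full snapshot
    2: {("m.room.name", ""): "$name_v1"},    # delta on top of 1
    3: {("m.room.name", ""): "$name_v2"},    # delta on top of 2
}


def resolve_state(group):
    """Walk the prev_group chain, then apply deltas oldest-first so newer wins."""
    chain = [group]
    while chain[-1] in edges and len(chain) <= MAX_STATE_DELTA_HOPS:
        chain.append(edges[chain[-1]])
    state = {}
    for g in reversed(chain):
        state.update(deltas.get(g, {}))
    return state


assert resolve_state(3)[("m.room.name", "")] == "$name_v2"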
(If the map wasn't immutable then this prefill could + # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] != EventTypes.Member + } + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=state_group, + value=dict(current_non_member_state_ids), + ) + + return state_group + + return self.db.runInteraction("store_state_group", _store_state_group_txn) + + def purge_unreferenced_state_groups( + self, room_id: str, state_groups_to_delete + ) -> defer.Deferred: + """Deletes no longer referenced state groups and de-deltas any state + groups that reference them. + + Args: + room_id: The room the state groups belong to (must all be in the + same room). + state_groups_to_delete (Collection[int]): Set of all state groups + to delete. + """ + + return self.db.runInteraction( + "purge_unreferenced_state_groups", + self._purge_unreferenced_state_groups, + room_id, + state_groups_to_delete, + ) + + def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): + logger.info( + "[purge] found %i state groups to delete", len(state_groups_to_delete) + ) + + rows = self.db.simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=state_groups_to_delete, + keyvalues={}, + retcols=("state_group",), + ) + + remaining_state_groups = { + row["state_group"] + for row in rows + if row["state_group"] not in state_groups_to_delete + } + + logger.info( + "[purge] de-delta-ing %i remaining state groups", + len(remaining_state_groups), + ) + + # Now we turn the state groups that reference to-be-deleted state + # groups to non delta versions. + for sg in remaining_state_groups: + logger.info("[purge] de-delta-ing remaining state group %s", sg) + curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) + curr_state = curr_state[sg] + + self.db.simple_delete_txn( + txn, table="state_groups_state", keyvalues={"state_group": sg} + ) + + self.db.simple_delete_txn( + txn, table="state_group_edges", keyvalues={"state_group": sg} + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": sg, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(curr_state) + ], + ) + + logger.info("[purge] removing redundant state groups") + txn.executemany( + "DELETE FROM state_groups_state WHERE state_group = ?", + ((sg,) for sg in state_groups_to_delete), + ) + txn.executemany( + "DELETE FROM state_groups WHERE id = ?", + ((sg,) for sg in state_groups_to_delete), + ) + + @defer.inlineCallbacks + def get_previous_state_groups(self, state_groups): + """Fetch the previous groups of the given state groups. + + Args: + state_groups (Iterable[int]) + + Returns: + Deferred[dict[int, int]]: mapping from state group to previous + state group. 
+ """ + + rows = yield self.db.simple_select_many_batch( + table="state_group_edges", + column="prev_state_group", + iterable=state_groups, + keyvalues={}, + retcols=("prev_state_group", "state_group"), + desc="get_previous_state_groups", + ) + + return {row["state_group"]: row["prev_state_group"] for row in rows} + + def purge_room_state(self, room_id, state_groups_to_delete): + """Deletes all record of a room from state tables + + Args: + room_id (str): + state_groups_to_delete (list[int]): State groups to delete + """ + + return self.db.runInteraction( + "purge_room_state", + self._purge_room_state_txn, + room_id, + state_groups_to_delete, + ) + + def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): + # first we have to delete the state groups states + logger.info("[purge] removing %s from state_groups_state", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_groups_state", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state group edges + logger.info("[purge] removing %s from state_group_edges", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_group_edges", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state groups + logger.info("[purge] removing %s from state_groups", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_groups", + column="id", + iterable=state_groups_to_delete, + keyvalues={}, + ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py new file mode 100644 index 0000000000..b112ff3df2 --- /dev/null +++ b/synapse/storage/database.py @@ -0,0 +1,1644 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import time +from time import monotonic as monotonic_time +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + TypeVar, +) + +from six import iteritems, iterkeys, itervalues +from six.moves import intern, range + +from prometheus_client import Histogram + +from twisted.enterprise import adbapi +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.config.database import DatabaseConnectionConfig +from synapse.logging.context import ( + LoggingContext, + LoggingContextOrSentinel, + current_context, + make_deferred_yieldable, +) +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.background_updates import BackgroundUpdater +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.types import Connection, Cursor +from synapse.types import Collection + +logger = logging.getLogger(__name__) + +# python 3 does not have a maximum int value +MAX_TXN_ID = 2 ** 63 - 1 + +sql_logger = logging.getLogger("synapse.storage.SQL") +transaction_logger = logging.getLogger("synapse.storage.txn") +perf_logger = logging.getLogger("synapse.storage.TIME") + +sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec") + +sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"]) +sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"]) + + +# Unique indexes which have been added in background updates. Maps from table name +# to the name of the background update which added the unique index to that table. +# +# This is used by the upsert logic to figure out which tables are safe to do a proper +# UPSERT on: until the relevant background update has completed, we +# have to emulate an upsert by locking the table. +# +UNIQUE_INDEX_BACKGROUND_UPDATES = { + "user_ips": "user_ips_device_unique_index", + "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", + "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", + "event_search": "event_search_event_id_idx", +} + + +def make_pool( + reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine +) -> adbapi.ConnectionPool: + """Get the connection pool for the database. + """ + + return adbapi.ConnectionPool( + db_config.config["name"], + cp_reactor=reactor, + cp_openfun=engine.on_new_connection, + **db_config.config.get("args", {}) + ) + + +def make_conn( + db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine +) -> Connection: + """Make a new connection to the database and return it. + + Returns: + Connection + """ + + db_params = { + k: v + for k, v in db_config.config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = engine.module.connect(**db_params) + engine.on_new_connection(db_conn) + return db_conn + + +# The type of entry which goes on our after_callbacks and exception_callbacks lists. +# +# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so +# that mypy sees the type but the runtime python doesn't. +_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]] + + +class LoggingTransaction: + """An object that almost-transparently proxies for the 'txn' object + passed to the constructor. Adds logging and metrics to the .execute() + method. + + Args: + txn: The database transcation object to wrap. + name: The name of this transactions for logging. 
+ database_engine + after_callbacks: A list that callbacks will be appended to + that have been added by `call_after` which should be run on + successful completion of the transaction. None indicates that no + callbacks should be allowed to be scheduled to run. + exception_callbacks: A list that callbacks will be appended + to that have been added by `call_on_exception` which should be run + if transaction ends with an error. None indicates that no callbacks + should be allowed to be scheduled to run. + """ + + __slots__ = [ + "txn", + "name", + "database_engine", + "after_callbacks", + "exception_callbacks", + ] + + def __init__( + self, + txn: Cursor, + name: str, + database_engine: BaseDatabaseEngine, + after_callbacks: Optional[List[_CallbackListEntry]] = None, + exception_callbacks: Optional[List[_CallbackListEntry]] = None, + ): + self.txn = txn + self.name = name + self.database_engine = database_engine + self.after_callbacks = after_callbacks + self.exception_callbacks = exception_callbacks + + def call_after(self, callback: "Callable[..., None]", *args, **kwargs): + """Call the given callback on the main twisted thread after the + transaction has finished. Used to invalidate the caches on the + correct thread. + """ + # if self.after_callbacks is None, that means that whatever constructed the + # LoggingTransaction isn't expecting there to be any callbacks; assert that + # is not the case. + assert self.after_callbacks is not None + self.after_callbacks.append((callback, args, kwargs)) + + def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs): + # if self.exception_callbacks is None, that means that whatever constructed the + # LoggingTransaction isn't expecting there to be any callbacks; assert that + # is not the case. + assert self.exception_callbacks is not None + self.exception_callbacks.append((callback, args, kwargs)) + + def fetchall(self) -> List[Tuple]: + return self.txn.fetchall() + + def fetchone(self) -> Tuple: + return self.txn.fetchone() + + def __iter__(self) -> Iterator[Tuple]: + return self.txn.__iter__() + + @property + def rowcount(self) -> int: + return self.txn.rowcount + + @property + def description(self) -> Any: + return self.txn.description + + def execute_batch(self, sql, args): + if isinstance(self.database_engine, PostgresEngine): + from psycopg2.extras import execute_batch # type: ignore + + self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) + else: + for val in args: + self.execute(sql, val) + + def execute(self, sql: str, *args: Any): + self._do_execute(self.txn.execute, sql, *args) + + def executemany(self, sql: str, *args: Any): + self._do_execute(self.txn.executemany, sql, *args) + + def _make_sql_one_line(self, sql: str) -> str: + "Strip newlines out of SQL so that the loggers in the DB are on one line" + return " ".join(line.strip() for line in sql.splitlines() if line.strip()) + + def _do_execute(self, func, sql, *args): + sql = self._make_sql_one_line(sql) + + # TODO(paul): Maybe use 'info' and 'debug' for values? 
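`LoggingTransaction` above is a thin proxy around the DB-API cursor that flattens the SQL onto one line, logs it, and times each call. A simplified, self-contained version of that wrapping idea using `sqlite3`; the name `TimedCursor` is invented for this sketch, and the real class additionally feeds Prometheus histograms, supports `execute_batch`, and collects after/exception callbacks:

import logging
import sqlite3
import time

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger("sql")


class TimedCursor:
    """Simplified stand-in for the wrapping done by LoggingTransaction."""

    def __init__(self, cursor, name):
        self.cursor = cursor
        self.name = name

    def execute(self, sql, args=()):
        one_line = " ".join(line.strip() for line in sql.splitlines() if line.strip())
        log.debug("[SQL] {%s} %s", self.name, one_line)
        start = time.monotonic()
        try:
            return self.cursor.execute(sql, args)
        finally:
            log.debug("[SQL time] {%s} %f sec", self.name, time.monotonic() - start)

    def __getattr__(self, name):
        # Everything else (fetchall, rowcount, ...) is proxied straight through.
        return getattr(self.cursor, name)


conn = sqlite3.connect(":memory:")
txn = TimedCursor(conn.cursor(), "demo-1")
txn.execute("CREATE TABLE t (x INTEGER)")
txn.execute("INSERT INTO t VALUES (?)", (1,))
print(txn.execute("SELECT x FROM t").fetchall())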
+ sql_logger.debug("[SQL] {%s} %s", self.name, sql) + + sql = self.database_engine.convert_param_style(sql) + if args: + try: + sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) + except Exception: + # Don't let logging failures stop SQL from working + pass + + start = time.time() + + try: + return func(sql, *args) + except Exception as e: + logger.debug("[SQL FAIL] {%s} %s", self.name, e) + raise + finally: + secs = time.time() - start + sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) + sql_query_timer.labels(sql.split()[0]).observe(secs) + + def close(self): + self.txn.close() + + +class PerformanceCounters(object): + def __init__(self): + self.current_counters = {} + self.previous_counters = {} + + def update(self, key, duration_secs): + count, cum_time = self.current_counters.get(key, (0, 0)) + count += 1 + cum_time += duration_secs + self.current_counters[key] = (count, cum_time) + + def interval(self, interval_duration_secs, limit=3): + counters = [] + for name, (count, cum_time) in iteritems(self.current_counters): + prev_count, prev_time = self.previous_counters.get(name, (0, 0)) + counters.append( + ( + (cum_time - prev_time) / interval_duration_secs, + count - prev_count, + name, + ) + ) + + self.previous_counters = dict(self.current_counters) + + counters.sort(reverse=True) + + top_n_counters = ", ".join( + "%s(%d): %.3f%%" % (name, count, 100 * ratio) + for ratio, count, name in counters[:limit] + ) + + return top_n_counters + + +class Database(object): + """Wraps a single physical database and connection pool. + + A single database may be used by multiple data stores. + """ + + _TXN_ID = 0 + + def __init__( + self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + ): + self.hs = hs + self._clock = hs.get_clock() + self._database_config = database_config + self._db_pool = make_pool(hs.get_reactor(), database_config, engine) + + self.updates = BackgroundUpdater(hs, self) + + self._previous_txn_total_time = 0.0 + self._current_txn_total_time = 0.0 + self._previous_loop_ts = 0.0 + + # TODO(paul): These can eventually be removed once the metrics code + # is running in mainline, and we have some nice monitoring frontends + # to watch it + self._txn_perf_counters = PerformanceCounters() + + self.engine = engine + + # A set of tables that are not safe to use native upserts in. + self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) + + # We add the user_directory_search table to the blacklist on SQLite + # because the existing search table does not have an index, making it + # unsafe to use native upserts. + if isinstance(self.engine, Sqlite3Engine): + self._unsafe_to_upsert_tables.add("user_directory_search") + + if self.engine.can_native_upsert: + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates of tables that aren't safe to update. + self._clock.call_later( + 0.0, + run_as_background_process, + "upsert_safety_check", + self._check_safe_to_upsert, + ) + + def is_running(self): + """Is the database pool currently running + """ + return self._db_pool.running + + @defer.inlineCallbacks + def _check_safe_to_upsert(self): + """ + Is it safe to use native UPSERT? + + If there are background updates, we will need to wait, as they may be + the addition of indexes that set the UNIQUE constraint that we require. + + If the background updates have not completed, wait 15 sec and check again. 
+ """ + updates = yield self.simple_select_list( + "background_updates", + keyvalues=None, + retcols=["update_name"], + desc="check_background_updates", + ) + updates = [x["update_name"] for x in updates] + + for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): + if update_name not in updates: + logger.debug("Now safe to upsert in %s", table) + self._unsafe_to_upsert_tables.discard(table) + + # If there's any updates still running, reschedule to run. + if updates: + self._clock.call_later( + 15.0, + run_as_background_process, + "upsert_safety_check", + self._check_safe_to_upsert, + ) + + def start_profiling(self): + self._previous_loop_ts = monotonic_time() + + def loop(): + curr = self._current_txn_total_time + prev = self._previous_txn_total_time + self._previous_txn_total_time = curr + + time_now = monotonic_time() + time_then = self._previous_loop_ts + self._previous_loop_ts = time_now + + duration = time_now - time_then + ratio = (curr - prev) / duration + + top_three_counters = self._txn_perf_counters.interval(duration, limit=3) + + perf_logger.debug( + "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters + ) + + self._clock.looping_call(loop, 10000) + + def new_transaction( + self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs + ): + start = monotonic_time() + txn_id = self._TXN_ID + + # We don't really need these to be unique, so lets stop it from + # growing really large. + self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) + + name = "%s-%x" % (desc, txn_id) + + transaction_logger.debug("[TXN START] {%s}", name) + + try: + i = 0 + N = 5 + while True: + cursor = LoggingTransaction( + conn.cursor(), + name, + self.engine, + after_callbacks, + exception_callbacks, + ) + try: + r = func(cursor, *args, **kwargs) + conn.commit() + return r + except self.engine.module.OperationalError as e: + # This can happen if the database disappears mid + # transaction. + logger.warning( + "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N, + ) + if i < N: + i += 1 + try: + conn.rollback() + except self.engine.module.Error as e1: + logger.warning("[TXN EROLL] {%s} %s", name, e1) + continue + raise + except self.engine.module.DatabaseError as e: + if self.engine.is_deadlock(e): + logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) + if i < N: + i += 1 + try: + conn.rollback() + except self.engine.module.Error as e1: + logger.warning( + "[TXN EROLL] {%s} %s", name, e1, + ) + continue + raise + finally: + # we're either about to retry with a new cursor, or we're about to + # release the connection. Once we release the connection, it could + # get used for another query, which might do a conn.rollback(). + # + # In the latter case, even though that probably wouldn't affect the + # results of this transaction, python's sqlite will reset all + # statements on the connection [1], which will make our cursor + # invalid [2]. + # + # In any case, continuing to read rows after commit()ing seems + # dubious from the PoV of ACID transactional semantics + # (sqlite explicitly says that once you commit, you may see rows + # from subsequent updates.) + # + # In psycopg2, cursors are essentially a client-side fabrication - + # all the data is transferred to the client side when the statement + # finishes executing - so in theory we could go on streaming results + # from the cursor, but attempting to do so would make us + # incompatible with sqlite, so let's make sure we're not doing that + # by closing the cursor. 
+ # + # (*named* cursors in psycopg2 are different and are proper server- + # side things, but (a) we don't use them and (b) they are implicitly + # closed by ending the transaction anyway.) + # + # In short, if we haven't finished with the cursor yet, that's a + # problem waiting to bite us. + # + # TL;DR: we're done with the cursor, so we can close it. + # + # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465 + # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 + cursor.close() + except Exception as e: + logger.debug("[TXN FAIL] {%s} %s", name, e) + raise + finally: + end = monotonic_time() + duration = end - start + + current_context().add_database_transaction(duration) + + transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) + + self._current_txn_total_time += duration + self._txn_perf_counters.update(desc, duration) + sql_txn_timer.labels(desc).observe(duration) + + @defer.inlineCallbacks + def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any): + """Starts a transaction on the database and runs a given function + + Arguments: + desc: description of the transaction, for logging and metrics + func: callback function, which will be called with a + database transaction (twisted.enterprise.adbapi.Transaction) as + its first argument, followed by `args` and `kwargs`. + + args: positional args to pass to `func` + kwargs: named args to pass to `func` + + Returns: + Deferred: The result of func + """ + after_callbacks = [] # type: List[_CallbackListEntry] + exception_callbacks = [] # type: List[_CallbackListEntry] + + if not current_context(): + logger.warning("Starting db txn '%s' from sentinel context", desc) + + try: + result = yield self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + **kwargs + ) + + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + except: # noqa: E722, as we reraise the exception this is fine. + for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise + + return result + + @defer.inlineCallbacks + def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any): + """Wraps the .runWithConnection() method on the underlying db_pool. + + Arguments: + func: callback function, which will be called with a + database connection (twisted.enterprise.adbapi.Connection) as + its first argument, followed by `args` and `kwargs`. 
+ args: positional args to pass to `func` + kwargs: named args to pass to `func` + + Returns: + Deferred: The result of func + """ + parent_context = current_context() # type: Optional[LoggingContextOrSentinel] + if not parent_context: + logger.warning( + "Starting db connection from sentinel context: metrics will be lost" + ) + parent_context = None + + start_time = monotonic_time() + + def inner_func(conn, *args, **kwargs): + with LoggingContext("runWithConnection", parent_context) as context: + sched_duration_sec = monotonic_time() - start_time + sql_scheduling_timer.observe(sched_duration_sec) + context.add_database_scheduled(sched_duration_sec) + + if self.engine.is_connection_closed(conn): + logger.debug("Reconnecting closed database connection") + conn.reconnect() + + return func(conn, *args, **kwargs) + + result = yield make_deferred_yieldable( + self._db_pool.runWithConnection(inner_func, *args, **kwargs) + ) + + return result + + @staticmethod + def cursor_to_dict(cursor): + """Converts a SQL cursor into an list of dicts. + + Args: + cursor : The DBAPI cursor which has executed a query. + Returns: + A list of dicts where the key is the column header. + """ + col_headers = [intern(str(column[0])) for column in cursor.description] + results = [dict(zip(col_headers, row)) for row in cursor] + return results + + def execute(self, desc, decoder, query, *args): + """Runs a single query for a result set. + + Args: + decoder - The function which can resolve the cursor results to + something meaningful. + query - The query string to execute + *args - Query args. + Returns: + The result of decoder(results) + """ + + def interaction(txn): + txn.execute(query, args) + if decoder: + return decoder(txn) + else: + return txn.fetchall() + + return self.runInteraction(desc, interaction) + + # "Simple" SQL API methods that operate on a single table with no JOINs, + # no complex WHERE clauses, just a dict of values for columns. + + @defer.inlineCallbacks + def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): + """Executes an INSERT query on the named table. + + Args: + table : string giving the table name + values : dict of new column names and values for them + or_ignore : bool stating whether an exception should be raised + when a conflicting row already exists. If True, False will be + returned by the function instead + desc : string giving a description of the transaction + + Returns: + bool: Whether the row was inserted or not. Only useful when + `or_ignore` is True + """ + try: + yield self.runInteraction(desc, self.simple_insert_txn, table, values) + except self.engine.module.IntegrityError: + # We have to do or_ignore flag at this layer, since we can't reuse + # a cursor after we receive an error from the db. + if not or_ignore: + raise + return False + return True + + @staticmethod + def simple_insert_txn(txn, table, values): + keys, vals = zip(*values.items()) + + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in keys), + ", ".join("?" for _ in keys), + ) + + txn.execute(sql, vals) + + def simple_insert_many(self, table, values, desc): + return self.runInteraction(desc, self.simple_insert_many_txn, table, values) + + @staticmethod + def simple_insert_many_txn(txn, table, values): + if not values: + return + + # This is a *slight* abomination to get a list of tuples of key names + # and a list of tuples of value names. + # + # i.e. 
[{"a": 1, "b": 2}, {"c": 3, "d": 4}] + # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)] + # + # The sort is to ensure that we don't rely on dictionary iteration + # order. + keys, vals = zip( + *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] + ) + + for k in keys: + if k != keys[0]: + raise RuntimeError("All items must have the same keys") + + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in keys[0]), + ", ".join("?" for _ in keys[0]), + ) + + txn.executemany(sql, vals) + + @defer.inlineCallbacks + def simple_upsert( + self, + table, + keyvalues, + values, + insertion_values={}, + desc="simple_upsert", + lock=True, + ): + """ + + `lock` should generally be set to True (the default), but can be set + to False if either of the following are true: + + * there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + + * we somehow know that we are the only thread which will be updating + this table. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key columns and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + Deferred(None or bool): Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + attempts = 0 + while True: + try: + result = yield self.runInteraction( + desc, + self.simple_upsert_txn, + table, + keyvalues, + values, + insertion_values, + lock=lock, + ) + return result + except self.engine.module.IntegrityError as e: + attempts += 1 + if attempts >= 5: + # don't retry forever, because things other than races + # can cause IntegrityErrors + raise + + # presumably we raced with another transaction: let's retry. + logger.warning( + "IntegrityError when upserting into %s; retrying: %s", table, e + ) + + def simple_upsert_txn( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): + """ + Pick the UPSERT method which works best on the platform. Either the + native one (Pg9.5+, recent SQLites), or fall back to an emulated method. + + Args: + txn: The transaction to use. + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + None or bool: Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. 
+ """ + if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: + return self.simple_upsert_txn_native_upsert( + txn, table, keyvalues, values, insertion_values=insertion_values + ) + else: + return self.simple_upsert_txn_emulated( + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + lock=lock, + ) + + def simple_upsert_txn_emulated( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): + """ + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + bool: Return True if a new entry was created, False if an existing + one was updated. + """ + # We need to lock the table :(, unless we're *really* careful + if lock: + self.engine.lock_table(txn, table) + + def _getwhere(key): + # If the value we're passing in is None (aka NULL), we need to use + # IS, not =, as NULL = NULL equals NULL (False). + if keyvalues[key] is None: + return "%s IS ?" % (key,) + else: + return "%s = ?" % (key,) + + if not values: + # If `values` is empty, then all of the values we care about are in + # the unique key, so there is nothing to UPDATE. We can just do a + # SELECT instead to see if it exists. + sql = "SELECT 1 FROM %s WHERE %s" % ( + table, + " AND ".join(_getwhere(k) for k in keyvalues), + ) + sqlargs = list(keyvalues.values()) + txn.execute(sql, sqlargs) + if txn.fetchall(): + # We have an existing record. + return False + else: + # First try to update. + sql = "UPDATE %s SET %s WHERE %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in values), + " AND ".join(_getwhere(k) for k in keyvalues), + ) + sqlargs = list(values.values()) + list(keyvalues.values()) + + txn.execute(sql, sqlargs) + if txn.rowcount > 0: + # successfully updated at least one row. + return False + + # We didn't find any existing rows, so insert a new one + allvalues = {} # type: Dict[str, Any] + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) + + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ) + txn.execute(sql, list(allvalues.values())) + # successfully inserted + return True + + def simple_upsert_txn_native_upsert( + self, txn, table, keyvalues, values, insertion_values={} + ): + """ + Use the native UPSERT functionality in recent PostgreSQL versions. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + Returns: + None + """ + allvalues = {} # type: Dict[str, Any] + allvalues.update(keyvalues) + allvalues.update(insertion_values) + + if not values: + latter = "NOTHING" + else: + allvalues.update(values) + latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) + + sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" 
for _ in allvalues), + ", ".join(k for k in keyvalues), + latter, + ) + txn.execute(sql, list(allvalues.values())) + + def simple_upsert_many_txn( + self, + txn: LoggingTransaction, + table: str, + key_names: Collection[str], + key_values: Collection[Iterable[Any]], + value_names: Collection[str], + value_values: Iterable[Iterable[str]], + ) -> None: + """ + Upsert, many times. + + Args: + table: The table to upsert into + key_names: The key column names. + key_values: A list of each row's key column values. + value_names: The value column names + value_values: A list of each row's value column values. + Ignored if value_names is empty. + """ + if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: + return self.simple_upsert_many_txn_native_upsert( + txn, table, key_names, key_values, value_names, value_values + ) + else: + return self.simple_upsert_many_txn_emulated( + txn, table, key_names, key_values, value_names, value_values + ) + + def simple_upsert_many_txn_emulated( + self, + txn: LoggingTransaction, + table: str, + key_names: Iterable[str], + key_values: Collection[Iterable[Any]], + value_names: Collection[str], + value_values: Iterable[Iterable[str]], + ) -> None: + """ + Upsert, many times, but without native UPSERT support or batching. + + Args: + table: The table to upsert into + key_names: The key column names. + key_values: A list of each row's key column values. + value_names: The value column names + value_values: A list of each row's value column values. + Ignored if value_names is empty. + """ + # No value columns, therefore make a blank list so that the following + # zip() works correctly. + if not value_names: + value_values = [() for x in range(len(key_values))] + + for keyv, valv in zip(key_values, value_values): + _keys = {x: y for x, y in zip(key_names, keyv)} + _vals = {x: y for x, y in zip(value_names, valv)} + + self.simple_upsert_txn_emulated(txn, table, _keys, _vals) + + def simple_upsert_many_txn_native_upsert( + self, + txn: LoggingTransaction, + table: str, + key_names: Collection[str], + key_values: Collection[Iterable[Any]], + value_names: Collection[str], + value_values: Iterable[Iterable[Any]], + ) -> None: + """ + Upsert, many times, using batching where possible. + + Args: + table: The table to upsert into + key_names: The key column names. + key_values: A list of each row's key column values. + value_names: The value column names + value_values: A list of each row's value column values. + Ignored if value_names is empty. + """ + allnames = [] # type: List[str] + allnames.extend(key_names) + allnames.extend(value_names) + + if not value_names: + # No value columns, therefore make a blank list so that the + # following zip() works correctly. + latter = "NOTHING" + value_values = [() for x in range(len(key_values))] + else: + latter = "UPDATE SET " + ", ".join( + k + "=EXCLUDED." + k for k in value_names + ) + + sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( + table, + ", ".join(k for k in allnames), + ", ".join("?" for _ in allnames), + ", ".join(key_names), + latter, + ) + + args = [] + + for x, y in zip(key_values, value_values): + args.append(tuple(x) + tuple(y)) + + return txn.execute_batch(sql, args) + + def simple_select_one( + self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" + ): + """Executes a SELECT query on the named table, which is expected to + return a single row, returning multiple columns from it. 
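The native upsert path above builds an `INSERT ... ON CONFLICT (...) DO UPDATE SET ...` statement from the key and value columns. A standalone sketch that rebuilds that SQL string outside the store; the table name `device_last_seen` and the values are invented purely for illustration:

def native_upsert_sql(table, keyvalues, values, insertion_values=None):
    """Build an ON CONFLICT upsert in the same shape as the native path above."""
    allvalues = dict(keyvalues)
    allvalues.update(insertion_values or {})
    if not values:
        latter = "NOTHING"
    else:
        allvalues.update(values)
        latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
    sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
        table,
        ", ".join(allvalues),
        ", ".join("?" for _ in allvalues),
        ", ".join(keyvalues),
        latter,
    )
    return sql, list(allvalues.values())


sql, args = native_upsert_sql(
    "device_last_seen",
    keyvalues={"user_id": "@alice:example.org", "device_id": "PHONE"},
    values={"last_seen": 1591785600000},
)
print(sql)   # the ON CONFLICT (user_id, device_id) DO UPDATE SET ... form
print(args)  # ['@alice:example.org', 'PHONE', 1591785600000]

On the emulated path (older SQLite, or tables whose unique index is still being built by a background update) the store instead locks the table, tries an UPDATE, and falls back to an INSERT, as `simple_upsert_txn_emulated` above shows.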
+ + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + retcols : list of strings giving the names of the columns to return + + allow_none : If true, return None instead of failing if the SELECT + statement returns no rows + """ + return self.runInteraction( + desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none + ) + + def simple_select_one_onecol( + self, + table, + keyvalues, + retcol, + allow_none=False, + desc="simple_select_one_onecol", + ): + """Executes a SELECT query on the named table, which is expected to + return a single row, returning a single column from it. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + retcol : string giving the name of the column to return + """ + return self.runInteraction( + desc, + self.simple_select_one_onecol_txn, + table, + keyvalues, + retcol, + allow_none=allow_none, + ) + + @classmethod + def simple_select_one_onecol_txn( + cls, txn, table, keyvalues, retcol, allow_none=False + ): + ret = cls.simple_select_onecol_txn( + txn, table=table, keyvalues=keyvalues, retcol=retcol + ) + + if ret: + return ret[0] + else: + if allow_none: + return None + else: + raise StoreError(404, "No row found") + + @staticmethod + def simple_select_onecol_txn(txn, table, keyvalues, retcol): + sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} + + if keyvalues: + sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + txn.execute(sql, list(keyvalues.values())) + else: + txn.execute(sql) + + return [r[0] for r in txn] + + def simple_select_onecol( + self, table, keyvalues, retcol, desc="simple_select_onecol" + ): + """Executes a SELECT query on the named table, which returns a list + comprising of the values of the named column from the selected rows. + + Args: + table (str): table name + keyvalues (dict|None): column names and values to select the rows with + retcol (str): column whos value we wish to retrieve. + + Returns: + Deferred: Results in a list + """ + return self.runInteraction( + desc, self.simple_select_onecol_txn, table, keyvalues, retcol + ) + + def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + keyvalues (dict[str, Any] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, self.simple_select_list_txn, table, keyvalues, retcols + ) + + @classmethod + def simple_select_list_txn(cls, txn, table, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + """ + if keyvalues: + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" 
% (k,) for k in keyvalues), + ) + txn.execute(sql, list(keyvalues.values())) + else: + sql = "SELECT %s FROM %s" % (", ".join(retcols), table) + txn.execute(sql) + + return cls.cursor_to_dict(txn) + + @defer.inlineCallbacks + def simple_select_many_batch( + self, + table, + column, + iterable, + retcols, + keyvalues={}, + desc="simple_select_many_batch", + batch_size=100, + ): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Filters rows by if value of `column` is in `iterable`. + + Args: + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + results = [] # type: List[Dict[str, Any]] + + if not iterable: + return results + + # iterables can not be sliced, so convert it to a list first + it_list = list(iterable) + + chunks = [ + it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) + ] + for chunk in chunks: + rows = yield self.runInteraction( + desc, + self.simple_select_many_txn, + table, + column, + chunk, + keyvalues, + retcols, + ) + + results.extend(rows) + + return results + + @classmethod + def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + if not iterable: + return [] + + clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clauses = [clause] + + for key, value in iteritems(keyvalues): + clauses.append("%s = ?" % (key,)) + values.append(value) + + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join(clauses), + ) + + txn.execute(sql, values) + return cls.cursor_to_dict(txn) + + def simple_update(self, table, keyvalues, updatevalues, desc): + return self.runInteraction( + desc, self.simple_update_txn, table, keyvalues, updatevalues + ) + + @staticmethod + def simple_update_txn(txn, table, keyvalues, updatevalues): + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + else: + where = "" + + update_sql = "UPDATE %s SET %s %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in updatevalues), + where, + ) + + txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values())) + + return txn.rowcount + + def simple_update_one( + self, table, keyvalues, updatevalues, desc="simple_update_one" + ): + """Executes an UPDATE query on the named table, setting new values for + columns in a row matching the key values. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + updatevalues : dict giving column names and values to update + retcols : optional list of column names to return + + If present, retcols gives a list of column names on which to perform + a SELECT statement *before* performing the UPDATE statement. The values + of these will be returned in a dict. 
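`simple_select_many_batch` above slices the iterable into chunks (100 rows by default) and runs one `IN`/`ANY` query per chunk, so a very long list of IDs never turns into a single enormous statement. A small self-contained sketch of just that chunking idea with `sqlite3`; the table and rows are made up, and `batch_size=2` is only there to force several chunks:

import sqlite3


def batch_select(conn, table, column, values, retcols, batch_size=100):
    """Run one IN query per chunk of `values`, then concatenate the results."""
    results = []
    values = list(values)
    for i in range(0, len(values), batch_size):
        chunk = values[i : i + batch_size]
        sql = "SELECT %s FROM %s WHERE %s IN (%s)" % (
            ", ".join(retcols),
            table,
            column,
            ", ".join("?" for _ in chunk),
        )
        results.extend(conn.execute(sql, chunk).fetchall())
    return results


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE devices (user_id TEXT, device_id TEXT)")
conn.executemany(
    "INSERT INTO devices VALUES (?, ?)",
    [("@alice:example.org", "D%d" % i) for i in range(250)],
)
print(
    batch_select(
        conn, "devices", "device_id", ["D1", "D7", "D200"],
        retcols=("user_id", "device_id"), batch_size=2,
    )
)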
+ + These are performed within the same transaction, allowing an atomic + get-and-set. This can be used to implement compare-and-set by putting + the update column in the 'keyvalues' dict as well. + """ + return self.runInteraction( + desc, self.simple_update_one_txn, table, keyvalues, updatevalues + ) + + @classmethod + def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) + + if rowcount == 0: + raise StoreError(404, "No row found (%s)" % (table,)) + if rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + @staticmethod + def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): + select_sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(select_sql, list(keyvalues.values())) + row = txn.fetchone() + + if not row: + if allow_none: + return None + raise StoreError(404, "No row found (%s)" % (table,)) + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + return dict(zip(retcols, row)) + + def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): + """Executes a DELETE query on the named table, expecting to delete a + single row. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) + + @staticmethod + def simple_delete_one_txn(txn, table, keyvalues): + """Executes a DELETE query on the named table, expecting to delete a + single row. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(sql, list(keyvalues.values())) + if txn.rowcount == 0: + raise StoreError(404, "No row found (%s)" % (table,)) + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + def simple_delete(self, table, keyvalues, desc): + return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) + + @staticmethod + def simple_delete_txn(txn, table, keyvalues): + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(sql, list(keyvalues.values())) + return txn.rowcount + + def simple_delete_many(self, table, column, iterable, keyvalues, desc): + return self.runInteraction( + desc, self.simple_delete_many_txn, table, column, iterable, keyvalues + ) + + @staticmethod + def simple_delete_many_txn(txn, table, column, iterable, keyvalues): + """Executes a DELETE query on the named table. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + + Returns: + int: Number rows deleted + """ + if not iterable: + return 0 + + sql = "DELETE FROM %s" % table + + clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clauses = [clause] + + for key, value in iteritems(keyvalues): + clauses.append("%s = ?" 
% (key,)) + values.append(value) + + if clauses: + sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) + txn.execute(sql, values) + + return txn.rowcount + + def get_cache_dict( + self, db_conn, table, entity_column, stream_column, max_value, limit=100000 + ): + # Fetch a mapping of room_id -> max stream position for "recent" rooms. + # It doesn't really matter how many we get, the StreamChangeCache will + # do the right thing to ensure it respects the max size of cache. + sql = ( + "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" + " WHERE %(stream)s > ? - %(limit)s" + " GROUP BY %(entity)s" + ) % { + "table": table, + "entity": entity_column, + "stream": stream_column, + "limit": limit, + } + + sql = self.engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (int(max_value),)) + + cache = {row[0]: int(row[1]) for row in txn} + + txn.close() + + if cache: + min_val = min(itervalues(cache)) + else: + min_val = max_value + + return cache, min_val + + def simple_select_list_paginate( + self, + table, + orderby, + start, + limit, + retcols, + filters=None, + keyvalues=None, + order_direction="ASC", + desc="simple_select_list_paginate", + ): + """ + Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + table (str): the table name + filters (dict[str, T] | None): + column names and values to filter the rows with, or None to not + apply a WHERE ? LIKE ? clause. + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. + retcols (iterable[str]): the names of the columns to return + order_direction (str): Whether the results should be ordered "ASC" or "DESC". + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, + self.simple_select_list_paginate_txn, + table, + orderby, + start, + limit, + retcols, + filters=filters, + keyvalues=keyvalues, + order_direction=order_direction, + ) + + @classmethod + def simple_select_list_paginate_txn( + cls, + txn, + table, + orderby, + start, + limit, + retcols, + filters=None, + keyvalues=None, + order_direction="ASC", + ): + """ + Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to + select attributes with exact matches. All constraints are joined together + using 'AND'. + + Args: + txn : Transaction object + table (str): the table name + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. + retcols (iterable[str]): the names of the columns to return + filters (dict[str, T] | None): + column names and values to filter the rows with, or None to not + apply a WHERE ? LIKE ? clause. + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + order_direction (str): Whether the results should be ordered "ASC" or "DESC". 
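`get_cache_dict` above primes a stream-change cache with `entity -> MAX(stream)` for recent rows only, plus the lowest stream position seen, so the cache knows how far back it is authoritative. Roughly, with `sqlite3` and invented rooms and stream positions (the real helper interpolates `limit` into the SQL rather than binding it):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (room_id TEXT, stream_ordering INTEGER)")
conn.executemany(
    "INSERT INTO events VALUES (?, ?)",
    [("!a:hs", 1), ("!a:hs", 5), ("!b:hs", 3), ("!b:hs", 9), ("!c:hs", 2)],
)

max_value, limit = 10, 4
rows = conn.execute(
    "SELECT room_id, MAX(stream_ordering) FROM events"
    " WHERE stream_ordering > ? - ? GROUP BY room_id",
    (max_value, limit),
).fetchall()
cache = {room_id: stream for room_id, stream in rows}
min_val = min(cache.values()) if cache else max_value
print(cache, min_val)  # only rooms seen in the last `limit` stream positions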
+ Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + if order_direction not in ["ASC", "DESC"]: + raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") + + where_clause = "WHERE " if filters or keyvalues else "" + arg_list = [] # type: List[Any] + if filters: + where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters) + arg_list += list(filters.values()) + where_clause += " AND " if filters and keyvalues else "" + if keyvalues: + where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues) + arg_list += list(keyvalues.values()) + + sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( + ", ".join(retcols), + table, + where_clause, + orderby, + order_direction, + ) + txn.execute(sql, arg_list + [limit, start]) + + return cls.cursor_to_dict(txn) + + def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + + return self.runInteraction( + desc, self.simple_search_list_txn, table, term, col, retcols + ) + + @classmethod + def simple_search_list_txn(cls, txn, table, term, col, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + if term: + sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) + termvalues = ["%%" + term + "%%"] + txn.execute(sql, termvalues) + else: + return 0 + + return cls.cursor_to_dict(txn) + + +def make_in_list_sql_clause( + database_engine, column: str, iterable: Iterable +) -> Tuple[str, list]: + """Returns an SQL clause that checks the given column is in the iterable. + + On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres + it expands to `column = ANY(?)`. While both DBs support the `IN` form, + using the `ANY` form on postgres means that it views queries with + different length iterables as the same, helping the query stats. + + Args: + database_engine + column: Name of the column + iterable: The values to check the column against. + + Returns: + A tuple of SQL query and the args + """ + + if database_engine.supports_using_any_list: + # This should hopefully be faster, but also makes postgres query + # stats easier to understand. + return "%s = ANY(?)" % (column,), [list(iterable)] + else: + return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) + + +KV = TypeVar("KV") + + +def make_tuple_comparison_clause( + database_engine: BaseDatabaseEngine, keys: List[Tuple[str, KV]] +) -> Tuple[str, List[KV]]: + """Returns a tuple comparison SQL clause + + Depending what the SQL engine supports, builds a SQL clause that looks like either + "(a, b) > (?, ?)", or "(a > ?) OR (a == ? AND b > ?)". + + Args: + database_engine + keys: A set of (column, value) pairs to be compared. 
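`make_in_list_sql_clause` above emits `column = ANY(?)` on Postgres, so that differently sized lists normalise to the same statement in the query statistics, and a plain `IN (?, ?, ...)` elsewhere. A usage-style sketch with the engine check reduced to a boolean flag (illustrative only, not the real signature):

def make_in_list_sql_clause(supports_any_list, column, iterable):
    """Same shape as the helper above, with the engine reduced to a boolean."""
    if supports_any_list:
        # Postgres: one array-valued parameter.
        return "%s = ANY(?)" % (column,), [list(iterable)]
    # SQLite: one placeholder per value.
    return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)


events = ["$a", "$b", "$c"]
print(make_in_list_sql_clause(True, "event_id", events))
# ('event_id = ANY(?)', [['$a', '$b', '$c']])
print(make_in_list_sql_clause(False, "event_id", events))
# ('event_id IN (?,?,?)', ['$a', '$b', '$c'])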
+ + Returns: + A tuple of SQL query and the args + """ + if database_engine.supports_tuple_comparison: + return ( + "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)), + [k[1] for k in keys], + ) + + # we want to build a clause + # (a > ?) OR + # (a == ? AND b > ?) OR + # (a == ? AND b == ? AND c > ?) + # ... + # (a == ? AND b == ? AND ... AND z > ?) + # + # or, equivalently: + # + # (a > ? OR (a == ? AND + # (b > ? OR (b == ? AND + # ... + # (y > ? OR (y == ? AND + # z > ? + # )) + # ... + # )) + # )) + # + # which itself is equivalent to (and apparently easier for the query optimiser): + # + # (a >= ? AND (a > ? OR + # (b >= ? AND (b > ? OR + # ... + # (y >= ? AND (y > ? OR + # z > ? + # )) + # ... + # )) + # )) + # + # + + clause = "" + args = [] # type: List[KV] + for k, v in keys[:-1]: + clause = clause + "(%s >= ? AND (%s > ? OR " % (k, k) + args.extend([v, v]) + + (k, v) = keys[-1] + clause += "%s > ?" % (k,) + args.append(v) + + clause += "))" * (len(keys) - 1) + return clause, args diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py deleted file mode 100644 index 33e3a84933..0000000000 --- a/synapse/storage/end_to_end_keys.py +++ /dev/null @@ -1,313 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from six import iteritems - -from canonicaljson import encode_canonical_json - -from twisted.internet import defer - -from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.util.caches.descriptors import cached - -from ._base import SQLBaseStore, db_to_json - - -class EndToEndKeyWorkerStore(SQLBaseStore): - @trace - @defer.inlineCallbacks - def get_e2e_device_keys( - self, query_list, include_all_devices=False, include_deleted_devices=False - ): - """Fetch a list of device keys. - Args: - query_list(list): List of pairs of user_ids and device_ids. - include_all_devices (bool): whether to include entries for devices - that don't have device keys - include_deleted_devices (bool): whether to include null entries for - devices which no longer exist (but were in the query_list). - This option only takes effect if include_all_devices is true. - Returns: - Dict mapping from user-id to dict mapping from device_id to - dict containing "key_json", "device_display_name". 
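As a standalone illustration of the two clause builders above, the sketch below mirrors their branching with plain booleans standing in for the engine properties supports_using_any_list and supports_tuple_comparison; the column names and values are invented.

from typing import Any, Iterable, List, Tuple

def in_list_clause(column: str, values: Iterable, supports_any: bool) -> Tuple[str, list]:
    # Mirrors make_in_list_sql_clause: ANY(array) on Postgres, IN (...) elsewhere.
    values = list(values)
    if supports_any:
        return "%s = ANY(?)" % (column,), [values]
    return "%s IN (%s)" % (column, ",".join("?" for _ in values)), values

def tuple_gt_clause(keys: List[Tuple[str, Any]], supports_tuples: bool) -> Tuple[str, list]:
    # Mirrors make_tuple_comparison_clause: native tuple comparison where available,
    # otherwise the nested ">= AND (> OR ...)" expansion described in the comment above.
    if supports_tuples:
        return (
            "(%s) > (%s)" % (",".join(k for k, _ in keys), ",".join("?" for _ in keys)),
            [v for _, v in keys],
        )
    clause, args = "", []
    for k, v in keys[:-1]:
        clause += "(%s >= ? AND (%s > ? OR " % (k, k)
        args.extend([v, v])
    k, v = keys[-1]
    clause += "%s > ?" % (k,)
    args.append(v)
    return clause + "))" * (len(keys) - 1), args

print(in_list_clause("event_id", ["$a", "$b"], supports_any=False))
# ('event_id IN (?,?)', ['$a', '$b'])
print(tuple_gt_clause([("stream_ordering", 10), ("event_id", "$x")], supports_tuples=False))
# ('(stream_ordering >= ? AND (stream_ordering > ? OR event_id > ?))', [10, 10, '$x'])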
- """ - set_tag("query_list", query_list) - if not query_list: - return {} - - results = yield self.runInteraction( - "get_e2e_device_keys", - self._get_e2e_device_keys_txn, - query_list, - include_all_devices, - include_deleted_devices, - ) - - for user_id, device_keys in iteritems(results): - for device_id, device_info in iteritems(device_keys): - device_info["keys"] = db_to_json(device_info.pop("key_json")) - - return results - - @trace - def _get_e2e_device_keys_txn( - self, txn, query_list, include_all_devices=False, include_deleted_devices=False - ): - set_tag("include_all_devices", include_all_devices) - set_tag("include_deleted_devices", include_deleted_devices) - - query_clauses = [] - query_params = [] - - if include_all_devices is False: - include_deleted_devices = False - - if include_deleted_devices: - deleted_devices = set(query_list) - - for (user_id, device_id) in query_list: - query_clause = "user_id = ?" - query_params.append(user_id) - - if device_id is not None: - query_clause += " AND device_id = ?" - query_params.append(device_id) - - query_clauses.append(query_clause) - - sql = ( - "SELECT user_id, device_id, " - " d.display_name AS device_display_name, " - " k.key_json" - " FROM devices d" - " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" - " WHERE %s" - ) % ( - "LEFT" if include_all_devices else "INNER", - " OR ".join("(" + q + ")" for q in query_clauses), - ) - - txn.execute(sql, query_params) - rows = self.cursor_to_dict(txn) - - result = {} - for row in rows: - if include_deleted_devices: - deleted_devices.remove((row["user_id"], row["device_id"])) - result.setdefault(row["user_id"], {})[row["device_id"]] = row - - if include_deleted_devices: - for user_id, device_id in deleted_devices: - result.setdefault(user_id, {})[device_id] = None - - log_kv(result) - return result - - @defer.inlineCallbacks - def get_e2e_one_time_keys(self, user_id, device_id, key_ids): - """Retrieve a number of one-time keys for a user - - Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - key_ids(list[str]): list of key ids (excluding algorithm) to - retrieve - - Returns: - deferred resolving to Dict[(str, str), str]: map from (algorithm, - key_id) to json string for key - """ - - rows = yield self._simple_select_many_batch( - table="e2e_one_time_keys_json", - column="key_id", - iterable=key_ids, - retcols=("algorithm", "key_id", "key_json"), - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="add_e2e_one_time_keys_check", - ) - result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows} - log_kv({"message": "Fetched one time keys for user", "one_time_keys": result}) - return result - - @defer.inlineCallbacks - def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): - """Insert some new one time keys for a device. Errors if any of the - keys already exist. - - Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - time_now(long): insertion time to record (ms since epoch) - new_keys(iterable[(str, str, str)]: keys to add - each a tuple of - (algorithm, key_id, key json) - """ - - def _add_e2e_one_time_keys(txn): - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("new_keys", new_keys) - # We are protected from race between lookup and insertion due to - # a unique constraint. If there is a race of two calls to - # `add_e2e_one_time_keys` then they'll conflict and we will only - # insert one set. 
- self._simple_insert_many_txn( - txn, - table="e2e_one_time_keys_json", - values=[ - { - "user_id": user_id, - "device_id": device_id, - "algorithm": algorithm, - "key_id": key_id, - "ts_added_ms": time_now, - "key_json": json_bytes, - } - for algorithm, key_id, json_bytes in new_keys - ], - ) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - - yield self.runInteraction( - "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys - ) - - @cached(max_entries=10000) - def count_e2e_one_time_keys(self, user_id, device_id): - """ Count the number of one time keys the server has for a device - Returns: - Dict mapping from algorithm to number of keys for that algorithm. - """ - - def _count_e2e_one_time_keys(txn): - sql = ( - "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ?" - " GROUP BY algorithm" - ) - txn.execute(sql, (user_id, device_id)) - result = {} - for algorithm, key_count in txn: - result[algorithm] = key_count - return result - - return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys) - - -class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): - def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): - """Stores device keys for a device. Returns whether there was a change - or the keys were already in the database. - """ - - def _set_e2e_device_keys_txn(txn): - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("time_now", time_now) - set_tag("device_keys", device_keys) - - old_key_json = self._simple_select_one_onecol_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcol="key_json", - allow_none=True, - ) - - # In py3 we need old_key_json to match new_key_json type. The DB - # returns unicode while encode_canonical_json returns bytes. - new_key_json = encode_canonical_json(device_keys).decode("utf-8") - - if old_key_json == new_key_json: - log_kv({"Message": "Device key already stored."}) - return False - - self._simple_upsert_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - values={"ts_added_ms": time_now, "key_json": new_key_json}, - ) - log_kv({"message": "Device keys stored."}) - return True - - return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) - - def claim_e2e_one_time_keys(self, query_list): - """Take a list of one time keys out of the database""" - - @trace - def _claim_e2e_one_time_keys(txn): - sql = ( - "SELECT key_id, key_json FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " LIMIT 1" - ) - result = {} - delete = [] - for user_id, device_id, algorithm in query_list: - user_result = result.setdefault(user_id, {}) - device_result = user_result.setdefault(device_id, {}) - txn.execute(sql, (user_id, device_id, algorithm)) - for key_id, key_json in txn: - device_result[algorithm + ":" + key_id] = key_json - delete.append((user_id, device_id, algorithm, key_id)) - sql = ( - "DELETE FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " AND key_id = ?" - ) - for user_id, device_id, algorithm, key_id in delete: - log_kv( - { - "message": "Executing claim e2e_one_time_keys transaction on database." 
- } - ) - txn.execute(sql, (user_id, device_id, algorithm, key_id)) - log_kv({"message": "finished executing and invalidating cache"}) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - return result - - return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys) - - def delete_e2e_keys_by_device(self, user_id, device_id): - def delete_e2e_keys_by_device_txn(txn): - log_kv( - { - "message": "Deleting keys for device", - "device_id": device_id, - "user_id": user_id, - } - ) - self._simple_delete_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - ) - self._simple_delete_txn( - txn, - table="e2e_one_time_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - ) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - - return self.runInteraction( - "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn - ) diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 9d2d519922..035f9ea6e9 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -12,29 +12,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import importlib import platform -from ._base import IncorrectDatabaseSetup +from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup from .postgres import PostgresEngine from .sqlite import Sqlite3Engine -SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine} - -def create_engine(database_config): +def create_engine(database_config) -> BaseDatabaseEngine: name = database_config["name"] - engine_class = SUPPORTED_MODULE.get(name, None) - if engine_class: + if name == "sqlite3": + import sqlite3 + + return Sqlite3Engine(sqlite3, database_config) + + if name == "psycopg2": # pypy requires psycopg2cffi rather than psycopg2 - if name == "psycopg2" and platform.python_implementation() == "PyPy": - name = "psycopg2cffi" - module = importlib.import_module(name) - return engine_class(module, database_config) + if platform.python_implementation() == "PyPy": + import psycopg2cffi as psycopg2 # type: ignore + else: + import psycopg2 # type: ignore + + return PostgresEngine(psycopg2, database_config) raise RuntimeError("Unsupported database engine '%s'" % (name,)) -__all__ = ["create_engine", "IncorrectDatabaseSetup"] +__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"] diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index ec5a4d198b..ab0bbe4bd3 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -12,7 +12,94 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc +from typing import Generic, TypeVar + +from synapse.storage.types import Connection class IncorrectDatabaseSetup(RuntimeError): pass + + +ConnectionType = TypeVar("ConnectionType", bound=Connection) + + +class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): + def __init__(self, module, database_config: dict): + self.module = module + + @property + @abc.abstractmethod + def single_threaded(self) -> bool: + ... 
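A hedged usage sketch of the refactored factory: it assumes a Synapse checkout (and its dependencies) is importable, and the config dict is illustrative rather than a full homeserver database config. Only the "name" key drives the dispatch shown above; the capability properties come from the engine-specific hunks later in this patch.

from synapse.storage.engines import BaseDatabaseEngine, create_engine

engine = create_engine({"name": "sqlite3", "args": {"database": ":memory:"}})
assert isinstance(engine, BaseDatabaseEngine)
print(engine.single_threaded)          # True for Sqlite3Engine
print(engine.supports_using_any_list)  # False on SQLite, True on Postgres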
+ + @property + @abc.abstractmethod + def can_native_upsert(self) -> bool: + """ + Do we support native UPSERTs? + """ + ... + + @property + @abc.abstractmethod + def supports_tuple_comparison(self) -> bool: + """ + Do we support comparing tuples, i.e. `(a, b) > (c, d)`? + """ + ... + + @property + @abc.abstractmethod + def supports_using_any_list(self) -> bool: + """ + Do we support using `a = ANY(?)` and passing a list + """ + ... + + @abc.abstractmethod + def check_database( + self, db_conn: ConnectionType, allow_outdated_version: bool = False + ) -> None: + ... + + @abc.abstractmethod + def check_new_database(self, txn) -> None: + """Gets called when setting up a brand new database. This allows us to + apply stricter checks on new databases versus existing database. + """ + ... + + @abc.abstractmethod + def convert_param_style(self, sql: str) -> str: + ... + + @abc.abstractmethod + def on_new_connection(self, db_conn: ConnectionType) -> None: + ... + + @abc.abstractmethod + def is_deadlock(self, error: Exception) -> bool: + ... + + @abc.abstractmethod + def is_connection_closed(self, conn: ConnectionType) -> bool: + ... + + @abc.abstractmethod + def lock_table(self, txn, table: str) -> None: + ... + + @abc.abstractmethod + def get_next_state_group_id(self, txn) -> int: + """Returns an int that can be used as a new state_group ID + """ + ... + + @property + @abc.abstractmethod + def server_version(self) -> str: + """Gets a string giving the server version. For example: '3.22.0' + """ + ... diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 289b6bc281..6c7d08a6f2 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,32 +13,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import IncorrectDatabaseSetup +import logging +from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup -class PostgresEngine(object): - single_threaded = False +logger = logging.getLogger(__name__) + +class PostgresEngine(BaseDatabaseEngine): def __init__(self, database_module, database_config): - self.module = database_module + super().__init__(database_module, database_config) self.module.extensions.register_type(self.module.extensions.UNICODE) - self.synchronous_commit = database_config.get("synchronous_commit", True) - self._version = None # unknown as yet - def check_database(self, txn): - txn.execute("SHOW SERVER_ENCODING") - rows = txn.fetchall() - if rows and rows[0][0] != "UTF8": - raise IncorrectDatabaseSetup( - "Database has incorrect encoding: '%s' instead of 'UTF8'\n" - "See docs/postgres.rst for more information." % (rows[0][0],) - ) + # Disables passing `bytes` to txn.execute, c.f. #6186. If you do + # actually want to use bytes than wrap it in `bytearray`. + def _disable_bytes_adapter(_): + raise Exception("Passing bytes to DB is disabled.") - def convert_param_style(self, sql): - return sql.replace("?", "%s") + self.module.extensions.register_adapter(bytes, _disable_bytes_adapter) + self.synchronous_commit = database_config.get("synchronous_commit", True) + self._version = None # unknown as yet - def on_new_connection(self, db_conn): + @property + def single_threaded(self) -> bool: + return False + def check_database(self, db_conn, allow_outdated_version: bool = False): # Get the version of PostgreSQL that we're using. 
As per the psycopg2 # docs: The number is formed by converting the major, minor, and # revision numbers into two-decimal-digit numbers and appending them @@ -46,9 +46,64 @@ class PostgresEngine(object): self._version = db_conn.server_version # Are we on a supported PostgreSQL version? - if self._version < 90500: + if not allow_outdated_version and self._version < 90500: raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.") + with db_conn.cursor() as txn: + txn.execute("SHOW SERVER_ENCODING") + rows = txn.fetchall() + if rows and rows[0][0] != "UTF8": + raise IncorrectDatabaseSetup( + "Database has incorrect encoding: '%s' instead of 'UTF8'\n" + "See docs/postgres.md for more information." % (rows[0][0],) + ) + + txn.execute( + "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()" + ) + collation, ctype = txn.fetchone() + if collation != "C": + logger.warning( + "Database has incorrect collation of %r. Should be 'C'\n" + "See docs/postgres.md for more information.", + collation, + ) + + if ctype != "C": + logger.warning( + "Database has incorrect ctype of %r. Should be 'C'\n" + "See docs/postgres.md for more information.", + ctype, + ) + + def check_new_database(self, txn): + """Gets called when setting up a brand new database. This allows us to + apply stricter checks on new databases versus existing database. + """ + + txn.execute( + "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()" + ) + collation, ctype = txn.fetchone() + + errors = [] + + if collation != "C": + errors.append(" - 'COLLATE' is set to %r. Should be 'C'" % (collation,)) + + if ctype != "C": + errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (collation,)) + + if errors: + raise IncorrectDatabaseSetup( + "Database is incorrectly configured:\n\n%s\n\n" + "See docs/postgres.md for more information." % ("\n".join(errors)) + ) + + def convert_param_style(self, sql): + return sql.replace("?", "%s") + + def on_new_connection(self, db_conn): db_conn.set_isolation_level( self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) @@ -72,6 +127,19 @@ class PostgresEngine(object): """ return True + @property + def supports_tuple_comparison(self): + """ + Do we support comparing tuples, i.e. `(a, b) > (c, d)`? + """ + return True + + @property + def supports_using_any_list(self): + """Do we support using `a = ANY(?)` and passing a list + """ + return True + def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html @@ -99,8 +167,8 @@ class PostgresEngine(object): Returns: string """ - # note that this is a bit of a hack because it relies on on_new_connection - # having been called at least once. Still, that should be a safe bet here. + # note that this is a bit of a hack because it relies on check_database + # having been called. Still, that should be a safe bet here. numver = self._version assert numver is not None diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index e9b9caa49a..215a949442 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -12,18 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
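Two details from the PostgresEngine hunks above, shown as a standalone sketch with no psycopg2 required: the placeholder rewrite performed by convert_param_style, and the packed integer form of server_version that makes PostgreSQL 9.5 the threshold value 90500. The helper names below are local to the example.

def convert_param_style(sql: str) -> str:
    # Synapse writes "?" placeholders; psycopg2 expects "%s".
    return sql.replace("?", "%s")

print(convert_param_style("SELECT datcollate FROM pg_database WHERE datname = ?"))
# SELECT datcollate FROM pg_database WHERE datname = %s

def decode_pg_version(numver: int) -> str:
    # Major, minor and revision are packed as two-decimal-digit groups,
    # per the comment above, so 90500 decodes to 9.5.0.
    return "%d.%d.%d" % (numver // 10000, (numver % 10000) // 100, numver % 100)

print(decode_pg_version(90500))  # 9.5.0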
- import struct import threading +import typing -from synapse.storage.prepare_database import prepare_database +from synapse.storage.engines import BaseDatabaseEngine +if typing.TYPE_CHECKING: + import sqlite3 # noqa: F401 -class Sqlite3Engine(object): - single_threaded = True +class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): def __init__(self, database_module, database_config): - self.module = database_module + super().__init__(database_module, database_config) + + database = database_config.get("args", {}).get("database") + self._is_in_memory = database in (None, ":memory:",) # The current max state_group, or None if we haven't looked # in the DB yet. @@ -31,6 +35,10 @@ class Sqlite3Engine(object): self._current_state_group_id_lock = threading.Lock() @property + def single_threaded(self) -> bool: + return True + + @property def can_native_upsert(self): """ Do we support native UPSERTs? This requires SQLite3 3.24+, plus some @@ -38,15 +46,46 @@ class Sqlite3Engine(object): """ return self.module.sqlite_version_info >= (3, 24, 0) - def check_database(self, txn): - pass + @property + def supports_tuple_comparison(self): + """ + Do we support comparing tuples, i.e. `(a, b) > (c, d)`? This requires + SQLite 3.15+. + """ + return self.module.sqlite_version_info >= (3, 15, 0) + + @property + def supports_using_any_list(self): + """Do we support using `a = ANY(?)` and passing a list + """ + return False + + def check_database(self, db_conn, allow_outdated_version: bool = False): + if not allow_outdated_version: + version = self.module.sqlite_version_info + if version < (3, 11, 0): + raise RuntimeError("Synapse requires sqlite 3.11 or above.") + + def check_new_database(self, txn): + """Gets called when setting up a brand new database. This allows us to + apply stricter checks on new databases versus existing database. + """ def convert_param_style(self, sql): return sql def on_new_connection(self, db_conn): - prepare_database(db_conn, self, config=None) + # We need to import here to avoid an import loop. + from synapse.storage.prepare_database import prepare_database + + if self._is_in_memory: + # In memory databases need to be rebuilt each time. Ideally we'd + # reuse the same connection as we do when starting up, but that + # would involve using adbapi before we have started the reactor. + prepare_database(db_conn, self, config=None) + db_conn.create_function("rank", 1, _rank) + db_conn.execute("PRAGMA foreign_keys = ON;") def is_deadlock(self, error): return False diff --git a/synapse/storage/events.py b/synapse/storage/events.py deleted file mode 100644 index ddf7ab6479..0000000000 --- a/synapse/storage/events.py +++ /dev/null @@ -1,2470 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
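The SQLite capability properties introduced above all reduce to version checks against the standard-library sqlite3 module (which is what Sqlite3Engine receives as its database_module). A standalone sketch that reports what they would return for the interpreter it runs under:

import sqlite3

version = sqlite3.sqlite_version_info
print("native upsert (needs 3.24+):", version >= (3, 24, 0))
print("tuple comparison (needs 3.15+):", version >= (3, 15, 0))
if version < (3, 11, 0):
    # Mirrors the new check_database: Synapse refuses to run on older SQLite.
    raise RuntimeError("Synapse requires sqlite 3.11 or above.")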
- -import itertools -import logging -from collections import Counter as c_counter, OrderedDict, deque, namedtuple -from functools import wraps - -from six import iteritems, text_type -from six.moves import range - -from canonicaljson import encode_canonical_json, json -from prometheus_client import Counter, Histogram - -from twisted.internet import defer - -import synapse.metrics -from synapse.api.constants import EventTypes -from synapse.api.errors import SynapseError -from synapse.events import EventBase # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 -from synapse.events.utils import prune_event_dict -from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable -from synapse.logging.utils import log_function -from synapse.metrics import BucketCollector -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.state import StateResolutionStore -from synapse.storage.background_updates import BackgroundUpdateStore -from synapse.storage.event_federation import EventFederationStore -from synapse.storage.events_worker import EventsWorkerStore -from synapse.storage.state import StateGroupWorkerStore -from synapse.types import RoomStreamToken, get_domain_from_id -from synapse.util import batch_iter -from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from synapse.util.frozenutils import frozendict_json_encoder -from synapse.util.metrics import Measure - -logger = logging.getLogger(__name__) - -persist_event_counter = Counter("synapse_storage_events_persisted_events", "") -event_counter = Counter( - "synapse_storage_events_persisted_events_sep", - "", - ["type", "origin_type", "origin_entity"], -) - -# The number of times we are recalculating the current state -state_delta_counter = Counter("synapse_storage_events_state_delta", "") - -# The number of times we are recalculating state when there is only a -# single forward extremity -state_delta_single_event_counter = Counter( - "synapse_storage_events_state_delta_single_event", "" -) - -# The number of times we are recalculating state when we could have reasonably -# calculated the delta when we calculated the state for an event we were -# persisting. -state_delta_reuse_delta_counter = Counter( - "synapse_storage_events_state_delta_reuse_delta", "" ) - -# The number of forward extremities for each new event. -forward_extremities_counter = Histogram( - "synapse_storage_events_forward_extremities_persisted", - "Number of forward extremities for each new event", - buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), -) - -# The number of stale forward extremities for each new event. Stale extremities -# are those that were in the previous set of extremities as well as the new. -stale_forward_extremities_counter = Histogram( - "synapse_storage_events_stale_forward_extremities_persisted", - "Number of unchanged forward extremities for each new event", - buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), -) - - -def encode_json(json_object): - """ - Encode a Python object as JSON and return it in a Unicode string. - """ - out = frozendict_json_encoder.encode(json_object) - if isinstance(out, bytes): - out = out.decode("utf8") - return out - - -class _EventPeristenceQueue(object): - """Queues up events so that they can be persisted in bulk with only one - concurrent transaction per room.
- """ - - _EventPersistQueueItem = namedtuple( - "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred") - ) - - def __init__(self): - self._event_persist_queues = {} - self._currently_persisting_rooms = set() - - def add_to_queue(self, room_id, events_and_contexts, backfilled): - """Add events to the queue, with the given persist_event options. - - NB: due to the normal usage pattern of this method, it does *not* - follow the synapse logcontext rules, and leaves the logcontext in - place whether or not the returned deferred is ready. - - Args: - room_id (str): - events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): - - Returns: - defer.Deferred: a deferred which will resolve once the events are - persisted. Runs its callbacks *without* a logcontext. - """ - queue = self._event_persist_queues.setdefault(room_id, deque()) - if queue: - # if the last item in the queue has the same `backfilled` setting, - # we can just add these new events to that item. - end_item = queue[-1] - if end_item.backfilled == backfilled: - end_item.events_and_contexts.extend(events_and_contexts) - return end_item.deferred.observe() - - deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) - - queue.append( - self._EventPersistQueueItem( - events_and_contexts=events_and_contexts, - backfilled=backfilled, - deferred=deferred, - ) - ) - - return deferred.observe() - - def handle_queue(self, room_id, per_item_callback): - """Attempts to handle the queue for a room if not already being handled. - - The given callback will be invoked with for each item in the queue, - of type _EventPersistQueueItem. The per_item_callback will continuously - be called with new items, unless the queue becomnes empty. The return - value of the function will be given to the deferreds waiting on the item, - exceptions will be passed to the deferreds as well. - - This function should therefore be called whenever anything is added - to the queue. - - If another callback is currently handling the queue then it will not be - invoked. - """ - - if room_id in self._currently_persisting_rooms: - return - - self._currently_persisting_rooms.add(room_id) - - @defer.inlineCallbacks - def handle_queue_loop(): - try: - queue = self._get_drainining_queue(room_id) - for item in queue: - try: - ret = yield per_item_callback(item) - except Exception: - with PreserveLoggingContext(): - item.deferred.errback() - else: - with PreserveLoggingContext(): - item.deferred.callback(ret) - finally: - queue = self._event_persist_queues.pop(room_id, None) - if queue: - self._event_persist_queues[room_id] = queue - self._currently_persisting_rooms.discard(room_id) - - # set handle_queue_loop off in the background - run_as_background_process("persist_events", handle_queue_loop) - - def _get_drainining_queue(self, room_id): - queue = self._event_persist_queues.setdefault(room_id, deque()) - - try: - while True: - yield queue.popleft() - except IndexError: - # Queue has been drained. - pass - - -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) - - -def _retry_on_integrity_error(func): - """Wraps a database function so that it gets retried on IntegrityError, - with `delete_existing=True` passed in. 
- - Args: - func: function that returns a Deferred and accepts a `delete_existing` arg - """ - - @wraps(func) - @defer.inlineCallbacks - def f(self, *args, **kwargs): - try: - res = yield func(self, *args, **kwargs) - except self.database_engine.module.IntegrityError: - logger.exception("IntegrityError, retrying.") - res = yield func(self, *args, delete_existing=True, **kwargs) - return res - - return f - - -# inherits from EventFederationStore so that we can call _update_backward_extremities -# and _handle_mult_prev_events (though arguably those could both be moved in here) -class EventsStore( - StateGroupWorkerStore, - EventFederationStore, - EventsWorkerStore, - BackgroundUpdateStore, -): - def __init__(self, db_conn, hs): - super(EventsStore, self).__init__(db_conn, hs) - - self._event_persist_queue = _EventPeristenceQueue() - self._state_resolution_handler = hs.get_state_resolution_handler() - - # Collect metrics on the number of forward extremities that exist. - # Counter of number of extremities to count - self._current_forward_extremities_amount = c_counter() - - BucketCollector( - "synapse_forward_extremities", - lambda: self._current_forward_extremities_amount, - buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], - ) - - # Read the extrems every 60 minutes - def read_forward_extremities(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "read_forward_extremities", self._read_forward_extremities - ) - - hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) - - def _censor_redactions(): - return run_as_background_process( - "_censor_redactions", self._censor_redactions - ) - - if self.hs.config.redaction_retention_period is not None: - hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) - - @defer.inlineCallbacks - def _read_forward_extremities(self): - def fetch(txn): - txn.execute( - """ - select count(*) c from event_forward_extremities - group by room_id - """ - ) - return txn.fetchall() - - res = yield self.runInteraction("read_forward_extremities", fetch) - self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) - - @defer.inlineCallbacks - def persist_events(self, events_and_contexts, backfilled=False): - """ - Write events to the database - Args: - events_and_contexts: list of tuples of (event, context) - backfilled (bool): Whether the results are retrieved from federation - via backfill or not. Used to determine if they're "new" events - which might update the current state etc. 
- - Returns: - Deferred[int]: the stream ordering of the latest persisted event - """ - partitioned = {} - for event, ctx in events_and_contexts: - partitioned.setdefault(event.room_id, []).append((event, ctx)) - - deferreds = [] - for room_id, evs_ctxs in iteritems(partitioned): - d = self._event_persist_queue.add_to_queue( - room_id, evs_ctxs, backfilled=backfilled - ) - deferreds.append(d) - - for room_id in partitioned: - self._maybe_start_persisting(room_id) - - yield make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True) - ) - - max_persisted_id = yield self._stream_id_gen.get_current_token() - - return max_persisted_id - - @defer.inlineCallbacks - @log_function - def persist_event(self, event, context, backfilled=False): - """ - - Args: - event (EventBase): - context (EventContext): - backfilled (bool): - - Returns: - Deferred: resolves to (int, int): the stream ordering of ``event``, - and the stream ordering of the latest persisted event - """ - deferred = self._event_persist_queue.add_to_queue( - event.room_id, [(event, context)], backfilled=backfilled - ) - - self._maybe_start_persisting(event.room_id) - - yield make_deferred_yieldable(deferred) - - max_persisted_id = yield self._stream_id_gen.get_current_token() - return (event.internal_metadata.stream_ordering, max_persisted_id) - - def _maybe_start_persisting(self, room_id): - @defer.inlineCallbacks - def persisting_queue(item): - with Measure(self._clock, "persist_events"): - yield self._persist_events( - item.events_and_contexts, backfilled=item.backfilled - ) - - self._event_persist_queue.handle_queue(room_id, persisting_queue) - - @_retry_on_integrity_error - @defer.inlineCallbacks - def _persist_events( - self, events_and_contexts, backfilled=False, delete_existing=False - ): - """Persist events to db - - Args: - events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): - delete_existing (bool): - - Returns: - Deferred: resolves when the events have been persisted - """ - if not events_and_contexts: - return - - chunks = [ - events_and_contexts[x : x + 100] - for x in range(0, len(events_and_contexts), 100) - ] - - for chunk in chunks: - # We can't easily parallelize these since different chunks - # might contain the same event. :( - - # NB: Assumes that we are only persisting events for one room - # at a time. - - # map room_id->list[event_ids] giving the new forward - # extremities in each room - new_forward_extremeties = {} - - # map room_id->(type,state_key)->event_id tracking the full - # state in each room after adding these events. - # This is simply used to prefill the get_current_state_ids - # cache - current_state_for_room = {} - - # map room_id->(to_delete, to_insert) where to_delete is a list - # of type/state keys to remove from current state, and to_insert - # is a map (type,key)->event_id giving the state delta in each - # room - state_delta_for_room = {} - - if not backfilled: - with Measure(self._clock, "_calculate_state_and_extrem"): - # Work out the new "current state" for each room. - # We do this by working out what the new extremities are and then - # calculating the state from that. 
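persist_events above buckets the incoming (event, context) pairs by room, and _persist_events then slices each batch into chunks of at most 100 per database transaction. A standalone sketch of that bookkeeping, using invented room and event identifiers as stand-ins for real (EventBase, EventContext) pairs:

events = [("!a:hs", "$e%d" % i) for i in range(150)] + [("!b:hs", "$f1")]

# One persistence queue per room.
partitioned = {}
for room_id, event_id in events:
    partitioned.setdefault(room_id, []).append(event_id)
print({room: len(evs) for room, evs in partitioned.items()})  # {'!a:hs': 150, '!b:hs': 1}

# Each queued batch is persisted in chunks of at most 100 per transaction.
batch = partitioned["!a:hs"]
chunks = [batch[x : x + 100] for x in range(0, len(batch), 100)]
print([len(chunk) for chunk in chunks])  # [100, 50]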
- events_by_room = {} - for event, context in chunk: - events_by_room.setdefault(event.room_id, []).append( - (event, context) - ) - - for room_id, ev_ctx_rm in iteritems(events_by_room): - latest_event_ids = yield self.get_latest_event_ids_in_room( - room_id - ) - new_latest_event_ids = yield self._calculate_new_extremities( - room_id, ev_ctx_rm, latest_event_ids - ) - - latest_event_ids = set(latest_event_ids) - if new_latest_event_ids == latest_event_ids: - # No change in extremities, so no change in state - continue - - # there should always be at least one forward extremity. - # (except during the initial persistence of the send_join - # results, in which case there will be no existing - # extremities, so we'll `continue` above and skip this bit.) - assert new_latest_event_ids, "No forward extremities left!" - - new_forward_extremeties[room_id] = new_latest_event_ids - - len_1 = ( - len(latest_event_ids) == 1 - and len(new_latest_event_ids) == 1 - ) - if len_1: - all_single_prev_not_state = all( - len(event.prev_event_ids()) == 1 - and not event.is_state() - for event, ctx in ev_ctx_rm - ) - # Don't bother calculating state if they're just - # a long chain of single ancestor non-state events. - if all_single_prev_not_state: - continue - - state_delta_counter.inc() - if len(new_latest_event_ids) == 1: - state_delta_single_event_counter.inc() - - # This is a fairly handwavey check to see if we could - # have guessed what the delta would have been when - # processing one of these events. - # What we're interested in is if the latest extremities - # were the same when we created the event as they are - # now. When this server creates a new event (as opposed - # to receiving it over federation) it will use the - # forward extremities as the prev_events, so we can - # guess this by looking at the prev_events and checking - # if they match the current forward extremities. - for ev, _ in ev_ctx_rm: - prev_event_ids = set(ev.prev_event_ids()) - if latest_event_ids == prev_event_ids: - state_delta_reuse_delta_counter.inc() - break - - logger.info("Calculating state delta for room %s", room_id) - with Measure( - self._clock, "persist_events.get_new_state_after_events" - ): - res = yield self._get_new_state_after_events( - room_id, - ev_ctx_rm, - latest_event_ids, - new_latest_event_ids, - ) - current_state, delta_ids = res - - # If either are not None then there has been a change, - # and we need to work out the delta (or use that - # given) - if delta_ids is not None: - # If there is a delta we know that we've - # only added or replaced state, never - # removed keys entirely. - state_delta_for_room[room_id] = ([], delta_ids) - elif current_state is not None: - with Measure( - self._clock, "persist_events.calculate_state_delta" - ): - delta = yield self._calculate_state_delta( - room_id, current_state - ) - state_delta_for_room[room_id] = delta - - # If we have the current_state then lets prefill - # the cache with it. - if current_state is not None: - current_state_for_room[room_id] = current_state - - # We want to calculate the stream orderings as late as possible, as - # we only notify after all events with a lesser stream ordering have - # been persisted. I.e. if we spend 10s inside the with block then - # that will delay all subsequent events from being notified about. - # Hence why we do it down here rather than wrapping the entire - # function. - # - # Its safe to do this after calculating the state deltas etc as we - # only need to protect the *persistence* of the events. 
This is to - # ensure that queries of the form "fetch events since X" don't - # return events and stream positions after events that are still in - # flight, as otherwise subsequent requests "fetch event since Y" - # will not return those events. - # - # Note: Multiple instances of this function cannot be in flight at - # the same time for the same room. - if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( - len(chunk) - ) - else: - stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk)) - - with stream_ordering_manager as stream_orderings: - for (event, context), stream in zip(chunk, stream_orderings): - event.internal_metadata.stream_ordering = stream - - yield self.runInteraction( - "persist_events", - self._persist_events_txn, - events_and_contexts=chunk, - backfilled=backfilled, - delete_existing=delete_existing, - state_delta_for_room=state_delta_for_room, - new_forward_extremeties=new_forward_extremeties, - ) - persist_event_counter.inc(len(chunk)) - - if not backfilled: - # backfilled events have negative stream orderings, so we don't - # want to set the event_persisted_position to that. - synapse.metrics.event_persisted_position.set( - chunk[-1][0].internal_metadata.stream_ordering - ) - - for event, context in chunk: - if context.app_service: - origin_type = "local" - origin_entity = context.app_service.id - elif self.hs.is_mine_id(event.sender): - origin_type = "local" - origin_entity = "*client*" - else: - origin_type = "remote" - origin_entity = get_domain_from_id(event.sender) - - event_counter.labels(event.type, origin_type, origin_entity).inc() - - for room_id, new_state in iteritems(current_state_for_room): - self.get_current_state_ids.prefill((room_id,), new_state) - - for room_id, latest_event_ids in iteritems(new_forward_extremeties): - self.get_latest_event_ids_in_room.prefill( - (room_id,), list(latest_event_ids) - ) - - @defer.inlineCallbacks - def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): - """Calculates the new forward extremities for a room given events to - persist. - - Assumes that we are only persisting events for one room at a time. - """ - - # we're only interested in new events which aren't outliers and which aren't - # being rejected. - new_events = [ - event - for event, ctx in event_contexts - if not event.internal_metadata.is_outlier() - and not ctx.rejected - and not event.internal_metadata.is_soft_failed() - ] - - latest_event_ids = set(latest_event_ids) - - # start with the existing forward extremities - result = set(latest_event_ids) - - # add all the new events to the list - result.update(event.event_id for event in new_events) - - # Now remove all events which are prev_events of any of the new events - result.difference_update( - e_id for event in new_events for e_id in event.prev_event_ids() - ) - - # Remove any events which are prev_events of any existing events. - existing_prevs = yield self._get_events_which_are_prevs(result) - result.difference_update(existing_prevs) - - # Finally handle the case where the new events have soft-failed prev - # events. If they do we need to remove them and their prev events, - # otherwise we end up with dangling extremities. - existing_prevs = yield self._get_prevs_before_rejected( - e_id for event in new_events for e_id in event.prev_event_ids() - ) - result.difference_update(existing_prevs) - - # We only update metrics for events that change forward extremities - # (e.g. 
we ignore backfill/outliers/etc) - if result != latest_event_ids: - forward_extremities_counter.observe(len(result)) - stale = latest_event_ids & result - stale_forward_extremities_counter.observe(len(stale)) - - return result - - @defer.inlineCallbacks - def _get_events_which_are_prevs(self, event_ids): - """Filter the supplied list of event_ids to get those which are prev_events of - existing (non-outlier/rejected) events. - - Args: - event_ids (Iterable[str]): event ids to filter - - Returns: - Deferred[List[str]]: filtered event ids - """ - results = [] - - def _get_events_which_are_prevs_txn(txn, batch): - sql = """ - SELECT prev_event_id, internal_metadata - FROM event_edges - INNER JOIN events USING (event_id) - LEFT JOIN rejections USING (event_id) - LEFT JOIN event_json USING (event_id) - WHERE - prev_event_id IN (%s) - AND NOT events.outlier - AND rejections.event_id IS NULL - """ % ( - ",".join("?" for _ in batch), - ) - - txn.execute(sql, batch) - results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) - - for chunk in batch_iter(event_ids, 100): - yield self.runInteraction( - "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk - ) - - return results - - @defer.inlineCallbacks - def _get_prevs_before_rejected(self, event_ids): - """Get soft-failed ancestors to remove from the extremities. - - Given a set of events, find all those that have been soft-failed or - rejected. Returns those soft failed/rejected events and their prev - events (whether soft-failed/rejected or not), and recurses up the - prev-event graph until it finds no more soft-failed/rejected events. - - This is used to find extremities that are ancestors of new events, but - are separated by soft failed events. - - Args: - event_ids (Iterable[str]): Events to find prev events for. Note - that these must have already been persisted. - - Returns: - Deferred[set[str]] - """ - - # The set of event_ids to return. This includes all soft-failed events - # and their prev events. - existing_prevs = set() - - def _get_prevs_before_rejected_txn(txn, batch): - to_recursively_check = batch - - while to_recursively_check: - sql = """ - SELECT - event_id, prev_event_id, internal_metadata, - rejections.event_id IS NOT NULL - FROM event_edges - INNER JOIN events USING (event_id) - LEFT JOIN rejections USING (event_id) - LEFT JOIN event_json USING (event_id) - WHERE - event_id IN (%s) - AND NOT events.outlier - """ % ( - ",".join("?" for _ in to_recursively_check), - ) - - txn.execute(sql, to_recursively_check) - to_recursively_check = [] - - for event_id, prev_event_id, metadata, rejected in txn: - if prev_event_id in existing_prevs: - continue - - soft_failed = json.loads(metadata).get("soft_failed") - if soft_failed or rejected: - to_recursively_check.append(prev_event_id) - existing_prevs.add(prev_event_id) - - for chunk in batch_iter(event_ids, 100): - yield self.runInteraction( - "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk - ) - - return existing_prevs - - @defer.inlineCallbacks - def _get_new_state_after_events( - self, room_id, events_context, old_latest_event_ids, new_latest_event_ids - ): - """Calculate the current state dict after adding some new events to - a room - - Args: - room_id (str): - room to which the events are being added. Used for logging etc - - events_context (list[(EventBase, EventContext)]): - events and contexts which are being added to the room - - old_latest_event_ids (iterable[str]): - the old forward extremities for the room. 
- - new_latest_event_ids (iterable[str]): - the new forward extremities for the room. - - Returns: - Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]: - Returns a tuple of two state maps, the first being the full new current - state and the second being the delta to the existing current state. - If both are None then there has been no change. - - If there has been a change then we only return the delta if its - already been calculated. Conversely if we do know the delta then - the new current state is only returned if we've already calculated - it. - """ - # map from state_group to ((type, key) -> event_id) state map - state_groups_map = {} - - # Map from (prev state group, new state group) -> delta state dict - state_group_deltas = {} - - for ev, ctx in events_context: - if ctx.state_group is None: - # This should only happen for outlier events. - if not ev.internal_metadata.is_outlier(): - raise Exception( - "Context for new event %s has no state " - "group" % (ev.event_id,) - ) - continue - - if ctx.state_group in state_groups_map: - continue - - # We're only interested in pulling out state that has already - # been cached in the context. We'll pull stuff out of the DB later - # if necessary. - current_state_ids = ctx.get_cached_current_state_ids() - if current_state_ids is not None: - state_groups_map[ctx.state_group] = current_state_ids - - if ctx.prev_group: - state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids - - # We need to map the event_ids to their state groups. First, let's - # check if the event is one we're persisting, in which case we can - # pull the state group from its context. - # Otherwise we need to pull the state group from the database. - - # Set of events we need to fetch groups for. (We know none of the old - # extremities are going to be in events_context). - missing_event_ids = set(old_latest_event_ids) - - event_id_to_state_group = {} - for event_id in new_latest_event_ids: - # First search in the list of new events we're adding. - for ev, ctx in events_context: - if event_id == ev.event_id and ctx.state_group is not None: - event_id_to_state_group[event_id] = ctx.state_group - break - else: - # If we couldn't find it, then we'll need to pull - # the state from the database - missing_event_ids.add(event_id) - - if missing_event_ids: - # Now pull out the state groups for any missing events from DB - event_to_groups = yield self._get_state_group_for_events(missing_event_ids) - event_id_to_state_group.update(event_to_groups) - - # State groups of old_latest_event_ids - old_state_groups = set( - event_id_to_state_group[evid] for evid in old_latest_event_ids - ) - - # State groups of new_latest_event_ids - new_state_groups = set( - event_id_to_state_group[evid] for evid in new_latest_event_ids - ) - - # If they old and new groups are the same then we don't need to do - # anything. - if old_state_groups == new_state_groups: - return None, None - - if len(new_state_groups) == 1 and len(old_state_groups) == 1: - # If we're going from one state group to another, lets check if - # we have a delta for that transition. If we do then we can just - # return that. - - new_state_group = next(iter(new_state_groups)) - old_state_group = next(iter(old_state_groups)) - - delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) - if delta_ids is not None: - # We have a delta from the existing to new current state, - # so lets just return that. 
If we happen to already have - # the current state in memory then lets also return that, - # but it doesn't matter if we don't. - new_state = state_groups_map.get(new_state_group) - return new_state, delta_ids - - # Now that we have calculated new_state_groups we need to get - # their state IDs so we can resolve to a single state set. - missing_state = new_state_groups - set(state_groups_map) - if missing_state: - group_to_state = yield self._get_state_for_groups(missing_state) - state_groups_map.update(group_to_state) - - if len(new_state_groups) == 1: - # If there is only one state group, then we know what the current - # state is. - return state_groups_map[new_state_groups.pop()], None - - # Ok, we need to defer to the state handler to resolve our state sets. - - state_groups = {sg: state_groups_map[sg] for sg in new_state_groups} - - events_map = {ev.event_id: ev for ev, _ in events_context} - - # We need to get the room version, which is in the create event. - # Normally that'd be in the database, but its also possible that we're - # currently trying to persist it. - room_version = None - for ev, _ in events_context: - if ev.type == EventTypes.Create and ev.state_key == "": - room_version = ev.content.get("room_version", "1") - break - - if not room_version: - room_version = yield self.get_room_version(room_id) - - logger.debug("calling resolve_state_groups from preserve_events") - res = yield self._state_resolution_handler.resolve_state_groups( - room_id, - room_version, - state_groups, - events_map, - state_res_store=StateResolutionStore(self), - ) - - return res.state, None - - @defer.inlineCallbacks - def _calculate_state_delta(self, room_id, current_state): - """Calculate the new state deltas for a room. - - Assumes that we are only persisting events for one room at a time. - - Returns: - tuple[list, dict] (to_delete, to_insert): where to_delete are the - type/state_keys to remove from current_state_events and `to_insert` - are the updates to current_state_events. - """ - existing_state = yield self.get_current_state_ids(room_id) - - to_delete = [key for key in existing_state if key not in current_state] - - to_insert = { - key: ev_id - for key, ev_id in iteritems(current_state) - if ev_id != existing_state.get(key) - } - - return to_delete, to_insert - - @log_function - def _persist_events_txn( - self, - txn, - events_and_contexts, - backfilled, - delete_existing=False, - state_delta_for_room={}, - new_forward_extremeties={}, - ): - """Insert some number of room events into the necessary database tables. - - Rejected events are only inserted into the events table, the events_json table, - and the rejections table. Things reading from those table will need to check - whether the event was rejected. - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): - events to persist - backfilled (bool): True if the events were backfilled - delete_existing (bool): True to purge existing table rows for the - events from the database. This is useful when retrying due to - IntegrityError. - state_delta_for_room (dict[str, (list, dict)]): - The current-state delta for each room. For each room, a tuple - (to_delete, to_insert), being a list of type/state keys to be - removed from the current state, and a state set to be added to - the current state. - new_forward_extremeties (dict[str, list[str]]): - The new forward extremities for each room. For each room, a - list of the event ids which are the forward extremities. 
- - """ - all_events_and_contexts = events_and_contexts - - min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering - max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering - - self._update_forward_extremities_txn( - txn, - new_forward_extremities=new_forward_extremeties, - max_stream_order=max_stream_order, - ) - - # Ensure that we don't have the same event twice. - events_and_contexts = self._filter_events_and_contexts_for_duplicates( - events_and_contexts - ) - - self._update_room_depths_txn( - txn, events_and_contexts=events_and_contexts, backfilled=backfilled - ) - - # _update_outliers_txn filters out any events which have already been - # persisted, and returns the filtered list. - events_and_contexts = self._update_outliers_txn( - txn, events_and_contexts=events_and_contexts - ) - - # From this point onwards the events are only events that we haven't - # seen before. - - if delete_existing: - # For paranoia reasons, we go and delete all the existing entries - # for these events so we can reinsert them. - # This gets around any problems with some tables already having - # entries. - self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts) - - self._store_event_txn(txn, events_and_contexts=events_and_contexts) - - # Insert into event_to_state_groups. - self._store_event_state_mappings_txn(txn, events_and_contexts) - - # We want to store event_auth mappings for rejected events, as they're - # used in state res v2. - # This is only necessary if the rejected event appears in an accepted - # event's auth chain, but its easier for now just to store them (and - # it doesn't take much storage compared to storing the entire event - # anyway). - self._simple_insert_many_txn( - txn, - table="event_auth", - values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "auth_id": auth_id, - } - for event, _ in events_and_contexts - for auth_id in event.auth_event_ids() - if event.is_state() - ], - ) - - # _store_rejected_events_txn filters out any events which were - # rejected, and returns the filtered list. - events_and_contexts = self._store_rejected_events_txn( - txn, events_and_contexts=events_and_contexts - ) - - # From this point onwards the events are only ones that weren't - # rejected. - - self._update_metadata_tables_txn( - txn, - events_and_contexts=events_and_contexts, - all_events_and_contexts=all_events_and_contexts, - backfilled=backfilled, - ) - - # We call this last as it assumes we've inserted the events into - # room_memberships, where applicable. - self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) - - def _update_current_state_txn(self, txn, state_delta_by_room, stream_id): - for room_id, current_state_tuple in iteritems(state_delta_by_room): - to_delete, to_insert = current_state_tuple - - # First we add entries to the current_state_delta_stream. We - # do this before updating the current_state_events table so - # that we can use it to calculate the `prev_event_id`. (This - # allows us to not have to pull out the existing state - # unnecessarily). - # - # The stream_id for the update is chosen to be the minimum of the stream_ids - # for the batch of the events that we are persisting; that means we do not - # end up in a situation where workers see events before the - # current_state_delta updates. 
- # - sql = """ - INSERT INTO current_state_delta_stream - (stream_id, room_id, type, state_key, event_id, prev_event_id) - SELECT ?, ?, ?, ?, ?, ( - SELECT event_id FROM current_state_events - WHERE room_id = ? AND type = ? AND state_key = ? - ) - """ - txn.executemany( - sql, - ( - ( - stream_id, - room_id, - etype, - state_key, - None, - room_id, - etype, - state_key, - ) - for etype, state_key in to_delete - # We sanity check that we're deleting rather than updating - if (etype, state_key) not in to_insert - ), - ) - txn.executemany( - sql, - ( - ( - stream_id, - room_id, - etype, - state_key, - ev_id, - room_id, - etype, - state_key, - ) - for (etype, state_key), ev_id in iteritems(to_insert) - ), - ) - - # Now we actually update the current_state_events table - - txn.executemany( - "DELETE FROM current_state_events" - " WHERE room_id = ? AND type = ? AND state_key = ?", - ( - (room_id, etype, state_key) - for etype, state_key in itertools.chain(to_delete, to_insert) - ), - ) - - # We include the membership in the current state table, hence we do - # a lookup when we insert. This assumes that all events have already - # been inserted into room_memberships. - txn.executemany( - """INSERT INTO current_state_events - (room_id, type, state_key, event_id, membership) - VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) - """, - [ - (room_id, key[0], key[1], ev_id, ev_id) - for key, ev_id in iteritems(to_insert) - ], - ) - - txn.call_after( - self._curr_state_delta_stream_cache.entity_has_changed, - room_id, - stream_id, - ) - - # Invalidate the various caches - - # Figure out the changes of membership to invalidate the - # `get_rooms_for_user` cache. - # We find out which membership events we may have deleted - # and which we have added, then we invlidate the caches for all - # those users. - members_changed = set( - state_key - for ev_type, state_key in itertools.chain(to_delete, to_insert) - if ev_type == EventTypes.Member - ) - - for member in members_changed: - txn.call_after( - self.get_rooms_for_user_with_stream_ordering.invalidate, (member,) - ) - - self._invalidate_state_caches_and_stream(txn, room_id, members_changed) - - def _update_forward_extremities_txn( - self, txn, new_forward_extremities, max_stream_order - ): - for room_id, new_extrem in iteritems(new_forward_extremities): - self._simple_delete_txn( - txn, table="event_forward_extremities", keyvalues={"room_id": room_id} - ) - txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - - self._simple_insert_many_txn( - txn, - table="event_forward_extremities", - values=[ - {"event_id": ev_id, "room_id": room_id} - for room_id, new_extrem in iteritems(new_forward_extremities) - for ev_id in new_extrem - ], - ) - # We now insert into stream_ordering_to_exterm a mapping from room_id, - # new stream_ordering to new forward extremeties in the room. - # This allows us to later efficiently look up the forward extremeties - # for a room before a given stream_ordering - self._simple_insert_many_txn( - txn, - table="stream_ordering_to_exterm", - values=[ - { - "room_id": room_id, - "event_id": event_id, - "stream_ordering": max_stream_order, - } - for room_id, new_extrem in iteritems(new_forward_extremities) - for event_id in new_extrem - ], - ) - - @classmethod - def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): - """Ensure that we don't have the same event twice. - - Pick the earliest non-outlier if there is one, else the earliest one. 
- - Args: - events_and_contexts (list[(EventBase, EventContext)]): - Returns: - list[(EventBase, EventContext)]: filtered list - """ - new_events_and_contexts = OrderedDict() - for event, context in events_and_contexts: - prev_event_context = new_events_and_contexts.get(event.event_id) - if prev_event_context: - if not event.internal_metadata.is_outlier(): - if prev_event_context[0].internal_metadata.is_outlier(): - # To ensure correct ordering we pop, as OrderedDict is - # ordered by first insertion. - new_events_and_contexts.pop(event.event_id, None) - new_events_and_contexts[event.event_id] = (event, context) - else: - new_events_and_contexts[event.event_id] = (event, context) - return list(new_events_and_contexts.values()) - - def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): - """Update min_depth for each room - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - backfilled (bool): True if the events were backfilled - """ - depth_updates = {} - for event, context in events_and_contexts: - # Remove the any existing cache entries for the event_ids - txn.call_after(self._invalidate_get_event_cache, event.event_id) - if not backfilled: - txn.call_after( - self._events_stream_cache.entity_has_changed, - event.room_id, - event.internal_metadata.stream_ordering, - ) - - if not event.internal_metadata.is_outlier() and not context.rejected: - depth_updates[event.room_id] = max( - event.depth, depth_updates.get(event.room_id, event.depth) - ) - - for room_id, depth in iteritems(depth_updates): - self._update_min_depth_for_room_txn(txn, room_id, depth) - - def _update_outliers_txn(self, txn, events_and_contexts): - """Update any outliers with new event info. - - This turns outliers into ex-outliers (unless the new event was - rejected). - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - - Returns: - list[(EventBase, EventContext)] new list, without events which - are already in the events table. - """ - txn.execute( - "SELECT event_id, outlier FROM events WHERE event_id in (%s)" - % (",".join(["?"] * len(events_and_contexts)),), - [event.event_id for event, _ in events_and_contexts], - ) - - have_persisted = {event_id: outlier for event_id, outlier in txn} - - to_remove = set() - for event, context in events_and_contexts: - if event.event_id not in have_persisted: - continue - - to_remove.add(event) - - if context.rejected: - # If the event is rejected then we don't care if the event - # was an outlier or not. - continue - - outlier_persisted = have_persisted[event.event_id] - if not event.internal_metadata.is_outlier() and outlier_persisted: - # We received a copy of an event that we had already stored as - # an outlier in the database. We now have some state at that - # so we need to update the state_groups table with that state. - - # insert into event_to_state_groups. - try: - self._store_event_state_mappings_txn(txn, ((event, context),)) - except Exception: - logger.exception("") - raise - - metadata_json = encode_json(event.internal_metadata.get_dict()) - - sql = ( - "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?" - ) - txn.execute(sql, (metadata_json, event.event_id)) - - # Add an entry to the ex_outlier_stream table to replicate the - # change in outlier status to our workers. 
- stream_order = event.internal_metadata.stream_ordering - state_group_id = context.state_group - self._simple_insert_txn( - txn, - table="ex_outlier_stream", - values={ - "event_stream_ordering": stream_order, - "event_id": event.event_id, - "state_group": state_group_id, - }, - ) - - sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?" - txn.execute(sql, (False, event.event_id)) - - # Update the event_backward_extremities table now that this - # event isn't an outlier any more. - self._update_backward_extremeties(txn, [event]) - - return [ec for ec in events_and_contexts if ec[0] not in to_remove] - - @classmethod - def _delete_existing_rows_txn(cls, txn, events_and_contexts): - if not events_and_contexts: - # nothing to do here - return - - logger.info("Deleting existing") - - for table in ( - "events", - "event_auth", - "event_json", - "event_edges", - "event_forward_extremities", - "event_reference_hashes", - "event_search", - "event_to_state_groups", - "local_invites", - "state_events", - "rejections", - "redactions", - "room_memberships", - ): - txn.executemany( - "DELETE FROM %s WHERE event_id = ?" % (table,), - [(ev.event_id,) for ev, _ in events_and_contexts], - ) - - for table in ("event_push_actions",): - txn.executemany( - "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,), - [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts], - ) - - def _store_event_txn(self, txn, events_and_contexts): - """Insert new events into the event and event_json tables - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - """ - - if not events_and_contexts: - # nothing to do here - return - - def event_dict(event): - d = event.get_dict() - d.pop("redacted", None) - d.pop("redacted_because", None) - return d - - self._simple_insert_many_txn( - txn, - table="event_json", - values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "internal_metadata": encode_json( - event.internal_metadata.get_dict() - ), - "json": encode_json(event_dict(event)), - "format_version": event.format_version, - } - for event, _ in events_and_contexts - ], - ) - - self._simple_insert_many_txn( - txn, - table="events", - values=[ - { - "stream_ordering": event.internal_metadata.stream_ordering, - "topological_ordering": event.depth, - "depth": event.depth, - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "processed": True, - "outlier": event.internal_metadata.is_outlier(), - "origin_server_ts": int(event.origin_server_ts), - "received_ts": self._clock.time_msec(), - "sender": event.sender, - "contains_url": ( - "url" in event.content - and isinstance(event.content["url"], text_type) - ), - } - for event, _ in events_and_contexts - ], - ) - - def _store_rejected_events_txn(self, txn, events_and_contexts): - """Add rows to the 'rejections' table for received events which were - rejected - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - - Returns: - list[(EventBase, EventContext)] new list, without the rejected - events. - """ - # Remove the rejected events from the list now that we've added them - # to the events table and the events_json table. 
- to_remove = set() - for event, context in events_and_contexts: - if context.rejected: - # Insert the event_id into the rejections table - self._store_rejections_txn(txn, event.event_id, context.rejected) - to_remove.add(event) - - return [ec for ec in events_and_contexts if ec[0] not in to_remove] - - def _update_metadata_tables_txn( - self, txn, events_and_contexts, all_events_and_contexts, backfilled - ): - """Update all the miscellaneous tables for new events - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. - backfilled (bool): True if the events were backfilled - """ - - # Insert all the push actions into the event_push_actions table. - self._set_push_actions_for_event_and_users_txn( - txn, - events_and_contexts=events_and_contexts, - all_events_and_contexts=all_events_and_contexts, - ) - - if not events_and_contexts: - # nothing to do here - return - - for event, context in events_and_contexts: - if event.type == EventTypes.Redaction and event.redacts is not None: - # Remove the entries in the event_push_actions table for the - # redacted event. - self._remove_push_actions_for_event_id_txn( - txn, event.room_id, event.redacts - ) - - # Remove from relations table. - self._handle_redaction(txn, event.redacts) - - # Update the event_forward_extremities, event_backward_extremities and - # event_edges tables. - self._handle_mult_prev_events( - txn, events=[event for event, _ in events_and_contexts] - ) - - for event, _ in events_and_contexts: - if event.type == EventTypes.Name: - # Insert into the event_search table. - self._store_room_name_txn(txn, event) - elif event.type == EventTypes.Topic: - # Insert into the event_search table. - self._store_room_topic_txn(txn, event) - elif event.type == EventTypes.Message: - # Insert into the event_search table. - self._store_room_message_txn(txn, event) - elif event.type == EventTypes.Redaction: - # Insert into the redactions table. - self._store_redaction(txn, event) - - self._handle_event_relations(txn, event) - - # Insert into the room_memberships table. - self._store_room_members_txn( - txn, - [ - event - for event, _ in events_and_contexts - if event.type == EventTypes.Member - ], - backfilled=backfilled, - ) - - # Insert event_reference_hashes table. - self._store_event_reference_hashes_txn( - txn, [event for event, _ in events_and_contexts] - ) - - state_events_and_contexts = [ - ec for ec in events_and_contexts if ec[0].is_state() - ] - - state_values = [] - for event, context in state_events_and_contexts: - vals = { - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - } - - # TODO: How does this work with backfilling? 
- if hasattr(event, "replaces_state"): - vals["prev_state"] = event.replaces_state - - state_values.append(vals) - - self._simple_insert_many_txn(txn, table="state_events", values=state_values) - - # Prefill the event cache - self._add_to_cache(txn, events_and_contexts) - - def _add_to_cache(self, txn, events_and_contexts): - to_prefill = [] - - rows = [] - N = 200 - for i in range(0, len(events_and_contexts), N): - ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]} - if not ev_map: - break - - sql = ( - "SELECT " - " e.event_id as event_id, " - " r.redacts as redacts," - " rej.event_id as rejects " - " FROM events as e" - " LEFT JOIN rejections as rej USING (event_id)" - " LEFT JOIN redactions as r ON e.event_id = r.redacts" - " WHERE e.event_id IN (%s)" - ) % (",".join(["?"] * len(ev_map)),) - - txn.execute(sql, list(ev_map)) - rows = self.cursor_to_dict(txn) - for row in rows: - event = ev_map[row["event_id"]] - if not row["rejects"] and not row["redacts"]: - to_prefill.append( - _EventCacheEntry(event=event, redacted_event=None) - ) - - def prefill(): - for cache_entry in to_prefill: - self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry) - - txn.call_after(prefill) - - def _store_redaction(self, txn, event): - # invalidate the cache for the redacted event - txn.call_after(self._invalidate_get_event_cache, event.redacts) - txn.execute( - "INSERT INTO redactions (event_id, redacts) VALUES (?,?)", - (event.event_id, event.redacts), - ) - - @defer.inlineCallbacks - def _censor_redactions(self): - """Censors all redactions older than the configured period that haven't - been censored yet. - - By censor we mean update the event_json table with the redacted event. - - Returns: - Deferred - """ - - if self.hs.config.redaction_retention_period is None: - return - - max_pos = yield self.find_first_stream_ordering_after_ts( - self._clock.time_msec() - self.hs.config.redaction_retention_period - ) - - # We fetch all redactions that: - # 1. point to an event we have, - # 2. has a stream ordering from before the cut off, and - # 3. we haven't yet censored. - # - # This is limited to 100 events to ensure that we don't try and do too - # much at once. We'll get called again so this should eventually catch - # up. - # - # We use the range [-max_pos, max_pos] to handle backfilled events, - # which are given negative stream ordering. - sql = """ - SELECT redact_event.event_id, redacts FROM redactions - INNER JOIN events AS redact_event USING (event_id) - INNER JOIN events AS original_event ON ( - redact_event.room_id = original_event.room_id - AND redacts = original_event.event_id - ) - WHERE NOT have_censored - AND ? <= redact_event.stream_ordering AND redact_event.stream_ordering <= ? - ORDER BY redact_event.stream_ordering ASC - LIMIT ? - """ - - rows = yield self._execute( - "_censor_redactions_fetch", None, sql, -max_pos, max_pos, 100 - ) - - updates = [] - - for redaction_id, event_id in rows: - redaction_event = yield self.get_event(redaction_id, allow_none=True) - original_event = yield self.get_event( - event_id, allow_rejected=True, allow_none=True - ) - - # The SQL above ensures that we have both the redaction and - # original event, so if the `get_event` calls return None it - # means that the redaction wasn't allowed. Either way we know that - # the result won't change so we mark the fact that we've checked. 
- if ( - redaction_event - and original_event - and original_event.internal_metadata.is_redacted() - ): - # Redaction was allowed - pruned_json = encode_canonical_json( - prune_event_dict(original_event.get_dict()) - ) - else: - # Redaction wasn't allowed - pruned_json = None - - updates.append((redaction_id, event_id, pruned_json)) - - def _update_censor_txn(txn): - for redaction_id, event_id, pruned_json in updates: - if pruned_json: - self._simple_update_one_txn( - txn, - table="event_json", - keyvalues={"event_id": event_id}, - updatevalues={"json": pruned_json}, - ) - - self._simple_update_one_txn( - txn, - table="redactions", - keyvalues={"event_id": redaction_id}, - updatevalues={"have_censored": True}, - ) - - yield self.runInteraction("_update_censor_txn", _update_censor_txn) - - @defer.inlineCallbacks - def count_daily_messages(self): - """ - Returns an estimate of the number of messages sent in the last day. - - If it has been significantly less or more than one day since the last - call to this function, it will return None. - """ - - def _count_messages(txn): - sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events - WHERE type = 'm.room.message' - AND stream_ordering > ? - """ - txn.execute(sql, (self.stream_ordering_day_ago,)) - count, = txn.fetchone() - return count - - ret = yield self.runInteraction("count_messages", _count_messages) - return ret - - @defer.inlineCallbacks - def count_daily_sent_messages(self): - def _count_messages(txn): - # This is good enough as if you have silly characters in your own - # hostname then thats your own fault. - like_clause = "%:" + self.hs.hostname - - sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events - WHERE type = 'm.room.message' - AND sender LIKE ? - AND stream_ordering > ? - """ - - txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - count, = txn.fetchone() - return count - - ret = yield self.runInteraction("count_daily_sent_messages", _count_messages) - return ret - - @defer.inlineCallbacks - def count_daily_active_rooms(self): - def _count(txn): - sql = """ - SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events - WHERE type = 'm.room.message' - AND stream_ordering > ? - """ - txn.execute(sql, (self.stream_ordering_day_ago,)) - count, = txn.fetchone() - return count - - ret = yield self.runInteraction("count_daily_active_rooms", _count) - return ret - - def get_current_backfill_token(self): - """The current minimum token that backfilled events have reached""" - return -self._backfill_id_gen.get_current_token() - - def get_current_events_token(self): - """The current maximum token that events have reached""" - return self._stream_id_gen.get_current_token() - - def get_all_new_forward_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_forward_event_rows(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" 
- ) - txn.execute(sql, (last_id, current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_id, upper_bound)) - new_event_updates.extend(txn) - - return new_event_updates - - return self.runInteraction( - "get_all_new_forward_event_rows", get_all_new_forward_event_rows - ) - - def get_all_new_backfill_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_backfill_event_rows(txn): - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (-last_id, -current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_id, -upper_bound)) - new_event_updates.extend(txn.fetchall()) - - return new_event_updates - - return self.runInteraction( - "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows - ) - - @cached(num_args=5, max_entries=10) - def get_all_new_events( - self, - last_backfill_id, - last_forward_id, - current_backfill_id, - current_forward_id, - limit, - ): - """Get all the new events that have arrived at the server either as - new events or as backfilled events""" - have_backfill_events = last_backfill_id != current_backfill_id - have_forward_events = last_forward_id != current_forward_id - - if not have_backfill_events and not have_forward_events: - return defer.succeed(AllNewEventsResult([], [], [], [], [])) - - def get_all_new_events_txn(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - if have_forward_events: - txn.execute(sql, (last_forward_id, current_forward_id, limit)) - new_forward_events = txn.fetchall() - - if len(new_forward_events) == limit: - upper_bound = new_forward_events[-1][0] - else: - upper_bound = current_forward_id - - sql = ( - "SELECT event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? 
> event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_forward_id, upper_bound)) - forward_ex_outliers = txn.fetchall() - else: - new_forward_events = [] - forward_ex_outliers = [] - - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - if have_backfill_events: - txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) - new_backfill_events = txn.fetchall() - - if len(new_backfill_events) == limit: - upper_bound = new_backfill_events[-1][0] - else: - upper_bound = current_backfill_id - - sql = ( - "SELECT -event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_backfill_id, -upper_bound)) - backward_ex_outliers = txn.fetchall() - else: - new_backfill_events = [] - backward_ex_outliers = [] - - return AllNewEventsResult( - new_forward_events, - new_backfill_events, - forward_ex_outliers, - backward_ex_outliers, - ) - - return self.runInteraction("get_all_new_events", get_all_new_events_txn) - - def purge_history(self, room_id, token, delete_local_events): - """Deletes room history before a certain point - - Args: - room_id (str): - - token (str): A topological token to delete events before - - delete_local_events (bool): - if True, we will delete local events as well as remote ones - (instead of just marking them as outliers and deleting their - state groups). - """ - - return self.runInteraction( - "purge_history", - self._purge_history_txn, - room_id, - token, - delete_local_events, - ) - - def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): - token = RoomStreamToken.parse(token_str) - - # Tables that should be pruned: - # event_auth - # event_backward_extremities - # event_edges - # event_forward_extremities - # event_json - # event_push_actions - # event_reference_hashes - # event_search - # event_to_state_groups - # events - # rejections - # room_depth - # state_groups - # state_groups_state - - # we will build a temporary table listing the events so that we don't - # have to keep shovelling the list back and forth across the - # connection. Annoyingly the python sqlite driver commits the - # transaction on CREATE, so let's do this first. - # - # furthermore, we might already have the table from a previous (failed) - # purge attempt, so let's drop the table first. 
- - txn.execute("DROP TABLE IF EXISTS events_to_purge") - - txn.execute( - "CREATE TEMPORARY TABLE events_to_purge (" - " event_id TEXT NOT NULL," - " should_delete BOOLEAN NOT NULL" - ")" - ) - - # First ensure that we're not about to delete all the forward extremeties - txn.execute( - "SELECT e.event_id, e.depth FROM events as e " - "INNER JOIN event_forward_extremities as f " - "ON e.event_id = f.event_id " - "AND e.room_id = f.room_id " - "WHERE f.room_id = ?", - (room_id,), - ) - rows = txn.fetchall() - max_depth = max(row[1] for row in rows) - - if max_depth < token.topological: - # We need to ensure we don't delete all the events from the database - # otherwise we wouldn't be able to send any events (due to not - # having any backwards extremeties) - raise SynapseError( - 400, "topological_ordering is greater than forward extremeties" - ) - - logger.info("[purge] looking for events to delete") - - should_delete_expr = "state_key IS NULL" - should_delete_params = () - if not delete_local_events: - should_delete_expr += " AND event_id NOT LIKE ?" - - # We include the parameter twice since we use the expression twice - should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname) - - should_delete_params += (room_id, token.topological) - - # Note that we insert events that are outliers and aren't going to be - # deleted, as nothing will happen to them. - txn.execute( - "INSERT INTO events_to_purge" - " SELECT event_id, %s" - " FROM events AS e LEFT JOIN state_events USING (event_id)" - " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" - % (should_delete_expr, should_delete_expr), - should_delete_params, - ) - - # We create the indices *after* insertion as that's a lot faster. - - # create an index on should_delete because later we'll be looking for - # the should_delete / shouldn't_delete subsets - txn.execute( - "CREATE INDEX events_to_purge_should_delete" - " ON events_to_purge(should_delete)" - ) - - # We do joins against events_to_purge for e.g. calculating state - # groups to purge, etc., so lets make an index. - txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)") - - txn.execute("SELECT event_id, should_delete FROM events_to_purge") - event_rows = txn.fetchall() - logger.info( - "[purge] found %i events before cutoff, of which %i can be deleted", - len(event_rows), - sum(1 for e in event_rows if e[1]), - ) - - logger.info("[purge] Finding new backward extremities") - - # We calculate the new entries for the backward extremeties by finding - # events to be purged that are pointed to by events we're not going to - # purge. - txn.execute( - "SELECT DISTINCT e.event_id FROM events_to_purge AS e" - " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" - " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id" - " WHERE ep2.event_id IS NULL" - ) - new_backwards_extrems = txn.fetchall() - - logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) - - txn.execute( - "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) - ) - - # Update backward extremeties - txn.executemany( - "INSERT INTO event_backward_extremities (room_id, event_id)" - " VALUES (?, ?)", - [(room_id, event_id) for event_id, in new_backwards_extrems], - ) - - logger.info("[purge] finding redundant state groups") - - # Get all state groups that are referenced by events that are to be - # deleted. 
We then go and check if they are referenced by other events - # or state groups, and if not we delete them. - txn.execute( - """ - SELECT DISTINCT state_group FROM events_to_purge - INNER JOIN event_to_state_groups USING (event_id) - """ - ) - - referenced_state_groups = set(sg for sg, in txn) - logger.info( - "[purge] found %i referenced state groups", len(referenced_state_groups) - ) - - logger.info("[purge] finding state groups that can be deleted") - - _ = self._find_unreferenced_groups_during_purge(txn, referenced_state_groups) - state_groups_to_delete, remaining_state_groups = _ - - logger.info( - "[purge] found %i state groups to delete", len(state_groups_to_delete) - ) - - logger.info( - "[purge] de-delta-ing %i remaining state groups", - len(remaining_state_groups), - ) - - # Now we turn the state groups that reference to-be-deleted state - # groups to non delta versions. - for sg in remaining_state_groups: - logger.info("[purge] de-delta-ing remaining state group %s", sg) - curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) - curr_state = curr_state[sg] - - self._simple_delete_txn( - txn, table="state_groups_state", keyvalues={"state_group": sg} - ) - - self._simple_delete_txn( - txn, table="state_group_edges", keyvalues={"state_group": sg} - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": sg, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(curr_state) - ], - ) - - logger.info("[purge] removing redundant state groups") - txn.executemany( - "DELETE FROM state_groups_state WHERE state_group = ?", - ((sg,) for sg in state_groups_to_delete), - ) - txn.executemany( - "DELETE FROM state_groups WHERE id = ?", - ((sg,) for sg in state_groups_to_delete), - ) - - logger.info("[purge] removing events from event_to_state_groups") - txn.execute( - "DELETE FROM event_to_state_groups " - "WHERE event_id IN (SELECT event_id from events_to_purge)" - ) - for event_id, _ in event_rows: - txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) - - # Delete all remote non-state events - for table in ( - "events", - "event_json", - "event_auth", - "event_edges", - "event_forward_extremities", - "event_reference_hashes", - "event_search", - "rejections", - ): - logger.info("[purge] removing events from %s", table) - - txn.execute( - "DELETE FROM %s WHERE event_id IN (" - " SELECT event_id FROM events_to_purge WHERE should_delete" - ")" % (table,) - ) - - # event_push_actions lacks an index on event_id, and has one on - # (room_id, event_id) instead. - for table in ("event_push_actions",): - logger.info("[purge] removing events from %s", table) - - txn.execute( - "DELETE FROM %s WHERE room_id = ? AND event_id IN (" - " SELECT event_id FROM events_to_purge WHERE should_delete" - ")" % (table,), - (room_id,), - ) - - # Mark all state and own events as outliers - logger.info("[purge] marking remaining events as outliers") - txn.execute( - "UPDATE events SET outlier = ?" - " WHERE event_id IN (" - " SELECT event_id FROM events_to_purge " - " WHERE NOT should_delete" - ")", - (True,), - ) - - # synapse tries to take out an exclusive lock on room_depth whenever it - # persists events (because upsert), and once we run this update, we - # will block that for the rest of our transaction. - # - # So, let's stick it at the end so that we don't block event - # persistence. - # - # We do this by calculating the minimum depth of the backwards - # extremities. 
However, the events in event_backward_extremities - # are ones we don't have yet so we need to look at the events that - # point to it via event_edges table. - txn.execute( - """ - SELECT COALESCE(MIN(depth), 0) - FROM event_backward_extremities AS eb - INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id - INNER JOIN events AS e ON e.event_id = eg.event_id - WHERE eb.room_id = ? - """, - (room_id,), - ) - min_depth, = txn.fetchone() - - logger.info("[purge] updating room_depth to %d", min_depth) - - txn.execute( - "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", - (min_depth, room_id), - ) - - # finally, drop the temp table. this will commit the txn in sqlite, - # so make sure to keep this actually last. - txn.execute("DROP TABLE events_to_purge") - - logger.info("[purge] done") - - def _find_unreferenced_groups_during_purge(self, txn, state_groups): - """Used when purging history to figure out which state groups can be - deleted and which need to be de-delta'ed (due to one of its prev groups - being scheduled for deletion). - - Args: - txn - state_groups (set[int]): Set of state groups referenced by events - that are going to be deleted. - - Returns: - tuple[set[int], set[int]]: The set of state groups that can be - deleted and the set of state groups that need to be de-delta'ed - """ - # Graph of state group -> previous group - graph = {} - - # Set of events that we have found to be referenced by events - referenced_groups = set() - - # Set of state groups we've already seen - state_groups_seen = set(state_groups) - - # Set of state groups to handle next. - next_to_search = set(state_groups) - while next_to_search: - # We bound size of groups we're looking up at once, to stop the - # SQL query getting too big - if len(next_to_search) < 100: - current_search = next_to_search - next_to_search = set() - else: - current_search = set(itertools.islice(next_to_search, 100)) - next_to_search -= current_search - - # Check if state groups are referenced - sql = """ - SELECT DISTINCT state_group FROM event_to_state_groups - LEFT JOIN events_to_purge AS ep USING (event_id) - WHERE state_group IN (%s) AND ep.event_id IS NULL - """ % ( - ",".join("?" for _ in current_search), - ) - txn.execute(sql, list(current_search)) - - referenced = set(sg for sg, in txn) - referenced_groups |= referenced - - # We don't continue iterating up the state group graphs for state - # groups that are referenced. 
- current_search -= referenced - - rows = self._simple_select_many_txn( - txn, - table="state_group_edges", - column="prev_state_group", - iterable=current_search, - keyvalues={}, - retcols=("prev_state_group", "state_group"), - ) - - prevs = set(row["state_group"] for row in rows) - # We don't bother re-handling groups we've already seen - prevs -= state_groups_seen - next_to_search |= prevs - state_groups_seen |= prevs - - for row in rows: - # Note: Each state group can have at most one prev group - graph[row["state_group"]] = row["prev_state_group"] - - to_delete = state_groups_seen - referenced_groups - - to_dedelta = set() - for sg in referenced_groups: - prev_sg = graph.get(sg) - if prev_sg and prev_sg in to_delete: - to_dedelta.add(sg) - - return to_delete, to_dedelta - - def purge_room(self, room_id): - """Deletes all record of a room - - Args: - room_id (str): - """ - - return self.runInteraction("purge_room", self._purge_room_txn, room_id) - - def _purge_room_txn(self, txn, room_id): - # first we have to delete the state groups states - logger.info("[purge] removing %s from state_groups_state", room_id) - - txn.execute( - """ - DELETE FROM state_groups_state WHERE state_group IN ( - SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) - WHERE events.room_id=? - ) - """, - (room_id,), - ) - - # ... and the state group edges - logger.info("[purge] removing %s from state_group_edges", room_id) - - txn.execute( - """ - DELETE FROM state_group_edges WHERE state_group IN ( - SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) - WHERE events.room_id=? - ) - """, - (room_id,), - ) - - # ... and the state groups - logger.info("[purge] removing %s from state_groups", room_id) - - txn.execute( - """ - DELETE FROM state_groups WHERE id IN ( - SELECT state_group FROM events JOIN event_to_state_groups USING(event_id) - WHERE events.room_id=? - ) - """, - (room_id,), - ) - - # and then tables which lack an index on room_id but have one on event_id - for table in ( - "event_auth", - "event_edges", - "event_push_actions_staging", - "event_reference_hashes", - "event_relations", - "event_to_state_groups", - "redactions", - "rejections", - "state_events", - ): - logger.info("[purge] removing %s from %s", room_id, table) - - txn.execute( - """ - DELETE FROM %s WHERE event_id IN ( - SELECT event_id FROM events WHERE room_id=? - ) - """ - % (table,), - (room_id,), - ) - - # and finally, the tables with an index on room_id (or no useful index) - for table in ( - "current_state_events", - "event_backward_extremities", - "event_forward_extremities", - "event_json", - "event_push_actions", - "event_search", - "events", - "group_rooms", - "public_room_list_stream", - "receipts_graph", - "receipts_linearized", - "room_aliases", - "room_depth", - "room_memberships", - "room_stats_state", - "room_stats_current", - "room_stats_historical", - "room_stats_earliest_token", - "rooms", - "stream_ordering_to_exterm", - "topics", - "users_in_public_rooms", - "users_who_share_private_rooms", - # no useful index, but let's clear them anyway - "appservice_room_list", - "e2e_room_keys", - "event_push_summary", - "pusher_throttle", - "group_summary_rooms", - "local_invites", - "room_account_data", - "room_tags", - ): - logger.info("[purge] removing %s from %s", room_id, table) - txn.execute("DELETE FROM %s WHERE room_id=?" 
% (table,), (room_id,)) - - # Other tables we do NOT need to clear out: - # - # - blocked_rooms - # This is important, to make sure that we don't accidentally rejoin a blocked - # room after it was purged - # - # - user_directory - # This has a room_id column, but it is unused - # - - # Other tables that we might want to consider clearing out include: - # - # - event_reports - # Given that these are intended for abuse management my initial - # inclination is to leave them in place. - # - # - current_state_delta_stream - # - ex_outlier_stream - # - room_tags_revisions - # The problem with these is that they are largeish and there is no room_id - # index on them. In any case we should be clearing out 'stream' tables - # periodically anyway (#5888) - - # TODO: we could probably usefully do a bunch of cache invalidation here - - logger.info("[purge] done") - - @defer.inlineCallbacks - def is_event_after(self, event_id1, event_id2): - """Returns True if event_id1 is after event_id2 in the stream - """ - to_1, so_1 = yield self._get_event_ordering(event_id1) - to_2, so_2 = yield self._get_event_ordering(event_id2) - return (to_1, so_1) > (to_2, so_2) - - @cachedInlineCallbacks(max_entries=5000) - def _get_event_ordering(self, event_id): - res = yield self._simple_select_one( - table="events", - retcols=["topological_ordering", "stream_ordering"], - keyvalues={"event_id": event_id}, - allow_none=True, - ) - - if not res: - raise SynapseError(404, "Could not find event %s" % (event_id,)) - - return (int(res["topological_ordering"]), int(res["stream_ordering"])) - - def get_all_updated_current_state_deltas(self, from_token, to_token, limit): - def get_all_updated_current_state_deltas_txn(txn): - sql = """ - SELECT stream_id, room_id, type, state_key, event_id - FROM current_state_delta_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC LIMIT ? - """ - txn.execute(sql, (from_token, to_token, limit)) - return txn.fetchall() - - return self.runInteraction( - "get_all_updated_current_state_deltas", - get_all_updated_current_state_deltas_txn, - ) - - -AllNewEventsResult = namedtuple( - "AllNewEventsResult", - [ - "new_forward_events", - "new_backfill_events", - "forward_ex_outliers", - "backward_ex_outliers", - ], -) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index e72f89e446..4769b21529 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -14,208 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import itertools import logging -import six - import attr -from signedjson.key import decode_verify_key_bytes - -from synapse.util import batch_iter -from synapse.util.caches.descriptors import cached, cachedList - -from ._base import SQLBaseStore logger = logging.getLogger(__name__) -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - @attr.s(slots=True, frozen=True) class FetchKeyResult(object): verify_key = attr.ib() # VerifyKey: the key itself valid_until_ts = attr.ib() # int: how long we can use this key for - - -class KeyStore(SQLBaseStore): - """Persistence for signature verification keys - """ - - @cached() - def _get_server_verify_key(self, server_name_and_key_id): - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" - ) - def get_server_verify_keys(self, server_name_and_key_ids): - """ - Args: - server_name_and_key_ids (iterable[Tuple[str, str]]): - iterable of (server_name, key-id) tuples to fetch keys for - - Returns: - Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: - map from (server_name, key_id) -> FetchKeyResult, or None if the key is - unknown - """ - keys = {} - - def _get_keys(txn, batch): - """Processes a batch of keys to fetch, and adds the result to `keys`.""" - - # batch_iter always returns tuples so it's safe to do len(batch) - sql = ( - "SELECT server_name, key_id, verify_key, ts_valid_until_ms " - "FROM server_signature_keys WHERE 1=0" - ) + " OR (server_name=? AND key_id=?)" * len(batch) - - txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) - - for row in txn: - server_name, key_id, key_bytes, ts_valid_until_ms = row - - if ts_valid_until_ms is None: - # Old keys may be stored with a ts_valid_until_ms of null, - # in which case we treat this as if it was set to `0`, i.e. - # it won't match key requests that define a minimum - # `ts_valid_until_ms`. - ts_valid_until_ms = 0 - - res = FetchKeyResult( - verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), - valid_until_ts=ts_valid_until_ms, - ) - keys[(server_name, key_id)] = res - - def _txn(txn): - for batch in batch_iter(server_name_and_key_ids, 50): - _get_keys(txn, batch) - return keys - - return self.runInteraction("get_server_verify_keys", _txn) - - def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): - """Stores NACL verification keys for remote servers. - Args: - from_server (str): Where the verification keys were looked up - ts_added_ms (int): The time to record that the key was added - verify_keys (iterable[tuple[str, str, FetchKeyResult]]): - keys to be stored. Each entry is a triplet of - (server_name, key_id, key). - """ - key_values = [] - value_values = [] - invalidations = [] - for server_name, key_id, fetch_result in verify_keys: - key_values.append((server_name, key_id)) - value_values.append( - ( - from_server, - ts_added_ms, - fetch_result.valid_until_ts, - db_binary_type(fetch_result.verify_key.encode()), - ) - ) - # invalidate takes a tuple corresponding to the params of - # _get_server_verify_key. _get_server_verify_key only takes one - # param, which is itself the 2-tuple (server_name, key_id). 
- invalidations.append((server_name, key_id)) - - def _invalidate(res): - f = self._get_server_verify_key.invalidate - for i in invalidations: - f((i,)) - return res - - return self.runInteraction( - "store_server_verify_keys", - self._simple_upsert_many_txn, - table="server_signature_keys", - key_names=("server_name", "key_id"), - key_values=key_values, - value_names=( - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "verify_key", - ), - value_values=value_values, - ).addCallback(_invalidate) - - def store_server_keys_json( - self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes - ): - """Stores the JSON bytes for a set of keys from a server - The JSON should be signed by the originating server, the intermediate - server, and by this server. Updates the value for the - (server_name, key_id, from_server) triplet if one already existed. - Args: - server_name (str): The name of the server. - key_id (str): The identifer of the key this JSON is for. - from_server (str): The server this JSON was fetched from. - ts_now_ms (int): The time now in milliseconds. - ts_valid_until_ms (int): The time when this json stops being valid. - key_json (bytes): The encoded JSON. - """ - return self._simple_upsert( - table="server_keys_json", - keyvalues={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - }, - values={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - "ts_added_ms": ts_now_ms, - "ts_valid_until_ms": ts_expires_ms, - "key_json": db_binary_type(key_json_bytes), - }, - desc="store_server_keys_json", - ) - - def get_server_keys_json(self, server_keys): - """Retrive the key json for a list of server_keys and key ids. - If no keys are found for a given server, key_id and source then - that server, key_id, and source triplet entry will be an empty list. - The JSON is returned as a byte array so that it can be efficiently - used in an HTTP response. - Args: - server_keys (list): List of (server_name, key_id, source) triplets. - Returns: - Deferred[dict[Tuple[str, str, str|None], list[dict]]]: - Dict mapping (server_name, key_id, source) triplets to lists of dicts - """ - - def _get_server_keys_json_txn(txn): - results = {} - for server_name, key_id, from_server in server_keys: - keyvalues = {"server_name": server_name} - if key_id is not None: - keyvalues["key_id"] = key_id - if from_server is not None: - keyvalues["from_server"] = from_server - rows = self._simple_select_list_txn( - txn, - "server_keys_json", - keyvalues=keyvalues, - retcols=( - "key_id", - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "key_json", - ), - ) - results[(server_name, key_id, from_server)] = rows - return results - - return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py new file mode 100644 index 0000000000..f159400a87 --- /dev/null +++ b/synapse/storage/persist_events.py @@ -0,0 +1,794 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging +from collections import deque, namedtuple +from typing import Iterable, List, Optional, Set, Tuple + +from six import iteritems +from six.moves import range + +from prometheus_client import Counter, Histogram + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.events import FrozenEvent +from synapse.events.snapshot import EventContext +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.state import StateResolutionStore +from synapse.storage.data_stores import DataStores +from synapse.storage.data_stores.main.events import DeltaState +from synapse.types import StateMap +from synapse.util.async_helpers import ObservableDeferred +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + +# The number of times we are recalculating the current state +state_delta_counter = Counter("synapse_storage_events_state_delta", "") + +# The number of times we are recalculating state when there is only a +# single forward extremity +state_delta_single_event_counter = Counter( + "synapse_storage_events_state_delta_single_event", "" +) + +# The number of times we are reculating state when we could have resonably +# calculated the delta when we calculated the state for an event we were +# persisting. +state_delta_reuse_delta_counter = Counter( + "synapse_storage_events_state_delta_reuse_delta", "" +) + +# The number of forward extremities for each new event. +forward_extremities_counter = Histogram( + "synapse_storage_events_forward_extremities_persisted", + "Number of forward extremities for each new event", + buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + +# The number of stale forward extremities for each new event. Stale extremities +# are those that were in the previous set of extremities as well as the new. +stale_forward_extremities_counter = Histogram( + "synapse_storage_events_stale_forward_extremities_persisted", + "Number of unchanged forward extremities for each new event", + buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + + +class _EventPeristenceQueue(object): + """Queues up events so that they can be persisted in bulk with only one + concurrent transaction per room. + """ + + _EventPersistQueueItem = namedtuple( + "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred") + ) + + def __init__(self): + self._event_persist_queues = {} + self._currently_persisting_rooms = set() + + def add_to_queue(self, room_id, events_and_contexts, backfilled): + """Add events to the queue, with the given persist_event options. + + NB: due to the normal usage pattern of this method, it does *not* + follow the synapse logcontext rules, and leaves the logcontext in + place whether or not the returned deferred is ready. 
+ + Args: + room_id (str): + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + + Returns: + defer.Deferred: a deferred which will resolve once the events are + persisted. Runs its callbacks *without* a logcontext. + """ + queue = self._event_persist_queues.setdefault(room_id, deque()) + if queue: + # if the last item in the queue has the same `backfilled` setting, + # we can just add these new events to that item. + end_item = queue[-1] + if end_item.backfilled == backfilled: + end_item.events_and_contexts.extend(events_and_contexts) + return end_item.deferred.observe() + + deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) + + queue.append( + self._EventPersistQueueItem( + events_and_contexts=events_and_contexts, + backfilled=backfilled, + deferred=deferred, + ) + ) + + return deferred.observe() + + def handle_queue(self, room_id, per_item_callback): + """Attempts to handle the queue for a room if not already being handled. + + The given callback will be invoked with for each item in the queue, + of type _EventPersistQueueItem. The per_item_callback will continuously + be called with new items, unless the queue becomnes empty. The return + value of the function will be given to the deferreds waiting on the item, + exceptions will be passed to the deferreds as well. + + This function should therefore be called whenever anything is added + to the queue. + + If another callback is currently handling the queue then it will not be + invoked. + """ + + if room_id in self._currently_persisting_rooms: + return + + self._currently_persisting_rooms.add(room_id) + + async def handle_queue_loop(): + try: + queue = self._get_drainining_queue(room_id) + for item in queue: + try: + ret = await per_item_callback(item) + except Exception: + with PreserveLoggingContext(): + item.deferred.errback() + else: + with PreserveLoggingContext(): + item.deferred.callback(ret) + finally: + queue = self._event_persist_queues.pop(room_id, None) + if queue: + self._event_persist_queues[room_id] = queue + self._currently_persisting_rooms.discard(room_id) + + # set handle_queue_loop off in the background + run_as_background_process("persist_events", handle_queue_loop) + + def _get_drainining_queue(self, room_id): + queue = self._event_persist_queues.setdefault(room_id, deque()) + + try: + while True: + yield queue.popleft() + except IndexError: + # Queue has been drained. + pass + + +class EventsPersistenceStorage(object): + """High level interface for handling persisting newly received events. + + Takes care of batching up events by room, and calculating the necessary + current state and forward extremity changes. + """ + + def __init__(self, hs, stores: DataStores): + # We ultimately want to split out the state store from the main store, + # so we use separate variables here even though they point to the same + # store for now. + self.main_store = stores.main + self.state_store = stores.state + self.persist_events_store = stores.persist_events + + self._clock = hs.get_clock() + self.is_mine_id = hs.is_mine_id + self._event_persist_queue = _EventPeristenceQueue() + self._state_resolution_handler = hs.get_state_resolution_handler() + + @defer.inlineCallbacks + def persist_events( + self, + events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + backfilled: bool = False, + ): + """ + Write events to the database + Args: + events_and_contexts: list of tuples of (event, context) + backfilled: Whether the results are retrieved from federation + via backfill or not. 
Used to determine if they're "new" events + which might update the current state etc. + + Returns: + Deferred[int]: the stream ordering of the latest persisted event + """ + partitioned = {} + for event, ctx in events_and_contexts: + partitioned.setdefault(event.room_id, []).append((event, ctx)) + + deferreds = [] + for room_id, evs_ctxs in iteritems(partitioned): + d = self._event_persist_queue.add_to_queue( + room_id, evs_ctxs, backfilled=backfilled + ) + deferreds.append(d) + + for room_id in partitioned: + self._maybe_start_persisting(room_id) + + yield make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True) + ) + + max_persisted_id = yield self.main_store.get_current_events_token() + + return max_persisted_id + + @defer.inlineCallbacks + def persist_event( + self, event: FrozenEvent, context: EventContext, backfilled: bool = False + ): + """ + Returns: + Deferred[Tuple[int, int]]: the stream ordering of ``event``, + and the stream ordering of the latest persisted event + """ + deferred = self._event_persist_queue.add_to_queue( + event.room_id, [(event, context)], backfilled=backfilled + ) + + self._maybe_start_persisting(event.room_id) + + yield make_deferred_yieldable(deferred) + + max_persisted_id = yield self.main_store.get_current_events_token() + return (event.internal_metadata.stream_ordering, max_persisted_id) + + def _maybe_start_persisting(self, room_id: str): + async def persisting_queue(item): + with Measure(self._clock, "persist_events"): + await self._persist_events( + item.events_and_contexts, backfilled=item.backfilled + ) + + self._event_persist_queue.handle_queue(room_id, persisting_queue) + + async def _persist_events( + self, + events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + backfilled: bool = False, + ): + """Calculates the change to current state and forward extremities, and + persists the given events and with those updates. + """ + if not events_and_contexts: + return + + chunks = [ + events_and_contexts[x : x + 100] + for x in range(0, len(events_and_contexts), 100) + ] + + for chunk in chunks: + # We can't easily parallelize these since different chunks + # might contain the same event. :( + + # NB: Assumes that we are only persisting events for one room + # at a time. + + # map room_id->list[event_ids] giving the new forward + # extremities in each room + new_forward_extremeties = {} + + # map room_id->(type,state_key)->event_id tracking the full + # state in each room after adding these events. + # This is simply used to prefill the get_current_state_ids + # cache + current_state_for_room = {} + + # map room_id->(to_delete, to_insert) where to_delete is a list + # of type/state keys to remove from current state, and to_insert + # is a map (type,key)->event_id giving the state delta in each + # room + state_delta_for_room = {} + + # Set of remote users which were in rooms the server has left. We + # should check if we still share any rooms and if not we mark their + # device lists as stale. + potentially_left_users = set() # type: Set[str] + + if not backfilled: + with Measure(self._clock, "_calculate_state_and_extrem"): + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. 
+ events_by_room = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) + ) + + for room_id, ev_ctx_rm in iteritems(events_by_room): + latest_event_ids = await self.main_store.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = await self._calculate_new_extremities( + room_id, ev_ctx_rm, latest_event_ids + ) + + latest_event_ids = set(latest_event_ids) + if new_latest_event_ids == latest_event_ids: + # No change in extremities, so no change in state + continue + + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + + new_forward_extremeties[room_id] = new_latest_event_ids + + len_1 = ( + len(latest_event_ids) == 1 + and len(new_latest_event_ids) == 1 + ) + if len_1: + all_single_prev_not_state = all( + len(event.prev_event_ids()) == 1 + and not event.is_state() + for event, ctx in ev_ctx_rm + ) + # Don't bother calculating state if they're just + # a long chain of single ancestor non-state events. + if all_single_prev_not_state: + continue + + state_delta_counter.inc() + if len(new_latest_event_ids) == 1: + state_delta_single_event_counter.inc() + + # This is a fairly handwavey check to see if we could + # have guessed what the delta would have been when + # processing one of these events. + # What we're interested in is if the latest extremities + # were the same when we created the event as they are + # now. When this server creates a new event (as opposed + # to receiving it over federation) it will use the + # forward extremities as the prev_events, so we can + # guess this by looking at the prev_events and checking + # if they match the current forward extremities. + for ev, _ in ev_ctx_rm: + prev_event_ids = set(ev.prev_event_ids()) + if latest_event_ids == prev_event_ids: + state_delta_reuse_delta_counter.inc() + break + + logger.debug("Calculating state delta for room %s", room_id) + with Measure( + self._clock, "persist_events.get_new_state_after_events" + ): + res = await self._get_new_state_after_events( + room_id, + ev_ctx_rm, + latest_event_ids, + new_latest_event_ids, + ) + current_state, delta_ids = res + + # If either are not None then there has been a change, + # and we need to work out the delta (or use that + # given) + delta = None + if delta_ids is not None: + # If there is a delta we know that we've + # only added or replaced state, never + # removed keys entirely. + delta = DeltaState([], delta_ids) + elif current_state is not None: + with Measure( + self._clock, "persist_events.calculate_state_delta" + ): + delta = await self._calculate_state_delta( + room_id, current_state + ) + + if delta: + # If we have a change of state then lets check + # whether we're actually still a member of the room, + # or if our last user left. If we're no longer in + # the room then we delete the current state and + # extremities. + is_still_joined = await self._is_server_still_joined( + room_id, + ev_ctx_rm, + delta, + current_state, + potentially_left_users, + ) + if not is_still_joined: + logger.info("Server no longer in room %s", room_id) + latest_event_ids = [] + current_state = {} + delta.no_longer_in_room = True + + state_delta_for_room[room_id] = delta + + # If we have the current_state then lets prefill + # the cache with it. 
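# --- Illustrative sketch only, not part of the patch above. ---
# The two cheap checks in the chunk loop that let the persister skip state
# recalculation, pulled out as a standalone helper. `events` is assumed to be
# an iterable of objects exposing prev_event_ids() and is_state(), mirroring
# EventBase.
def can_skip_state_calculation(old_extremities, new_extremities, events):
    old_extremities = set(old_extremities)
    new_extremities = set(new_extremities)

    # 1. The forward extremities did not change at all.
    if new_extremities == old_extremities:
        return True

    # 2. One old extremity was replaced by exactly one new one via a straight
    #    chain of single-parent, non-state events: the state cannot change.
    if len(old_extremities) == 1 and len(new_extremities) == 1:
        return all(
            len(ev.prev_event_ids()) == 1 and not ev.is_state() for ev in events
        )

    return False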
+ if current_state is not None: + current_state_for_room[room_id] = current_state + + await self.persist_events_store._persist_events_and_state_updates( + chunk, + current_state_for_room=current_state_for_room, + state_delta_for_room=state_delta_for_room, + new_forward_extremeties=new_forward_extremeties, + backfilled=backfilled, + ) + + await self._handle_potentially_left_users(potentially_left_users) + + async def _calculate_new_extremities( + self, + room_id: str, + event_contexts: List[Tuple[FrozenEvent, EventContext]], + latest_event_ids: List[str], + ): + """Calculates the new forward extremities for a room given events to + persist. + + Assumes that we are only persisting events for one room at a time. + """ + + # we're only interested in new events which aren't outliers and which aren't + # being rejected. + new_events = [ + event + for event, ctx in event_contexts + if not event.internal_metadata.is_outlier() + and not ctx.rejected + and not event.internal_metadata.is_soft_failed() + ] + + latest_event_ids = set(latest_event_ids) + + # start with the existing forward extremities + result = set(latest_event_ids) + + # add all the new events to the list + result.update(event.event_id for event in new_events) + + # Now remove all events which are prev_events of any of the new events + result.difference_update( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + + # Remove any events which are prev_events of any existing events. + existing_prevs = await self.persist_events_store._get_events_which_are_prevs( + result + ) + result.difference_update(existing_prevs) + + # Finally handle the case where the new events have soft-failed prev + # events. If they do we need to remove them and their prev events, + # otherwise we end up with dangling extremities. + existing_prevs = await self.persist_events_store._get_prevs_before_rejected( + e_id for event in new_events for e_id in event.prev_event_ids() + ) + result.difference_update(existing_prevs) + + # We only update metrics for events that change forward extremities + # (e.g. we ignore backfill/outliers/etc) + if result != latest_event_ids: + forward_extremities_counter.observe(len(result)) + stale = latest_event_ids & result + stale_forward_extremities_counter.observe(len(stale)) + + return result + + async def _get_new_state_after_events( + self, + room_id: str, + events_context: List[Tuple[FrozenEvent, EventContext]], + old_latest_event_ids: Iterable[str], + new_latest_event_ids: Iterable[str], + ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: + """Calculate the current state dict after adding some new events to + a room + + Args: + room_id (str): + room to which the events are being added. Used for logging etc + + events_context (list[(EventBase, EventContext)]): + events and contexts which are being added to the room + + old_latest_event_ids (iterable[str]): + the old forward extremities for the room. + + new_latest_event_ids (iterable[str]): + the new forward extremities for the room. + + Returns: + Returns a tuple of two state maps, the first being the full new current + state and the second being the delta to the existing current state. + If both are None then there has been no change. + + If there has been a change then we only return the delta if its + already been calculated. Conversely if we do know the delta then + the new current state is only returned if we've already calculated + it. 
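# --- Illustrative sketch only, not part of the patch above. ---
# The core set arithmetic of _calculate_new_extremities(), for plain
# (event_id, prev_event_ids) pairs and ignoring the outlier / rejection /
# soft-failure filtering and the extra database lookups for pre-existing
# prev events.
def new_forward_extremities(old_extremities, new_events):
    # new_events: list of (event_id, prev_event_ids) tuples.
    result = set(old_extremities)

    # Every new event is a candidate extremity...
    result.update(event_id for event_id, _ in new_events)

    # ...unless one of the events being persisted points at it.
    result.difference_update(
        prev_id for _, prev_ids in new_events for prev_id in prev_ids
    )
    return result


# Appending C with prev_events=[B] to a room whose only extremity is B
# leaves C as the sole forward extremity:
assert new_forward_extremities({"B"}, [("C", ["B"])]) == {"C"}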
+ """ + # map from state_group to ((type, key) -> event_id) state map + state_groups_map = {} + + # Map from (prev state group, new state group) -> delta state dict + state_group_deltas = {} + + for ev, ctx in events_context: + if ctx.state_group is None: + # This should only happen for outlier events. + if not ev.internal_metadata.is_outlier(): + raise Exception( + "Context for new event %s has no state " + "group" % (ev.event_id,) + ) + continue + + if ctx.state_group in state_groups_map: + continue + + # We're only interested in pulling out state that has already + # been cached in the context. We'll pull stuff out of the DB later + # if necessary. + current_state_ids = ctx.get_cached_current_state_ids() + if current_state_ids is not None: + state_groups_map[ctx.state_group] = current_state_ids + + if ctx.prev_group: + state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids + + # We need to map the event_ids to their state groups. First, let's + # check if the event is one we're persisting, in which case we can + # pull the state group from its context. + # Otherwise we need to pull the state group from the database. + + # Set of events we need to fetch groups for. (We know none of the old + # extremities are going to be in events_context). + missing_event_ids = set(old_latest_event_ids) + + event_id_to_state_group = {} + for event_id in new_latest_event_ids: + # First search in the list of new events we're adding. + for ev, ctx in events_context: + if event_id == ev.event_id and ctx.state_group is not None: + event_id_to_state_group[event_id] = ctx.state_group + break + else: + # If we couldn't find it, then we'll need to pull + # the state from the database + missing_event_ids.add(event_id) + + if missing_event_ids: + # Now pull out the state groups for any missing events from DB + event_to_groups = await self.main_store._get_state_group_for_events( + missing_event_ids + ) + event_id_to_state_group.update(event_to_groups) + + # State groups of old_latest_event_ids + old_state_groups = { + event_id_to_state_group[evid] for evid in old_latest_event_ids + } + + # State groups of new_latest_event_ids + new_state_groups = { + event_id_to_state_group[evid] for evid in new_latest_event_ids + } + + # If they old and new groups are the same then we don't need to do + # anything. + if old_state_groups == new_state_groups: + return None, None + + if len(new_state_groups) == 1 and len(old_state_groups) == 1: + # If we're going from one state group to another, lets check if + # we have a delta for that transition. If we do then we can just + # return that. + + new_state_group = next(iter(new_state_groups)) + old_state_group = next(iter(old_state_groups)) + + delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) + if delta_ids is not None: + # We have a delta from the existing to new current state, + # so lets just return that. If we happen to already have + # the current state in memory then lets also return that, + # but it doesn't matter if we don't. + new_state = state_groups_map.get(new_state_group) + return new_state, delta_ids + + # Now that we have calculated new_state_groups we need to get + # their state IDs so we can resolve to a single state set. 
+ missing_state = new_state_groups - set(state_groups_map) + if missing_state: + group_to_state = await self.state_store._get_state_for_groups(missing_state) + state_groups_map.update(group_to_state) + + if len(new_state_groups) == 1: + # If there is only one state group, then we know what the current + # state is. + return state_groups_map[new_state_groups.pop()], None + + # Ok, we need to defer to the state handler to resolve our state sets. + + state_groups = {sg: state_groups_map[sg] for sg in new_state_groups} + + events_map = {ev.event_id: ev for ev, _ in events_context} + + # We need to get the room version, which is in the create event. + # Normally that'd be in the database, but its also possible that we're + # currently trying to persist it. + room_version = None + for ev, _ in events_context: + if ev.type == EventTypes.Create and ev.state_key == "": + room_version = ev.content.get("room_version", "1") + break + + if not room_version: + room_version = await self.main_store.get_room_version_id(room_id) + + logger.debug("calling resolve_state_groups from preserve_events") + res = await self._state_resolution_handler.resolve_state_groups( + room_id, + room_version, + state_groups, + events_map, + state_res_store=StateResolutionStore(self.main_store), + ) + + return res.state, None + + async def _calculate_state_delta( + self, room_id: str, current_state: StateMap[str] + ) -> DeltaState: + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + """ + existing_state = await self.main_store.get_current_state_ids(room_id) + + to_delete = [key for key in existing_state if key not in current_state] + + to_insert = { + key: ev_id + for key, ev_id in iteritems(current_state) + if ev_id != existing_state.get(key) + } + + return DeltaState(to_delete=to_delete, to_insert=to_insert) + + async def _is_server_still_joined( + self, + room_id: str, + ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]], + delta: DeltaState, + current_state: Optional[StateMap[str]], + potentially_left_users: Set[str], + ) -> bool: + """Check if the server will still be joined after the given events have + been persised. + + Args: + room_id + ev_ctx_rm + delta: The delta of current state between what is in the database + and what the new current state will be. + current_state: The new current state if it already been calculated, + otherwise None. + potentially_left_users: If the server has left the room, then joined + remote users will be added to this set to indicate that the + server may no longer be sharing a room with them. + """ + + if not any( + self.is_mine_id(state_key) + for typ, state_key in itertools.chain(delta.to_delete, delta.to_insert) + if typ == EventTypes.Member + ): + # There have been no changes to membership of our users, so nothing + # has changed and we assume we're still in the room. + return True + + # Check if any of the given events are a local join that appear in the + # current state + events_to_check = [] # Event IDs that aren't an event we're persisting + for (typ, state_key), event_id in delta.to_insert.items(): + if typ != EventTypes.Member or not self.is_mine_id(state_key): + continue + + for event, _ in ev_ctx_rm: + if event_id == event.event_id: + if event.membership == Membership.JOIN: + return True + + # The event is not in `ev_ctx_rm`, so we need to pull it out of + # the DB. + events_to_check.append(event_id) + + # Check if any of the changes that we don't have events for are joins. 
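# --- Illustrative sketch only, not part of the patch above. ---
# _calculate_state_delta() boils down to a dictionary diff between the stored
# current state and the newly calculated one; both maps are keyed by
# (event type, state_key). The example values below are made up.
def state_delta(existing_state, current_state):
    to_delete = [key for key in existing_state if key not in current_state]
    to_insert = {
        key: event_id
        for key, event_id in current_state.items()
        if existing_state.get(key) != event_id
    }
    return to_delete, to_insert


old = {("m.room.name", ""): "$ev1", ("m.room.member", "@a:hs"): "$ev2"}
new = {("m.room.member", "@a:hs"): "$ev3"}
deleted, inserted = state_delta(old, new)
# The room name key disappears, and @a:hs's membership event is replaced.
assert deleted == [("m.room.name", "")]
assert inserted == {("m.room.member", "@a:hs"): "$ev3"}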
+ if events_to_check: + rows = await self.main_store.get_membership_from_event_ids(events_to_check) + is_still_joined = any(row["membership"] == Membership.JOIN for row in rows) + if is_still_joined: + return True + + # None of the new state events are local joins, so we check the database + # to see if there are any other local users in the room. We ignore users + # whose state has changed as we've already their new state above. + users_to_ignore = [ + state_key + for typ, state_key in itertools.chain(delta.to_insert, delta.to_delete) + if typ == EventTypes.Member and self.is_mine_id(state_key) + ] + + if await self.main_store.is_local_host_in_room_ignoring_users( + room_id, users_to_ignore + ): + return True + + # The server will leave the room, so we go and find out which remote + # users will still be joined when we leave. + if current_state is None: + current_state = await self.main_store.get_current_state_ids(room_id) + current_state = dict(current_state) + for key in delta.to_delete: + current_state.pop(key, None) + + current_state.update(delta.to_insert) + + remote_event_ids = [ + event_id + for (typ, state_key,), event_id in current_state.items() + if typ == EventTypes.Member and not self.is_mine_id(state_key) + ] + rows = await self.main_store.get_membership_from_event_ids(remote_event_ids) + potentially_left_users.update( + row["user_id"] for row in rows if row["membership"] == Membership.JOIN + ) + + return False + + async def _handle_potentially_left_users(self, user_ids: Set[str]): + """Given a set of remote users check if the server still shares a room with + them. If not then mark those users' device cache as stale. + """ + + if not user_ids: + return + + joined_users = await self.main_store.get_users_server_still_shares_room_with( + user_ids + ) + left_users = user_ids - joined_users + + for user_id in left_users: + await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id) + + async def locally_reject_invite(self, user_id: str, room_id: str) -> int: + """Mark the invite has having been rejected even though we failed to + create a leave event for it. + """ + return await self.persist_events_store.locally_reject_invite(user_id, room_id) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index e96eed8a6d..9cc3b51fe6 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -14,20 +14,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -import fnmatch import imp import logging import os import re +from collections import Counter +from typing import TextIO + +import attr from synapse.storage.engines.postgres import PostgresEngine +from synapse.storage.types import Cursor logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 56 +# XXX: If you're about to bump this to 59 (or higher) please create an update +# that drops the unused `cache_invalidation_stream` table, as per #7436! +# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656! +SCHEMA_VERSION = 58 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -40,7 +47,7 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine, config): +def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]): """Prepares a database for usage. 
Will either create all necessary tables or upgrade from an older schema version. @@ -53,7 +60,10 @@ def prepare_database(db_conn, database_engine, config): config (synapse.config.homeserver.HomeServerConfig|None): application config, or None if we are connecting to an existing database which we expect to be configured already + data_stores (list[str]): The name of the data stores that will be used + with this database. Defaults to all data stores. """ + try: cur = db_conn.cursor() version_info = _get_or_create_schema_state(cur, database_engine) @@ -65,13 +75,22 @@ def prepare_database(db_conn, database_engine, config): if user_version != SCHEMA_VERSION: # If we don't pass in a config file then we are expecting to # have already upgraded the DB. - raise UpgradeDatabaseException("Database needs to be upgraded") + raise UpgradeDatabaseException( + "Expected database schema version %i but got %i" + % (SCHEMA_VERSION, user_version) + ) else: _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine, config + cur, + user_version, + delta_files, + upgraded, + database_engine, + config, + data_stores=data_stores, ) else: - _setup_new_database(cur, database_engine) + _setup_new_database(cur, database_engine, data_stores=data_stores) # check if any of our configured dynamic modules want a database if config is not None: @@ -84,9 +103,10 @@ def prepare_database(db_conn, database_engine, config): raise -def _setup_new_database(cur, database_engine): +def _setup_new_database(cur, database_engine, data_stores): """Sets up the database by finding a base set of "full schemas" and then - applying any necessary deltas. + applying any necessary deltas, including schemas from the given data + stores. The "full_schemas" directory has subdirectories named after versions. This function searches for the highest version less than or equal to @@ -111,52 +131,83 @@ def _setup_new_database(cur, database_engine): In the example foo.sql and bar.sql would be run, and then any delta files for versions strictly greater than 11. + + Note: we apply the full schemas and deltas from the top level `schema/` + folder as well those in the data stores specified. + + Args: + cur (Cursor): a database cursor + database_engine (DatabaseEngine) + data_stores (list[str]): The names of the data stores to instantiate + on the given database. """ - current_dir = os.path.join(dir_path, "schema", "full_schemas") - directory_entries = os.listdir(current_dir) - valid_dirs = [] - pattern = re.compile(r"^\d+(\.sql)?$") + # We're about to set up a brand new database so we check that its + # configured to our liking. + database_engine.check_new_database(cur) - if isinstance(database_engine, PostgresEngine): - specific = "postgres" - else: - specific = "sqlite" + current_dir = os.path.join(dir_path, "schema", "full_schemas") + directory_entries = os.listdir(current_dir) - specific_pattern = re.compile(r"^\d+(\.sql." 
+ specific + r")?$") + # First we find the highest full schema version we have + valid_versions = [] for filename in directory_entries: - match = pattern.match(filename) or specific_pattern.match(filename) - abs_path = os.path.join(current_dir, filename) - if match and os.path.isdir(abs_path): - ver = int(match.group(0)) - if ver <= SCHEMA_VERSION: - valid_dirs.append((ver, abs_path)) - else: - logger.debug("Ignoring entry '%s' in 'full_schemas'", filename) + try: + ver = int(filename) + except ValueError: + continue + + if ver <= SCHEMA_VERSION: + valid_versions.append(ver) - if not valid_dirs: + if not valid_versions: raise PrepareDatabaseException( "Could not find a suitable base set of full schemas" ) - max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) + max_current_ver = max(valid_versions) logger.debug("Initialising schema v%d", max_current_ver) - directory_entries = os.listdir(sql_dir) + # Now lets find all the full schema files, both in the global schema and + # in data store schemas. + directories = [os.path.join(current_dir, str(max_current_ver))] + directories.extend( + os.path.join( + dir_path, + "data_stores", + data_store, + "schema", + "full_schemas", + str(max_current_ver), + ) + for data_store in data_stores + ) + + directory_entries = [] + for directory in directories: + directory_entries.extend( + _DirectoryListing(file_name, os.path.join(directory, file_name)) + for file_name in os.listdir(directory) + ) + + if isinstance(database_engine, PostgresEngine): + specific = "postgres" + else: + specific = "sqlite" - for filename in sorted( - fnmatch.filter(directory_entries, "*.sql") - + fnmatch.filter(directory_entries, "*.sql." + specific) - ): - sql_loc = os.path.join(sql_dir, filename) - logger.debug("Applying schema %s", sql_loc) - executescript(cur, sql_loc) + directory_entries.sort() + for entry in directory_entries: + if entry.file_name.endswith(".sql") or entry.file_name.endswith( + ".sql." + specific + ): + logger.debug("Applying schema %s", entry.absolute_path) + executescript(cur, entry.absolute_path) cur.execute( database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)" + "INSERT INTO schema_version (version, upgraded) VALUES (?,?)" ), (max_current_ver, False), ) @@ -168,6 +219,7 @@ def _setup_new_database(cur, database_engine): upgraded=False, database_engine=database_engine, config=None, + data_stores=data_stores, is_empty=True, ) @@ -179,6 +231,7 @@ def _upgrade_existing_database( upgraded, database_engine, config, + data_stores, is_empty=False, ): """Upgrades an existing database. @@ -215,6 +268,10 @@ def _upgrade_existing_database( only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in some arbitrary order. + Note: we apply the delta files from the specified data stores as well as + those in the top-level schema. We apply all delta files across data stores + for a version before applying those in the next version. + Args: cur (Cursor) current_version (int): The current version of the schema. @@ -224,7 +281,19 @@ def _upgrade_existing_database( applied deltas or from full schema file. If `True` the function will never apply delta files for the given `current_version`, since the current_version wasn't generated by applying those delta files. 
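# --- Illustrative sketch only, not part of the patch above. ---
# How the full-schema files are now gathered across the global schema
# directory and each data store's directory, then applied in a deterministic
# order. The base path and arguments in the usage comment are hypothetical;
# actually applying a file is out of scope here.
import os


def collect_schema_files(base_dir, data_stores, version, engine_suffix):
    directories = [os.path.join(base_dir, "schema", "full_schemas", str(version))]
    directories.extend(
        os.path.join(base_dir, "data_stores", ds, "schema", "full_schemas", str(version))
        for ds in data_stores
    )

    entries = []
    for directory in directories:
        for file_name in os.listdir(directory):
            entries.append((file_name, os.path.join(directory, file_name)))

    # Sorting (file_name, path) pairs keeps the apply order stable no matter
    # which directory a file came from -- the same property the
    # _DirectoryListing attrs class provides by keeping `file_name` first.
    entries.sort()

    return [
        path
        for file_name, path in entries
        if file_name.endswith(".sql") or file_name.endswith(".sql." + engine_suffix)
    ]


# e.g. collect_schema_files("/path/to/synapse/storage", ["main", "state"], 58, "postgres")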
+ database_engine (DatabaseEngine) + config (synapse.config.homeserver.HomeServerConfig|None): + None if we are initialising a blank database, otherwise the application + config + data_stores (list[str]): The names of the data stores to instantiate + on the given database. + is_empty (bool): Is this a blank database? I.e. do we need to run the + upgrade portions of the delta scripts. """ + if is_empty: + assert not applied_delta_files + else: + assert config if current_version > SCHEMA_VERSION: raise ValueError( @@ -232,6 +301,13 @@ def _upgrade_existing_database( + "new for the server to understand" ) + # some of the deltas assume that config.server_name is set correctly, so now + # is a good time to run the sanity check. + if not is_empty and "main" in data_stores: + from synapse.storage.data_stores.main import check_database_before_upgrade + + check_database_before_upgrade(cur, database_engine, config) + start_ver = current_version if not upgraded: start_ver += 1 @@ -248,24 +324,65 @@ def _upgrade_existing_database( for v in range(start_ver, SCHEMA_VERSION + 1): logger.info("Upgrading schema to v%d", v) + # We need to search both the global and per data store schema + # directories for schema updates. + + # First we find the directories to search in delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) + directories = [delta_dir] + for data_store in data_stores: + directories.append( + os.path.join( + dir_path, "data_stores", data_store, "schema", "delta", str(v) + ) + ) - try: - directory_entries = os.listdir(delta_dir) - except OSError: - logger.exception("Could not open delta dir for version %d", v) - raise UpgradeDatabaseException( - "Could not open delta dir for version %d" % (v,) + # Used to check if we have any duplicate file names + file_name_counter = Counter() + + # Now find which directories have anything of interest. + directory_entries = [] + for directory in directories: + logger.debug("Looking for schema deltas in %s", directory) + try: + file_names = os.listdir(directory) + directory_entries.extend( + _DirectoryListing(file_name, os.path.join(directory, file_name)) + for file_name in file_names + ) + + for file_name in file_names: + file_name_counter[file_name] += 1 + except FileNotFoundError: + # Data stores can have empty entries for a given version delta. + pass + except OSError: + raise UpgradeDatabaseException( + "Could not open delta dir for version %d: %s" % (v, directory) + ) + + duplicates = { + file_name for file_name, count in file_name_counter.items() if count > 1 + } + if duplicates: + # We don't support using the same file name in the same delta version. + raise PrepareDatabaseException( + "Found multiple delta files with the same name in v%d: %s" + % (v, duplicates,) ) + # We sort to ensure that we apply the delta files in a consistent + # order (to avoid bugs caused by inconsistent directory listing order) directory_entries.sort() - for file_name in directory_entries: + for entry in directory_entries: + file_name = entry.file_name relative_path = os.path.join(str(v), file_name) - logger.debug("Found file: %s", relative_path) + absolute_path = entry.absolute_path + + logger.debug("Found file: %s (%s)", relative_path, absolute_path) if relative_path in applied_delta_files: continue - absolute_path = os.path.join(dir_path, "schema", "delta", relative_path) root_name, ext = os.path.splitext(file_name) if ext == ".py": # This is a python upgrade module. 
We need to import into some @@ -352,7 +469,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) ), (modname,), ) - applied_deltas = set(d for d, in cur) + applied_deltas = {d for d, in cur} for (name, stream) in names_and_streams: if name in applied_deltas: continue @@ -364,13 +481,12 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) ) logger.info("applying schema %s for %s", name, modname) - for statement in get_statements(stream): - cur.execute(statement) + execute_statements_from_stream(cur, stream) # Mark as done. cur.execute( database_engine.convert_param_style( - "INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)" + "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)" ), (modname, name), ) @@ -423,8 +539,12 @@ def get_statements(f): def executescript(txn, schema_path): with open(schema_path, "r") as f: - for statement in get_statements(f): - txn.execute(statement) + execute_statements_from_stream(txn, f) + + +def execute_statements_from_stream(cur: Cursor, f: TextIO): + for statement in get_statements(f): + cur.execute(statement) def _get_or_create_schema_state(txn, database_engine): @@ -448,3 +568,16 @@ def _get_or_create_schema_state(txn, database_engine): return current_version, applied_deltas, upgraded return None + + +@attr.s() +class _DirectoryListing(object): + """Helper class to store schema file name and the + absolute path to it. + + These entries get sorted, so for consistency we want to ensure that + `file_name` attr is kept first. + """ + + file_name = attr.ib() + absolute_path = attr.ib() diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 5db6f2d84a..18a462f0ee 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -15,13 +15,7 @@ from collections import namedtuple -from twisted.internet import defer - from synapse.api.constants import PresenceState -from synapse.util import batch_iter -from synapse.util.caches.descriptors import cached, cachedList - -from ._base import SQLBaseStore class UserPresenceState( @@ -73,133 +67,3 @@ class UserPresenceState( status_msg=None, currently_active=False, ) - - -class PresenceStore(SQLBaseStore): - @defer.inlineCallbacks - def update_presence(self, presence_states): - stream_ordering_manager = self._presence_id_gen.get_next_mult( - len(presence_states) - ) - - with stream_ordering_manager as stream_orderings: - yield self.runInteraction( - "update_presence", - self._update_presence_txn, - stream_orderings, - presence_states, - ) - - return stream_orderings[-1], self._presence_id_gen.get_current_token() - - def _update_presence_txn(self, txn, stream_orderings, presence_states): - for stream_id, state in zip(stream_orderings, presence_states): - txn.call_after( - self.presence_stream_cache.entity_has_changed, state.user_id, stream_id - ) - txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) - - # Actually insert new rows - self._simple_insert_many_txn( - txn, - table="presence_stream", - values=[ - { - "stream_id": stream_id, - "user_id": state.user_id, - "state": state.state, - "last_active_ts": state.last_active_ts, - "last_federation_update_ts": state.last_federation_update_ts, - "last_user_sync_ts": state.last_user_sync_ts, - "status_msg": state.status_msg, - "currently_active": state.currently_active, - } - for state in presence_states - ], - ) - - # Delete old rows to stop database from getting really big - sql = ( - "DELETE FROM presence_stream WHERE" " 
stream_id < ?" " AND user_id IN (%s)" - ) - - for states in batch_iter(presence_states, 50): - args = [stream_id] - args.extend(s.user_id for s in states) - txn.execute(sql % (",".join("?" for _ in states),), args) - - def get_all_presence_updates(self, last_id, current_id): - if last_id == current_id: - return defer.succeed([]) - - def get_all_presence_updates_txn(txn): - sql = ( - "SELECT stream_id, user_id, state, last_active_ts," - " last_federation_update_ts, last_user_sync_ts, status_msg," - " currently_active" - " FROM presence_stream" - " WHERE ? < stream_id AND stream_id <= ?" - ) - txn.execute(sql, (last_id, current_id)) - return txn.fetchall() - - return self.runInteraction( - "get_all_presence_updates", get_all_presence_updates_txn - ) - - @cached() - def _get_presence_for_user(self, user_id): - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_presence_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, - ) - def get_presence_for_users(self, user_ids): - rows = yield self._simple_select_many_batch( - table="presence_stream", - column="user_id", - iterable=user_ids, - keyvalues={}, - retcols=( - "user_id", - "state", - "last_active_ts", - "last_federation_update_ts", - "last_user_sync_ts", - "status_msg", - "currently_active", - ), - desc="get_presence_for_users", - ) - - for row in rows: - row["currently_active"] = bool(row["currently_active"]) - - return {row["user_id"]: UserPresenceState(**row) for row in rows} - - def get_current_presence_token(self): - return self._presence_id_gen.get_current_token() - - def allow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_insert( - table="presence_allow_inbound", - values={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="allow_presence_visible", - or_ignore=True, - ) - - def disallow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_delete_one( - table="presence_allow_inbound", - keyvalues={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="disallow_presence_visible", - ) diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py new file mode 100644 index 0000000000..fdc0abf5cf --- /dev/null +++ b/synapse/storage/purge_events.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging + +from twisted.internet import defer + +logger = logging.getLogger(__name__) + + +class PurgeEventsStorage(object): + """High level interface for purging rooms and event history. 
+ """ + + def __init__(self, hs, stores): + self.stores = stores + + @defer.inlineCallbacks + def purge_room(self, room_id: str): + """Deletes all record of a room + """ + + state_groups_to_delete = yield self.stores.main.purge_room(room_id) + yield self.stores.state.purge_room_state(room_id, state_groups_to_delete) + + @defer.inlineCallbacks + def purge_history(self, room_id, token, delete_local_events): + """Deletes room history before a certain point + + Args: + room_id (str): + + token (str): A topological token to delete events before + + delete_local_events (bool): + if True, we will delete local events as well as remote ones + (instead of just marking them as outliers and deleting their + state groups). + """ + state_groups = yield self.stores.main.purge_history( + room_id, token, delete_local_events + ) + + logger.info("[purge] finding state groups that can be deleted") + + sg_to_delete = yield self._find_unreferenced_groups(state_groups) + + yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) + + @defer.inlineCallbacks + def _find_unreferenced_groups(self, state_groups): + """Used when purging history to figure out which state groups can be + deleted. + + Args: + state_groups (set[int]): Set of state groups referenced by events + that are going to be deleted. + + Returns: + Deferred[set[int]] The set of state groups that can be deleted. + """ + # Graph of state group -> previous group + graph = {} + + # Set of events that we have found to be referenced by events + referenced_groups = set() + + # Set of state groups we've already seen + state_groups_seen = set(state_groups) + + # Set of state groups to handle next. + next_to_search = set(state_groups) + while next_to_search: + # We bound size of groups we're looking up at once, to stop the + # SQL query getting too big + if len(next_to_search) < 100: + current_search = next_to_search + next_to_search = set() + else: + current_search = set(itertools.islice(next_to_search, 100)) + next_to_search -= current_search + + referenced = yield self.stores.main.get_referenced_state_groups( + current_search + ) + referenced_groups |= referenced + + # We don't continue iterating up the state group graphs for state + # groups that are referenced. + current_search -= referenced + + edges = yield self.stores.state.get_previous_state_groups(current_search) + + prevs = set(edges.values()) + # We don't bother re-handling groups we've already seen + prevs -= state_groups_seen + next_to_search |= prevs + state_groups_seen |= prevs + + graph.update(edges) + + to_delete = state_groups_seen - referenced_groups + + return to_delete diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index a6517c4cf3..f47cec0d86 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -14,708 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
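# --- Illustrative sketch only, not part of the patch above. ---
# The graph walk performed by _find_unreferenced_groups(), with the two
# database calls replaced by plain in-memory mappings and the batching of 100
# groups per lookup omitted. `referenced` is the set of state groups still
# used by undeleted events; `prev_group` maps a state group to its
# predecessor, if any.
def find_deletable_groups(candidates, referenced, prev_group):
    seen = set(candidates)
    kept = set()              # groups we must keep because they are referenced
    to_search = set(candidates)

    while to_search:
        current = to_search
        to_search = set()

        still_used = {sg for sg in current if sg in referenced}
        kept |= still_used

        # Only walk further up the chain from groups that are NOT referenced.
        for sg in current - still_used:
            prev = prev_group.get(sg)
            if prev is not None and prev not in seen:
                seen.add(prev)
                to_search.add(prev)

    return seen - kept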
-import abc -import logging - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.push.baserules import list_with_base_rules -from synapse.storage.appservice import ApplicationServiceWorkerStore -from synapse.storage.pusher import PusherWorkerStore -from synapse.storage.receipts import ReceiptsWorkerStore -from synapse.storage.roommember import RoomMemberWorkerStore -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList -from synapse.util.caches.stream_change_cache import StreamChangeCache - -from ._base import SQLBaseStore - -logger = logging.getLogger(__name__) - - -def _load_rules(rawrules, enabled_map): - ruleslist = [] - for rawrule in rawrules: - rule = dict(rawrule) - rule["conditions"] = json.loads(rawrule["conditions"]) - rule["actions"] = json.loads(rawrule["actions"]) - ruleslist.append(rule) - - # We're going to be mutating this a lot, so do a deep copy - rules = list(list_with_base_rules(ruleslist)) - - for i, rule in enumerate(rules): - rule_id = rule["rule_id"] - if rule_id in enabled_map: - if rule.get("enabled", True) != bool(enabled_map[rule_id]): - # Rules are cached across users. - rule = dict(rule) - rule["enabled"] = bool(enabled_map[rule_id]) - rules[i] = rule - - return rules - - -class PushRulesWorkerStore( - ApplicationServiceWorkerStore, - ReceiptsWorkerStore, - PusherWorkerStore, - RoomMemberWorkerStore, - SQLBaseStore, -): - """This is an abstract base class where subclasses must implement - `get_max_push_rules_stream_id` which can be called in the initializer. - """ - - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - - def __init__(self, db_conn, hs): - super(PushRulesWorkerStore, self).__init__(db_conn, hs) - - push_rules_prefill, push_rules_id = self._get_cache_dict( - db_conn, - "push_rules_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self.get_max_push_rules_stream_id(), - ) - - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", - push_rules_id, - prefilled_cache=push_rules_prefill, - ) - - @abc.abstractmethod - def get_max_push_rules_stream_id(self): - """Get the position of the push rules stream. 
- - Returns: - int - """ - raise NotImplementedError() - - @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_for_user(self, user_id): - rows = yield self._simple_select_list( - table="push_rules", - keyvalues={"user_name": user_id}, - retcols=( - "user_name", - "rule_id", - "priority_class", - "priority", - "conditions", - "actions", - ), - desc="get_push_rules_enabled_for_user", - ) - - rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) - - enabled_map = yield self.get_push_rules_enabled_for_user(user_id) - - rules = _load_rules(rows, enabled_map) - - return rules - - @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_enabled_for_user(self, user_id): - results = yield self._simple_select_list( - table="push_rules_enable", - keyvalues={"user_name": user_id}, - retcols=("user_name", "rule_id", "enabled"), - desc="get_push_rules_enabled_for_user", - ) - return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} - - def have_push_rules_changed_for_user(self, user_id, last_id): - if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) - else: - - def have_push_rules_changed_txn(txn): - sql = ( - "SELECT COUNT(stream_id) FROM push_rules_stream" - " WHERE user_id = ? AND ? < stream_id" - ) - txn.execute(sql, (user_id, last_id)) - count, = txn.fetchone() - return bool(count) - - return self.runInteraction( - "have_push_rules_changed", have_push_rules_changed_txn - ) - - @cachedList( - cached_method_name="get_push_rules_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, - ) - def bulk_get_push_rules(self, user_ids): - if not user_ids: - return {} - - results = {user_id: [] for user_id in user_ids} - - rows = yield self._simple_select_many_batch( - table="push_rules", - column="user_name", - iterable=user_ids, - retcols=("*",), - desc="bulk_get_push_rules", - ) - - rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) - - for row in rows: - results.setdefault(row["user_name"], []).append(row) - - enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) - - for user_id, rules in results.items(): - results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) - - return results - - @defer.inlineCallbacks - def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule): - """Move a single push rule from one room to another for a specific user. - - Args: - new_room_id (str): ID of the new room. - user_id (str): ID of user the push rule belongs to. - rule (Dict): A push rule. - """ - # Create new rule id - rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) - new_rule_id = rule_id_scope + "/" + new_room_id - - # Change room id in each condition - for condition in rule.get("conditions", []): - if condition.get("key") == "room_id": - condition["pattern"] = new_room_id - - # Add the rule for the new room - yield self.add_push_rule( - user_id=user_id, - rule_id=new_rule_id, - priority_class=rule["priority_class"], - conditions=rule["conditions"], - actions=rule["actions"], - ) - - # Delete push rule for the old room - yield self.delete_push_rule(user_id, rule["rule_id"]) - - @defer.inlineCallbacks - def move_push_rules_from_room_to_room_for_user( - self, old_room_id, new_room_id, user_id - ): - """Move all of the push rules from one room to another for a specific - user. - - Args: - old_room_id (str): ID of the old room. - new_room_id (str): ID of the new room. 
- user_id (str): ID of user to copy push rules for. - """ - # Retrieve push rules for this user - user_push_rules = yield self.get_push_rules_for_user(user_id) - - # Get rules relating to the old room, move them to the new room, then - # delete them from the old room - for rule in user_push_rules: - conditions = rule.get("conditions", []) - if any( - (c.get("key") == "room_id" and c.get("pattern") == old_room_id) - for c in conditions - ): - self.move_push_rule_from_room_to_room(new_room_id, user_id, rule) - - @defer.inlineCallbacks - def bulk_get_push_rules_for_room(self, event, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - current_state_ids = yield context.get_current_state_ids(self) - result = yield self._bulk_get_push_rules_for_room( - event.room_id, state_group, current_state_ids, event=event - ) - return result - - @cachedInlineCallbacks(num_args=2, cache_context=True) - def _bulk_get_push_rules_for_room( - self, room_id, state_group, current_state_ids, cache_context, event=None - ): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - # We also will want to generate notifs for other people in the room so - # their unread countss are correct in the event stream, but to avoid - # generating them for bot / AS users etc, we only do so for people who've - # sent a read receipt into the room. - - users_in_room = yield self._get_joined_users_from_context( - room_id, - state_group, - current_state_ids, - on_invalidate=cache_context.invalidate, - event=event, - ) - - # We ignore app service users for now. This is so that we don't fill - # up the `get_if_users_have_pushers` cache with AS entries that we - # know don't have pushers, nor even read receipts. 
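# --- Illustrative sketch only, not part of the patch above. ---
# Which users the (re)moved _bulk_get_push_rules_for_room() ends up
# evaluating push rules for, written as plain set algebra. The predicate
# callables are assumptions standing in for the real store lookups: only
# local, non-appservice users are considered, and only if they either have a
# pusher or have sent a read receipt into the room.
def users_needing_push_rules(
    users_in_room, is_local, is_appservice_user, has_pusher, has_read_receipt
):
    local_users = {
        u for u in users_in_room if is_local(u) and not is_appservice_user(u)
    }
    return {u for u in local_users if has_pusher(u) or has_read_receipt(u)}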
- local_users_in_room = set( - u - for u in users_in_room - if self.hs.is_mine_id(u) - and not self.get_if_app_services_interested_in_user(u) - ) - - # users in the room who have pushers need to get push rules run because - # that's how their pushers work - if_users_with_pushers = yield self.get_if_users_have_pushers( - local_users_in_room, on_invalidate=cache_context.invalidate - ) - user_ids = set( - uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - ) - - users_with_receipts = yield self.get_users_with_read_receipts_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - - # any users with pushers must be ours: they have pushers - for uid in users_with_receipts: - if uid in local_users_in_room: - user_ids.add(uid) - - rules_by_user = yield self.bulk_get_push_rules( - user_ids, on_invalidate=cache_context.invalidate - ) - - rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} - - return rules_by_user - - @cachedList( - cached_method_name="get_push_rules_enabled_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, - ) - def bulk_get_push_rules_enabled(self, user_ids): - if not user_ids: - return {} - - results = {user_id: {} for user_id in user_ids} - - rows = yield self._simple_select_many_batch( - table="push_rules_enable", - column="user_name", - iterable=user_ids, - retcols=("user_name", "rule_id", "enabled"), - desc="bulk_get_push_rules_enabled", - ) - for row in rows: - enabled = bool(row["enabled"]) - results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled - return results - - -class PushRuleStore(PushRulesWorkerStore): - @defer.inlineCallbacks - def add_push_rule( - self, - user_id, - rule_id, - priority_class, - conditions, - actions, - before=None, - after=None, - ): - conditions_json = json.dumps(conditions) - actions_json = json.dumps(actions) - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - if before or after: - yield self.runInteraction( - "_add_push_rule_relative_txn", - self._add_push_rule_relative_txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - before, - after, - ) - else: - yield self.runInteraction( - "_add_push_rule_highest_priority_txn", - self._add_push_rule_highest_priority_txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - ) - - def _add_push_rule_relative_txn( - self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - before, - after, - ): - # Lock the table since otherwise we'll have annoying races between the - # SELECT here and the UPSERT below. - self.database_engine.lock_table(txn, "push_rules") - - relative_to_rule = before or after - - res = self._simple_select_one_txn( - txn, - table="push_rules", - keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, - retcols=["priority_class", "priority"], - allow_none=True, - ) - - if not res: - raise RuleNotFoundException( - "before/after rule not found: %s" % (relative_to_rule,) - ) - - base_priority_class = res["priority_class"] - base_rule_priority = res["priority"] - - if base_priority_class != priority_class: - raise InconsistentRuleException( - "Given priority class does not match class of relative rule" - ) - - if before: - # Higher priority rules are executed first, So adding a rule before - # a rule means giving it a higher priority than that rule. 
- new_rule_priority = base_rule_priority + 1 - else: - # We increment the priority of the existing rules to make space for - # the new rule. Therefore if we want this rule to appear after - # an existing rule we give it the priority of the existing rule, - # and then increment the priority of the existing rule. - new_rule_priority = base_rule_priority - - sql = ( - "UPDATE push_rules SET priority = priority + 1" - " WHERE user_name = ? AND priority_class = ? AND priority >= ?" - ) - - txn.execute(sql, (user_id, priority_class, new_rule_priority)) - - self._upsert_push_rule_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - new_rule_priority, - conditions_json, - actions_json, - ) - - def _add_push_rule_highest_priority_txn( - self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - conditions_json, - actions_json, - ): - # Lock the table since otherwise we'll have annoying races between the - # SELECT here and the UPSERT below. - self.database_engine.lock_table(txn, "push_rules") - - # find the highest priority rule in that class - sql = ( - "SELECT COUNT(*), MAX(priority) FROM push_rules" - " WHERE user_name = ? and priority_class = ?" - ) - txn.execute(sql, (user_id, priority_class)) - res = txn.fetchall() - (how_many, highest_prio) = res[0] - - new_prio = 0 - if how_many > 0: - new_prio = highest_prio + 1 - - self._upsert_push_rule_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - new_prio, - conditions_json, - actions_json, - ) - - def _upsert_push_rule_txn( - self, - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - priority, - conditions_json, - actions_json, - update_stream=True, - ): - """Specialised version of _simple_upsert_txn that picks a push_rule_id - using the _push_rule_id_gen if it needs to insert the rule. It assumes - that the "push_rules" table is locked""" - - sql = ( - "UPDATE push_rules" - " SET priority_class = ?, priority = ?, conditions = ?, actions = ?" - " WHERE user_name = ? AND rule_id = ?" - ) - - txn.execute( - sql, - (priority_class, priority, conditions_json, actions_json, user_id, rule_id), - ) - - if txn.rowcount == 0: - # We didn't update a row with the given rule_id so insert one - push_rule_id = self._push_rule_id_gen.get_next() - - self._simple_insert_txn( - txn, - table="push_rules", - values={ - "id": push_rule_id, - "user_name": user_id, - "rule_id": rule_id, - "priority_class": priority_class, - "priority": priority, - "conditions": conditions_json, - "actions": actions_json, - }, - ) - - if update_stream: - self._insert_push_rules_update_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - op="ADD", - data={ - "priority_class": priority_class, - "priority": priority, - "conditions": conditions_json, - "actions": actions_json, - }, - ) - - @defer.inlineCallbacks - def delete_push_rule(self, user_id, rule_id): - """ - Delete a push rule. 
Args specify the row to be deleted and can be - any of the columns in the push_rule table, but below are the - standard ones - - Args: - user_id (str): The matrix ID of the push rule owner - rule_id (str): The rule_id of the rule to be deleted - """ - - def delete_push_rule_txn(txn, stream_id, event_stream_ordering): - self._simple_delete_one_txn( - txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} - ) - - self._insert_push_rules_update_txn( - txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" - ) - - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.runInteraction( - "delete_push_rule", - delete_push_rule_txn, - stream_id, - event_stream_ordering, - ) - - @defer.inlineCallbacks - def set_push_rule_enabled(self, user_id, rule_id, enabled): - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.runInteraction( - "_set_push_rule_enabled_txn", - self._set_push_rule_enabled_txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - enabled, - ) - - def _set_push_rule_enabled_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled - ): - new_id = self._push_rules_enable_id_gen.get_next() - self._simple_upsert_txn( - txn, - "push_rules_enable", - {"user_name": user_id, "rule_id": rule_id}, - {"enabled": 1 if enabled else 0}, - {"id": new_id}, - ) - - self._insert_push_rules_update_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - op="ENABLE" if enabled else "DISABLE", - ) - - @defer.inlineCallbacks - def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): - actions_json = json.dumps(actions) - - def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): - if is_default_rule: - # Add a dummy rule to the rules table with the user specified - # actions. 
- priority_class = -1 - priority = 1 - self._upsert_push_rule_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - priority_class, - priority, - "[]", - actions_json, - update_stream=False, - ) - else: - self._simple_update_one_txn( - txn, - "push_rules", - {"user_name": user_id, "rule_id": rule_id}, - {"actions": actions_json}, - ) - - self._insert_push_rules_update_txn( - txn, - stream_id, - event_stream_ordering, - user_id, - rule_id, - op="ACTIONS", - data={"actions": actions_json}, - ) - - with self._push_rules_stream_id_gen.get_next() as ids: - stream_id, event_stream_ordering = ids - yield self.runInteraction( - "set_push_rule_actions", - set_push_rule_actions_txn, - stream_id, - event_stream_ordering, - ) - - def _insert_push_rules_update_txn( - self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None - ): - values = { - "stream_id": stream_id, - "event_stream_ordering": event_stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": op, - } - if data is not None: - values.update(data) - - self._simple_insert_txn(txn, "push_rules_stream", values=values) - - txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) - txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) - txn.call_after( - self.push_rules_stream_cache.entity_has_changed, user_id, stream_id - ) - - def get_all_push_rule_updates(self, last_id, current_id, limit): - """Get all the push rules changes that have happend on the server""" - if last_id == current_id: - return defer.succeed([]) - - def get_all_push_rule_updates_txn(txn): - sql = ( - "SELECT stream_id, event_stream_ordering, user_id, rule_id," - " op, priority_class, priority, conditions, actions" - " FROM push_rules_stream" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() - - return self.runInteraction( - "get_all_push_rule_updates", get_all_push_rule_updates_txn - ) - - def get_push_rules_stream_token(self): - """Get the position of the push rules stream. - Returns a pair of a stream id for the push_rules stream and the - room stream ordering it corresponds to.""" - return self._push_rules_stream_id_gen.get_current_token() - - def get_max_push_rules_stream_id(self): - return self.get_push_rules_stream_token()[0] - class RuleNotFoundException(Exception): pass diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index fcb5f2f23a..d471ec9860 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -17,11 +17,7 @@ import logging import attr -from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore -from synapse.storage.stream import generate_pagination_where_clause -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -113,358 +109,3 @@ class AggregationPaginationToken(object): def as_tuple(self): return attr.astuple(self) - - -class RelationsWorkerStore(SQLBaseStore): - @cached(tree=True) - def get_relations_for_event( - self, - event_id, - relation_type=None, - event_type=None, - aggregation_key=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): - """Get a list of relations for an event, ordered by topological ordering. - - Args: - event_id (str): Fetch events that relate to this event ID. 
- relation_type (str|None): Only fetch events with this relation - type, if given. - event_type (str|None): Only fetch events with this event type, if - given. - aggregation_key (str|None): Only fetch events with this aggregation - key, if given. - limit (int): Only fetch the most recent `limit` events. - direction (str): Whether to fetch the most recent first (`"b"`) or - the oldest first (`"f"`). - from_token (RelationPaginationToken|None): Fetch rows from the given - token, or from the start if None. - to_token (RelationPaginationToken|None): Fetch rows up to the given - token, or up to the end if None. - - Returns: - Deferred[PaginationChunk]: List of event IDs that match relations - requested. The rows are of the form `{"event_id": "..."}`. - """ - - where_clause = ["relates_to_id = ?"] - where_args = [event_id] - - if relation_type is not None: - where_clause.append("relation_type = ?") - where_args.append(relation_type) - - if event_type is not None: - where_clause.append("type = ?") - where_args.append(event_type) - - if aggregation_key: - where_clause.append("aggregation_key = ?") - where_args.append(aggregation_key) - - pagination_clause = generate_pagination_where_clause( - direction=direction, - column_names=("topological_ordering", "stream_ordering"), - from_token=attr.astuple(from_token) if from_token else None, - to_token=attr.astuple(to_token) if to_token else None, - engine=self.database_engine, - ) - - if pagination_clause: - where_clause.append(pagination_clause) - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - sql = """ - SELECT event_id, topological_ordering, stream_ordering - FROM event_relations - INNER JOIN events USING (event_id) - WHERE %s - ORDER BY topological_ordering %s, stream_ordering %s - LIMIT ? - """ % ( - " AND ".join(where_clause), - order, - order, - ) - - def _get_recent_references_for_event_txn(txn): - txn.execute(sql, where_args + [limit + 1]) - - last_topo_id = None - last_stream_id = None - events = [] - for row in txn: - events.append({"event_id": row[0]}) - last_topo_id = row[1] - last_stream_id = row[2] - - next_batch = None - if len(events) > limit and last_topo_id and last_stream_id: - next_batch = RelationPaginationToken(last_topo_id, last_stream_id) - - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token - ) - - return self.runInteraction( - "get_recent_references_for_event", _get_recent_references_for_event_txn - ) - - @cached(tree=True) - def get_aggregation_groups_for_event( - self, - event_id, - event_type=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): - """Get a list of annotations on the event, grouped by event type and - aggregation key, sorted by count. - - This is used e.g. to get the what and how many reactions have happend - on an event. - - Args: - event_id (str): Fetch events that relate to this event ID. - event_type (str|None): Only fetch events with this event type, if - given. - limit (int): Only fetch the `limit` groups. - direction (str): Whether to fetch the highest count first (`"b"`) or - the lowest count first (`"f"`). - from_token (AggregationPaginationToken|None): Fetch rows from the - given token, or from the start if None. - to_token (AggregationPaginationToken|None): Fetch rows up to the - given token, or up to the end if None. - - - Returns: - Deferred[PaginationChunk]: List of groups of annotations that - match. Each row is a dict with `type`, `key` and `count` fields. 
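# --- Illustrative sketch only, not part of the patch above. ---
# The "fetch limit + 1 rows" pagination pattern used by the (re)moved
# get_relations_for_event(): asking the database for one extra row reveals
# whether another page exists without a second query, and that extra row
# supplies the ordering values for the next pagination token.
def paginate(rows, limit):
    # rows: (event_id, topological_ordering, stream_ordering) tuples, already
    # ordered and already capped at limit + 1 by the SQL LIMIT.
    page = rows[:limit]
    next_batch = None
    if len(rows) > limit:
        # Ordering values of the extra row become the token for the next
        # request (a RelationPaginationToken in the original code).
        _, next_topo, next_stream = rows[limit]
        next_batch = (next_topo, next_stream)
    return page, next_batch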
- """ - - where_clause = ["relates_to_id = ?", "relation_type = ?"] - where_args = [event_id, RelationTypes.ANNOTATION] - - if event_type: - where_clause.append("type = ?") - where_args.append(event_type) - - having_clause = generate_pagination_where_clause( - direction=direction, - column_names=("COUNT(*)", "MAX(stream_ordering)"), - from_token=attr.astuple(from_token) if from_token else None, - to_token=attr.astuple(to_token) if to_token else None, - engine=self.database_engine, - ) - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - if having_clause: - having_clause = "HAVING " + having_clause - else: - having_clause = "" - - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) - FROM event_relations - INNER JOIN events USING (event_id) - WHERE {where_clause} - GROUP BY relation_type, type, aggregation_key - {having_clause} - ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} - LIMIT ? - """.format( - where_clause=" AND ".join(where_clause), - order=order, - having_clause=having_clause, - ) - - def _get_aggregation_groups_for_event_txn(txn): - txn.execute(sql, where_args + [limit + 1]) - - next_batch = None - events = [] - for row in txn: - events.append({"type": row[0], "key": row[1], "count": row[2]}) - next_batch = AggregationPaginationToken(row[2], row[3]) - - if len(events) <= limit: - next_batch = None - - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token - ) - - return self.runInteraction( - "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn - ) - - @cachedInlineCallbacks() - def get_applicable_edit(self, event_id): - """Get the most recent edit (if any) that has happened for the given - event. - - Correctly handles checking whether edits were allowed to happen. - - Args: - event_id (str): The original event ID - - Returns: - Deferred[EventBase|None]: Returns the most recent edit, if any. - """ - - # We only allow edits for `m.room.message` events that have the same sender - # and event type. We can't assert these things during regular event auth so - # we have to do the checks post hoc. - - # Fetches latest edit that has the same type and sender as the - # original, and is an `m.room.message`. - sql = """ - SELECT edit.event_id FROM events AS edit - INNER JOIN event_relations USING (event_id) - INNER JOIN events AS original ON - original.event_id = relates_to_id - AND edit.type = original.type - AND edit.sender = original.sender - WHERE - relates_to_id = ? - AND relation_type = ? - AND edit.type = 'm.room.message' - ORDER by edit.origin_server_ts DESC, edit.event_id DESC - LIMIT 1 - """ - - def _get_applicable_edit_txn(txn): - txn.execute(sql, (event_id, RelationTypes.REPLACE)) - row = txn.fetchone() - if row: - return row[0] - - edit_id = yield self.runInteraction( - "get_applicable_edit", _get_applicable_edit_txn - ) - - if not edit_id: - return - - edit_event = yield self.get_event(edit_id, allow_none=True) - return edit_event - - def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): - """Check if a user has already annotated an event with the same key - (e.g. already liked an event). 
- - Args: - parent_id (str): The event being annotated - event_type (str): The event type of the annotation - aggregation_key (str): The aggregation key of the annotation - sender (str): The sender of the annotation - - Returns: - Deferred[bool] - """ - - sql = """ - SELECT 1 FROM event_relations - INNER JOIN events USING (event_id) - WHERE - relates_to_id = ? - AND relation_type = ? - AND type = ? - AND sender = ? - AND aggregation_key = ? - LIMIT 1; - """ - - def _get_if_user_has_annotated_event(txn): - txn.execute( - sql, - ( - parent_id, - RelationTypes.ANNOTATION, - event_type, - sender, - aggregation_key, - ), - ) - - return bool(txn.fetchone()) - - return self.runInteraction( - "get_if_user_has_annotated_event", _get_if_user_has_annotated_event - ) - - -class RelationsStore(RelationsWorkerStore): - def _handle_event_relations(self, txn, event): - """Handles inserting relation data during peristence of events - - Args: - txn - event (EventBase) - """ - relation = event.content.get("m.relates_to") - if not relation: - # No relations - return - - rel_type = relation.get("rel_type") - if rel_type not in ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.REPLACE, - ): - # Unknown relation type - return - - parent_id = relation.get("event_id") - if not parent_id: - # Invalid relation - return - - aggregation_key = relation.get("key") - - self._simple_insert_txn( - txn, - table="event_relations", - values={ - "event_id": event.event_id, - "relates_to_id": parent_id, - "relation_type": rel_type, - "aggregation_key": aggregation_key, - }, - ) - - txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,)) - txn.call_after( - self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) - ) - - if rel_type == RelationTypes.REPLACE: - txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) - - def _handle_redaction(self, txn, redacted_event_id): - """Handles receiving a redaction and checking whether we need to remove - any redacted relations from the database. - - Args: - txn - redacted_event_id (str): The event that was redacted. - """ - - self._simple_delete_txn( - txn, table="event_relations", keyvalues={"event_id": redacted_event_id} - ) diff --git a/synapse/storage/room.py b/synapse/storage/room.py deleted file mode 100644 index 08e13f3a3b..0000000000 --- a/synapse/storage/room.py +++ /dev/null @@ -1,585 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
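As a rough sketch of what _handle_event_relations above does when persisting an event (not part of this patch, and using sqlite3 plus literal rel_type strings as stand-ins for Synapse's transaction helpers and RelationTypes constants):

    import sqlite3

    KNOWN_REL_TYPES = {"m.annotation", "m.reference", "m.replace"}

    def handle_event_relations(conn, event_id, content):
        # Flatten an "m.relates_to" payload into an event_relations row, skipping
        # events with no relation, an unknown rel_type, or a missing parent event.
        relation = content.get("m.relates_to")
        if not relation:
            return
        rel_type = relation.get("rel_type")
        parent_id = relation.get("event_id")
        if rel_type not in KNOWN_REL_TYPES or not parent_id:
            return
        conn.execute(
            "INSERT INTO event_relations"
            " (event_id, relates_to_id, relation_type, aggregation_key)"
            " VALUES (?, ?, ?, ?)",
            (event_id, parent_id, rel_type, relation.get("key")),
        )

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE event_relations ("
        " event_id TEXT, relates_to_id TEXT, relation_type TEXT, aggregation_key TEXT)"
    )
    handle_event_relations(
        conn,
        "$reaction",
        {"m.relates_to": {"rel_type": "m.annotation", "event_id": "$parent", "key": "👍"}},
    )
    print(conn.execute("SELECT * FROM event_relations").fetchall())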
- -import collections -import logging -import re - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.api.errors import StoreError -from synapse.storage._base import SQLBaseStore -from synapse.storage.search import SearchStore -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks - -logger = logging.getLogger(__name__) - - -OpsLevel = collections.namedtuple( - "OpsLevel", ("ban_level", "kick_level", "redact_level") -) - -RatelimitOverride = collections.namedtuple( - "RatelimitOverride", ("messages_per_second", "burst_count") -) - - -class RoomWorkerStore(SQLBaseStore): - def get_room(self, room_id): - """Retrieve a room. - - Args: - room_id (str): The ID of the room to retrieve. - Returns: - A dict containing the room information, or None if the room is unknown. - """ - return self._simple_select_one( - table="rooms", - keyvalues={"room_id": room_id}, - retcols=("room_id", "is_public", "creator"), - desc="get_room", - allow_none=True, - ) - - def get_public_room_ids(self): - return self._simple_select_onecol( - table="rooms", - keyvalues={"is_public": True}, - retcol="room_id", - desc="get_public_room_ids", - ) - - @cached(num_args=2, max_entries=100) - def get_public_room_ids_at_stream_id(self, stream_id, network_tuple): - """Get pulbic rooms for a particular list, or across all lists. - - Args: - stream_id (int) - network_tuple (ThirdPartyInstanceID): The list to use (None, None) - means the main list, None means all lsits. - """ - return self.runInteraction( - "get_public_room_ids_at_stream_id", - self.get_public_room_ids_at_stream_id_txn, - stream_id, - network_tuple=network_tuple, - ) - - def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple): - return { - rm - for rm, vis in self.get_published_at_stream_id_txn( - txn, stream_id, network_tuple=network_tuple - ).items() - if vis - } - - def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple): - if network_tuple: - # We want to get from a particular list. No aggregation required. - - sql = """ - SELECT room_id, visibility FROM public_room_list_stream - INNER JOIN ( - SELECT room_id, max(stream_id) AS stream_id - FROM public_room_list_stream - WHERE stream_id <= ? %s - GROUP BY room_id - ) grouped USING (room_id, stream_id) - """ - - if network_tuple.appservice_id is not None: - txn.execute( - sql % ("AND appservice_id = ? AND network_id = ?",), - (stream_id, network_tuple.appservice_id, network_tuple.network_id), - ) - else: - txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,)) - return dict(txn) - else: - # We want to get from all lists, so we need to aggregate the results - - logger.info("Executing full list") - - sql = """ - SELECT room_id, visibility - FROM public_room_list_stream - INNER JOIN ( - SELECT - room_id, max(stream_id) AS stream_id, appservice_id, - network_id - FROM public_room_list_stream - WHERE stream_id <= ? - GROUP BY room_id, appservice_id, network_id - ) grouped USING (room_id, stream_id) - """ - - txn.execute(sql, (stream_id,)) - - results = {} - # A room is visible if its visible on any list. 
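The "latest row per room wins, and a room is public if it is visible on any list" logic in get_published_at_stream_id_txn can be mirrored in pure Python; the sketch below is illustrative only and ignores the appservice/network split.

    def public_rooms_at_stream_id(rows, stream_id):
        # rows: (room_id, stream_id, visibility) tuples from a
        # public_room_list_stream-style append-only log. For each room the row with
        # the highest stream_id at or below the cutoff wins.
        latest = {}
        for room_id, sid, visible in rows:
            if sid <= stream_id and (room_id not in latest or sid > latest[room_id][0]):
                latest[room_id] = (sid, visible)
        return {room_id for room_id, (_, visible) in latest.items() if visible}

    log = [("!a:hs", 1, True), ("!a:hs", 3, False), ("!b:hs", 2, True)]
    print(public_rooms_at_stream_id(log, 2))   # {'!a:hs', '!b:hs'}
    print(public_rooms_at_stream_id(log, 3))   # {'!b:hs'}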
- for room_id, visibility in txn: - results[room_id] = bool(visibility) or results.get(room_id, False) - - return results - - def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple): - def get_public_room_changes_txn(txn): - then_rooms = self.get_public_room_ids_at_stream_id_txn( - txn, prev_stream_id, network_tuple - ) - - now_rooms_dict = self.get_published_at_stream_id_txn( - txn, new_stream_id, network_tuple - ) - - now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis) - now_rooms_not_visible = set( - rm for rm, vis in now_rooms_dict.items() if not vis - ) - - newly_visible = now_rooms_visible - then_rooms - newly_unpublished = now_rooms_not_visible & then_rooms - - return newly_visible, newly_unpublished - - return self.runInteraction( - "get_public_room_changes", get_public_room_changes_txn - ) - - @cached(max_entries=10000) - def is_room_blocked(self, room_id): - return self._simple_select_one_onecol( - table="blocked_rooms", - keyvalues={"room_id": room_id}, - retcol="1", - allow_none=True, - desc="is_room_blocked", - ) - - @cachedInlineCallbacks(max_entries=10000) - def get_ratelimit_for_user(self, user_id): - """Check if there are any overrides for ratelimiting for the given - user - - Args: - user_id (str) - - Returns: - RatelimitOverride if there is an override, else None. If the contents - of RatelimitOverride are None or 0 then ratelimitng has been - disabled for that user entirely. - """ - row = yield self._simple_select_one( - table="ratelimit_override", - keyvalues={"user_id": user_id}, - retcols=("messages_per_second", "burst_count"), - allow_none=True, - desc="get_ratelimit_for_user", - ) - - if row: - return RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - ) - else: - return None - - -class RoomStore(RoomWorkerStore, SearchStore): - @defer.inlineCallbacks - def store_room(self, room_id, room_creator_user_id, is_public): - """Stores a room. - - Args: - room_id (str): The desired room ID, can be None. - room_creator_user_id (str): The user ID of the room creator. - is_public (bool): True to indicate that this room should appear in - public room lists. - Raises: - StoreError if the room could not be stored. 
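The semantics described in the get_ratelimit_for_user docstring above (no row means defaults apply, a row whose values are None or 0 disables ratelimiting entirely) might be consumed like the following simplified stand-in; effective_ratelimit is hypothetical, not Synapse's ratelimiter.

    def effective_ratelimit(override, default_mps, default_burst):
        # override: a RatelimitOverride-style object with messages_per_second and
        # burst_count attributes, or None when no row exists for the user.
        if override is None:
            return default_mps, default_burst
        if not override.messages_per_second or not override.burst_count:
            return None, None   # (None, None) here meaning "no limit at all"
        return override.messages_per_second, override.burst_count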
- """ - try: - - def store_room_txn(txn, next_id): - self._simple_insert_txn( - txn, - "rooms", - { - "room_id": room_id, - "creator": room_creator_user_id, - "is_public": is_public, - }, - ) - if is_public: - self._simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - }, - ) - - with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction("store_room_txn", store_room_txn, next_id) - except Exception as e: - logger.error("store_room with room_id=%s failed: %s", room_id, e) - raise StoreError(500, "Problem creating room.") - - @defer.inlineCallbacks - def set_room_is_public(self, room_id, is_public): - def set_room_is_public_txn(txn, next_id): - self._simple_update_one_txn( - txn, - table="rooms", - keyvalues={"room_id": room_id}, - updatevalues={"is_public": is_public}, - ) - - entries = self._simple_select_list_txn( - txn, - table="public_room_list_stream", - keyvalues={ - "room_id": room_id, - "appservice_id": None, - "network_id": None, - }, - retcols=("stream_id", "visibility"), - ) - - entries.sort(key=lambda r: r["stream_id"]) - - add_to_stream = True - if entries: - add_to_stream = bool(entries[-1]["visibility"]) != is_public - - if add_to_stream: - self._simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - "appservice_id": None, - "network_id": None, - }, - ) - - with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction( - "set_room_is_public", set_room_is_public_txn, next_id - ) - self.hs.get_notifier().on_new_replication_data() - - @defer.inlineCallbacks - def set_room_is_public_appservice( - self, room_id, appservice_id, network_id, is_public - ): - """Edit the appservice/network specific public room list. - - Each appservice can have a number of published room lists associated - with them, keyed off of an appservice defined `network_id`, which - basically represents a single instance of a bridge to a third party - network. - - Args: - room_id (str) - appservice_id (str) - network_id (str) - is_public (bool): Whether to publish or unpublish the room from the - list. - """ - - def set_room_is_public_appservice_txn(txn, next_id): - if is_public: - try: - self._simple_insert_txn( - txn, - table="appservice_room_list", - values={ - "appservice_id": appservice_id, - "network_id": network_id, - "room_id": room_id, - }, - ) - except self.database_engine.module.IntegrityError: - # We've already inserted, nothing to do. 
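The set_room_is_public transaction above only appends a new public_room_list_stream row when the visibility actually changes. A minimal, illustrative version of that check (assuming entries shaped like the deleted _simple_select_list_txn result):

    def needs_stream_entry(entries, new_is_public):
        # entries: dicts with "stream_id" and "visibility" keys for one room.
        if not entries:
            return True
        latest = max(entries, key=lambda r: r["stream_id"])
        return bool(latest["visibility"]) != bool(new_is_public)

    print(needs_stream_entry([], True))                                    # True
    print(needs_stream_entry([{"stream_id": 5, "visibility": 1}], True))   # False
    print(needs_stream_entry([{"stream_id": 5, "visibility": 1}], False))  # True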
- return - else: - self._simple_delete_txn( - txn, - table="appservice_room_list", - keyvalues={ - "appservice_id": appservice_id, - "network_id": network_id, - "room_id": room_id, - }, - ) - - entries = self._simple_select_list_txn( - txn, - table="public_room_list_stream", - keyvalues={ - "room_id": room_id, - "appservice_id": appservice_id, - "network_id": network_id, - }, - retcols=("stream_id", "visibility"), - ) - - entries.sort(key=lambda r: r["stream_id"]) - - add_to_stream = True - if entries: - add_to_stream = bool(entries[-1]["visibility"]) != is_public - - if add_to_stream: - self._simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - "appservice_id": appservice_id, - "network_id": network_id, - }, - ) - - with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction( - "set_room_is_public_appservice", - set_room_is_public_appservice_txn, - next_id, - ) - self.hs.get_notifier().on_new_replication_data() - - def get_room_count(self): - """Retrieve a list of all rooms - """ - - def f(txn): - sql = "SELECT count(*) FROM rooms" - txn.execute(sql) - row = txn.fetchone() - return row[0] or 0 - - return self.runInteraction("get_rooms", f) - - def _store_room_topic_txn(self, txn, event): - if hasattr(event, "content") and "topic" in event.content: - self.store_event_search_txn( - txn, event, "content.topic", event.content["topic"] - ) - - def _store_room_name_txn(self, txn, event): - if hasattr(event, "content") and "name" in event.content: - self.store_event_search_txn( - txn, event, "content.name", event.content["name"] - ) - - def _store_room_message_txn(self, txn, event): - if hasattr(event, "content") and "body" in event.content: - self.store_event_search_txn( - txn, event, "content.body", event.content["body"] - ) - - def add_event_report( - self, room_id, event_id, user_id, reason, content, received_ts - ): - next_id = self._event_reports_id_gen.get_next() - return self._simple_insert( - table="event_reports", - values={ - "id": next_id, - "received_ts": received_ts, - "room_id": room_id, - "event_id": event_id, - "user_id": user_id, - "reason": reason, - "content": json.dumps(content), - }, - desc="add_event_report", - ) - - def get_current_public_room_stream_id(self): - return self._public_room_id_gen.get_current_token() - - def get_all_new_public_rooms(self, prev_id, current_id, limit): - def get_all_new_public_rooms(txn): - sql = """ - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - - txn.execute(sql, (prev_id, current_id, limit)) - return txn.fetchall() - - if prev_id == current_id: - return defer.succeed([]) - - return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms) - - @defer.inlineCallbacks - def block_room(self, room_id, user_id): - """Marks the room as blocked. Can be called multiple times. 
- - Args: - room_id (str): Room to block - user_id (str): Who blocked it - - Returns: - Deferred - """ - yield self._simple_upsert( - table="blocked_rooms", - keyvalues={"room_id": room_id}, - values={}, - insertion_values={"user_id": user_id}, - desc="block_room", - ) - yield self.runInteraction( - "block_room_invalidation", - self._invalidate_cache_and_stream, - self.is_room_blocked, - (room_id,), - ) - - def get_media_mxcs_in_room(self, room_id): - """Retrieves all the local and remote media MXC URIs in a given room - - Args: - room_id (str) - - Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. - """ - - def _get_media_mxcs_in_room_txn(txn): - local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) - local_media_mxcs = [] - remote_media_mxcs = [] - - # Convert the IDs to MXC URIs - for media_id in local_mxcs: - local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id)) - for hostname, media_id in remote_mxcs: - remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) - - return local_media_mxcs, remote_media_mxcs - - return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) - - def quarantine_media_ids_in_room(self, room_id, quarantined_by): - """For a room loops through all events with media and quarantines - the associated media - """ - - def _quarantine_media_in_room_txn(txn): - local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) - total_media_quarantined = 0 - - # Now update all the tables to set the quarantined_by flag - - txn.executemany( - """ - UPDATE local_media_repository - SET quarantined_by = ? - WHERE media_id = ? - """, - ((quarantined_by, media_id) for media_id in local_mxcs), - ) - - txn.executemany( - """ - UPDATE remote_media_cache - SET quarantined_by = ? - WHERE media_origin = ? AND media_id = ? - """, - ( - (quarantined_by, origin, media_id) - for origin, media_id in remote_mxcs - ), - ) - - total_media_quarantined += len(local_mxcs) - total_media_quarantined += len(remote_mxcs) - - return total_media_quarantined - - return self.runInteraction( - "quarantine_media_in_room", _quarantine_media_in_room_txn - ) - - def _get_media_mxcs_in_room_txn(self, txn, room_id): - """Retrieves all the local and remote media MXC URIs in a given room - - Args: - txn (cursor) - room_id (str) - - Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. - """ - mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") - - next_token = self.get_current_events_token() + 1 - local_media_mxcs = [] - remote_media_mxcs = [] - - while next_token: - sql = """ - SELECT stream_ordering, json FROM events - JOIN event_json USING (room_id, event_id) - WHERE room_id = ? - AND stream_ordering < ? - AND contains_url = ? AND outlier = ? - ORDER BY stream_ordering DESC - LIMIT ? 
- """ - txn.execute(sql, (room_id, next_token, True, False, 100)) - - next_token = None - for stream_ordering, content_json in txn: - next_token = stream_ordering - event_json = json.loads(content_json) - content = event_json["content"] - content_url = content.get("url") - thumbnail_url = content.get("info", {}).get("thumbnail_url") - - for url in (content_url, thumbnail_url): - if not url: - continue - matches = mxc_re.match(url) - if matches: - hostname = matches.group(1) - media_id = matches.group(2) - if hostname == self.hs.hostname: - local_media_mxcs.append(media_id) - else: - remote_media_mxcs.append((hostname, media_id)) - - return local_media_mxcs, remote_media_mxcs diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 4df8ebdacd..8c4a83a840 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -17,24 +17,6 @@ import logging from collections import namedtuple -from six import iteritems, itervalues - -from canonicaljson import json - -from twisted.internet import defer - -from synapse.api.constants import EventTypes, Membership -from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction -from synapse.storage.engines import Sqlite3Engine -from synapse.storage.events_worker import EventsWorkerStore -from synapse.types import get_domain_from_id -from synapse.util.async_helpers import Linearizer -from synapse.util.caches import intern_string -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from synapse.util.stringutils import to_ascii - logger = logging.getLogger(__name__) @@ -55,1057 +37,3 @@ ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name")) # a given membership type, suitable for use in calculating heroes for a room. # "count" points to the total numberr of users of a given membership type. MemberSummary = namedtuple("MemberSummary", ("members", "count")) - -_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" -_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" - - -class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, db_conn, hs): - super(RoomMemberWorkerStore, self).__init__(db_conn, hs) - - # Is the current_state_events.membership up to date? Or is the - # background update still running? - self._current_state_events_membership_up_to_date = False - - txn = LoggingTransaction( - db_conn.cursor(), - name="_check_safe_current_state_events_membership_updated", - database_engine=self.database_engine, - ) - self._check_safe_current_state_events_membership_updated_txn(txn) - txn.close() - - if self.hs.config.metrics_flags.known_servers: - self._known_servers_count = 1 - self.hs.get_clock().looping_call( - run_as_background_process, - 60 * 1000, - "_count_known_servers", - self._count_known_servers, - ) - self.hs.get_clock().call_later( - 1000, - run_as_background_process, - "_count_known_servers", - self._count_known_servers, - ) - LaterGauge( - "synapse_federation_known_servers", - "", - [], - lambda: self._known_servers_count, - ) - - @defer.inlineCallbacks - def _count_known_servers(self): - """ - Count the servers that this server knows about. - - The statistic is stored on the class for the - `synapse_federation_known_servers` LaterGauge to collect. 
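The media-quarantine path above extracts MXC URIs from event content with a regex and splits them into local and remote media. A self-contained sketch of that split (same regex as the deleted code, everything else a stand-in):

    import re

    MXC_RE = re.compile(r"^mxc://([^/]+)/([^/#?]+)")

    def split_media(urls, local_hostname):
        # Separate mxc:// URLs into local media IDs and (hostname, media_id) pairs
        # for remote media; None entries (e.g. a missing thumbnail_url) are skipped.
        local, remote = [], []
        for url in urls:
            m = MXC_RE.match(url or "")
            if not m:
                continue
            hostname, media_id = m.group(1), m.group(2)
            if hostname == local_hostname:
                local.append(media_id)
            else:
                remote.append((hostname, media_id))
        return local, remote

    print(split_media(["mxc://example.org/abc", "mxc://other.net/def", None], "example.org"))
    # (['abc'], [('other.net', 'def')])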
- """ - - def _transact(txn): - if isinstance(self.database_engine, Sqlite3Engine): - query = """ - SELECT COUNT(DISTINCT substr(out.user_id, pos+1)) - FROM ( - SELECT rm.user_id as user_id, instr(rm.user_id, ':') - AS pos FROM room_memberships as rm - INNER JOIN current_state_events as c ON rm.event_id = c.event_id - WHERE c.type = 'm.room.member' - ) as out - """ - else: - query = """ - SELECT COUNT(DISTINCT split_part(state_key, ':', 2)) - FROM current_state_events - WHERE type = 'm.room.member' AND membership = 'join'; - """ - txn.execute(query) - return list(txn)[0][0] - - count = yield self.runInteraction("get_known_servers", _transact) - - # We always know about ourselves, even if we have nothing in - # room_memberships (for example, the server is new). - self._known_servers_count = max([count, 1]) - return self._known_servers_count - - def _check_safe_current_state_events_membership_updated_txn(self, txn): - """Checks if it is safe to assume the new current_state_events - membership column is up to date - """ - - pending_update = self._simple_select_one_txn( - txn, - table="background_updates", - keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, - retcols=["update_name"], - allow_none=True, - ) - - self._current_state_events_membership_up_to_date = not pending_update - - # If the update is still running, reschedule to run. - if pending_update: - self._clock.call_later( - 15.0, - run_as_background_process, - "_check_safe_current_state_events_membership_updated", - self.runInteraction, - "_check_safe_current_state_events_membership_updated", - self._check_safe_current_state_events_membership_updated_txn, - ) - - @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) - def get_hosts_in_room(self, room_id, cache_context): - """Returns the set of all hosts currently in the room - """ - user_ids = yield self.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) - return hosts - - @cached(max_entries=100000, iterable=True) - def get_users_in_room(self, room_id): - return self.runInteraction( - "get_users_in_room", self.get_users_in_room_txn, room_id - ) - - def get_users_in_room_txn(self, txn, room_id): - # If we can assume current_state_events.membership is up to date - # then we can avoid a join, which is a Very Good Thing given how - # frequently this function gets called. - if self._current_state_events_membership_up_to_date: - sql = """ - SELECT state_key FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? AND membership = ? - """ - else: - sql = """ - SELECT state_key FROM room_memberships as m - INNER JOIN current_state_events as c - ON m.event_id = c.event_id - AND m.room_id = c.room_id - AND m.user_id = c.state_key - WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? - """ - - txn.execute(sql, (room_id, Membership.JOIN)) - return [to_ascii(r[0]) for r in txn] - - @cached(max_entries=100000) - def get_room_summary(self, room_id): - """ Get the details of a room roughly suitable for use by the room - summary extension to /sync. Useful when lazy loading room members. - Args: - room_id (str): The room ID to query - Returns: - Deferred[dict[str, MemberSummary]: - dict of membership states, pointing to a MemberSummary named tuple. - """ - - def _get_room_summary_txn(txn): - # first get counts. - # We do this all in one transaction to keep the cache small. 
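The _count_known_servers queries above boil down to "count the distinct server names of joined members, and never report fewer than one". A pure-Python equivalent, for illustration only:

    def count_known_servers(joined_user_ids):
        # joined_user_ids: Matrix user IDs of joined members across all rooms. The
        # server name is everything after the first ":", matching get_domain_from_id;
        # the local server always knows about itself, hence the floor of 1.
        domains = {uid.split(":", 1)[1] for uid in joined_user_ids if ":" in uid}
        return max(len(domains), 1)

    print(count_known_servers(["@a:example.org", "@b:example.org", "@c:matrix.org"]))  # 2
    print(count_known_servers([]))                                                     # 1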
- # FIXME: get rid of this when we have room_stats - - # If we can assume current_state_events.membership is up to date - # then we can avoid a join, which is a Very Good Thing given how - # frequently this function gets called. - if self._current_state_events_membership_up_to_date: - # Note, rejected events will have a null membership field, so - # we we manually filter them out. - sql = """ - SELECT count(*), membership FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? - AND membership IS NOT NULL - GROUP BY membership - """ - else: - sql = """ - SELECT count(*), m.membership FROM room_memberships as m - INNER JOIN current_state_events as c - ON m.event_id = c.event_id - AND m.room_id = c.room_id - AND m.user_id = c.state_key - WHERE c.type = 'm.room.member' AND c.room_id = ? - GROUP BY m.membership - """ - - txn.execute(sql, (room_id,)) - res = {} - for count, membership in txn: - summary = res.setdefault(to_ascii(membership), MemberSummary([], count)) - - # we order by membership and then fairly arbitrarily by event_id so - # heroes are consistent - if self._current_state_events_membership_up_to_date: - # Note, rejected events will have a null membership field, so - # we we manually filter them out. - sql = """ - SELECT state_key, membership, event_id - FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? - AND membership IS NOT NULL - ORDER BY - CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, - event_id ASC - LIMIT ? - """ - else: - sql = """ - SELECT c.state_key, m.membership, c.event_id - FROM room_memberships as m - INNER JOIN current_state_events as c USING (room_id, event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? - ORDER BY - CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, - c.event_id ASC - LIMIT ? - """ - - # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. - txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) - for user_id, membership, event_id in txn: - summary = res[to_ascii(membership)] - # we will always have a summary for this membership type at this - # point given the summary currently contains the counts. - members = summary.members - members.append((to_ascii(user_id), to_ascii(event_id))) - - return res - - return self.runInteraction("get_room_summary", _get_room_summary_txn) - - def _get_user_counts_in_room_txn(self, txn, room_id): - """ - Get the user count in a room by membership. - - Args: - room_id (str) - membership (Membership) - - Returns: - Deferred[int] - """ - sql = """ - SELECT m.membership, count(*) FROM room_memberships as m - INNER JOIN current_state_events as c USING(event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? - GROUP BY m.membership - """ - - txn.execute(sql, (room_id,)) - return {row[0]: row[1] for row in txn} - - @cached() - def get_invited_rooms_for_user(self, user_id): - """ Get all the rooms the user is invited to - Args: - user_id (str): The user ID. - Returns: - A deferred list of RoomsForUser. - """ - - return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE]) - - @defer.inlineCallbacks - def get_invite_for_user_in_room(self, user_id, room_id): - """Gets the invite for the given user and room - - Args: - user_id (str) - room_id (str) - - Returns: - Deferred: Resolves to either a RoomsForUser or None if no invite was - found. 
- """ - invites = yield self.get_invited_rooms_for_user(user_id) - for invite in invites: - if invite.room_id == room_id: - return invite - return None - - @defer.inlineCallbacks - def get_rooms_for_user_where_membership_is(self, user_id, membership_list): - """ Get all the rooms for this user where the membership for this user - matches one in the membership list. - - Filters out forgotten rooms. - - Args: - user_id (str): The user ID. - membership_list (list): A list of synapse.api.constants.Membership - values which the user must be in. - - Returns: - Deferred[list[RoomsForUser]] - """ - if not membership_list: - return defer.succeed(None) - - rooms = yield self.runInteraction( - "get_rooms_for_user_where_membership_is", - self._get_rooms_for_user_where_membership_is_txn, - user_id, - membership_list, - ) - - # Now we filter out forgotten rooms - forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) - return [room for room in rooms if room.room_id not in forgotten_rooms] - - def _get_rooms_for_user_where_membership_is_txn( - self, txn, user_id, membership_list - ): - - do_invite = Membership.INVITE in membership_list - membership_list = [m for m in membership_list if m != Membership.INVITE] - - results = [] - if membership_list: - if self._current_state_events_membership_up_to_date: - sql = """ - SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering - FROM current_state_events AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND state_key = ? - AND c.membership IN (%s) - """ % ( - ",".join("?" * len(membership_list)) - ) - else: - sql = """ - SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering - FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (room_id, event_id) - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND state_key = ? - AND m.membership IN (%s) - """ % ( - ",".join("?" * len(membership_list)) - ) - - txn.execute(sql, (user_id, *membership_list)) - results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] - - if do_invite: - sql = ( - "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" - " FROM local_invites as i" - " INNER JOIN events as e USING (event_id)" - " WHERE invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - txn.execute(sql, (user_id,)) - results.extend( - RoomsForUser( - room_id=r["room_id"], - sender=r["inviter"], - event_id=r["event_id"], - stream_ordering=r["stream_ordering"], - membership=Membership.INVITE, - ) - for r in self.cursor_to_dict(txn) - ) - - return results - - @cachedInlineCallbacks(max_entries=500000, iterable=True) - def get_rooms_for_user_with_stream_ordering(self, user_id): - """Returns a set of room_ids the user is currently joined to - - Args: - user_id (str) - - Returns: - Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns - the rooms the user is in currently, along with the stream ordering - of the most recent join for that user and room. 
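A pure-Python outline of what get_rooms_for_user_where_membership_is above computes, keeping only the shape of the logic (match any requested membership, then drop forgotten rooms); the function and its arguments are illustrative stand-ins.

    def rooms_where_membership_is(current_memberships, membership_list, forgotten):
        # current_memberships: {room_id: membership} for a single user.
        # forgotten: the set of room IDs that user has forgotten.
        wanted = set(membership_list)
        return [room_id for room_id, membership in current_memberships.items()
                if membership in wanted and room_id not in forgotten]

    print(rooms_where_membership_is(
        {"!a:hs": "join", "!b:hs": "invite", "!c:hs": "join"},
        ["join"],
        forgotten={"!c:hs"},
    ))  # ['!a:hs']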
- """ - rooms = yield self.get_rooms_for_user_where_membership_is( - user_id, membership_list=[Membership.JOIN] - ) - return frozenset( - GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) - for r in rooms - ) - - @defer.inlineCallbacks - def get_rooms_for_user(self, user_id, on_invalidate=None): - """Returns a set of room_ids the user is currently joined to - """ - rooms = yield self.get_rooms_for_user_with_stream_ordering( - user_id, on_invalidate=on_invalidate - ) - return frozenset(r.room_id for r in rooms) - - @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) - def get_users_who_share_room_with_user(self, user_id, cache_context): - """Returns the set of users who share a room with `user_id` - """ - room_ids = yield self.get_rooms_for_user( - user_id, on_invalidate=cache_context.invalidate - ) - - user_who_share_room = set() - for room_id in room_ids: - user_ids = yield self.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - user_who_share_room.update(user_ids) - - return user_who_share_room - - @defer.inlineCallbacks - def get_joined_users_from_context(self, event, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - current_state_ids = yield context.get_current_state_ids(self) - result = yield self._get_joined_users_from_context( - event.room_id, state_group, current_state_ids, event=event, context=context - ) - return result - - def get_joined_users_from_state(self, room_id, state_entry): - state_group = state_entry.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - return self._get_joined_users_from_context( - room_id, state_group, state_entry.state, context=state_entry - ) - - @cachedInlineCallbacks( - num_args=2, cache_context=True, iterable=True, max_entries=100000 - ) - def _get_joined_users_from_context( - self, - room_id, - state_group, - current_state_ids, - cache_context, - event=None, - context=None, - ): - # We don't use `state_group`, it's there so that we can cache based - # on it. However, it's important that it's never None, since two current_states - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - users_in_room = {} - member_event_ids = [ - e_id - for key, e_id in iteritems(current_state_ids) - if key[0] == EventTypes.Member - ] - - if context is not None: - # If we have a context with a delta from a previous state group, - # check if we also have the result from the previous group in cache. 
- # If we do then we can reuse that result and simply update it with - # any membership changes in `delta_ids` - if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get( - (room_id, context.prev_group), None - ) - if prev_res and isinstance(prev_res, dict): - users_in_room = dict(prev_res) - member_event_ids = [ - e_id - for key, e_id in iteritems(context.delta_ids) - if key[0] == EventTypes.Member - ] - for etype, state_key in context.delta_ids: - users_in_room.pop(state_key, None) - - # We check if we have any of the member event ids in the event cache - # before we ask the DB - - # We don't update the event cache hit ratio as it completely throws off - # the hit ratio counts. After all, we don't populate the cache if we - # miss it here - event_map = self._get_events_from_cache( - member_event_ids, allow_rejected=False, update_metrics=False - ) - - missing_member_event_ids = [] - for event_id in member_event_ids: - ev_entry = event_map.get(event_id) - if ev_entry: - if ev_entry.event.membership == Membership.JOIN: - users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo( - display_name=to_ascii( - ev_entry.event.content.get("displayname", None) - ), - avatar_url=to_ascii( - ev_entry.event.content.get("avatar_url", None) - ), - ) - else: - missing_member_event_ids.append(event_id) - - if missing_member_event_ids: - rows = yield self._simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=missing_member_event_ids, - retcols=("user_id", "display_name", "avatar_url"), - keyvalues={"membership": Membership.JOIN}, - batch_size=500, - desc="_get_joined_users_from_context", - ) - - users_in_room.update( - { - to_ascii(row["user_id"]): ProfileInfo( - avatar_url=to_ascii(row["avatar_url"]), - display_name=to_ascii(row["display_name"]), - ) - for row in rows - } - ) - - if event is not None and event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - if event.event_id in member_event_ids: - users_in_room[to_ascii(event.state_key)] = ProfileInfo( - display_name=to_ascii(event.content.get("displayname", None)), - avatar_url=to_ascii(event.content.get("avatar_url", None)), - ) - - return users_in_room - - @cachedInlineCallbacks(max_entries=10000) - def is_host_joined(self, room_id, host): - if "%" in host or "_" in host: - raise Exception("Invalid host name") - - sql = """ - SELECT state_key FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (event_id) - WHERE m.membership = 'join' - AND type = 'm.room.member' - AND c.room_id = ? - AND state_key LIKE ? - LIMIT 1 - """ - - # We do need to be careful to ensure that host doesn't have any wild cards - # in it, but we checked above for known ones and we'll check below that - # the returned user actually has the correct domain. - like_clause = "%:" + host - - rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) - - if not rows: - return False - - user_id = rows[0][0] - if get_domain_from_id(user_id) != host: - # This can only happen if the host name has something funky in it - raise Exception("Invalid host name") - - return True - - @cachedInlineCallbacks() - def was_host_joined(self, room_id, host): - """Check whether the server is or ever was in the room. - - Args: - room_id (str) - host (str) - - Returns: - Deferred: Resolves to True if the host is/was in the room, otherwise - False. 
- """ - if "%" in host or "_" in host: - raise Exception("Invalid host name") - - sql = """ - SELECT user_id FROM room_memberships - WHERE room_id = ? - AND user_id LIKE ? - AND membership = 'join' - LIMIT 1 - """ - - # We do need to be careful to ensure that host doesn't have any wild cards - # in it, but we checked above for known ones and we'll check below that - # the returned user actually has the correct domain. - like_clause = "%:" + host - - rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) - - if not rows: - return False - - user_id = rows[0][0] - if get_domain_from_id(user_id) != host: - # This can only happen if the host name has something funky in it - raise Exception("Invalid host name") - - return True - - def get_joined_hosts(self, room_id, state_entry): - state_group = state_entry.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - return self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry - ) - - @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) - # @defer.inlineCallbacks - def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - cache = self._get_joined_hosts_cache(room_id) - joined_hosts = yield cache.get_destinations(state_entry) - - return joined_hosts - - @cached(max_entries=10000) - def _get_joined_hosts_cache(self, room_id): - return _JoinedHostsCache(self, room_id) - - @cachedInlineCallbacks(num_args=2) - def did_forget(self, user_id, room_id): - """Returns whether user_id has elected to discard history for room_id. - - Returns False if they have since re-joined.""" - - def f(txn): - sql = ( - "SELECT" - " COUNT(*)" - " FROM" - " room_memberships" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - " AND" - " forgotten = 0" - ) - txn.execute(sql, (user_id, room_id)) - rows = txn.fetchall() - return rows[0][0] - - count = yield self.runInteraction("did_forget_membership", f) - return count == 0 - - @cached() - def get_forgotten_rooms_for_user(self, user_id): - """Gets all rooms the user has forgotten. - - Args: - user_id (str) - - Returns: - Deferred[set[str]] - """ - - def _get_forgotten_rooms_for_user_txn(txn): - # This is a slightly convoluted query that first looks up all rooms - # that the user has forgotten in the past, then rechecks that list - # to see if any have subsequently been updated. This is done so that - # we can use a partial index on `forgotten = 1` on the assumption - # that few users will actually forget many rooms. - # - # Note that a room is considered "forgotten" if *all* membership - # events for that user and room have the forgotten field set (as - # when a user forgets a room we update all rows for that user and - # room, not just the current one). 
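The rule spelled out in the comment above (a room is forgotten only when *every* membership row for that user and room has the forgotten flag set) can be expressed directly in Python; this sketch is illustrative and operates on plain tuples rather than the real table.

    def forgotten_rooms(membership_rows, user_id):
        # membership_rows: (user_id, room_id, forgotten) tuples, one per membership row.
        flags_by_room = {}
        for uid, room_id, forgotten in membership_rows:
            if uid == user_id:
                flags_by_room.setdefault(room_id, []).append(bool(forgotten))
        return {room for room, flags in flags_by_room.items() if flags and all(flags)}

    rows = [
        ("@a:hs", "!x:hs", 1), ("@a:hs", "!x:hs", 1),
        ("@a:hs", "!y:hs", 1), ("@a:hs", "!y:hs", 0),
    ]
    print(forgotten_rooms(rows, "@a:hs"))   # {'!x:hs'}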
- sql = """ - SELECT room_id, ( - SELECT count(*) FROM room_memberships - WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0 - ) AS count - FROM room_memberships AS m - WHERE user_id = ? AND forgotten = 1 - GROUP BY room_id, user_id; - """ - txn.execute(sql, (user_id,)) - return set(row[0] for row in txn if row[1] == 0) - - return self.runInteraction( - "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn - ) - - @defer.inlineCallbacks - def get_rooms_user_has_been_in(self, user_id): - """Get all rooms that the user has ever been in. - - Args: - user_id (str) - - Returns: - Deferred[set[str]]: Set of room IDs. - """ - - room_ids = yield self._simple_select_onecol( - table="room_memberships", - keyvalues={"membership": Membership.JOIN, "user_id": user_id}, - retcol="room_id", - desc="get_rooms_user_has_been_in", - ) - - return set(room_ids) - - -class RoomMemberStore(RoomMemberWorkerStore): - def __init__(self, db_conn, hs): - super(RoomMemberStore, self).__init__(db_conn, hs) - self.register_background_update_handler( - _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile - ) - self.register_background_update_handler( - _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, - self._background_current_state_membership, - ) - self.register_background_index_update( - "room_membership_forgotten_idx", - index_name="room_memberships_user_room_forgotten", - table="room_memberships", - columns=["user_id", "room_id"], - where_clause="forgotten = 1", - ) - - def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database. - """ - self._simple_insert_many_txn( - txn, - table="room_memberships", - values=[ - { - "event_id": event.event_id, - "user_id": event.state_key, - "sender": event.user_id, - "room_id": event.room_id, - "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), - } - for event in events - ], - ) - - for event in events: - txn.call_after( - self._membership_stream_cache.entity_has_changed, - event.state_key, - event.internal_metadata.stream_ordering, - ) - txn.call_after( - self.get_invited_rooms_for_user.invalidate, (event.state_key,) - ) - - # We update the local_invites table only if the event is "current", - # i.e., its something that has just happened. If the event is an - # outlier it is only current if its an "out of band membership", - # like a remote invite or a rejection of a remote invite. - is_new_state = not backfilled and ( - not event.internal_metadata.is_outlier() - or event.internal_metadata.is_out_of_band_membership() - ) - is_mine = self.hs.is_mine_id(event.state_key) - if is_new_state and is_mine: - if event.membership == Membership.INVITE: - self._simple_insert_txn( - txn, - table="local_invites", - values={ - "event_id": event.event_id, - "invitee": event.state_key, - "inviter": event.sender, - "room_id": event.room_id, - "stream_id": event.internal_metadata.stream_ordering, - }, - ) - else: - sql = ( - "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - txn.execute( - sql, - ( - event.internal_metadata.stream_ordering, - event.event_id, - event.room_id, - event.state_key, - ), - ) - - @defer.inlineCallbacks - def locally_reject_invite(self, user_id, room_id): - sql = ( - "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" - " room_id = ? AND invitee = ? 
AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - def f(txn, stream_ordering): - txn.execute(sql, (stream_ordering, True, room_id, user_id)) - - with self._stream_id_gen.get_next() as stream_ordering: - yield self.runInteraction("locally_reject_invite", f, stream_ordering) - - def forget(self, user_id, room_id): - """Indicate that user_id wishes to discard history for room_id.""" - - def f(txn): - sql = ( - "UPDATE" - " room_memberships" - " SET" - " forgotten = 1" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - ) - txn.execute(sql, (user_id, room_id)) - - self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) - self._invalidate_cache_and_stream( - txn, self.get_forgotten_rooms_for_user, (user_id,) - ) - - return self.runInteraction("forget_membership", f) - - @defer.inlineCallbacks - def _background_add_membership_profile(self, progress, batch_size): - target_min_stream_id = progress.get( - "target_min_stream_id_inclusive", self._min_stream_order_on_start - ) - max_stream_id = progress.get( - "max_stream_id_exclusive", self._stream_order_on_start + 1 - ) - - INSERT_CLUMP_SIZE = 1000 - - def add_membership_profile_txn(txn): - sql = """ - SELECT stream_ordering, event_id, events.room_id, event_json.json - FROM events - INNER JOIN event_json USING (event_id) - INNER JOIN room_memberships USING (event_id) - WHERE ? <= stream_ordering AND stream_ordering < ? - AND type = 'm.room.member' - ORDER BY stream_ordering DESC - LIMIT ? - """ - - txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - - rows = self.cursor_to_dict(txn) - if not rows: - return 0 - - min_stream_id = rows[-1]["stream_ordering"] - - to_update = [] - for row in rows: - event_id = row["event_id"] - room_id = row["room_id"] - try: - event_json = json.loads(row["json"]) - content = event_json["content"] - except Exception: - continue - - display_name = content.get("displayname", None) - avatar_url = content.get("avatar_url", None) - - if display_name or avatar_url: - to_update.append((display_name, avatar_url, event_id, room_id)) - - to_update_sql = """ - UPDATE room_memberships SET display_name = ?, avatar_url = ? - WHERE event_id = ? AND room_id = ? - """ - for index in range(0, len(to_update), INSERT_CLUMP_SIZE): - clump = to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(to_update_sql, clump) - - progress = { - "target_min_stream_id_inclusive": target_min_stream_id, - "max_stream_id_exclusive": min_stream_id, - } - - self._background_update_progress_txn( - txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress - ) - - return len(rows) - - result = yield self.runInteraction( - _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn - ) - - if not result: - yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) - - return result - - @defer.inlineCallbacks - def _background_current_state_membership(self, progress, batch_size): - """Update the new membership column on current_state_events. - - This works by iterating over all rooms in alphebetical order. - """ - - def _background_current_state_membership_txn(txn, last_processed_room): - processed = 0 - while processed < batch_size: - txn.execute( - """ - SELECT MIN(room_id) FROM current_state_events WHERE room_id > ? 
- """, - (last_processed_room,), - ) - row = txn.fetchone() - if not row or not row[0]: - return processed, True - - next_room, = row - - sql = """ - UPDATE current_state_events - SET membership = ( - SELECT membership FROM room_memberships - WHERE event_id = current_state_events.event_id - ) - WHERE room_id = ? - """ - txn.execute(sql, (next_room,)) - processed += txn.rowcount - - last_processed_room = next_room - - self._background_update_progress_txn( - txn, - _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, - {"last_processed_room": last_processed_room}, - ) - - return processed, False - - # If we haven't got a last processed room then just use the empty - # string, which will compare before all room IDs correctly. - last_processed_room = progress.get("last_processed_room", "") - - row_count, finished = yield self.runInteraction( - "_background_current_state_membership_update", - _background_current_state_membership_txn, - last_processed_room, - ) - - if finished: - yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME) - - return row_count - - -class _JoinedHostsCache(object): - """Cache for joined hosts in a room that is optimised to handle updates - via state deltas. - """ - - def __init__(self, store, room_id): - self.store = store - self.room_id = room_id - - self.hosts_to_joined_users = {} - - self.state_group = object() - - self.linearizer = Linearizer("_JoinedHostsCache") - - self._len = 0 - - @defer.inlineCallbacks - def get_destinations(self, state_entry): - """Get set of destinations for a state entry - - Args: - state_entry(synapse.state._StateCacheEntry) - """ - if state_entry.state_group == self.state_group: - return frozenset(self.hosts_to_joined_users) - - with (yield self.linearizer.queue(())): - if state_entry.state_group == self.state_group: - pass - elif state_entry.prev_group == self.state_group: - for (typ, state_key), event_id in iteritems(state_entry.delta_ids): - if typ != EventTypes.Member: - continue - - host = intern_string(get_domain_from_id(state_key)) - user_id = state_key - known_joins = self.hosts_to_joined_users.setdefault(host, set()) - - event = yield self.store.get_event(event_id) - if event.membership == Membership.JOIN: - known_joins.add(user_id) - else: - known_joins.discard(user_id) - - if not known_joins: - self.hosts_to_joined_users.pop(host, None) - else: - joined_users = yield self.store.get_joined_users_from_state( - self.room_id, state_entry - ) - - self.hosts_to_joined_users = {} - for user_id in joined_users: - host = intern_string(get_domain_from_id(user_id)) - self.hosts_to_joined_users.setdefault(host, set()).add(user_id) - - if state_entry.state_group: - self.state_group = state_entry.state_group - else: - self.state_group = object() - self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users)) - return frozenset(self.hosts_to_joined_users) - - def __len__(self): - return self._len diff --git a/synapse/storage/schema/delta/35/00background_updates_add_col.sql b/synapse/storage/schema/delta/35/00background_updates_add_col.sql new file mode 100644 index 0000000000..c2d2a4f836 --- /dev/null +++ b/synapse/storage/schema/delta/35/00background_updates_add_col.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +ALTER TABLE background_updates ADD COLUMN depends_on TEXT; diff --git a/synapse/storage/schema/delta/58/00background_update_ordering.sql b/synapse/storage/schema/delta/58/00background_update_ordering.sql new file mode 100644 index 0000000000..02dae587cc --- /dev/null +++ b/synapse/storage/schema/delta/58/00background_update_ordering.sql @@ -0,0 +1,19 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* add an "ordering" column to background_updates, which can be used to sort them + to achieve some level of consistency. */ + +ALTER TABLE background_updates ADD COLUMN ordering INT NOT NULL DEFAULT 0; diff --git a/synapse/storage/schema/full_schemas/54/full.sql b/synapse/storage/schema/full_schemas/54/full.sql new file mode 100644 index 0000000000..1005880466 --- /dev/null +++ b/synapse/storage/schema/full_schemas/54/full.sql @@ -0,0 +1,8 @@ + + +CREATE TABLE background_updates ( + update_name text NOT NULL, + progress_json text NOT NULL, + depends_on text, + CONSTRAINT background_updates_uniqueness UNIQUE (update_name) +); diff --git a/synapse/storage/schema/full_schemas/README.txt b/synapse/storage/schema/full_schemas/README.txt deleted file mode 100644 index d3f6401344..0000000000 --- a/synapse/storage/schema/full_schemas/README.txt +++ /dev/null @@ -1,19 +0,0 @@ -Building full schema dumps -========================== - -These schemas need to be made from a database that has had all background updates run. - -Postgres --------- - -$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres - -SQLite ------- - -$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite - -After ------ - -Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas". \ No newline at end of file diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 1980a87108..c522c80922 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,43 +14,21 @@ # limitations under the License. 
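The two schema deltas above give background_updates a depends_on column and a new ordering column "to achieve some level of consistency". A hypothetical picture of how a scheduler might use them (purely illustrative; the real scheduling lives in synapse/storage/background_updates.py):

    def next_background_update(rows, completed):
        # rows: (update_name, ordering, depends_on) tuples mirroring the table above.
        # completed: names of updates that have already finished.
        pending = [r for r in rows if r[0] not in completed]
        runnable = [r for r in pending if not r[2] or r[2] in completed]
        if not runnable:
            return None
        # Sort by the new "ordering" column, then by name.
        return min(runnable, key=lambda r: (r[1], r[0]))[0]

    rows = [("b_update", 0, "a_update"), ("a_update", 0, None), ("c_update", 1, None)]
    print(next_background_update(rows, completed=set()))         # a_update
    print(next_background_update(rows, completed={"a_update"}))  # b_update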
import logging -from collections import namedtuple +from typing import Iterable, List, TypeVar from six import iteritems, itervalues -from six.moves import range import attr from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.api.errors import NotFoundError -from synapse.storage._base import SQLBaseStore -from synapse.storage.background_updates import BackgroundUpdateStore -from synapse.storage.engines import PostgresEngine -from synapse.storage.events_worker import EventsWorkerStore -from synapse.util.caches import get_cache_factor_for, intern_string -from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches.dictionary_cache import DictionaryCache -from synapse.util.stringutils import to_ascii +from synapse.types import StateMap logger = logging.getLogger(__name__) - -MAX_STATE_DELTA_HOPS = 100 - - -class _GetStateGroupDelta( - namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) -): - """Return type of get_state_group_delta that implements __len__, which lets - us use the itrable flag when caching - """ - - __slots__ = [] - - def __len__(self): - return len(self.delta_ids) if self.delta_ids else 0 +# Used for generic functions below +T = TypeVar("T") @attr.s(slots=True) @@ -260,14 +238,14 @@ class StateFilter(object): return len(self.concrete_types()) - def filter_state(self, state_dict): + def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]: """Returns the state filtered with by this StateFilter Args: - state (dict[tuple[str, str], Any]): The state map to filter + state: The state map to filter Returns: - dict[tuple[str, str], Any]: The filtered state map + The filtered state map """ if self.is_full(): return dict(state_dict) @@ -353,252 +331,23 @@ class StateFilter(object): return member_filter, non_member_filter -# this inherits from EventsWorkerStore because it calls self.get_events -class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): - """The parts of StateGroupStore that can be called from workers. +class StateGroupStorage(object): + """High level interface to fetching state for event. """ - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - - def __init__(self, db_conn, hs): - super(StateGroupWorkerStore, self).__init__(db_conn, hs) - - # Originally the state store used a single DictionaryCache to cache the - # event IDs for the state types in a given state group to avoid hammering - # on the state_group* tables. - # - # The point of using a DictionaryCache is that it can cache a subset - # of the state events for a given state group (i.e. a subset of the keys for a - # given dict which is an entry in the cache for a given state group ID). - # - # However, this poses problems when performing complicated queries - # on the store - for instance: "give me all the state for this group, but - # limit members to this subset of users", as DictionaryCache's API isn't - # rich enough to say "please cache any of these fields, apart from this subset". - # This is problematic when lazy loading members, which requires this behaviour, - # as without it the cache has no choice but to speculatively load all - # state events for the group, which negates the efficiency being sought. 
- # - # Rather than overcomplicating DictionaryCache's API, we instead split the - # state_group_cache into two halves - one for tracking non-member events, - # and the other for tracking member_events. This means that lazy loading - # queries can be made in a cache-friendly manner by querying both caches - # separately and then merging the result. So for the example above, you - # would query the members cache for a specific subset of state keys - # (which DictionaryCache will handle efficiently and fine) and the non-members - # cache for all state (which DictionaryCache will similarly handle fine) - # and then just merge the results together. - # - # We size the non-members cache to be smaller than the members cache as the - # vast majority of state in Matrix (today) is member events. - - self._state_group_cache = DictionaryCache( - "*stateGroupCache*", - # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache"), - ) - self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache"), - ) - - @defer.inlineCallbacks - def get_room_version(self, room_id): - """Get the room_version of a given room - - Args: - room_id (str) - - Returns: - Deferred[str] - - Raises: - NotFoundError if the room is unknown - """ - # for now we do this by looking at the create event. We may want to cache this - # more intelligently in future. - - # Retrieve the room's create event - create_event = yield self.get_create_event_for_room(room_id) - return create_event.content.get("room_version", "1") - - @defer.inlineCallbacks - def get_room_predecessor(self, room_id): - """Get the predecessor room of an upgraded room if one exists. - Otherwise return None. - - Args: - room_id (str) - - Returns: - Deferred[unicode|None]: predecessor room id - - Raises: - NotFoundError if the room is unknown - """ - # Retrieve the room's create event - create_event = yield self.get_create_event_for_room(room_id) - - # Return predecessor if present - return create_event.content.get("predecessor", None) - - @defer.inlineCallbacks - def get_create_event_for_room(self, room_id): - """Get the create state event for a room. - - Args: - room_id (str) - - Returns: - Deferred[EventBase]: The room creation event. - - Raises: - NotFoundError if the room is unknown - """ - state_ids = yield self.get_current_state_ids(room_id) - create_id = state_ids.get((EventTypes.Create, "")) - - # If we can't find the create event, assume we've hit a dead end - if not create_id: - raise NotFoundError("Unknown room %s" % (room_id)) - - # Retrieve the room's create event and return - create_event = yield self.get_event(create_id) - return create_event - - @cached(max_entries=100000, iterable=True) - def get_current_state_ids(self, room_id): - """Get the current state event ids for a room based on the - current_state_events table. - - Args: - room_id (str) - - Returns: - deferred: dict of (type, state_key) -> event_id - """ - - def _get_current_state_ids_txn(txn): - txn.execute( - """SELECT type, state_key, event_id FROM current_state_events - WHERE room_id = ? - """, - (room_id,), - ) - - return { - (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn - } - - return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn) - - # FIXME: how should this be cached? 
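The removed comment block above explains why the state group cache was split into a members half and a non-members half: a lazy-loading query asks the members cache for just the membership keys it cares about, asks the non-members cache for everything, and merges the two answers. A toy sketch of that split-and-merge, with plain dicts standing in for DictionaryCache:

# Two caches keyed by state group; values map (event_type, state_key) -> event_id.
member_cache = {
    42: {("m.room.member", "@alice:example.com"): "$alice_join"},
}
non_member_cache = {
    42: {("m.room.create", ""): "$create", ("m.room.name", ""): "$name"},
}

def get_state(group, wanted_member_keys):
    """Fetch all non-member state plus only the requested member keys."""
    state = dict(non_member_cache.get(group, {}))
    members = member_cache.get(group, {})
    state.update({key: members[key] for key in wanted_member_keys if key in members})
    return state

print(get_state(42, [("m.room.member", "@alice:example.com")]))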
- def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): - """Get the current state event of a given type for a room based on the - current_state_events table. This may not be as up-to-date as the result - of doing a fresh state resolution as per state_handler.get_current_state - - Args: - room_id (str) - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - Deferred[dict[tuple[str, str], str]]: Map from type/state_key to - event ID. - """ - - where_clause, where_args = state_filter.make_sql_filter_clause() - - if not where_clause: - # We delegate to the cached version - return self.get_current_state_ids(room_id) - - def _get_filtered_current_state_ids_txn(txn): - results = {} - sql = """ - SELECT type, state_key, event_id FROM current_state_events - WHERE room_id = ? - """ - - if where_clause: - sql += " AND (%s)" % (where_clause,) + def __init__(self, hs, stores): + self.stores = stores - args = [room_id] - args.extend(where_args) - txn.execute(sql, args) - for row in txn: - typ, state_key, event_id = row - key = (intern_string(typ), intern_string(state_key)) - results[key] = event_id - - return results - - return self.runInteraction( - "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn - ) - - @defer.inlineCallbacks - def get_canonical_alias_for_room(self, room_id): - """Get canonical alias for room, if any - - Args: - room_id (str) - - Returns: - Deferred[str|None]: The canonical alias, if any - """ - - state = yield self.get_filtered_current_state_ids( - room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) - ) - - event_id = state.get((EventTypes.CanonicalAlias, "")) - if not event_id: - return - - event = yield self.get_event(event_id, allow_none=True) - if not event: - return - - return event.content.get("canonical_alias") - - @cached(max_entries=10000, iterable=True) - def get_state_group_delta(self, state_group): + def get_state_group_delta(self, state_group: int): """Given a state group try to return a previous group and a delta between the old and the new. Returns: - (prev_group, delta_ids), where both may be None. 
+ Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: + (prev_group, delta_ids) """ - def _get_state_group_delta_txn(txn): - prev_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - retcol="prev_state_group", - allow_none=True, - ) - - if not prev_group: - return _GetStateGroupDelta(None, None) - - delta_ids = self._simple_select_list_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - retcols=("type", "state_key", "event_id"), - ) - - return _GetStateGroupDelta( - prev_group, - {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, - ) - - return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn) + return self.stores.state.get_state_group_delta(state_group) @defer.inlineCallbacks def get_state_groups_ids(self, _room_id, event_ids): @@ -609,16 +358,16 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): event_ids (iterable[str]): ids of the events Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: + Deferred[dict[int, StateMap[str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ if not event_ids: return {} - event_to_groups = yield self._get_state_group_for_events(event_ids) + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups) + group_to_state = yield self.stores.state._get_state_for_groups(groups) return group_to_state @@ -639,7 +388,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): @defer.inlineCallbacks def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids - Returns: Deferred[dict[int, list[EventBase]]]: dict of state_group_id -> list of state events. @@ -649,7 +397,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) - state_event_map = yield self.get_events( + state_event_map = yield self.stores.main.get_events( [ ev_id for group_ids in itervalues(group_to_ids) @@ -667,153 +415,41 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for group, event_id_map in iteritems(group_to_ids) } - @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, state_filter): + def _get_state_groups_from_groups( + self, groups: List[int], state_filter: StateFilter + ): """Returns the state groups for a given set of groups, filtering on types of state events. Args: - groups(list[int]): list of state group IDs to query - state_filter (StateFilter): The state filter used to fetch state + groups: list of state group IDs to query + state_filter: The state filter used to fetch state from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. 
""" - results = {} - - chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] - for chunk in chunks: - res = yield self.runInteraction( - "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, - chunk, - state_filter, - ) - results.update(res) - return results - - def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter=StateFilter.all() - ): - results = {group: {} for group in groups} - - where_clause, where_args = state_filter.make_sql_filter_clause() - - # Unless the filter clause is empty, we're going to append it after an - # existing where clause - if where_clause: - where_clause = " AND (%s)" % (where_clause,) - - if isinstance(self.database_engine, PostgresEngine): - # Temporarily disable sequential scans in this transaction. This is - # a temporary hack until we can add the right indices in - txn.execute("SET LOCAL enable_seqscan=off") - - # The below query walks the state_group tree so that the "state" - # table includes all state_groups in the tree. It then joins - # against `state_groups_state` to fetch the latest state. - # It assumes that previous state groups are always numerically - # lesser. - # The PARTITION is used to get the event_id in the greatest state - # group for the given type, state_key. - # This may return multiple rows per (type, state_key), but last_value - # should be the same. - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT DISTINCT type, state_key, last_value(event_id) OVER ( - PARTITION BY type, state_key ORDER BY state_group ASC - ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - ) AS event_id FROM state_groups_state - WHERE state_group IN ( - SELECT state_group FROM state - ) - """ - - for group in groups: - args = [group] - args.extend(where_args) - - txn.execute(sql + where_clause, args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id - else: - max_entries_returned = state_filter.max_entries_returned() - - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - for group in groups: - next_group = group - - while next_group: - # We did this before by getting the list of group ids, and - # then passing that list to sqlite to get latest event for - # each (type, state_key). However, that was terribly slow - # without the right indices (which we can't add until - # after we finish deduping state, which requires this func) - args = [next_group] - args.extend(where_args) - - txn.execute( - "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? " + where_clause, - args, - ) - results[group].update( - ((typ, state_key), event_id) - for typ, state_key, event_id in txn - if (typ, state_key) not in results[group] - ) - - # If the number of entries in the (type,state_key)->event_id dict - # matches the number of (type,state_keys) types we were searching - # for, then we must have found them all, so no need to go walk - # further down the tree... UNLESS our types filter contained - # wildcards (i.e. 
Nones) in which case we have to do an exhaustive - # search - if ( - max_entries_returned is not None - and len(results[group]) == max_entries_returned - ): - break - - next_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - - return results + return self.stores.state._get_state_groups_from_groups(groups, state_filter) @defer.inlineCallbacks def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): """Given a list of event_ids and type tuples, return a list of state dicts for each event. - Args: event_ids (list[string]) state_filter (StateFilter): The state filter used to fetch state from the database. - Returns: deferred: A dict of (event_id) -> (type, state_key) -> [state_events] """ - event_to_groups = yield self._get_state_group_for_events(event_ids) + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) + group_to_state = yield self.stores.state._get_state_for_groups( + groups, state_filter + ) - state_event_map = yield self.get_events( + state_event_map = yield self.stores.main.get_events( [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], get_prev_content=False, ) @@ -843,10 +479,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred dict from event_id -> (type, state_key) -> event_id """ - event_to_groups = yield self._get_state_group_for_events(event_ids) + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) + group_to_state = yield self.stores.state._get_state_for_groups( + groups, state_filter + ) event_to_state = { event_id: group_to_state[group] @@ -887,77 +525,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): state_map = yield self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] - @cached(max_entries=50000) - def _get_state_group_for_event(self, event_id): - return self._simple_select_one_onecol( - table="event_to_state_groups", - keyvalues={"event_id": event_id}, - retcol="state_group", - allow_none=True, - desc="_get_state_group_for_event", - ) - - @cachedList( - cached_method_name="_get_state_group_for_event", - list_name="event_ids", - num_args=1, - inlineCallbacks=True, - ) - def _get_state_group_for_events(self, event_ids): - """Returns mapping event_id -> state_group - """ - rows = yield self._simple_select_many_batch( - table="event_to_state_groups", - column="event_id", - iterable=event_ids, - keyvalues={}, - retcols=("event_id", "state_group"), - desc="_get_state_group_for_events", - ) - - return {row["event_id"]: row["state_group"] for row in rows} - - def _get_state_for_group_using_cache(self, cache, group, state_filter): - """Checks if group is in cache. See `_get_state_for_groups` - - Args: - cache(DictionaryCache): the state group cache to use - group(int): The state group to lookup - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns 2-tuple (`state_dict`, `got_all`). - `got_all` is a bool indicating if we successfully retrieved all - requests state from the cache, if False we need to query the DB for the - missing state. 
- """ - is_all, known_absent, state_dict_ids = cache.get(group) - - if is_all or state_filter.is_full(): - # Either we have everything or want everything, either way - # `is_all` tells us whether we've gotten everything. - return state_filter.filter_state(state_dict_ids), is_all - - # tracks whether any of our requested types are missing from the cache - missing_types = False - - if state_filter.has_wildcards(): - # We don't know if we fetched all the state keys for the types in - # the filter that are wildcards, so we have to assume that we may - # have missed some. - missing_types = True - else: - # There aren't any wild cards, so `concrete_types()` returns the - # complete list of event types we're wanting. - for key in state_filter.concrete_types(): - if key not in state_dict_ids and key not in known_absent: - missing_types = True - break - - return state_filter.filter_state(state_dict_ids), not missing_types - - @defer.inlineCallbacks - def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + def _get_state_for_groups( + self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() + ): """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -967,157 +537,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): state_filter (StateFilter): The state filter used to fetch state from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. """ - - member_filter, non_member_filter = state_filter.get_member_split() - - # Now we look them up in the member and non-member caches - non_member_state, incomplete_groups_nm, = ( - yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, state_filter=non_member_filter - ) - ) - - member_state, incomplete_groups_m, = ( - yield self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, state_filter=member_filter - ) - ) - - state = dict(non_member_state) - for group in groups: - state[group].update(member_state[group]) - - # Now fetch any missing groups from the database - - incomplete_groups = incomplete_groups_m | incomplete_groups_nm - - if not incomplete_groups: - return state - - cache_sequence_nm = self._state_group_cache.sequence - cache_sequence_m = self._state_group_members_cache.sequence - - # Help the cache hit ratio by expanding the filter a bit - db_state_filter = state_filter.return_expanded() - - group_to_state_dict = yield self._get_state_groups_from_groups( - list(incomplete_groups), state_filter=db_state_filter - ) - - # Now lets update the caches - self._insert_into_cache( - group_to_state_dict, - db_state_filter, - cache_seq_num_members=cache_sequence_m, - cache_seq_num_non_members=cache_sequence_nm, - ) - - # And finally update the result dict, by filtering out any extra - # stuff we pulled out of the database. - for group, group_state_dict in iteritems(group_to_state_dict): - # We just replace any existing entries, as we will have loaded - # everything we need from the database anyway. - state[group] = state_filter.filter_state(group_state_dict) - - return state - - def _get_state_for_groups_using_cache(self, groups, cache, state_filter): - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key, querying from a specific cache. 
- - Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - cache (DictionaryCache): the cache of group ids to state dicts which - we will pass through - either the normal state cache or the specific - members state cache. - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of - dict of state_group_id -> (dict of (type, state_key) -> event id) - of entries in the cache, and the state group ids either missing - from the cache or incomplete. - """ - results = {} - incomplete_groups = set() - for group in set(groups): - state_dict_ids, got_all = self._get_state_for_group_using_cache( - cache, group, state_filter - ) - results[group] = state_dict_ids - - if not got_all: - incomplete_groups.add(group) - - return results, incomplete_groups - - def _insert_into_cache( - self, - group_to_state_dict, - state_filter, - cache_seq_num_members, - cache_seq_num_non_members, - ): - """Inserts results from querying the database into the relevant cache. - - Args: - group_to_state_dict (dict): The new entries pulled from database. - Map from state group to state dict - state_filter (StateFilter): The state filter used to fetch state - from the database. - cache_seq_num_members (int): Sequence number of member cache since - last lookup in cache - cache_seq_num_non_members (int): Sequence number of member cache since - last lookup in cache - """ - - # We need to work out which types we've fetched from the DB for the - # member vs non-member caches. This should be as accurate as possible, - # but can be an underestimate (e.g. when we have wild cards) - - member_filter, non_member_filter = state_filter.get_member_split() - if member_filter.is_full(): - # We fetched all member events - member_types = None - else: - # `concrete_types()` will only return a subset when there are wild - # cards in the filter, but that's fine. - member_types = member_filter.concrete_types() - - if non_member_filter.is_full(): - # We fetched all non member events - non_member_types = None - else: - non_member_types = non_member_filter.concrete_types() - - for group, group_state_dict in iteritems(group_to_state_dict): - state_dict_members = {} - state_dict_non_members = {} - - for k, v in iteritems(group_state_dict): - if k[0] == EventTypes.Member: - state_dict_members[k] = v - else: - state_dict_non_members[k] = v - - self._state_group_members_cache.update( - cache_seq_num_members, - key=group, - value=state_dict_members, - fetched_keys=member_types, - ) - - self._state_group_cache.update( - cache_seq_num_non_members, - key=group, - value=state_dict_non_members, - fetched_keys=non_member_types, - ) + return self.stores.state._get_state_for_groups(groups, state_filter) def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids @@ -1137,393 +559,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: Deferred[int]: The state group ID """ - - def _store_state_group_txn(txn): - if current_state_ids is None: - # AFAIK, this can never happen - raise Exception("current_state_ids cannot be None") - - state_group = self.database_engine.get_next_state_group_id(txn) - - self._simple_insert_txn( - txn, - table="state_groups", - values={"id": state_group, "room_id": room_id, "event_id": event_id}, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. 
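The comment just above is the key constraint on state groups: a group is persisted either in full or as a delta against prev_group, and the chain of deltas is capped (MAX_STATE_DELTA_HOPS, 100 in the removed code) because reading a group means walking the whole chain. A standalone sketch of that resolution walk, using plain dicts in place of the state_group_edges and state_groups_state tables:

# group_id -> (prev_group_id or None, delta of {(type, state_key): event_id})
groups = {
    1: (None, {("m.room.create", ""): "$create", ("m.room.name", ""): "$name1"}),
    2: (1, {("m.room.name", ""): "$name2"}),                       # delta: name changed
    3: (2, {("m.room.member", "@bob:example.com"): "$bob_join"}),  # delta: bob joined
}

def resolve(group_id):
    """Walk prev_group pointers, then apply deltas oldest-first so newer ones win."""
    chain = []
    current = group_id
    while current is not None:
        prev, delta = groups[current]
        chain.append(delta)
        current = prev
    state = {}
    for delta in reversed(chain):
        state.update(delta)
    return state, len(chain) - 1  # resolved state, plus the hops the cap bounds

state, hops = resolve(3)
print(hops)                        # 2
print(state[("m.room.name", "")])  # "$name2"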
- if prev_group: - is_in_db = self._simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self._simple_insert_txn( - txn, - table="state_group_edges", - values={"state_group": state_group, "prev_state_group": prev_group}, - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_ids) - ], - ) - else: - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(current_state_ids) - ], - ) - - # Prefill the state group caches with this group. - # It's fine to use the sequence like this as the state group map - # is immutable. (If the map wasn't immutable then this prefill could - # race with another update) - - current_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] == EventTypes.Member - } - txn.call_after( - self._state_group_members_cache.update, - self._state_group_members_cache.sequence, - key=state_group, - value=dict(current_member_state_ids), - ) - - current_non_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] != EventTypes.Member - } - txn.call_after( - self._state_group_cache.update, - self._state_group_cache.sequence, - key=state_group, - value=dict(current_non_member_state_ids), - ) - - return state_group - - return self.runInteraction("store_state_group", _store_state_group_txn) - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """ - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - - -class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): - """ Keeps track of the state at a given event. - - This is done by the concept of `state groups`. Every event is a assigned - a state group (identified by an arbitrary string), which references a - collection of state events. The current state of an event is then the - collection of state events referenced by the event's state group. - - Hence, every change in the current state causes a new state group to be - generated. 
However, if no change happens (e.g., if we get a message event - with only one parent it inherits the state group from its parent.) - - There are three tables: - * `state_groups`: Stores group name, first event with in the group and - room id. - * `event_to_state_groups`: Maps events to state groups. - * `state_groups_state`: Maps state group to state events. - """ - - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" - - def __init__(self, db_conn, hs): - super(StateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, - ) - self.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state - ) - self.register_background_index_update( - self.CURRENT_STATE_INDEX_UPDATE_NAME, - index_name="current_state_events_member_index", - table="current_state_events", - columns=["state_key"], - where_clause="type='m.room.member'", - ) - self.register_background_index_update( - self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, - index_name="event_to_state_groups_sg_index", - table="event_to_state_groups", - columns=["state_group"], - ) - - def _store_event_state_mappings_txn(self, txn, events_and_contexts): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue - - # if the event was rejected, just give it the same state as its - # predecessor. - if context.rejected: - state_groups[event.event_id] = context.prev_group - continue - - state_groups[event.event_id] = context.state_group - - self._simple_insert_many_txn( - txn, - table="event_to_state_groups", - values=[ - {"state_group": state_group_id, "event_id": event_id} - for event_id, state_group_id in iteritems(state_groups) - ], - ) - - for event_id, state_group_id in iteritems(state_groups): - txn.call_after( - self._get_state_group_for_event.prefill, (event_id,), state_group_id - ) - - @defer.inlineCallbacks - def _background_deduplicate_state(self, progress, batch_size): - """This background update will slowly deduplicate state by reencoding - them as deltas. - """ - last_state_group = progress.get("last_state_group", 0) - rows_inserted = progress.get("rows_inserted", 0) - max_group = progress.get("max_group", None) - - BATCH_SIZE_SCALE_FACTOR = 100 - - batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) - - if max_group is None: - rows = yield self._execute( - "_background_deduplicate_state", - None, - "SELECT coalesce(max(id), 0) FROM state_groups", - ) - max_group = rows[0][0] - - def reindex_txn(txn): - new_last_state_group = last_state_group - for count in range(batch_size): - txn.execute( - "SELECT id, room_id FROM state_groups" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC" - " LIMIT 1", - (new_last_state_group, max_group), - ) - row = txn.fetchone() - if row: - state_group, room_id = row - - if not row or not state_group: - return True, count - - txn.execute( - "SELECT state_group FROM state_group_edges" - " WHERE state_group = ?", - (state_group,), - ) - - # If we reach a point where we've already started inserting - # edges we should stop. - if txn.fetchall(): - return True, count - - txn.execute( - "SELECT coalesce(max(id), 0) FROM state_groups" - " WHERE id < ? 
AND room_id = ?", - (state_group, room_id), - ) - prev_group, = txn.fetchone() - new_last_state_group = state_group - - if prev_group: - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if potential_hops >= MAX_STATE_DELTA_HOPS: - # We want to ensure chains are at most this long,# - # otherwise read performance degrades. - continue - - prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group] - ) - prev_state = prev_state[prev_group] - - curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group] - ) - curr_state = curr_state[state_group] - - if not set(prev_state.keys()) - set(curr_state.keys()): - # We can only do a delta if the current has a strict super set - # of keys - - delta_state = { - key: value - for key, value in iteritems(curr_state) - if prev_state.get(key, None) != value - } - - self._simple_delete_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - ) - - self._simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": state_group, - "prev_state_group": prev_group, - }, - ) - - self._simple_delete_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_state) - ], - ) - - progress = { - "last_state_group": state_group, - "rows_inserted": rows_inserted + batch_size, - "max_group": max_group, - } - - self._background_update_progress_txn( - txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress - ) - - return False, batch_size - - finished, result = yield self.runInteraction( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn + return self.stores.state.store_state_group( + event_id, room_id, prev_group, delta_ids, current_state_ids ) - - if finished: - yield self._end_background_update( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME - ) - - return result * BATCH_SIZE_SCALE_FACTOR - - @defer.inlineCallbacks - def _background_index_state(self, progress, batch_size): - def reindex_txn(conn): - conn.rollback() - if isinstance(self.database_engine, PostgresEngine): - # postgres insists on autocommit for the index - conn.set_session(autocommit=True) - try: - txn = conn.cursor() - txn.execute( - "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - finally: - conn.set_session(autocommit=False) - else: - txn = conn.cursor() - txn.execute( - "CREATE INDEX state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - - yield self.runWithConnection(reindex_txn) - - yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) - - return 1 diff --git a/synapse/storage/types.py b/synapse/storage/types.py new file mode 100644 index 0000000000..daff81c5ee --- /dev/null +++ b/synapse/storage/types.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Iterable, Iterator, List, Tuple + +from typing_extensions import Protocol + + +""" +Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 +""" + + +class Cursor(Protocol): + def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any: + ... + + def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any: + ... + + def fetchall(self) -> List[Tuple]: + ... + + def fetchone(self) -> Tuple: + ... + + @property + def description(self) -> Any: + return None + + @property + def rowcount(self) -> int: + return 0 + + def __iter__(self) -> Iterator[Tuple]: + ... + + def close(self) -> None: + ... + + +class Connection(Protocol): + def cursor(self) -> Cursor: + ... + + def close(self) -> None: + ... + + def commit(self) -> None: + ... + + def rollback(self, *args, **kwargs) -> None: + ... diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index cbb0a4810a..f89ce0bed2 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -16,6 +16,11 @@ import contextlib import threading from collections import deque +from typing import Dict, Set, Tuple + +from typing_extensions import Deque + +from synapse.storage.database import Database, LoggingTransaction class IdGenerator(object): @@ -46,7 +51,7 @@ def _load_current_id(db_conn, table, column, step=1): cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) else: cur.execute("SELECT MIN(%s) FROM %s" % (column, table)) - val, = cur.fetchone() + (val,) = cur.fetchone() cur.close() current_id = int(val) if val else step return (max if step > 0 else min)(current_id, step) @@ -87,7 +92,7 @@ class StreamIdGenerator(object): self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) ) - self._unfinished_ids = deque() + self._unfinished_ids = deque() # type: Deque[int] def get_next(self): """ @@ -161,9 +166,10 @@ class ChainedIdGenerator(object): def __init__(self, chained_generator, db_conn, table, column): self.chained_generator = chained_generator + self._table = table self._lock = threading.Lock() self._current_max = _load_current_id(db_conn, table, column) - self._unfinished_ids = deque() + self._unfinished_ids = deque() # type: Deque[Tuple[int, int]] def get_next(self): """ @@ -198,3 +204,173 @@ class ChainedIdGenerator(object): return stream_id - 1, chained_id return self._current_max, self.chained_generator.get_current_token() + + def advance(self, token: int): + """Stub implementation for advancing the token when receiving updates + over replication; raises an exception as this instance should be the + only source of updates. + """ + + raise Exception( + "Attempted to advance token on source for table %r", self._table + ) + + +class MultiWriterIdGenerator: + """An ID generator that tracks a stream that can have multiple writers. + + Uses a Postgres sequence to coordinate ID assignment, but positions of other + writers will only get updated when `advance` is called (by replication). + + Note: Only works with Postgres. 
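The Cursor and Connection protocols added in synapse/storage/types.py are structural: any DB-API2 driver object with matching methods satisfies them without inheriting from anything. A small sketch of that idea (my example, with trimmed-down copies of the protocols), exercising stdlib sqlite3 objects against a function annotated with them:

import sqlite3
from typing import Any, Iterable
from typing_extensions import Protocol

class Cursor(Protocol):
    def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any: ...
    def fetchone(self) -> Any: ...

class Connection(Protocol):
    def cursor(self) -> Cursor: ...
    def commit(self) -> None: ...

def count_rows(conn: Connection, table: str) -> int:
    # Works with any driver whose objects quack like the protocols above.
    cur = conn.cursor()
    cur.execute("SELECT count(*) FROM %s" % (table,))
    (count,) = cur.fetchone()
    return count

db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE t (x INTEGER)")
db.execute("INSERT INTO t VALUES (1), (2)")
print(count_rows(db, "t"))  # 2: sqlite3.Connection satisfies the protocol structurally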
+ + Args: + db_conn + db + instance_name: The name of this instance. + table: Database table associated with stream. + instance_column: Column that stores the row's writer's instance name + id_column: Column that stores the stream ID. + sequence_name: The name of the postgres sequence used to generate new + IDs. + """ + + def __init__( + self, + db_conn, + db: Database, + instance_name: str, + table: str, + instance_column: str, + id_column: str, + sequence_name: str, + ): + self._db = db + self._instance_name = instance_name + self._sequence_name = sequence_name + + # We lock as some functions may be called from DB threads. + self._lock = threading.Lock() + + self._current_positions = self._load_current_ids( + db_conn, table, instance_column, id_column + ) + + # Set of local IDs that we're still processing. The current position + # should be less than the minimum of this set (if not empty). + self._unfinished_ids = set() # type: Set[int] + + def _load_current_ids( + self, db_conn, table: str, instance_column: str, id_column: str + ) -> Dict[str, int]: + sql = """ + SELECT %(instance)s, MAX(%(id)s) FROM %(table)s + GROUP BY %(instance)s + """ % { + "instance": instance_column, + "id": id_column, + "table": table, + } + + cur = db_conn.cursor() + cur.execute(sql) + + # `cur` is an iterable over returned rows, which are 2-tuples. + current_positions = dict(cur) + + cur.close() + + return current_positions + + def _load_next_id_txn(self, txn): + txn.execute("SELECT nextval(?)", (self._sequence_name,)) + (next_id,) = txn.fetchone() + return next_id + + async def get_next(self): + """ + Usage: + with await stream_id_gen.get_next() as stream_id: + # ... persist event ... + """ + next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn) + + # Assert the fetched ID is actually greater than what we currently + # believe the ID to be. If not, then the sequence and table have got + # out of sync somehow. + assert self.get_current_token() < next_id + + with self._lock: + self._unfinished_ids.add(next_id) + + @contextlib.contextmanager + def manager(): + try: + yield next_id + finally: + self._mark_id_as_finished(next_id) + + return manager() + + def get_next_txn(self, txn: LoggingTransaction): + """ + Usage: + + stream_id = stream_id_gen.get_next(txn) + # ... persist event ... + """ + + next_id = self._load_next_id_txn(txn) + + with self._lock: + self._unfinished_ids.add(next_id) + + txn.call_after(self._mark_id_as_finished, next_id) + txn.call_on_exception(self._mark_id_as_finished, next_id) + + return next_id + + def _mark_id_as_finished(self, next_id: int): + """The ID has finished being processed so we should advance the + current poistion if possible. + """ + + with self._lock: + self._unfinished_ids.discard(next_id) + + # Figure out if its safe to advance the position by checking there + # aren't any lower allocated IDs that are yet to finish. + if all(c > next_id for c in self._unfinished_ids): + curr = self._current_positions.get(self._instance_name, 0) + self._current_positions[self._instance_name] = max(curr, next_id) + + def get_current_token(self, instance_name: str = None) -> int: + """Gets the current position of a named writer (defaults to current + instance). + + Returns 0 if we don't have a position for the named writer (likely due + to it being a new writer). 
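The interesting part of MultiWriterIdGenerator is _mark_id_as_finished: a writer's advertised position may only move past an ID once every lower in-flight ID has completed, otherwise readers could observe a position that skips rows still being persisted. A cut-down, Postgres-free sketch of just that bookkeeping (not the real class, which also coordinates IDs via a sequence):

import threading

class PositionTracker:
    """Tracks the highest ID that is safe to advertise for one writer."""

    def __init__(self):
        self._lock = threading.Lock()
        self._position = 0
        self._unfinished = set()

    def start(self, new_id: int):
        with self._lock:
            self._unfinished.add(new_id)

    def finish(self, new_id: int):
        with self._lock:
            self._unfinished.discard(new_id)
            # Only advance if no lower ID is still being processed.
            if all(other > new_id for other in self._unfinished):
                self._position = max(self._position, new_id)

    def current(self) -> int:
        with self._lock:
            return self._position

t = PositionTracker()
t.start(7)
t.start(8)
t.finish(8)
print(t.current())  # 0: ID 7 is lower and unfinished, so 8 cannot be advertised yet
t.finish(7)
print(t.current())  # 7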
+ """ + + if instance_name is None: + instance_name = self._instance_name + + with self._lock: + return self._current_positions.get(instance_name, 0) + + def get_positions(self) -> Dict[str, int]: + """Get a copy of the current positon map. + """ + + with self._lock: + return dict(self._current_positions) + + def advance(self, instance_name: str, new_id: int): + """Advance the postion of the named writer to the given ID, if greater + than existing entry. + """ + + with self._lock: + self._current_positions[instance_name] = max( + new_id, self._current_positions.get(instance_name, 0) + ) diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 02994ab2a5..cd56cd91ed 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -88,9 +88,12 @@ class PaginationConfig(object): raise SynapseError(400, "Invalid request.") def __repr__(self): - return ( - "PaginationConfig(from_tok=%r, to_tok=%r," " direction=%r, limit=%r)" - ) % (self.from_token, self.to_token, self.direction, self.limit) + return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % ( + self.from_token, + self.to_token, + self.direction, + self.limit, + ) def get_source_config(self, source_name): keyname = "%s_key" % source_name diff --git a/synapse/streams/events.py b/synapse/streams/events.py index b91fb2db7b..fcd2aaa9c9 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict + from twisted.internet import defer from synapse.handlers.account_data import AccountDataEventSource @@ -35,7 +37,7 @@ class EventSources(object): def __init__(self, hs): self.sources = { name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items() - } + } # type: Dict[str, Any] self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/synapse/types.py b/synapse/types.py index 51eadb6ad4..acf60baddc 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,11 +15,37 @@ # limitations under the License. import re import string +import sys from collections import namedtuple +from typing import Any, Dict, Tuple, TypeVar import attr +from signedjson.key import decode_verify_key_bytes +from unpaddedbase64 import decode_base64 -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError + +# define a version of typing.Collection that works on python 3.5 +if sys.version_info[:3] >= (3, 6, 0): + from typing import Collection +else: + from typing import Sized, Iterable, Container + + T_co = TypeVar("T_co", covariant=True) + + class Collection(Iterable[T_co], Container[T_co], Sized): + __slots__ = () + + +# Define a state map type from type/state_key to T (usually an event ID or +# event) +T = TypeVar("T") +StateMap = Dict[Tuple[str, str], T] + + +# the type of a JSON-serialisable dict. This could be made stronger, but it will +# do for now. 
+JsonDict = Dict[str, Any] class Requester( @@ -139,11 +166,13 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom return self @classmethod - def from_string(cls, s): + def from_string(cls, s: str): """Parse the string given by 's' into a structure object.""" if len(s) < 1 or s[0:1] != cls.SIGIL: raise SynapseError( - 400, "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL) + 400, + "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, ) parts = s[1:].split(":", 1) @@ -152,6 +181,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom 400, "Expected %s of the form '%slocalname:domain'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, ) domain = parts[1] @@ -208,11 +238,13 @@ class GroupID(DomainSpecificString): def from_string(cls, s): group_id = super(GroupID, cls).from_string(s) if not group_id.localpart: - raise SynapseError(400, "Group ID cannot be empty") + raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) if contains_invalid_mxid_characters(group_id.localpart): raise SynapseError( - 400, "Group ID can only contain characters a-z, 0-9, or '=_-./'" + 400, + "Group ID can only contain characters a-z, 0-9, or '=_-./'", + Codes.INVALID_PARAM, ) return group_id @@ -318,6 +350,7 @@ class StreamToken( ) ): _SEPARATOR = "_" + START = None # type: StreamToken @classmethod def from_string(cls, string): @@ -402,7 +435,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): followed by the "stream_ordering" id of the event it comes after. """ - __slots__ = [] + __slots__ = [] # type: list @classmethod def parse(cls, string): @@ -475,3 +508,24 @@ class ReadReceipt(object): user_id = attr.ib() event_ids = attr.ib() data = attr.ib() + + +def get_verify_key_from_cross_signing_key(key_info): + """Get the key ID and signedjson verify key from a cross-signing key dict + + Args: + key_info (dict): a cross-signing key dict, which must have a "keys" + property that has exactly one item in it + + Returns: + (str, VerifyKey): the key ID and verify key for the cross-signing key + """ + # make sure that exactly one key is provided + if "keys" not in key_info: + raise ValueError("Invalid key") + keys = key_info["keys"] + if len(keys) != 1: + raise ValueError("Invalid key") + # and return that one key + for key_id, key_data in keys.items(): + return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 7856353002..60f0de70f7 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -15,7 +15,6 @@ import logging import re -from itertools import islice import attr @@ -107,22 +106,6 @@ class Clock(object): raise -def batch_iter(iterable, size): - """batch an iterable up into tuples with a maximum size - - Args: - iterable (iterable): the iterable to slice - size (int): the maximum batch size - - Returns: - an iterator over the chunks - """ - # make sure we can deal with iterables like lists too - sourceiter = iter(iterable) - # call islice until it returns an empty tuple - return iter(lambda: tuple(islice(sourceiter, size)), ()) - - def log_failure(failure, msg, consumeErrors=True): """Creates a function suitable for passing to `Deferred.addErrback` that logs any failures that occur. 
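The batch_iter helper removed from synapse/util/__init__.py above is built on a compact idiom: the two-argument form of iter() keeps calling a zero-argument callable until it returns the sentinel value. Reproduced standalone to show how it behaves:

from itertools import islice

def batch_iter(iterable, size):
    """Yield tuples of at most `size` items, as in the removed helper."""
    sourceiter = iter(iterable)
    # iter(callable, sentinel): keep calling the lambda until it returns ().
    return iter(lambda: tuple(islice(sourceiter, size)), ())

print(list(batch_iter(range(7), 3)))  # [(0, 1, 2), (3, 4, 5), (6,)]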
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index f1c46836b1..f7af2bca7f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -13,12 +13,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import collections import logging from contextlib import contextmanager +from typing import Dict, Sequence, Set, Union from six.moves import range +import attr + from twisted.internet import defer from twisted.internet.defer import CancelledError from twisted.python import failure @@ -69,6 +73,10 @@ class ObservableDeferred(object): def errback(f): object.__setattr__(self, "_result", (False, f)) while self._observers: + # This is a little bit of magic to correctly propagate stack + # traces when we `await` on one of the observer deferreds. + f.value.__failure__ = f + try: # TODO: Handle errors here. self._observers.pop().errback(f) @@ -82,11 +90,12 @@ class ObservableDeferred(object): deferred.addCallbacks(callback, errback) - def observe(self): + def observe(self) -> defer.Deferred: """Observe the underlying deferred. - Can return either a deferred if the underlying deferred is still pending - (or has failed), or the actual value. Callers may need to use maybeDeferred. + This returns a brand new deferred that is resolved when the underlying + deferred is resolved. Interacting with the returned deferred does not + effect the underdlying deferred. """ if not self._result: d = defer.Deferred() @@ -101,7 +110,7 @@ class ObservableDeferred(object): return d else: success, res = self._result - return res if success else defer.fail(res) + return defer.succeed(res) if success else defer.fail(res) def observers(self): return self._observers @@ -134,9 +143,9 @@ def concurrently_execute(func, args, limit): the number of concurrent executions. Args: - func (func): Function to execute, should return a deferred. - args (list): List of arguments to pass to func, each invocation of func - gets a signle argument. + func (func): Function to execute, should return a deferred or coroutine. + args (Iterable): List of arguments to pass to func, each invocation of func + gets a single argument. limit (int): Maximum number of conccurent executions. Returns: @@ -144,11 +153,10 @@ def concurrently_execute(func, args, limit): """ it = iter(args) - @defer.inlineCallbacks - def _concurrently_execute_inner(): + async def _concurrently_execute_inner(): try: while True: - yield func(next(it)) + await maybe_awaitable(func(next(it))) except StopIteration: pass @@ -213,7 +221,21 @@ class Linearizer(object): # the first element is the number of things executing, and # the second element is an OrderedDict, where the keys are deferreds for the # things blocked from executing. - self.key_to_defer = {} + self.key_to_defer = ( + {} + ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]] + + def is_queued(self, key) -> bool: + """Checks whether there is a process queued up waiting + """ + entry = self.key_to_defer.get(key) + if not entry: + # No entry so nothing is waiting. + return False + + # There are waiting deferreds only in the OrderedDict of deferreds is + # non-empty. + return bool(entry[1]) def queue(self, key): # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly. 
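concurrently_execute now awaits maybe_awaitable(func(...)), so the callback may return a Deferred, a coroutine, or a plain value. A toy asyncio analogue (Synapse's Twisted-flavoured helper is added later in this diff) to illustrate why the normalisation is needed:

import asyncio

async def _wrap(value):
    return value

def maybe_awaitable(value):
    # Anything with __await__ (coroutines, futures) is awaited as-is;
    # plain return values are wrapped so callers can always `await`.
    return value if hasattr(value, "__await__") else _wrap(value)

async def run_all(func, args):
    return [await maybe_awaitable(func(arg)) for arg in args]

def sync_double(x):
    return 2 * x

async def async_double(x):
    return 2 * x

print(asyncio.run(run_all(sync_double, [1, 2, 3])))   # [2, 4, 6]
print(asyncio.run(run_all(async_double, [1, 2, 3])))  # [2, 4, 6]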
@@ -303,7 +325,7 @@ class Linearizer(object): ) else: - logger.warn( + logger.warning( "Unexpected exception waiting for linearizer lock %r for key %r", self.name, key, @@ -340,10 +362,10 @@ class ReadWriteLock(object): def __init__(self): # Latest readers queued - self.key_to_current_readers = {} + self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]] # Latest writer queued - self.key_to_current_writer = {} + self.key_to_current_writer = {} # type: Dict[str, defer.Deferred] @defer.inlineCallbacks def read(self, key): @@ -479,3 +501,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): deferred.addCallbacks(success_cb, failure_cb) return new_d + + +@attr.s(slots=True, frozen=True) +class DoneAwaitable(object): + """Simple awaitable that returns the provided value. + """ + + value = attr.ib() + + def __await__(self): + return self + + def __iter__(self): + return self + + def __next__(self): + raise StopIteration(self.value) + + +def maybe_awaitable(value): + """Convert a value to an awaitable if not already an awaitable. + """ + + if hasattr(value, "__await__"): + return value + + return DoneAwaitable(value) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index b50e3503f0..dd356bf156 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019, 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,34 +15,24 @@ # limitations under the License. import logging -import os +from sys import intern +from typing import Callable, Dict, Optional -import six -from six.moves import intern +import attr +from prometheus_client.core import Gauge -from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily +from synapse.config.cache import add_resizable_cache logger = logging.getLogger(__name__) -CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)) - - -def get_cache_factor_for(cache_name): - env_var = "SYNAPSE_CACHE_FACTOR_" + cache_name.upper() - factor = os.environ.get(env_var) - if factor: - return float(factor) - - return CACHE_SIZE_FACTOR - - caches_by_name = {} -collectors_by_name = {} +collectors_by_name = {} # type: Dict cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"]) cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"]) +cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"]) response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"]) response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"]) @@ -52,67 +42,82 @@ response_cache_evicted = Gauge( response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"]) -def register_cache(cache_type, cache_name, cache, collect_callback=None): - """Register a cache object for metric collection. 
+@attr.s +class CacheMetric(object): + + _cache = attr.ib() + _cache_type = attr.ib(type=str) + _cache_name = attr.ib(type=str) + _collect_callback = attr.ib(type=Optional[Callable]) + + hits = attr.ib(default=0) + misses = attr.ib(default=0) + evicted_size = attr.ib(default=0) + + def inc_hits(self): + self.hits += 1 + + def inc_misses(self): + self.misses += 1 + + def inc_evictions(self, size=1): + self.evicted_size += size + + def describe(self): + return [] + + def collect(self): + try: + if self._cache_type == "response_cache": + response_cache_size.labels(self._cache_name).set(len(self._cache)) + response_cache_hits.labels(self._cache_name).set(self.hits) + response_cache_evicted.labels(self._cache_name).set(self.evicted_size) + response_cache_total.labels(self._cache_name).set( + self.hits + self.misses + ) + else: + cache_size.labels(self._cache_name).set(len(self._cache)) + cache_hits.labels(self._cache_name).set(self.hits) + cache_evicted.labels(self._cache_name).set(self.evicted_size) + cache_total.labels(self._cache_name).set(self.hits + self.misses) + if getattr(self._cache, "max_size", None): + cache_max_size.labels(self._cache_name).set(self._cache.max_size) + if self._collect_callback: + self._collect_callback() + except Exception as e: + logger.warning("Error calculating metrics for %s: %s", self._cache_name, e) + raise + + +def register_cache( + cache_type: str, + cache_name: str, + cache, + collect_callback: Optional[Callable] = None, + resizable: bool = True, + resize_callback: Optional[Callable] = None, +) -> CacheMetric: + """Register a cache object for metric collection and resizing. Args: - cache_type (str): - cache_name (str): name of the cache - cache (object): cache itself - collect_callback (callable|None): if not None, a function which is called during - metric collection to update additional metrics. + cache_type + cache_name: name of the cache + cache: cache itself + collect_callback: If given, a function which is called during metric + collection to update additional metrics. + resizable: Whether this cache supports being resized. + resize_callback: A function which can be called to resize the cache. Returns: CacheMetric: an object which provides inc_{hits,misses,evictions} methods """ + if resizable: + if not resize_callback: + resize_callback = getattr(cache, "set_cache_factor") + add_resizable_cache(cache_name, resize_callback) - # Check if the metric is already registered. Unregister it, if so. - # This usually happens during tests, as at runtime these caches are - # effectively singletons. 
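register_cache now takes resizable and resize_callback arguments and hands the callback to add_resizable_cache (from the new synapse.config.cache module), so cache sizes can follow the configured cache factors at runtime. A rough sketch of that registry idea, with invented names rather than Synapse's actual implementation:

# name -> callback that applies a new cache factor to one cache
_resize_callbacks = {}

def add_resizable_cache(name, resize_callback):
    _resize_callbacks[name] = resize_callback

def apply_cache_factor(factor):
    """Resize every registered cache when the configured factor changes."""
    for resize in _resize_callbacks.values():
        resize(factor)

class ToyCache:
    def __init__(self, base_max_size):
        self._base = base_max_size
        self.max_size = base_max_size

    def set_cache_factor(self, factor):
        self.max_size = int(self._base * factor)

cache = ToyCache(1000)
add_resizable_cache("stateGroupCache", cache.set_cache_factor)
apply_cache_factor(2.0)
print(cache.max_size)  # 2000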
+ metric = CacheMetric(cache, cache_type, cache_name, collect_callback) metric_name = "cache_%s_%s" % (cache_type, cache_name) - if metric_name in collectors_by_name.keys(): - REGISTRY.unregister(collectors_by_name[metric_name]) - - class CacheMetric(object): - - hits = 0 - misses = 0 - evicted_size = 0 - - def inc_hits(self): - self.hits += 1 - - def inc_misses(self): - self.misses += 1 - - def inc_evictions(self, size=1): - self.evicted_size += size - - def describe(self): - return [] - - def collect(self): - try: - if cache_type == "response_cache": - response_cache_size.labels(cache_name).set(len(cache)) - response_cache_hits.labels(cache_name).set(self.hits) - response_cache_evicted.labels(cache_name).set(self.evicted_size) - response_cache_total.labels(cache_name).set(self.hits + self.misses) - else: - cache_size.labels(cache_name).set(len(cache)) - cache_hits.labels(cache_name).set(self.hits) - cache_evicted.labels(cache_name).set(self.evicted_size) - cache_total.labels(cache_name).set(self.hits + self.misses) - if collect_callback: - collect_callback() - except Exception as e: - logger.warn("Error calculating metrics for %s: %s", cache_name, e) - raise - - yield GaugeMetricFamily("__unused", "") - - metric = CacheMetric() - REGISTRY.register(metric) caches_by_name[cache_name] = cache collectors_by_name[metric_name] = metric return metric @@ -147,9 +152,6 @@ def intern_string(string): return None try: - if six.PY2: - string = string.encode("ascii") - return intern(string) except UnicodeEncodeError: return string diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 43f66ec4be..cd48262420 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,22 +13,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import functools import inspect import logging import threading -from collections import namedtuple +from typing import Any, Tuple, Union, cast +from weakref import WeakValueDictionary from six import itervalues from prometheus_client import Gauge +from typing_extensions import Protocol from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -36,6 +38,20 @@ from . import register_cache logger = logging.getLogger(__name__) +CacheKey = Union[Tuple, Any] + + +class _CachedFunction(Protocol): + invalidate = None # type: Any + invalidate_all = None # type: Any + invalidate_many = None # type: Any + prefill = None # type: Any + cache = None # type: Any + num_args = None # type: Any + + def __name__(self): + ... 
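The _CachedFunction protocol above exists so the decorator can keep attaching attributes (invalidate, prefill, cache, ...) to its wrapper while giving type checkers something to verify against; the wrapper is cast to the protocol before the attributes are set. A minimal, self-contained version of the pattern:

from typing import Any, Callable, cast
from typing_extensions import Protocol

class CachedFunc(Protocol):
    invalidate = None  # type: Any

    def __call__(self, *args: Any) -> Any:
        ...

def cached(func: Callable) -> CachedFunc:
    results = {}

    def _wrapped(*args):
        if args not in results:
            results[args] = func(*args)
        return results[args]

    # cast is a no-op at runtime; it just tells the type checker about
    # the extra attributes we are about to attach.
    wrapped = cast(CachedFunc, _wrapped)
    wrapped.invalidate = results.clear
    return wrapped

@cached
def square(x):
    return x * x

print(square(4))     # 16, computed
print(square(4))     # 16, served from the cache
square.invalidate()  # known to the type checker thanks to the protocol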
+ cache_pending_metric = Gauge( "synapse_util_caches_cache_pending", @@ -65,7 +81,6 @@ class CacheEntry(object): class Cache(object): __slots__ = ( "cache", - "max_entries", "name", "keylen", "thread", @@ -73,7 +88,29 @@ class Cache(object): "_pending_deferred_cache", ) - def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): + def __init__( + self, + name: str, + max_entries: int = 1000, + keylen: int = 1, + tree: bool = False, + iterable: bool = False, + apply_cache_factor_from_config: bool = True, + ): + """ + Args: + name: The name of the cache + max_entries: Maximum amount of entries that the cache will hold + keylen: The length of the tuple used as the cache key + tree: Use a TreeCache instead of a dict as the underlying cache type + iterable: If True, count each item in the cached object as an entry, + rather than each cached object + apply_cache_factor_from_config: Whether cache factors specified in the + config file affect `max_entries` + + Returns: + Cache + """ cache_type = TreeCache if tree else dict self._pending_deferred_cache = cache_type() @@ -83,6 +120,7 @@ class Cache(object): cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, evicted_callback=self._on_evicted, + apply_cache_factor_from_config=apply_cache_factor_from_config, ) self.name = name @@ -95,6 +133,10 @@ class Cache(object): collect_callback=self._metrics_collection_callback, ) + @property + def max_entries(self): + return self.cache.max_size + def _on_evicted(self, evicted_count): self.metrics.inc_evictions(evicted_count) @@ -245,7 +287,9 @@ class Cache(object): class _CacheDescriptorBase(object): - def __init__(self, orig, num_args, inlineCallbacks, cache_context=False): + def __init__( + self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False + ): self.orig = orig if inlineCallbacks: @@ -253,7 +297,7 @@ class _CacheDescriptorBase(object): else: self.function_to_call = orig - arg_spec = inspect.getargspec(orig) + arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args if "cache_context" in all_args: @@ -352,13 +396,11 @@ class CacheDescriptor(_CacheDescriptorBase): cache_context=cache_context, ) - max_entries = int(max_entries * get_cache_factor_for(orig.__name__)) - self.max_entries = max_entries self.tree = tree self.iterable = iterable - def __get__(self, obj, objtype=None): + def __get__(self, obj, owner): cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, @@ -404,7 +446,7 @@ class CacheDescriptor(_CacheDescriptorBase): return tuple(get_cache_key_gen(args, kwargs)) @functools.wraps(self.orig) - def wrapped(*args, **kwargs): + def _wrapped(*args, **kwargs): # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -414,7 +456,7 @@ class CacheDescriptor(_CacheDescriptorBase): # Add our own `cache_context` to argument list if the wrapped function # has asked for one if self.add_cache_context: - kwargs["cache_context"] = _CacheContext(cache, cache_key) + kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) try: cached_result_d = cache.get(cache_key, callback=invalidate_callback) @@ -422,7 +464,7 @@ class CacheDescriptor(_CacheDescriptorBase): if isinstance(cached_result_d, ObservableDeferred): observer = cached_result_d.observe() else: - observer = cached_result_d + observer = defer.succeed(cached_result_d) except KeyError: ret = defer.maybeDeferred( @@ -440,6 +482,8 @@ class 
CacheDescriptor(_CacheDescriptorBase): return make_deferred_yieldable(observer) + wrapped = cast(_CachedFunction, _wrapped) + if self.num_args == 1: wrapped.invalidate = lambda key: cache.invalidate(key[0]) wrapped.prefill = lambda key, val: cache.prefill(key[0], val) @@ -464,9 +508,8 @@ class CacheListDescriptor(_CacheDescriptorBase): Given a list of keys it looks in the cache to find any hits, then passes the list of missing keys to the wrapped function. - Once wrapped, the function returns either a Deferred which resolves to - the list of results, or (if all results were cached), just the list of - results. + Once wrapped, the function returns a Deferred which resolves to the list + of results. """ def __init__( @@ -600,21 +643,45 @@ class CacheListDescriptor(_CacheDescriptorBase): ) return make_deferred_yieldable(d) else: - return results + return defer.succeed(results) obj.__dict__[self.orig.__name__] = wrapped return wrapped -class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): - # We rely on _CacheContext implementing __eq__ and __hash__ sensibly, - # which namedtuple does for us (i.e. two _CacheContext are the same if - # their caches and keys match). This is important in particular to - # dedupe when we add callbacks to lru cache nodes, otherwise the number - # of callbacks would grow. - def invalidate(self): - self.cache.invalidate(self.key) +class _CacheContext: + """Holds cache information from the cached function higher in the calling order. + + Can be used to invalidate the higher level cache entry if something changes + on a lower level. + """ + + _cache_context_objects = ( + WeakValueDictionary() + ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext] + + def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None + self._cache = cache + self._cache_key = cache_key + + def invalidate(self): # type: () -> None + """Invalidates the cache entry referred to by the context.""" + self._cache.invalidate(self._cache_key) + + @classmethod + def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext + """Returns an instance constructed with the given arguments. + + A new instance is only created if none already exists. + """ + + # We make sure there are no identical _CacheContext instances. This is + # important in particular to dedupe when we add callbacks to lru cache + # nodes, otherwise the number of callbacks would grow. + return cls._cache_context_objects.setdefault( + (cache, cache_key), cls(cache, cache_key) + ) def cached( diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index cddf1ed515..2726b67b6d 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -18,6 +18,7 @@ from collections import OrderedDict from six import iteritems, itervalues +from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches import register_cache @@ -51,15 +52,16 @@ class ExpiringCache(object): an item on access. Defaults to False. iterable (bool): If true, the size is calculated by summing the sizes of all entries, rather than the number of entries. 
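The resizing support that the following hunks add to ExpiringCache, LruCache and StreamChangeCache all follows the same pattern: remember the originally requested maximum size, multiply it by the cache factor, and evict if the cache is now over the new limit. A rough sketch with LruCache (illustrative, not part of the patch; it assumes the default global cache factor, which I believe is 0.5):

from synapse.util.caches.lrucache import LruCache

cache = LruCache(max_size=100)   # real capacity is 100 * the configured cache factor
for i in range(200):
    cache[i] = str(i)
print(len(cache))                # capped at cache.max_size

cache.set_cache_factor(0.1)      # resize to int(100 * 0.1) = 10 entries
print(len(cache))                # least-recently-used entries were evicted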
- """ self._cache_name = cache_name + self._original_max_size = max_len + + self._max_size = int(max_len * cache_config.properties.default_factor_size) + self._clock = clock - self._max_len = max_len self._expiry_ms = expiry_ms - self._reset_expiry_on_get = reset_expiry_on_get self._cache = OrderedDict() @@ -82,9 +84,11 @@ class ExpiringCache(object): def __setitem__(self, key, value): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) + self.evict() + def evict(self): # Evict if there are now too many items - while self._max_len and len(self) > self._max_len: + while self._max_size and len(self) > self._max_size: _key, value = self._cache.popitem(last=False) if self.iterable: self.metrics.inc_evictions(len(value.value)) @@ -170,6 +174,23 @@ class ExpiringCache(object): else: return len(self._cache) + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = int(self._original_max_size * factor) + if new_size != self._max_size: + self._max_size = new_size + self.evict() + return True + return False + class _CacheEntry(object): __slots__ = ["time", "value"] diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 1536cb64f3..df4ea5901d 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - import threading from functools import wraps +from typing import Callable, Optional, Type, Union +from synapse.config import cache as cache_config from synapse.util.caches.treecache import TreeCache @@ -52,17 +53,18 @@ class LruCache(object): def __init__( self, - max_size, - keylen=1, - cache_type=dict, - size_callback=None, - evicted_callback=None, + max_size: int, + keylen: int = 1, + cache_type: Type[Union[dict, TreeCache]] = dict, + size_callback: Optional[Callable] = None, + evicted_callback: Optional[Callable] = None, + apply_cache_factor_from_config: bool = True, ): """ Args: - max_size (int): + max_size: The maximum amount of entries the cache can hold - keylen (int): + keylen: The length of the tuple used as the cache key cache_type (type): type of underlying cache to be used. Typically one of dict @@ -73,9 +75,24 @@ class LruCache(object): evicted_callback (func(int)|None): if not None, called on eviction with the size of the evicted entry + + apply_cache_factor_from_config (bool): If true, `max_size` will be + multiplied by a cache factor derived from the homeserver config """ cache = cache_type() self.cache = cache # Used for introspection. + self.apply_cache_factor_from_config = apply_cache_factor_from_config + + # Save the original max size, and apply the default size factor. + self._original_max_size = max_size + # We previously didn't apply the cache factor here, and as such some caches were + # not affected by the global cache factor. 
Add an option here to disable applying + # the cache factor when a cache is created + if apply_cache_factor_from_config: + self.max_size = int(max_size * cache_config.properties.default_factor_size) + else: + self.max_size = int(max_size) + list_root = _Node(None, None, None, None) list_root.next_node = list_root list_root.prev_node = list_root @@ -83,7 +100,7 @@ class LruCache(object): lock = threading.Lock() def evict(): - while cache_len() > max_size: + while cache_len() > self.max_size: todelete = list_root.prev_node evicted_len = delete_node(todelete) cache.pop(todelete.key, None) @@ -236,6 +253,7 @@ class LruCache(object): return key in cache self.sentinel = object() + self._on_resize = evict self.get = cache_get self.set = cache_set self.setdefault = cache_set_default @@ -266,3 +284,23 @@ class LruCache(object): def __contains__(self, key): return self.contains(key) + + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + if not self.apply_cache_factor_from_config: + return False + + new_size = int(self._original_max_size * factor) + if new_size != self.max_size: + self.max_size = new_size + self._on_resize() + return True + return False diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 82d3eefe0e..a6c60888e5 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -38,7 +38,7 @@ class ResponseCache(object): self.timeout_sec = timeout_ms / 1000.0 self._name = name - self._metrics = register_cache("response_cache", name, self) + self._metrics = register_cache("response_cache", name, self, resizable=False) def size(self): return len(self.pending_result_cache) @@ -144,7 +144,7 @@ class ResponseCache(object): """ result = self.get(key) if not result: - logger.info( + logger.debug( "[%s]: no cached result for [%s], calculating new one", self._name, key ) d = run_in_background(callback, *args, **kwargs) diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py deleted file mode 100644 index 8318db8d2c..0000000000 --- a/synapse/util/caches/snapshot_cache.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.util.async_helpers import ObservableDeferred - - -class SnapshotCache(object): - """Cache for snapshots like the response of /initialSync. - The response of initialSync only has to be a recent snapshot of the - server state. It shouldn't matter to clients if it is a few minutes out - of date. - - This caches a deferred response. Until the deferred completes it will be - returned from the cache. 
This means that if the client retries the request - while the response is still being computed, that original response will be - used rather than trying to compute a new response. - - Once the deferred completes it will removed from the cache after 5 minutes. - We delay removing it from the cache because a client retrying its request - could race with us finishing computing the response. - - Rather than tracking precisely how long something has been in the cache we - keep two generations of completed responses. Every 5 minutes discard the - old generation, move the new generation to the old generation, and set the - new generation to be empty. This means that a result will be in the cache - somewhere between 5 and 10 minutes. - """ - - DURATION_MS = 5 * 60 * 1000 # Cache results for 5 minutes. - - def __init__(self): - self.pending_result_cache = {} # Request that haven't finished yet. - self.prev_result_cache = {} # The older requests that have finished. - self.next_result_cache = {} # The newer requests that have finished. - self.time_last_rotated_ms = 0 - - def rotate(self, time_now_ms): - # Rotate once if the cache duration has passed since the last rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms += self.DURATION_MS - - # Rotate again if the cache duration has passed twice since the last - # rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms = time_now_ms - - def get(self, time_now_ms, key): - self.rotate(time_now_ms) - # This cache is intended to deduplicate requests, so we expect it to be - # missed most of the time. So we just lookup the key in all of the - # dictionaries rather than trying to short circuit the lookup if the - # key is found. - result = self.prev_result_cache.get(key) - result = self.next_result_cache.get(key, result) - result = self.pending_result_cache.get(key, result) - if result is not None: - return result.observe() - else: - return None - - def set(self, time_now_ms, key, deferred): - self.rotate(time_now_ms) - - result = ObservableDeferred(deferred) - - self.pending_result_cache[key] = result - - def shuffle_along(r): - # When the deferred completes we shuffle it along to the first - # generation of the result cache. So that it will eventually - # expire from the rotation of that cache. - self.next_result_cache[key] = result - self.pending_result_cache.pop(key, None) - return r - - result.addBoth(shuffle_along) - - return result.observe() diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 235f64049c..2a161bf244 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -14,17 +14,23 @@ # limitations under the License. import logging +import math +from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union from six import integer_types from sortedcontainers import SortedDict +from synapse.types import Collection from synapse.util import caches logger = logging.getLogger(__name__) +# for now, assume all entities in the cache are strings +EntityType = str -class StreamChangeCache(object): + +class StreamChangeCache: """Keeps track of the stream positions of the latest change in a set of entities. Typically the entity will be a room or user id. 
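A sketch (not part of the patch) of how StreamChangeCache is typically driven, using the API as reworked in the hunks below; the entity names and stream positions are made up:

from synapse.util.caches.stream_change_cache import StreamChangeCache

cache = StreamChangeCache("example", current_stream_pos=1)
cache.entity_has_changed("@user:test", 2)     # record a change at stream position 2

cache.has_entity_changed("@user:test", 1)     # True: a change happened after position 1
cache.has_entity_changed("@user:test", 2)     # False: nothing changed after position 2
cache.get_entities_changed(["@user:test", "@other:test"], 1)   # {"@user:test"}
cache.get_all_entities_changed(0)             # None: position 0 predates the cache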
@@ -34,19 +40,52 @@ class StreamChangeCache(object): old then the cache will simply return all given entities. """ - def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache=None): - self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR) - self._entity_to_key = {} - self._cache = SortedDict() + def __init__( + self, + name: str, + current_stream_pos: int, + max_size=10000, + prefilled_cache: Optional[Mapping[EntityType, int]] = None, + ): + self._original_max_size = max_size + self._max_size = math.floor(max_size) + self._entity_to_key = {} # type: Dict[EntityType, int] + + # map from stream id to the a set of entities which changed at that stream id. + self._cache = SortedDict() # type: SortedDict[int, Set[EntityType]] + + # the earliest stream_pos for which we can reliably answer + # get_all_entities_changed. In other words, one less than the earliest + # stream_pos for which we know _cache is valid. + # self._earliest_known_stream_pos = current_stream_pos self.name = name - self.metrics = caches.register_cache("cache", self.name, self._cache) + self.metrics = caches.register_cache( + "cache", self.name, self._cache, resize_callback=self.set_cache_factor + ) if prefilled_cache: for entity, stream_pos in prefilled_cache.items(): self.entity_has_changed(entity, stream_pos) - def has_entity_changed(self, entity, stream_pos): + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = math.floor(self._original_max_size * factor) + if new_size != self._max_size: + self.max_size = new_size + self._evict() + return True + return False + + def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: """Returns True if the entity may have been updated since stream_pos """ assert type(stream_pos) in integer_types @@ -67,22 +106,27 @@ class StreamChangeCache(object): self.metrics.inc_hits() return False - def get_entities_changed(self, entities, stream_pos): + def get_entities_changed( + self, entities: Collection[EntityType], stream_pos: int + ) -> Union[Set[EntityType], FrozenSet[EntityType]]: """ Returns subset of entities that have had new things since the given position. Entities unknown to the cache will be returned. If the position is too old it will just return the given list. """ - assert type(stream_pos) is int - - if stream_pos >= self._earliest_known_stream_pos: - changed_entities = { - self._cache[k] - for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) - } - - result = changed_entities.intersection(entities) - + changed_entities = self.get_all_entities_changed(stream_pos) + if changed_entities is not None: + # We now do an intersection, trying to do so in the most efficient + # way possible (some of these sets are *large*). First check in the + # given iterable is already set that we can reuse, otherwise we + # create a set of the *smallest* of the two iterables and call + # `intersection(..)` on it (this can be twice as fast as the reverse). 
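# An illustrative sketch (not part of the patch) of the ordering trick described
# in the comment above: building a set copies every element, so copying the
# smaller collection and intersecting it with the larger one is the cheap option.
import timeit

big = set(range(1_000_000))
small = list(range(1_000))

print(timeit.timeit(lambda: set(small).intersection(big), number=100))  # fast
print(timeit.timeit(lambda: set(big).intersection(small), number=100))  # expected to be much slower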
+ if isinstance(entities, (set, frozenset)): + result = entities.intersection(changed_entities) + elif len(changed_entities) < len(entities): + result = set(changed_entities).intersection(entities) + else: + result = set(entities).intersection(changed_entities) self.metrics.inc_hits() else: result = set(entities) @@ -90,13 +134,13 @@ class StreamChangeCache(object): return result - def has_any_entity_changed(self, stream_pos): + def has_any_entity_changed(self, stream_pos: int) -> bool: """Returns if any entity has changed """ assert type(stream_pos) is int if not self._cache: - # If we have no cache, nothing can have changed. + # If the cache is empty, nothing can have changed. return False if stream_pos >= self._earliest_known_stream_pos: @@ -106,42 +150,66 @@ class StreamChangeCache(object): self.metrics.inc_misses() return True - def get_all_entities_changed(self, stream_pos): - """Returns all entites that have had new things since the given + def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]: + """Returns all entities that have had new things since the given position. If the position is too old it will return None. + + Returns the entities in the order that they were changed. """ assert type(stream_pos) is int - if stream_pos >= self._earliest_known_stream_pos: - return [ - self._cache[k] - for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) - ] - else: + if stream_pos < self._earliest_known_stream_pos: return None - def entity_has_changed(self, entity, stream_pos): + changed_entities = [] # type: List[EntityType] + + for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)): + changed_entities.extend(self._cache[k]) + return changed_entities + + def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None: """Informs the cache that the entity has been changed at the given position. 
""" assert type(stream_pos) is int - if stream_pos > self._earliest_known_stream_pos: - old_pos = self._entity_to_key.get(entity, None) - if old_pos is not None: - stream_pos = max(stream_pos, old_pos) - self._cache.pop(old_pos, None) - self._cache[stream_pos] = entity - self._entity_to_key[entity] = stream_pos - - while len(self._cache) > self._max_size: - k, r = self._cache.popitem(0) - self._earliest_known_stream_pos = max( - k, self._earliest_known_stream_pos - ) - self._entity_to_key.pop(r, None) - - def get_max_pos_of_last_change(self, entity): + if stream_pos <= self._earliest_known_stream_pos: + return + + old_pos = self._entity_to_key.get(entity, None) + if old_pos is not None: + if old_pos >= stream_pos: + # nothing to do + return + e = self._cache[old_pos] + e.remove(entity) + if not e: + # cache at this point is now empty + del self._cache[old_pos] + + e1 = self._cache.get(stream_pos) + if e1 is None: + e1 = self._cache[stream_pos] = set() + e1.add(entity) + self._entity_to_key[entity] = stream_pos + self._evict() + + # if the cache is too big, remove entries + while len(self._cache) > self._max_size: + k, r = self._cache.popitem(0) + self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) + for entity in r: + del self._entity_to_key[entity] + + def _evict(self): + while len(self._cache) > self._max_size: + k, r = self._cache.popitem(0) + self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) + for entity in r: + self._entity_to_key.pop(entity, None) + + def get_max_pos_of_last_change(self, entity: EntityType) -> int: + """Returns an upper bound of the stream id of the last change to an entity. """ diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 9a72218d85..2ea4e4e911 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -1,3 +1,5 @@ +from typing import Dict + from six import itervalues SENTINEL = object() @@ -12,7 +14,7 @@ class TreeCache(object): def __init__(self): self.size = 0 - self.root = {} + self.root = {} # type: Dict def __setitem__(self, key, value): return self.set(key, value) diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 99646c7cf0..6437aa907e 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -38,7 +38,7 @@ class TTLCache(object): self._timer = timer - self._metrics = register_cache("ttl", cache_name, self) + self._metrics = register_cache("ttl", cache_name, self, resizable=False) def set(self, key, value, ttl): """Add/update an entry in the cache diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 635b897d6c..9815bb8667 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -30,7 +30,7 @@ def freeze(o): return o try: - return tuple([freeze(i) for i in o]) + return tuple(freeze(i) for i in o) except TypeError: pass @@ -65,5 +65,5 @@ def _handle_frozendict(obj): ) -# A JSONEncoder which is capable of encoding frozendics without barfing +# A JSONEncoder which is capable of encoding frozendicts without barfing frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict) diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index 1a20c596bf..3c0e8469f3 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) def create_resource_tree(desired_tree, root_resource): - """Create the resource tree for this Home Server. 
+ """Create the resource tree for this homeserver. This in unduly complicated because Twisted does not support putting child resources more than 1 level deep at a time. diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py new file mode 100644 index 0000000000..06faeebe7f --- /dev/null +++ b/synapse/util/iterutils.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from itertools import islice +from typing import Iterable, Iterator, Sequence, Tuple, TypeVar + +T = TypeVar("T") + + +def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]: + """batch an iterable up into tuples with a maximum size + + Args: + iterable (iterable): the iterable to slice + size (int): the maximum batch size + + Returns: + an iterator over the chunks + """ + # make sure we can deal with iterables like lists too + sourceiter = iter(iterable) + # call islice until it returns an empty tuple + return iter(lambda: tuple(islice(sourceiter, size)), ()) + + +ISeq = TypeVar("ISeq", bound=Sequence, covariant=True) + + +def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]: + """Split the given sequence into chunks of the given size + + The last chunk may be shorter than the given size. + + If the input is empty, no chunks are returned. + """ + return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen)) diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 0910930c21..ec61e14423 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import inspect import logging from functools import wraps @@ -20,7 +21,7 @@ from prometheus_client import Counter from twisted.internet import defer -from synapse.logging.context import LoggingContext +from synapse.logging.context import LoggingContext, current_context from synapse.metrics import InFlightGauge logger = logging.getLogger(__name__) @@ -60,14 +61,26 @@ in_flight = InFlightGauge( ) -def measure_func(name): +def measure_func(name=None): def wrapper(func): - @wraps(func) - @defer.inlineCallbacks - def measured_func(self, *args, **kwargs): - with Measure(self.clock, name): - r = yield func(self, *args, **kwargs) - return r + block_name = func.__name__ if name is None else name + + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = await func(self, *args, **kwargs) + return r + + else: + + @wraps(func) + @defer.inlineCallbacks + def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = yield func(self, *args, **kwargs) + return r return measured_func @@ -78,72 +91,48 @@ class Measure(object): __slots__ = [ "clock", "name", - "start_context", + "_logging_context", "start", - "created_context", - "start_usage", ] def __init__(self, clock, name): self.clock = clock self.name = name - self.start_context = None + self._logging_context = None self.start = None - self.created_context = False def __enter__(self): - self.start = self.clock.time() - self.start_context = LoggingContext.current_context() - if not self.start_context: - self.start_context = LoggingContext("Measure") - self.start_context.__enter__() - self.created_context = True - - self.start_usage = self.start_context.get_resource_usage() + if self._logging_context: + raise RuntimeError("Measure() objects cannot be re-used") + self.start = self.clock.time() + parent_context = current_context() + self._logging_context = LoggingContext( + "Measure[%s]" % (self.name,), parent_context + ) + self._logging_context.__enter__() in_flight.register((self.name,), self._update_in_flight) def __exit__(self, exc_type, exc_val, exc_tb): - if isinstance(exc_type, Exception) or not self.start_context: - return - - in_flight.unregister((self.name,), self._update_in_flight) + if not self._logging_context: + raise RuntimeError("Measure() block exited without being entered") duration = self.clock.time() - self.start + usage = self._logging_context.get_resource_usage() - block_counter.labels(self.name).inc() - block_timer.labels(self.name).inc(duration) - - context = LoggingContext.current_context() - - if context != self.start_context: - logger.warn( - "Context has unexpectedly changed from '%s' to '%s'. (%r)", - self.start_context, - context, - self.name, - ) - return - - if not context: - logger.warn("Expected context. (%r)", self.name) - return + in_flight.unregister((self.name,), self._update_in_flight) + self._logging_context.__exit__(exc_type, exc_val, exc_tb) - current = context.get_resource_usage() - usage = current - self.start_usage try: + block_counter.labels(self.name).inc() + block_timer.labels(self.name).inc(duration) block_ru_utime.labels(self.name).inc(usage.ru_utime) block_ru_stime.labels(self.name).inc(usage.ru_stime) block_db_txn_count.labels(self.name).inc(usage.db_txn_count) block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec) block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) except ValueError: - logger.warn( - "Failed to save metrics! 
OLD: %r, NEW: %r", self.start_usage, current - ) - - if self.created_context: - self.start_context.__exit__(exc_type, exc_val, exc_tb) + logger.warning("Failed to save metrics! Usage: %s", usage) def _update_in_flight(self, metrics): """Gets called when processing in flight metrics diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py index 522acd5aa8..bb62db4637 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py @@ -14,12 +14,13 @@ # limitations under the License. import importlib +import importlib.util from synapse.config._base import ConfigError def load_module(provider): - """ Loads a module with its config + """ Loads a synapse module with its config Take a dict with keys 'module' (the module name) and 'config' (the config dict). @@ -33,8 +34,25 @@ def load_module(provider): provider_class = getattr(module, clz) try: - provider_config = provider_class.parse_config(provider["config"]) + provider_config = provider_class.parse_config(provider.get("config")) except Exception as e: raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e)) return provider_class, provider_config + + +def load_python_module(location: str): + """Load a python module, and return a reference to its global namespace + + Args: + location (str): path to the module + + Returns: + python module object + """ + spec = importlib.util.spec_from_file_location(location, location) + if spec is None: + raise Exception("Unable to load module at %s" % (location,)) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) # type: ignore + return mod diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py new file mode 100644 index 0000000000..2605f3c65b --- /dev/null +++ b/synapse/util/patch_inline_callbacks.py @@ -0,0 +1,224 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import functools +import sys +from typing import Any, Callable, List + +from twisted.internet import defer +from twisted.internet.defer import Deferred +from twisted.python.failure import Failure + +# Tracks if we've already patched inlineCallbacks +_already_patched = False + + +def do_patch(): + """ + Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit + """ + + from synapse.logging.context import current_context + + global _already_patched + + orig_inline_callbacks = defer.inlineCallbacks + if _already_patched: + return + + def new_inline_callbacks(f): + @functools.wraps(f) + def wrapped(*args, **kwargs): + start_context = current_context() + changes = [] # type: List[str] + orig = orig_inline_callbacks(_check_yield_points(f, changes)) + + try: + res = orig(*args, **kwargs) + except Exception: + if current_context() != start_context: + for err in changes: + print(err, file=sys.stderr) + + err = "%s changed context from %s to %s on exception" % ( + f, + start_context, + current_context(), + ) + print(err, file=sys.stderr) + raise Exception(err) + raise + + if not isinstance(res, Deferred) or res.called: + if current_context() != start_context: + for err in changes: + print(err, file=sys.stderr) + + err = "Completed %s changed context from %s to %s" % ( + f, + start_context, + current_context(), + ) + # print the error to stderr because otherwise all we + # see in travis-ci is the 500 error + print(err, file=sys.stderr) + raise Exception(err) + return res + + if current_context(): + err = ( + "%s returned incomplete deferred in non-sentinel context " + "%s (start was %s)" + ) % (f, current_context(), start_context) + print(err, file=sys.stderr) + raise Exception(err) + + def check_ctx(r): + if current_context() != start_context: + for err in changes: + print(err, file=sys.stderr) + err = "%s completion of %s changed context from %s to %s" % ( + "Failure" if isinstance(r, Failure) else "Success", + f, + start_context, + current_context(), + ) + print(err, file=sys.stderr) + raise Exception(err) + return r + + res.addBoth(check_ctx) + return res + + return wrapped + + defer.inlineCallbacks = new_inline_callbacks + _already_patched = True + + +def _check_yield_points(f: Callable, changes: List[str]): + """Wraps a generator that is about to be passed to defer.inlineCallbacks + checking that after every yield the log contexts are correct. + + It's perfectly valid for log contexts to change within a function, e.g. due + to new Measure blocks, so such changes are added to the given `changes` + list instead of triggering an exception. + + Args: + f: generator function to wrap + changes: A list of strings detailing how the contexts + changed within a function. + + Returns: + function + """ + + from synapse.logging.context import current_context + + @functools.wraps(f) + def check_yield_points_inner(*args, **kwargs): + gen = f(*args, **kwargs) + + last_yield_line_no = gen.gi_frame.f_lineno + result = None # type: Any + while True: + expected_context = current_context() + + try: + isFailure = isinstance(result, Failure) + if isFailure: + d = result.throwExceptionIntoGenerator(gen) + else: + d = gen.send(result) + except (StopIteration, defer._DefGen_Return) as e: + if current_context() != expected_context: + # This happens when the context is lost sometime *after* the + # final yield and returning. E.g. we forgot to yield on a + # function that returns a deferred. 
+ # + # We don't raise here as it's perfectly valid for contexts to + # change in a function, as long as it sets the correct context + # on resolving (which is checked separately). + err = ( + "Function %r returned and changed context from %s to %s," + " in %s between %d and end of func" + % ( + f.__qualname__, + expected_context, + current_context(), + f.__code__.co_filename, + last_yield_line_no, + ) + ) + changes.append(err) + return getattr(e, "value", None) + + frame = gen.gi_frame + + if isinstance(d, defer.Deferred) and not d.called: + # This happens if we yield on a deferred that doesn't follow + # the log context rules without wrapping in a `make_deferred_yieldable`. + # We raise here as this should never happen. + if current_context(): + err = ( + "%s yielded with context %s rather than sentinel," + " yielded on line %d in %s" + % ( + frame.f_code.co_name, + current_context(), + frame.f_lineno, + frame.f_code.co_filename, + ) + ) + raise Exception(err) + + # the wrapped function yielded a Deferred: yield it back up to the parent + # inlineCallbacks(). + try: + result = yield d + except Exception: + # this will fish an earlier Failure out of the stack where possible, and + # thus is preferable to passing in an exeception to the Failure + # constructor, since it results in less stack-mangling. + result = Failure() + + if current_context() != expected_context: + + # This happens because the context is lost sometime *after* the + # previous yield and *after* the current yield. E.g. the + # deferred we waited on didn't follow the rules, or we forgot to + # yield on a function between the two yield points. + # + # We don't raise here as its perfectly valid for contexts to + # change in a function, as long as it sets the correct context + # on resolving (which is checked separately). + err = ( + "%s changed context from %s to %s, happened between lines %d and %d in %s" + % ( + frame.f_code.co_name, + expected_context, + current_context(), + last_yield_line_no, + frame.f_lineno, + frame.f_code.co_filename, + ) + ) + changes.append(err) + + last_yield_line_no = frame.f_lineno + + return check_yield_points_inner diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 5ca4521ce3..e5efdfcd02 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -43,7 +43,7 @@ class FederationRateLimiter(object): self.ratelimiters = collections.defaultdict(new_limiter) def ratelimit(self, host): - """Used to ratelimit an incoming request from given host + """Used to ratelimit an incoming request from a given host Example usage: diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index a5f2fbef5c..af69587196 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -29,7 +29,7 @@ MIN_RETRY_INTERVAL = 10 * 60 * 1000 RETRY_MULTIPLIER = 5 # a cap on the backoff. 
(Essentially none) -MAX_RETRY_INTERVAL = 2 ** 63 +MAX_RETRY_INTERVAL = 2 ** 62 class NotRetryingDestination(Exception): diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py index 6c0f2bb0cf..207cd17c2a 100644 --- a/synapse/util/rlimit.py +++ b/synapse/util/rlimit.py @@ -33,4 +33,4 @@ def change_resource_limit(soft_file_no): resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY) ) except (ValueError, resource.error) as e: - logger.warn("Failed to set file or core limit: %s", e) + logger.warning("Failed to set file or core limit: %s", e) diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 982c6d81ca..08c86e92b8 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,16 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import itertools import random +import re import string +from collections import Iterable -import six -from six import PY2, PY3 -from six.moves import range +from synapse.api.errors import Codes, SynapseError _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" +# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken +# Note: The : character is allowed here for older clients, but will be removed in a +# future release. Context: https://github.com/matrix-org/synapse/issues/6766 +client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$") + # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure # we get cryptographically-secure randoms. @@ -37,75 +43,37 @@ def random_string_with_symbols(length): def is_ascii(s): - - if PY3: - if isinstance(s, bytes): - try: - s.decode("ascii").encode("ascii") - except UnicodeDecodeError: - return False - except UnicodeEncodeError: - return False - return True - - try: - s.encode("ascii") - except UnicodeEncodeError: - return False - except UnicodeDecodeError: - return False - else: + if isinstance(s, bytes): + try: + s.decode("ascii").encode("ascii") + except UnicodeDecodeError: + return False + except UnicodeEncodeError: + return False return True -def to_ascii(s): - """Converts a string to ascii if it is ascii, otherwise leave it alone. - - If given None then will return None. - """ - if PY3: - return s - - if s is None: - return None +def assert_valid_client_secret(client_secret): + """Validate that a given string matches the client_secret regex defined by the spec""" + if client_secret_regex.match(client_secret) is None: + raise SynapseError( + 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM + ) - try: - return s.encode("ascii") - except UnicodeEncodeError: - return s +def shortstr(iterable: Iterable, maxitems: int = 5) -> str: + """If iterable has maxitems or fewer, return the stringification of a list + containing those items. -def exception_to_unicode(e): - """Helper function to extract the text of an exception as a unicode string + Otherwise, return the stringification of a a list with the first maxitems items, + followed by "...". 
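The client_secret validation added to synapse/util/stringutils.py above can be used as follows (a sketch, not part of the patch):

from synapse.api.errors import SynapseError
from synapse.util.stringutils import assert_valid_client_secret

assert_valid_client_secret("this.is_a-valid:secret")    # passes silently

try:
    assert_valid_client_secret("spaces are not allowed")
except SynapseError as e:
    print(e.code, e.errcode)    # 400 M_INVALID_PARAM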
Args: - e (Exception): exception to be stringified - - Returns: - unicode + iterable: iterable to truncate + maxitems: number of items to return before truncating """ - # urgh, this is a mess. The basic problem here is that psycopg2 constructs its - # exceptions with PyErr_SetString, with a (possibly non-ascii) argument. str() will - # then produce the raw byte sequence. Under Python 2, this will then cause another - # error if it gets mixed with a `unicode` object, as per - # https://github.com/matrix-org/synapse/issues/4252 - - # First of all, if we're under python3, everything is fine because it will sort this - # nonsense out for us. - if not PY2: - return str(e) - - # otherwise let's have a stab at decoding the exception message. We'll circumvent - # Exception.__str__(), which would explode if someone raised Exception(u'non-ascii') - # and instead look at what is in the args member. - - if len(e.args) == 0: - return "" - elif len(e.args) > 1: - return six.text_type(repr(e.args)) - - msg = e.args[0] - if isinstance(msg, bytes): - return msg.decode("utf-8", errors="replace") - else: - return msg + + items = list(itertools.islice(iterable, maxitems + 1)) + if len(items) <= maxitems: + return str(items) + return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index fa404b9d75..ab7d03af3a 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -42,6 +42,7 @@ def get_version_string(module): try: null = open(os.devnull, "w") cwd = os.path.dirname(os.path.abspath(module.__file__)) + try: git_branch = ( subprocess.check_output( @@ -51,7 +52,8 @@ def get_version_string(module): .decode("ascii") ) git_branch = "b=" + git_branch - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): + # FileNotFoundError can arise when git is not installed git_branch = "" try: @@ -63,7 +65,7 @@ def get_version_string(module): .decode("ascii") ) git_tag = "t=" + git_tag - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_tag = "" try: @@ -74,7 +76,7 @@ def get_version_string(module): .strip() .decode("ascii") ) - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_commit = "" try: @@ -89,7 +91,7 @@ def get_version_string(module): ) git_dirty = "dirty" if is_dirty else "" - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_dirty = "" if git_branch or git_tag or git_commit or git_dirty: diff --git a/synapse/visibility.py b/synapse/visibility.py index bf0f1eebd8..bab41182b9 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event +from synapse.storage import Storage from synapse.storage.state import StateFilter from synapse.types import get_domain_from_id @@ -43,14 +44,19 @@ MEMBERSHIP_PRIORITY = ( @defer.inlineCallbacks def filter_events_for_client( - store, user_id, events, is_peeking=False, always_include_ids=frozenset() + storage: Storage, + user_id, + events, + is_peeking=False, + always_include_ids=frozenset(), + filter_send_to_client=True, ): """ - Check which events a user is allowed to see + Check which events a user is allowed to see. 
If the user can see the event but its + sender asked for their data to be erased, prune the content of the event. Args: - store (synapse.storage.DataStore): our datastore (can also be a worker - store) + storage user_id(str): user id to be checked events(list[synapse.events.EventBase]): sequence of events to be checked is_peeking(bool): should be True if: @@ -59,21 +65,24 @@ def filter_events_for_client( events always_include_ids (set(event_id)): set of event ids to specifically include (unless sender is ignored) + filter_send_to_client (bool): Whether we're checking an event that's going to be + sent to a client. This might not always be the case since this function can + also be called to check whether a user can see the state at a given point. Returns: Deferred[list[synapse.events.EventBase]] """ # Filter out events that have been soft failed so that we don't relay them # to clients. - events = list(e for e in events if not e.internal_metadata.is_soft_failed()) + events = [e for e in events if not e.internal_metadata.is_soft_failed()] types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) - event_id_to_state = yield store.get_state_for_events( + event_id_to_state = yield storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) - ignore_dict_content = yield store.get_global_account_data_by_type_for_user( + ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id ) @@ -84,7 +93,16 @@ def filter_events_for_client( else [] ) - erased_senders = yield store.are_users_erased((e.sender for e in events)) + erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + + if filter_send_to_client: + room_ids = {e.room_id for e in events} + retention_policies = {} + + for room_id in room_ids: + retention_policies[ + room_id + ] = yield storage.main.get_retention_policy_for_room(room_id) def allowed(event): """ @@ -100,8 +118,36 @@ def filter_events_for_client( the original event if they can see it as normal. """ - if not event.is_state() and event.sender in ignore_list: - return None + # Only run some checks if these events aren't about to be sent to clients. This is + # because, if this is not the case, we're probably only checking if the users can + # see events in the room at that point in the DAG, and that shouldn't be decided + # on those checks. + if filter_send_to_client: + if event.type == "org.matrix.dummy_event": + return None + + if not event.is_state() and event.sender in ignore_list: + return None + + # Until MSC2261 has landed we can't redact malicious alias events, so for + # now we temporarily filter out m.room.aliases entirely to mitigate + # abuse, while we spec a better solution to advertising aliases + # on rooms. + if event.type == EventTypes.Aliases: + return None + + # Don't try to apply the room's retention policy if the event is a state + # event, as MSC1763 states that retention is only considered for non-state + # events. 
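# Illustrative arithmetic for the retention check below (the numbers are made
# up): with a max_lifetime of one day, anything older than now - one day is
# dropped.
ONE_DAY_MS = 24 * 60 * 60 * 1000
now = 1_700_000_000_000                     # what clock.time_msec() might return
oldest_allowed_ts = now - ONE_DAY_MS

two_day_old_event_ts = now - 2 * ONE_DAY_MS
print(two_day_old_event_ts < oldest_allowed_ts)   # True -> the event would be filtered out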
+ if not event.is_state(): + retention_policy = retention_policies[event.room_id] + max_lifetime = retention_policy.get("max_lifetime") + + if max_lifetime is not None: + oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime + + if event.origin_server_ts < oldest_allowed_ts: + return None if event.event_id in always_include_ids: return event @@ -213,13 +259,17 @@ def filter_events_for_client( @defer.inlineCallbacks def filter_events_for_server( - store, server_name, events, redact=True, check_history_visibility_only=False + storage: Storage, + server_name, + events, + redact=True, + check_history_visibility_only=False, ): """Filter a list of events based on whether given server is allowed to see them. Args: - store (DataStore) + storage server_name (str) events (iterable[FrozenEvent]) redact (bool): Whether to return a redacted version of the event, or @@ -274,7 +324,7 @@ def filter_events_for_server( # Lets check to see if all the events have a history visibility # of "shared" or "world_readable". If thats the case then we don't # need to check membership (as we know the server is in the room). - event_to_state_ids = yield store.get_state_ids_for_events( + event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""),) @@ -292,14 +342,14 @@ def filter_events_for_server( if not visibility_ids: all_open = True else: - event_map = yield store.get_events(visibility_ids) + event_map = yield storage.main.get_events(visibility_ids) all_open = all( e.content.get("history_visibility") in (None, "shared", "world_readable") for e in itervalues(event_map) ) if not check_history_visibility_only: - erased_senders = yield store.are_users_erased((e.sender for e in events)) + erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. @@ -328,7 +378,7 @@ def filter_events_for_server( # first, for each event we're wanting to return, get the event_ids # of the history vis and membership state at those events. - event_to_state_ids = yield store.get_state_ids_for_events( + event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) @@ -358,7 +408,7 @@ def filter_events_for_server( return False return state_key[idx + 1 :] == server_name - event_map = yield store.get_events( + event_map = yield storage.main.get_events( [ e_id for e_id, key in iteritems(event_id_to_state_key) diff --git a/synctl b/synctl index 45acece30b..960fd357ee 100755 --- a/synctl +++ b/synctl @@ -117,7 +117,17 @@ def start_worker(app: str, configfile: str, worker_configfile: str) -> bool: False if there was an error starting the process """ - args = [sys.executable, "-B", "-m", app, "-c", configfile, "-c", worker_configfile] + args = [ + sys.executable, + "-B", + "-m", + app, + "-c", + configfile, + "-c", + worker_configfile, + "--daemonize", + ] try: subprocess.check_call(args) @@ -132,12 +142,23 @@ def start_worker(app: str, configfile: str, worker_configfile: str) -> bool: return False -def stop(pidfile, app): +def stop(pidfile: str, app: str) -> bool: + """Attempts to kill a synapse worker from the pidfile. 
+ Args: + pidfile: path to file containing worker's pid + app: name of the worker's appservice + + Returns: + True if the process stopped successfully + False if process was already stopped or an error occured + """ + if os.path.exists(pidfile): pid = int(open(pidfile).read()) try: os.kill(pid, signal.SIGTERM) write("stopped %s" % (app,), colour=GREEN) + return True except OSError as err: if err.errno == errno.ESRCH: write("%s not running" % (app,), colour=YELLOW) @@ -145,6 +166,14 @@ def stop(pidfile, app): abort("Cannot stop %s: Operation not permitted" % (app,)) else: abort("Cannot stop %s: Unknown error" % (app,)) + return False + else: + write( + "No running worker of %s found (from %s)\nThe process might be managed by another controller (e.g. systemd)" + % (app, pidfile), + colour=YELLOW, + ) + return False Worker = collections.namedtuple( @@ -266,9 +295,6 @@ def main(): worker_cache_factors = ( worker_config.get("synctl_cache_factors") or cache_factors ) - daemonize = worker_config.get("daemonize") or config.get("daemonize") - assert daemonize, "Main process must have daemonize set to true" - # The master process doesn't support using worker_* config. for key in worker_config: if key == "worker_app": # But we allow worker_app @@ -278,11 +304,6 @@ def main(): ), "Main process cannot use worker_* config" else: worker_pidfile = worker_config["worker_pid_file"] - worker_daemonize = worker_config["worker_daemonize"] - assert worker_daemonize, "In config %r: expected '%s' to be True" % ( - worker_configfile, - "worker_daemonize", - ) worker_cache_factor = worker_config.get("synctl_cache_factor") worker_cache_factors = worker_config.get("synctl_cache_factors", {}) workers.append( @@ -298,11 +319,17 @@ def main(): action = options.action if action == "stop" or action == "restart": + has_stopped = True for worker in workers: - stop(worker.pidfile, worker.app) + if not stop(worker.pidfile, worker.app): + # A worker could not be stopped. + has_stopped = False if start_stop_synapse: - stop(pidfile, "synapse.app.homeserver") + if not stop(pidfile, "synapse.app.homeserver"): + has_stopped = False + if not has_stopped and action == "stop": + sys.exit(1) # Wait for synapse to actually shutdown before starting it again if action == "restart": diff --git a/synmark/__init__.py b/synmark/__init__.py new file mode 100644 index 0000000000..afe4fad8cb --- /dev/null +++ b/synmark/__init__.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +from twisted.internet import epollreactor +from twisted.internet.main import installReactor + +from synapse.config.homeserver import HomeServerConfig +from synapse.util import Clock + +from tests.utils import default_config, setup_test_homeserver + + +async def make_homeserver(reactor, config=None): + """ + Make a Homeserver suitable for running benchmarks against. + + Args: + reactor: A Twisted reactor to run under. + config: A HomeServerConfig to use, or None. 
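As an illustration of how make_homeserver is meant to be used, here is a hypothetical extra benchmark suite (not part of the patch) following the shape of the synmark/suites modules below:

from pyperf import perf_counter

from synmark import make_homeserver


async def main(reactor, loops):
    hs, clock_sleep, cleanup = await make_homeserver(reactor)

    start = perf_counter()
    for _ in range(loops):
        hs.get_clock()     # stand-in for the code being benchmarked
    end = perf_counter()

    cleanup()
    return end - start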
+ """ + cleanup_tasks = [] + clock = Clock(reactor) + + if not config: + config = default_config("test") + + config_obj = HomeServerConfig() + config_obj.parse_config_dict(config, "", "") + + hs = await setup_test_homeserver( + cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock + ) + stor = hs.get_datastore() + + # Run the database background updates. + if hasattr(stor.db.updates, "do_next_background_update"): + while not await stor.db.updates.has_completed_background_updates(): + await stor.db.updates.do_next_background_update(1) + + def cleanup(): + for i in cleanup_tasks: + i() + + return hs, clock.sleep, cleanup + + +def make_reactor(): + """ + Instantiate and install a Twisted reactor suitable for testing (i.e. not the + default global one). + """ + reactor = epollreactor.EPollReactor() + + if "twisted.internet.reactor" in sys.modules: + del sys.modules["twisted.internet.reactor"] + installReactor(reactor) + + return reactor diff --git a/synmark/__main__.py b/synmark/__main__.py new file mode 100644 index 0000000000..17df9ddeb7 --- /dev/null +++ b/synmark/__main__.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from argparse import REMAINDER +from contextlib import redirect_stderr +from io import StringIO + +import pyperf +from synmark import make_reactor +from synmark.suites import SUITES + +from twisted.internet.defer import Deferred, ensureDeferred +from twisted.logger import globalLogBeginner, textFileLogObserver +from twisted.python.failure import Failure + +from tests.utils import setupdb + + +def make_test(main): + """ + Take a benchmark function and wrap it in a reactor start and stop. 
+ """ + + def _main(loops): + + reactor = make_reactor() + + file_out = StringIO() + with redirect_stderr(file_out): + + d = Deferred() + d.addCallback(lambda _: ensureDeferred(main(reactor, loops))) + + def on_done(_): + if isinstance(_, Failure): + _.printTraceback() + print(file_out.getvalue()) + reactor.stop() + return _ + + d.addBoth(on_done) + reactor.callWhenRunning(lambda: d.callback(True)) + reactor.run() + + return d.result + + return _main + + +if __name__ == "__main__": + + def add_cmdline_args(cmd, args): + if args.log: + cmd.extend(["--log"]) + cmd.extend(args.tests) + + runner = pyperf.Runner( + processes=3, min_time=1.5, show_name=True, add_cmdline_args=add_cmdline_args + ) + runner.argparser.add_argument("--log", action="store_true") + runner.argparser.add_argument("tests", nargs=REMAINDER) + runner.parse_args() + + orig_loops = runner.args.loops + runner.args.inherit_environ = ["SYNAPSE_POSTGRES"] + + if runner.args.worker: + if runner.args.log: + globalLogBeginner.beginLoggingTo( + [textFileLogObserver(sys.__stdout__)], redirectStandardIO=False + ) + setupdb() + + if runner.args.tests: + SUITES = list( + filter(lambda x: x[0].__name__.split(".")[-1] in runner.args.tests, SUITES) + ) + + for suite, loops in SUITES: + if loops: + runner.args.loops = loops + else: + runner.args.loops = orig_loops + loops = "auto" + runner.bench_time_func( + suite.__name__ + "_" + str(loops), make_test(suite.main), + ) diff --git a/synmark/suites/__init__.py b/synmark/suites/__init__.py new file mode 100644 index 0000000000..d8445fc3df --- /dev/null +++ b/synmark/suites/__init__.py @@ -0,0 +1,9 @@ +from . import logging, lrucache, lrucache_evict + +SUITES = [ + (logging, 1000), + (logging, 10000), + (logging, None), + (lrucache, None), + (lrucache_evict, None), +] diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py new file mode 100644 index 0000000000..d8e4c7d58f --- /dev/null +++ b/synmark/suites/logging.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from io import StringIO + +from mock import Mock + +from pyperf import perf_counter +from synmark import make_homeserver + +from twisted.internet.defer import Deferred +from twisted.internet.protocol import ServerFactory +from twisted.logger import LogBeginner, Logger, LogPublisher +from twisted.protocols.basic import LineOnlyReceiver + +from synapse.logging._structured import setup_structured_logging + + +class LineCounter(LineOnlyReceiver): + + delimiter = b"\n" + + def __init__(self, *args, **kwargs): + self.count = 0 + super().__init__(*args, **kwargs) + + def lineReceived(self, line): + self.count += 1 + + if self.count >= self.factory.wait_for and self.factory.on_done: + on_done = self.factory.on_done + self.factory.on_done = None + on_done.callback(True) + + +async def main(reactor, loops): + """ + Benchmark how long it takes to send `loops` messages. 
+ """ + servers = [] + + def protocol(): + p = LineCounter() + servers.append(p) + return p + + logger_factory = ServerFactory.forProtocol(protocol) + logger_factory.wait_for = loops + logger_factory.on_done = Deferred() + port = reactor.listenTCP(0, logger_factory, interface="127.0.0.1") + + hs, wait, cleanup = await make_homeserver(reactor) + + errors = StringIO() + publisher = LogPublisher() + mock_sys = Mock() + beginner = LogBeginner( + publisher, errors, mock_sys, warnings, initialBufferSize=loops + ) + + log_config = { + "loggers": {"synapse": {"level": "DEBUG"}}, + "drains": { + "tersejson": { + "type": "network_json_terse", + "host": "127.0.0.1", + "port": port.getHost().port, + "maximum_buffer": 100, + } + }, + } + + logger = Logger(namespace="synapse.logging.test_terse_json", observer=publisher) + logging_system = setup_structured_logging( + hs, hs.config, log_config, logBeginner=beginner, redirect_stdlib_logging=False + ) + + # Wait for it to connect... + await logging_system._observers[0]._service.whenConnected() + + start = perf_counter() + + # Send a bunch of useful messages + for i in range(0, loops): + logger.info("test message %s" % (i,)) + + if ( + len(logging_system._observers[0]._buffer) + == logging_system._observers[0].maximum_buffer + ): + while ( + len(logging_system._observers[0]._buffer) + > logging_system._observers[0].maximum_buffer / 2 + ): + await wait(0.01) + + await logger_factory.on_done + + end = perf_counter() - start + + logging_system.stop() + port.stopListening() + cleanup() + + return end diff --git a/synmark/suites/lrucache.py b/synmark/suites/lrucache.py new file mode 100644 index 0000000000..69ab042ccc --- /dev/null +++ b/synmark/suites/lrucache.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyperf import perf_counter + +from synapse.util.caches.lrucache import LruCache + + +async def main(reactor, loops): + """ + Benchmark `loops` number of insertions into LruCache without eviction. + """ + cache = LruCache(loops) + + start = perf_counter() + + for i in range(loops): + cache[i] = True + + end = perf_counter() - start + + return end diff --git a/synmark/suites/lrucache_evict.py b/synmark/suites/lrucache_evict.py new file mode 100644 index 0000000000..532b1cc702 --- /dev/null +++ b/synmark/suites/lrucache_evict.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
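The lrucache suites above and the pyperf runner in synmark/__main__.py share a simple contract: a benchmark callable receives the loop count chosen by pyperf and returns the elapsed wall-clock seconds, and the runner registers it with bench_time_func. A minimal, synchronous sketch of that contract (not part of the diff above; it assumes only that pyperf and synapse are importable, and it skips the homeserver/reactor plumbing that make_test() provides for the real suites):

# Minimal sketch of the synmark timing contract: pyperf hands the benchmark
# a loop count and expects back the elapsed wall-clock seconds. The real
# suites are async and are wrapped in a fresh reactor by make_test().
from pyperf import Runner, perf_counter

from synapse.util.caches.lrucache import LruCache


def bench_lrucache_insert(loops: int) -> float:
    # Sized to the loop count so nothing is evicted (cf. lrucache.py above);
    # lrucache_evict.py uses loops // 2 to force evictions instead.
    cache = LruCache(loops)

    start = perf_counter()
    for i in range(loops):
        cache[i] = True
    return perf_counter() - start


if __name__ == "__main__":
    Runner().bench_time_func("lrucache_insert", bench_lrucache_insert)

Because pyperf runs the callable in several worker processes (processes=3 in the runner above), anything with global state, such as the Twisted reactor, has to be re-created per invocation, which is what make_reactor() is for.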
+ +from pyperf import perf_counter + +from synapse.util.caches.lrucache import LruCache + + +async def main(reactor, loops): + """ + Benchmark `loops` number of insertions into LruCache where half of them are + evicted. + """ + cache = LruCache(loops // 2) + + start = perf_counter() + + for i in range(loops): + cache[i] = True + + end = perf_counter() - start + + return end diff --git a/sytest-blacklist b/sytest-blacklist index 11785fd43f..79b2d4402a 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -1,6 +1,6 @@ # This file serves as a blacklist for SyTest tests that we expect will fail in # Synapse. -# +# # Each line of this file is scanned by sytest during a run and if the line # exactly matches the name of a test, it will be marked as "expected fail", # meaning the test will still run, but failure will not mark the entire test @@ -29,3 +29,10 @@ Enabling an unknown default rule fails with 404 # Blacklisted due to https://github.com/matrix-org/synapse/issues/1663 New federated private chats get full presence information (SYN-115) + +# Blacklisted due to https://github.com/matrix-org/matrix-doc/pull/2314 removing +# this requirement from the spec +Inbound federation of state requires event_id as a mandatory paramater + +# Blacklisted until https://github.com/matrix-org/synapse/pull/6486 lands +Can upload self-signing keys diff --git a/tests/__init__.py b/tests/__init__.py index f7fc502f01..ed805db1c2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -16,9 +16,9 @@ from twisted.trial import util -import tests.patch_inline_callbacks +from synapse.util.patch_inline_callbacks import do_patch # attempt to do the patch before we load any synapse code -tests.patch_inline_callbacks.do_patch() +do_patch() util.DEFAULT_TIMEOUT_DURATION = 20 diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 6121efcfa9..0bfb86bf1f 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -52,6 +52,10 @@ class AuthTestCase(unittest.TestCase): self.hs.handlers = TestHandlers(self.hs) self.auth = Auth(self.hs) + # AuthBlocking reads from the hs' config on initialization. 
We need to + # modify its config instead of the hs' + self.auth_blocking = self.auth._auth_blocking + self.test_user = "@foo:bar" self.test_token = b"_test_token_" @@ -68,7 +72,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): @@ -105,7 +109,7 @@ class AuthTestCase(unittest.TestCase): request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) @defer.inlineCallbacks @@ -125,7 +129,7 @@ class AuthTestCase(unittest.TestCase): request.getClientIP.return_value = "192.168.10.10" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_bad_ip(self): @@ -188,7 +192,7 @@ class AuthTestCase(unittest.TestCase): request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals( requester.user.to_string(), masquerading_user_id.decode("utf8") ) @@ -225,7 +229,9 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - user_info = yield self.auth.get_user_by_access_token(macaroon.serialize()) + user_info = yield defer.ensureDeferred( + self.auth.get_user_by_access_token(macaroon.serialize()) + ) user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) @@ -250,7 +256,9 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("guest = true") serialized = macaroon.serialize() - user_info = yield self.auth.get_user_by_access_token(serialized) + user_info = yield defer.ensureDeferred( + self.auth.get_user_by_access_token(serialized) + ) user = user_info["user"] is_guest = user_info["is_guest"] self.assertEqual(UserID.from_string(user_id), user) @@ -260,10 +268,13 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cannot_use_regular_token_as_guest(self): USER_ID = "@percy:matrix.org" - self.store.add_access_token_to_user = Mock() + self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None)) + self.store.get_device = Mock(return_value=defer.succeed(None)) - token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id( - USER_ID, "DEVICE", valid_until_ms=None + token = yield defer.ensureDeferred( + self.hs.handlers.auth_handler.get_access_token_for_user_id( + USER_ID, "DEVICE", valid_until_ms=None + ) ) self.store.add_access_token_to_user.assert_called_with( USER_ID, token, "DEVICE", None @@ -286,7 +297,9 @@ 
class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args[b"access_token"] = [token.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = yield defer.ensureDeferred( + self.auth.get_user_by_req(request, allow_guest=True) + ) self.assertEqual(UserID.from_string(USER_ID), requester.user) self.assertFalse(requester.is_guest) @@ -301,7 +314,9 @@ class AuthTestCase(unittest.TestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() with self.assertRaises(InvalidClientCredentialsError) as cm: - yield self.auth.get_user_by_req(request, allow_guest=True) + yield defer.ensureDeferred( + self.auth.get_user_by_req(request, allow_guest=True) + ) self.assertEqual(401, cm.exception.code) self.assertEqual("Guest access token used for regular user", cm.exception.msg) @@ -310,22 +325,22 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_blocking_mau(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.max_mau_value = 50 + self.auth_blocking._limit_usage_by_mau = False + self.auth_blocking._max_mau_value = 50 lots_of_users = 100 small_number_of_users = 1 # Ensure no error thrown - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=defer.succeed(lots_of_users) ) with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @@ -334,49 +349,54 @@ class AuthTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(small_number_of_users) ) - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) @defer.inlineCallbacks def test_blocking_mau__depending_on_user_type(self): - self.hs.config.max_mau_value = 50 - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._max_mau_value = 50 + self.auth_blocking._limit_usage_by_mau = True self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Support users allowed - yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) + ) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Bots not allowed with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking(user_type=UserTypes.BOT) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(user_type=UserTypes.BOT) + ) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Real users not allowed with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) @defer.inlineCallbacks def test_reserved_threepid(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 1 + self.auth_blocking._limit_usage_by_mau = True + self.auth_blocking._max_mau_value = 1 self.store.get_monthly_active_count = lambda: defer.succeed(2) threepid = {"medium": "email", "address": "reserved@server.com"} unknown_threepid = 
{"medium": "email", "address": "unreserved@server.com"} - self.hs.config.mau_limits_reserved_threepids = [threepid] + self.auth_blocking._mau_limits_reserved_threepids = [threepid] - yield self.store.register_user(user_id="user1", password_hash=None) with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking(threepid=unknown_threepid) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(threepid=unknown_threepid) + ) - yield self.auth.check_auth_blocking(threepid=threepid) + yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid)) @defer.inlineCallbacks def test_hs_disabled(self): - self.hs.config.hs_disabled = True - self.hs.config.hs_disabled_message = "Reason for being disabled" + self.auth_blocking._hs_disabled = True + self.auth_blocking._hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @@ -388,20 +408,20 @@ class AuthTestCase(unittest.TestCase): """ # this should be the default, but we had a bug where the test was doing the wrong # thing, so let's make it explicit - self.hs.config.server_notices_mxid = None + self.auth_blocking._server_notices_mxid = None - self.hs.config.hs_disabled = True - self.hs.config.hs_disabled_message = "Reason for being disabled" + self.auth_blocking._hs_disabled = True + self.auth_blocking._hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @defer.inlineCallbacks def test_server_notices_mxid_special_cased(self): - self.hs.config.hs_disabled = True + self.auth_blocking._hs_disabled = True user = "@user:server" - self.hs.config.server_notices_mxid = user - self.hs.config.hs_disabled_message = "Reason for being disabled" - yield self.auth.check_auth_blocking(user) + self.auth_blocking._server_notices_mxid = user + self.auth_blocking._hs_disabled_message = "Reason for being disabled" + yield defer.ensureDeferred(self.auth.check_auth_blocking(user)) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 6ba623de13..4e67503cf0 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
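The test_auth.py hunks above are mostly one mechanical change: the Auth methods now return coroutines rather than Deferreds, so the @defer.inlineCallbacks tests wrap each call in defer.ensureDeferred before yielding, and the MAU/blocking settings are tweaked on the Auth's internal AuthBlocking object rather than on hs.config. A sketch of the wrapping pattern in isolation (SomeAsyncApi and get_something are placeholders for illustration, not Synapse APIs):

from twisted.internet import defer


class SomeAsyncApi:
    async def get_something(self) -> str:
        # Stands in for the now-async Auth methods such as get_user_by_req
        # or check_auth_blocking.
        return "value"


@defer.inlineCallbacks
def legacy_style_test(api: SomeAsyncApi):
    # Yielding the coroutine directly would hand back the un-awaited
    # coroutine object; ensureDeferred adapts it to a Deferred that
    # inlineCallbacks knows how to wait on.
    result = yield defer.ensureDeferred(api.get_something())
    assert result == "value"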
@@ -19,9 +22,10 @@ import jsonschema from twisted.internet import defer +from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.api.filtering import Filter -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from tests import unittest from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver @@ -34,7 +38,7 @@ def MockEvent(**kwargs): kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: kwargs["type"] = "fake_type" - return FrozenEvent(kwargs) + return make_event_from_dict(kwargs) class FilteringTestCase(unittest.TestCase): @@ -95,6 +99,8 @@ class FilteringTestCase(unittest.TestCase): "types": ["m.room.message"], "not_rooms": ["!726s6s6q:example.com"], "not_senders": ["@spam:example.com"], + "org.matrix.labels": ["#fun"], + "org.matrix.not_labels": ["#work"], }, "ephemeral": { "types": ["m.receipt", "m.typing"], @@ -320,6 +326,46 @@ class FilteringTestCase(unittest.TestCase): ) self.assertFalse(Filter(definition).check(event)) + def test_filter_labels(self): + definition = {"org.matrix.labels": ["#fun"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={EventContentFields.LABELS: ["#fun"]}, + ) + + self.assertTrue(Filter(definition).check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={EventContentFields.LABELS: ["#notfun"]}, + ) + + self.assertFalse(Filter(definition).check(event)) + + def test_filter_not_labels(self): + definition = {"org.matrix.not_labels": ["#fun"]} + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={EventContentFields.LABELS: ["#fun"]}, + ) + + self.assertFalse(Filter(definition).check(event)) + + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown", + content={EventContentFields.LABELS: ["#notfun"]}, + ) + + self.assertTrue(Filter(definition).check(event)) + @defer.inlineCallbacks def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index dbdd427cac..d580e729c5 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,39 +1,97 @@ -from synapse.api.ratelimiting import Ratelimiter +from synapse.api.ratelimiting import LimitExceededError, Ratelimiter from tests import unittest class TestRatelimiter(unittest.TestCase): - def test_allowed(self): - limiter = Ratelimiter() - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1 - ) + def test_allowed_via_can_do_action(self): + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1 - ) + allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1 - ) + allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) - def test_pruning(self): - limiter = Ratelimiter() + 
def test_allowed_via_ratelimit(self): + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + + # Shouldn't raise + limiter.ratelimit(key="test_id", _time_now_s=0) + + # Should raise + with self.assertRaises(LimitExceededError) as context: + limiter.ratelimit(key="test_id", _time_now_s=5) + self.assertEqual(context.exception.retry_after_ms, 5000) + + # Shouldn't raise + limiter.ratelimit(key="test_id", _time_now_s=10) + + def test_allowed_via_can_do_action_and_overriding_parameters(self): + """Test that we can override options of can_do_action that would otherwise fail + an action + """ + # Create a Ratelimiter with a very low allowed rate_hz and burst_count + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + + # First attempt should be allowed + allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=0,) + self.assertTrue(allowed) + self.assertEqual(10.0, time_allowed) + + # Second attempt, 1s later, will fail + allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=1,) + self.assertFalse(allowed) + self.assertEqual(10.0, time_allowed) + + # But, if we allow 10 actions/sec for this request, we should be allowed + # to continue. allowed, time_allowed = limiter.can_do_action( - key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1 + ("test_id",), _time_now_s=1, rate_hz=10.0 ) + self.assertTrue(allowed) + self.assertEqual(1.1, time_allowed) - self.assertIn("test_id_1", limiter.message_counts) - + # Similarly if we allow a burst of 10 actions allowed, time_allowed = limiter.can_do_action( - key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1 + ("test_id",), _time_now_s=1, burst_count=10 ) + self.assertTrue(allowed) + self.assertEqual(1.0, time_allowed) + + def test_allowed_via_ratelimit_and_overriding_parameters(self): + """Test that we can override options of the ratelimit method that would otherwise + fail an action + """ + # Create a Ratelimiter with a very low allowed rate_hz and burst_count + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + + # First attempt should be allowed + limiter.ratelimit(key=("test_id",), _time_now_s=0) + + # Second attempt, 1s later, will fail + with self.assertRaises(LimitExceededError) as context: + limiter.ratelimit(key=("test_id",), _time_now_s=1) + self.assertEqual(context.exception.retry_after_ms, 9000) + + # But, if we allow 10 actions/sec for this request, we should be allowed + # to continue. + limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0) + + # Similarly if we allow a burst of 10 actions + limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10) + + def test_pruning(self): + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter.can_do_action(key="test_id_1", _time_now_s=0) + + self.assertIn("test_id_1", limiter.actions) + + limiter.can_do_action(key="test_id_2", _time_now_s=10) - self.assertNotIn("test_id_1", limiter.message_counts) + self.assertNotIn("test_id_1", limiter.actions) diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 8bdbc608a9..be20a89682 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
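Returning to the test_ratelimiting.py hunks above: they document the reworked Ratelimiter interface, where rate_hz and burst_count are fixed at construction and may be overridden per call, and where ratelimit() raises LimitExceededError instead of returning a flag. A small usage sketch distilled from those tests (_time_now_s is a test-only hook used because no clock is supplied; real callers pass a clock and omit it):

from synapse.api.ratelimiting import LimitExceededError, Ratelimiter

limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)

# can_do_action() reports whether the action is allowed and when the
# bucket next frees up.
allowed, time_allowed = limiter.can_do_action(key="@user:example.com", _time_now_s=0)

# ratelimit() performs the same check but raises on rejection.
try:
    limiter.ratelimit(key="@user:example.com", _time_now_s=5)
except LimitExceededError as e:
    print("rejected; retry after %d ms" % e.retry_after_ms)

# Per-call overrides relax the limits for a single action, e.g. for a
# more trusted requester.
limiter.can_do_action(key="@user:example.com", _time_now_s=5, rate_hz=10.0, burst_count=10)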
-from synapse.app.frontend_proxy import FrontendProxyServer +from synapse.app.generic_worker import GenericWorkerServer from tests.unittest import HomeserverTestCase @@ -22,11 +22,16 @@ class FrontendProxyTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=FrontendProxyServer + http_client=None, homeserverToUse=GenericWorkerServer ) return hs + def default_config(self): + c = super().default_config() + c["worker_app"] = "synapse.app.frontend_proxy" + return c + def test_listen_http_with_presence_enabled(self): """ When presence is on, the stub servlet will not register. @@ -46,9 +51,7 @@ class FrontendProxyTests(HomeserverTestCase): # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - self.resource = ( - site.resource.children[b"_matrix"].children[b"client"].children[b"r0"] - ) + self.resource = site.resource.children[b"_matrix"].children[b"client"] request, channel = self.make_request("PUT", "presence/a/status") self.render(request) @@ -76,9 +79,7 @@ class FrontendProxyTests(HomeserverTestCase): # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - self.resource = ( - site.resource.children[b"_matrix"].children[b"client"].children[b"r0"] - ) + self.resource = site.resource.children[b"_matrix"].children[b"client"] request, channel = self.make_request("PUT", "presence/a/status") self.render(request) diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 48792d1480..7364f9f1ec 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -16,7 +16,7 @@ from mock import Mock, patch from parameterized import parameterized -from synapse.app.federation_reader import FederationReaderServer +from synapse.app.generic_worker import GenericWorkerServer from synapse.app.homeserver import SynapseHomeServer from tests.unittest import HomeserverTestCase @@ -25,10 +25,18 @@ from tests.unittest import HomeserverTestCase class FederationReaderOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=FederationReaderServer + http_client=None, homeserverToUse=GenericWorkerServer ) return hs + def default_config(self): + conf = super().default_config() + # we're using FederationReaderServer, which uses a SlavedStore, so we + # have to tell the FederationHandler not to try to access stuff that is only + # in the primary store. + conf["worker_app"] = "yes" + return conf + @parameterized.expand( [ (["federation"], "auth_fail"), diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py new file mode 100644 index 0000000000..d3ec24c975 --- /dev/null +++ b/tests/config/test_cache.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config, RootConfig +from synapse.config.cache import CacheConfig, add_resizable_cache +from synapse.util.caches.lrucache import LruCache + +from tests.unittest import TestCase + + +class FakeServer(Config): + section = "server" + + +class TestConfig(RootConfig): + config_classes = [FakeServer, CacheConfig] + + +class CacheConfigTests(TestCase): + def setUp(self): + # Reset caches before each test + TestConfig().caches.reset() + + def test_individual_caches_from_environ(self): + """ + Individual cache factors will be loaded from the environment. + """ + config = {} + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", + "SYNAPSE_NOT_CACHE": "BLAH", + } + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0}) + + def test_config_overrides_environ(self): + """ + Individual cache factors defined in the environment will take precedence + over those in the config. + """ + config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", + "SYNAPSE_CACHE_FACTOR_FOO": 1, + } + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual( + dict(t.caches.cache_factors), + {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, + ) + + def test_individual_instantiated_before_config_load(self): + """ + If a cache is instantiated before the config is read, it will be given + the default cache size in the interim, and then resized once the config + is loaded. + """ + cache = LruCache(100) + + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 50) + + config = {"caches": {"per_cache_factors": {"foo": 3}}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(cache.max_size, 300) + + def test_individual_instantiated_after_config_load(self): + """ + If a cache is instantiated after the config is read, it will be + immediately resized to the correct size given the per_cache_factor if + there is one. + """ + config = {"caches": {"per_cache_factors": {"foo": 2}}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 200) + + def test_global_instantiated_before_config_load(self): + """ + If a cache is instantiated before the config is read, it will be given + the default cache size in the interim, and then resized to the new + default cache size once the config is loaded. + """ + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 50) + + config = {"caches": {"global_factor": 4}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(cache.max_size, 400) + + def test_global_instantiated_after_config_load(self): + """ + If a cache is instantiated after the config is read, it will be + immediately resized to the correct size given the global factor if there + is no per-cache factor. 
+ """ + config = {"caches": {"global_factor": 1.5}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 150) + + def test_cache_with_asterisk_in_name(self): + """Some caches have asterisks in their name, test that they are set correctly. + """ + + config = { + "caches": { + "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2} + } + } + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_CACHE_A": "2", + "SYNAPSE_CACHE_FACTOR_CACHE_B": 3, + } + t.read_config(config, config_dir_path="", data_dir_path="") + + cache_a = LruCache(100) + add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor) + self.assertEqual(cache_a.max_size, 200) + + cache_b = LruCache(100) + add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor) + self.assertEqual(cache_b.max_size, 300) + + cache_c = LruCache(100) + add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor) + self.assertEqual(cache_c.max_size, 200) + + def test_apply_cache_factor_from_config(self): + """Caches can disable applying cache factor updates, mainly used by + event cache size. + """ + + config = {"caches": {"event_cache_size": "10k"}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache( + max_size=t.caches.event_cache_size, apply_cache_factor_from_config=False, + ) + add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor) + + self.assertEqual(cache.max_size, 10240) diff --git a/tests/config/test_database.py b/tests/config/test_database.py index 151d3006ac..f675bde68e 100644 --- a/tests/config/test_database.py +++ b/tests/config/test_database.py @@ -21,9 +21,9 @@ from tests import unittest class DatabaseConfigTestCase(unittest.TestCase): - def test_database_configured_correctly_no_database_conf_param(self): + def test_database_configured_correctly(self): conf = yaml.safe_load( - DatabaseConfig().generate_config_section("/data_dir_path", None) + DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path") ) expected_database_conf = { @@ -32,21 +32,3 @@ class DatabaseConfigTestCase(unittest.TestCase): } self.assertEqual(conf["database"], expected_database_conf) - - def test_database_configured_correctly_database_conf_param(self): - - database_conf = { - "name": "my super fast datastore", - "args": { - "user": "matrix", - "password": "synapse_database_password", - "host": "synapse_database_host", - "database": "matrix", - }, - } - - conf = yaml.safe_load( - DatabaseConfig().generate_config_section("/data_dir_path", database_conf) - ) - - self.assertEqual(conf["database"], database_conf) diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index 2684e662de..463855ecc8 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -48,7 +48,7 @@ class ConfigGenerationTestCase(unittest.TestCase): ) self.assertSetEqual( - set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]), + {"homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"}, set(os.listdir(self.dir)), ) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index b3e557bd6a..734a9983e8 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -122,7 +122,7 @@ class ConfigLoadingTestCase(unittest.TestCase): with open(self.file, "r") as f: contents = 
f.readlines() - contents = [l for l in contents if needle not in l] + contents = [line for line in contents if needle not in line] with open(self.file, "w") as f: f.write("".join(contents)) diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index b02780772a..ec32d4b1ca 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -21,17 +21,24 @@ import yaml from OpenSSL import SSL +from synapse.config._base import Config, RootConfig from synapse.config.tls import ConfigError, TlsConfig -from synapse.crypto.context_factory import ClientTLSOptionsFactory +from synapse.crypto.context_factory import FederationPolicyForHTTPS from tests.unittest import TestCase -class TestConfig(TlsConfig): +class FakeServer(Config): + section = "server" + def has_tls_listener(self): return False +class TestConfig(RootConfig): + config_classes = [FakeServer, TlsConfig] + + class TLSConfigTests(TestCase): def test_warn_self_signed(self): """ @@ -173,12 +180,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= t = TestConfig() t.read_config(config, config_dir_path="", data_dir_path="") - cf = ClientTLSOptionsFactory(t) + cf = FederationPolicyForHTTPS(t) + options = _get_ssl_context_options(cf._verify_ssl_context) # The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2 - self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0) - self.assertNotEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0) - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0) + self.assertNotEqual(options & SSL.OP_NO_TLSv1, 0) + self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0) def test_tls_client_minimum_set_passed_through_1_0(self): """ @@ -188,12 +196,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= t = TestConfig() t.read_config(config, config_dir_path="", data_dir_path="") - cf = ClientTLSOptionsFactory(t) + cf = FederationPolicyForHTTPS(t) + options = _get_ssl_context_options(cf._verify_ssl_context) # The context has not had any of the NO_TLS set. 
- self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1, 0) - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_1, 0) - self.assertEqual(cf._verify_ssl._options & SSL.OP_NO_TLSv1_2, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0) + self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0) def test_acme_disabled_in_generated_config_no_acme_domain_provied(self): """ @@ -202,13 +211,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= conf = TestConfig() conf.read_config( yaml.safe_load( - TestConfig().generate_config_section( + TestConfig().generate_config( "/config_dir_path", "my_super_secure_server", "/data_dir_path", - "/tls_cert_path", - "tls_private_key", - None, # This is the acme_domain + tls_certificate_path="/tls_cert_path", + tls_private_key_path="tls_private_key", + acme_domain=None, # This is the acme_domain ) ), "/config_dir_path", @@ -223,13 +232,13 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= conf = TestConfig() conf.read_config( yaml.safe_load( - TestConfig().generate_config_section( + TestConfig().generate_config( "/config_dir_path", "my_super_secure_server", "/data_dir_path", - "/tls_cert_path", - "tls_private_key", - "my_supe_secure_server", # This is the acme_domain + tls_certificate_path="/tls_cert_path", + tls_private_key_path="tls_private_key", + acme_domain="my_supe_secure_server", # This is the acme_domain ) ), "/config_dir_path", @@ -266,7 +275,7 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= t = TestConfig() t.read_config(config, config_dir_path="", data_dir_path="") - cf = ClientTLSOptionsFactory(t) + cf = FederationPolicyForHTTPS(t) # Not in the whitelist opts = cf.get_options(b"notexample.com") @@ -275,3 +284,10 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= # Caught by the wildcard opts = cf.get_options(idna.encode("テスト.ドメイン.テスト")) self.assertFalse(opts._verifier._verify_certs) + + +def _get_ssl_context_options(ssl_context: SSL.Context) -> int: + """get the options bits from an openssl context object""" + # the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to + # use the low-level interface + return SSL._lib.SSL_CTX_get_options(ssl_context._context) diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 126e176004..62f639a18d 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -17,8 +17,9 @@ import nacl.signing from unpaddedbase64 import decode_base64 +from synapse.api.room_versions import RoomVersions from synapse.crypto.event_signing import add_hashes_and_signatures -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from tests import unittest @@ -49,9 +50,11 @@ class EventSigningTestCase(unittest.TestCase): "unsigned": {"age_ts": 1000000}, } - add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key) + add_hashes_and_signatures( + RoomVersions.V1, event_dict, HOSTNAME, self.signing_key + ) - event = FrozenEvent(event_dict) + event = make_event_from_dict(event_dict) self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) @@ -81,9 +84,11 @@ class EventSigningTestCase(unittest.TestCase): "unsigned": {"age_ts": 1000000}, } - add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key) + add_hashes_and_signatures( + RoomVersions.V1, event_dict, HOSTNAME, self.signing_key + ) - event = FrozenEvent(event_dict) + event = make_event_from_dict(event_dict) 
self.assertTrue(hasattr(event, "hashes")) self.assertIn("sha256", event.hashes) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index c4f0bbd3dd..70c8e72303 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -19,6 +19,7 @@ from mock import Mock import canonicaljson import signedjson.key import signedjson.sign +from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key from twisted.internet import defer @@ -33,6 +34,7 @@ from synapse.crypto.keyring import ( from synapse.logging.context import ( LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, ) from synapse.storage.keys import FetchKeyResult @@ -82,9 +84,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ) def check_context(self, _, expected): - self.assertEquals( - getattr(LoggingContext.current_context(), "request", None), expected - ) + self.assertEquals(getattr(current_context(), "request", None), expected) def test_verify_json_objects_for_server_awaits_previous_requests(self): key1 = signedjson.key.generate_signing_key(1) @@ -104,7 +104,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def get_perspectives(**kwargs): - self.assertEquals(LoggingContext.current_context().request, "11") + self.assertEquals(current_context().request, "11") with PreserveLoggingContext(): yield persp_deferred return persp_resp @@ -178,7 +178,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.datastore.store_server_verify_keys( + r = self.hs.get_datastore().store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], @@ -209,7 +209,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.datastore.store_server_verify_keys( + r = self.hs.get_datastore().store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], @@ -412,34 +412,37 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): handlers=None, http_client=self.http_client, config=config ) - def test_get_keys_from_perspectives(self): - # arbitrarily advance the clock a bit - self.reactor.advance(100) - - fetcher = PerspectivesKeyFetcher(self.hs) - - SERVER_NAME = "server2" - testkey = signedjson.key.generate_signing_key("ver1") - testverifykey = signedjson.key.get_verify_key(testkey) - testverifykey_id = "ed25519:ver1" - VALID_UNTIL_TS = 200 * 1000 + def build_perspectives_response( + self, server_name: str, signing_key: SigningKey, valid_until_ts: int, + ) -> dict: + """ + Build a valid perspectives server response to a request for the given key + """ + verify_key = signedjson.key.get_verify_key(signing_key) + verifykey_id = "%s:%s" % (verify_key.alg, verify_key.version) - # valid response response = { - "server_name": SERVER_NAME, + "server_name": server_name, "old_verify_keys": {}, - "valid_until_ts": VALID_UNTIL_TS, + "valid_until_ts": valid_until_ts, "verify_keys": { - testverifykey_id: { - "key": signedjson.key.encode_verify_key_base64(testverifykey) + verifykey_id: { + "key": signedjson.key.encode_verify_key_base64(verify_key) } }, } - # the response must be signed by both the origin server and the perspectives # server. 
- signedjson.sign.sign_json(response, SERVER_NAME, testkey) + signedjson.sign.sign_json(response, server_name, signing_key) self.mock_perspective_server.sign_response(response) + return response + + def expect_outgoing_key_query( + self, expected_server_name: str, expected_key_id: str, response: dict + ) -> None: + """ + Tell the mock http client to expect a perspectives-server key query + """ def post_json(destination, path, data, **kwargs): self.assertEqual(destination, self.mock_perspective_server.server_name) @@ -447,11 +450,79 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the request is for the expected key q = data["server_keys"] - self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"]) + self.assertEqual(list(q[expected_server_name].keys()), [expected_key_id]) return {"server_keys": [response]} self.http_client.post_json.side_effect = post_json + def test_get_keys_from_perspectives(self): + # arbitrarily advance the clock a bit + self.reactor.advance(100) + + fetcher = PerspectivesKeyFetcher(self.hs) + + SERVER_NAME = "server2" + testkey = signedjson.key.generate_signing_key("ver1") + testverifykey = signedjson.key.get_verify_key(testkey) + testverifykey_id = "ed25519:ver1" + VALID_UNTIL_TS = 200 * 1000 + + response = self.build_perspectives_response( + SERVER_NAME, testkey, VALID_UNTIL_TS, + ) + + self.expect_outgoing_key_query(SERVER_NAME, "key1", response) + + keys_to_fetch = {SERVER_NAME: {"key1": 0}} + keys = self.get_success(fetcher.get_keys(keys_to_fetch)) + self.assertIn(SERVER_NAME, keys) + k = keys[SERVER_NAME][testverifykey_id] + self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) + self.assertEqual(k.verify_key, testverifykey) + self.assertEqual(k.verify_key.alg, "ed25519") + self.assertEqual(k.verify_key.version, "ver1") + + # check that the perspectives store is correctly updated + lookup_triplet = (SERVER_NAME, testverifykey_id, None) + key_json = self.get_success( + self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + ) + res = key_json[lookup_triplet] + self.assertEqual(len(res), 1) + res = res[0] + self.assertEqual(res["key_id"], testverifykey_id) + self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) + self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) + self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) + + self.assertEqual( + bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) + ) + + def test_get_perspectives_own_key(self): + """Check that we can get the perspectives server's own keys + + This is slightly complicated by the fact that the perspectives server may + use different keys for signing notary responses. 
+ """ + + # arbitrarily advance the clock a bit + self.reactor.advance(100) + + fetcher = PerspectivesKeyFetcher(self.hs) + + SERVER_NAME = self.mock_perspective_server.server_name + testkey = signedjson.key.generate_signing_key("ver1") + testverifykey = signedjson.key.get_verify_key(testkey) + testverifykey_id = "ed25519:ver1" + VALID_UNTIL_TS = 200 * 1000 + + response = self.build_perspectives_response( + SERVER_NAME, testkey, VALID_UNTIL_TS + ) + + self.expect_outgoing_key_query(SERVER_NAME, "key1", response) + keys_to_fetch = {SERVER_NAME: {"key1": 0}} keys = self.get_success(fetcher.get_keys(keys_to_fetch)) self.assertIn(SERVER_NAME, keys) @@ -490,35 +561,14 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): VALID_UNTIL_TS = 200 * 1000 def build_response(): - # valid response - response = { - "server_name": SERVER_NAME, - "old_verify_keys": {}, - "valid_until_ts": VALID_UNTIL_TS, - "verify_keys": { - testverifykey_id: { - "key": signedjson.key.encode_verify_key_base64(testverifykey) - } - }, - } - - # the response must be signed by both the origin server and the perspectives - # server. - signedjson.sign.sign_json(response, SERVER_NAME, testkey) - self.mock_perspective_server.sign_response(response) - return response + return self.build_perspectives_response( + SERVER_NAME, testkey, VALID_UNTIL_TS + ) def get_key_from_perspectives(response): fetcher = PerspectivesKeyFetcher(self.hs) keys_to_fetch = {SERVER_NAME: {"key1": 0}} - - def post_json(destination, path, data, **kwargs): - self.assertEqual(destination, self.mock_perspective_server.server_name) - self.assertEqual(path, "/_matrix/key/v2/query") - return {"server_keys": [response]} - - self.http_client.post_json.side_effect = post_json - + self.expect_outgoing_key_query(SERVER_NAME, "key1", response) return self.get_success(fetcher.get_keys(keys_to_fetch)) # start with a valid response so we can check we are testing the right thing diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py new file mode 100644 index 0000000000..640f5f3bce --- /dev/null +++ b/tests/events/test_snapshot.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.events.snapshot import EventContext +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests import unittest +from tests.test_utils.event_injection import create_event + + +class TestEventContext(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.storage = hs.get_storage() + + self.user_id = self.register_user("u1", "pass") + self.user_tok = self.login("u1", "pass") + self.room_id = self.helper.create_room_as(tok=self.user_tok) + + def test_serialize_deserialize_msg(self): + """Test that an EventContext for a message event is the same after + serialize/deserialize. 
+ """ + + event, context = create_event( + self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, + ) + + self._check_serialize_deserialize(event, context) + + def test_serialize_deserialize_state_no_prev(self): + """Test that an EventContext for a state event (with not previous entry) + is the same after serialize/deserialize. + """ + event, context = create_event( + self.hs, + room_id=self.room_id, + type="m.test", + sender=self.user_id, + state_key="", + ) + + self._check_serialize_deserialize(event, context) + + def test_serialize_deserialize_state_prev(self): + """Test that an EventContext for a state event (which replaces a + previous entry) is the same after serialize/deserialize. + """ + event, context = create_event( + self.hs, + room_id=self.room_id, + type="m.room.member", + sender=self.user_id, + state_key=self.user_id, + content={"membership": "leave"}, + ) + + self._check_serialize_deserialize(event, context) + + def _check_serialize_deserialize(self, event, context): + serialized = self.get_success(context.serialize(event, self.store)) + + d_context = EventContext.deserialize(self.storage, serialized) + + self.assertEqual(context.state_group, d_context.state_group) + self.assertEqual(context.rejected, d_context.rejected) + self.assertEqual( + context.state_group_before_event, d_context.state_group_before_event + ) + self.assertEqual(context.prev_group, d_context.prev_group) + self.assertEqual(context.delta_ids, d_context.delta_ids) + self.assertEqual(context.app_service, d_context.app_service) + + self.assertEqual( + self.get_success(context.get_current_state_ids()), + self.get_success(d_context.get_current_state_ids()), + ) + self.assertEqual( + self.get_success(context.get_prev_state_ids()), + self.get_success(d_context.get_prev_state_ids()), + ) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 9e3d4d0f47..c1274c14af 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -13,11 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.room_versions import RoomVersions +from synapse.events import make_event_from_dict +from synapse.events.utils import ( + copy_power_levels_contents, + prune_event, + serialize_event, +) +from synapse.util.frozenutils import freeze -from synapse.events import FrozenEvent -from synapse.events.utils import prune_event, serialize_event - -from .. import unittest +from tests import unittest def MockEvent(**kwargs): @@ -25,15 +30,17 @@ def MockEvent(**kwargs): kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: kwargs["type"] = "fake_type" - return FrozenEvent(kwargs) + return make_event_from_dict(kwargs) class PruneEventTestCase(unittest.TestCase): """ Asserts that a new event constructed with `evdict` will look like `matchdict` when it is redacted. 
""" - def run_test(self, evdict, matchdict): - self.assertEquals(prune_event(FrozenEvent(evdict)).get_dict(), matchdict) + def run_test(self, evdict, matchdict, **kwargs): + self.assertEquals( + prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict + ) def test_minimal(self): self.run_test( @@ -122,6 +129,36 @@ class PruneEventTestCase(unittest.TestCase): }, ) + def test_alias_event(self): + """Alias events have special behavior up through room version 6.""" + self.run_test( + { + "type": "m.room.aliases", + "event_id": "$test:domain", + "content": {"aliases": ["test"]}, + }, + { + "type": "m.room.aliases", + "event_id": "$test:domain", + "content": {"aliases": ["test"]}, + "signatures": {}, + "unsigned": {}, + }, + ) + + def test_msc2432_alias_event(self): + """After MSC2432, alias events have no special behavior.""" + self.run_test( + {"type": "m.room.aliases", "content": {"aliases": ["test"]}}, + { + "type": "m.room.aliases", + "content": {}, + "signatures": {}, + "unsigned": {}, + }, + room_version=RoomVersions.V6, + ) + class SerializeEventTestCase(unittest.TestCase): def serialize(self, ev, fields): @@ -241,3 +278,39 @@ class SerializeEventTestCase(unittest.TestCase): self.serialize( MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] ) + + +class CopyPowerLevelsContentTestCase(unittest.TestCase): + def setUp(self) -> None: + self.test_content = { + "ban": 50, + "events": {"m.room.name": 100, "m.room.power_levels": 100}, + "events_default": 0, + "invite": 50, + "kick": 50, + "notifications": {"room": 20}, + "redact": 50, + "state_default": 50, + "users": {"@example:localhost": 100}, + "users_default": 0, + } + + def _test(self, input): + a = copy_power_levels_contents(input) + + self.assertEqual(a["ban"], 50) + self.assertEqual(a["events"]["m.room.name"], 100) + + # make sure that changing the copy changes the copy and not the orig + a["ban"] = 10 + a["events"]["m.room.power_levels"] = 20 + + self.assertEqual(input["ban"], 50) + self.assertEqual(input["events"]["m.room.power_levels"], 100) + + def test_unfrozen(self): + self._test(self.test_content) + + def test_frozen(self): + input = freeze(self.test_content) + self._test(input) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 51714a2b06..0c9987be54 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -18,17 +18,14 @@ from mock import Mock from twisted.internet import defer from synapse.api.errors import Codes, SynapseError -from synapse.config.ratelimiting import FederationRateLimitConfig -from synapse.federation.transport import server from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.types import UserID -from synapse.util.ratelimitutils import FederationRateLimiter from tests import unittest -class RoomComplexityTests(unittest.HomeserverTestCase): +class RoomComplexityTests(unittest.FederatingHomeserverTestCase): servlets = [ admin.register_servlets, @@ -36,30 +33,11 @@ class RoomComplexityTests(unittest.HomeserverTestCase): login.register_servlets, ] - def default_config(self, name="test"): - config = super().default_config(name=name) + def default_config(self): + config = super().default_config() config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05} return config - def prepare(self, reactor, clock, homeserver): - class Authenticator(object): - def authenticate_request(self, request, content): - return defer.succeed("otherserver.nottld") - - 
ratelimiter = FederationRateLimiter( - clock, - FederationRateLimitConfig( - window_size=1, - sleep_limit=1, - sleep_msec=1, - reject_limit=1000, - concurrent_requests=1000, - ), - ) - server.register_servlets( - homeserver, self.resource, Authenticator(), ratelimiter - ) - def test_complexity_simple(self): u1 = self.register_user("u1", "pass") @@ -101,11 +79,13 @@ class RoomComplexityTests(unittest.HomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) - handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) d = handler._remote_join( None, - ["otherserver.example"], + ["other.example.com"], "roomid", UserID.from_string(u1), {"membership": "join"}, @@ -137,7 +117,9 @@ class RoomComplexityTests(unittest.HomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) - handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) # Artificially raise the complexity self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed( @@ -146,7 +128,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase): d = handler._remote_join( None, - ["otherserver.example"], + ["other.example.com"], room_1, UserID.from_string(u1), {"membership": "join"}, diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index cce8d8c6de..ff12539041 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -12,23 +12,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
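The complexity tests above mock the federation client's `get_json` to report a remote room complexity of `{"v1": 9999}` against a configured limit of 0.05, and stub `do_invite_join` to return a `(room_id, stream_id)` tuple, so only the limiting logic is exercised. A minimal sketch of the kind of gate being tested, assuming a hypothetical `fetch_room_complexity` helper and exception type (the real logic lives in Synapse's room-member handler, not here):

from typing import Awaitable, Callable, Optional


class ComplexityLimitExceeded(Exception):
    """Hypothetical error raised when a remote room is too complex to join."""


async def check_remote_room_complexity(
    fetch_room_complexity: Callable[[str], Awaitable[Optional[dict]]],
    room_id: str,
    limit: Optional[float],
) -> None:
    """Refuse a remote join when the reported complexity exceeds the limit.

    `fetch_room_complexity` is assumed to return a dict such as {"v1": 9999}
    (as mocked in the tests above), or None when the remote gives no answer.
    """
    if limit is None:
        # Limiting disabled: nothing to check.
        return

    complexity = await fetch_room_complexity(room_id)
    if complexity is None:
        # No answer from the remote; the tests above cover this case by
        # letting the join proceed and re-checking afterwards.
        return

    if complexity["v1"] > limit:
        raise ComplexityLimitExceeded(room_id)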
+from typing import Optional from mock import Mock +from signedjson import key, sign +from signedjson.types import BaseKey, SigningKey + from twisted.internet import defer -from synapse.types import ReadReceipt +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.types import JsonDict, ReadReceipt -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config -class FederationSenderTestCases(HomeserverTestCase): +class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): - return super(FederationSenderTestCases, self).setup_test_homeserver( + return self.setup_test_homeserver( state_handler=Mock(spec=["get_current_hosts_in_room"]), federation_transport_client=Mock(spec=["send_transaction"]), ) + @override_config({"send_federation": True}) def test_send_receipts(self): mock_state_handler = self.hs.get_state_handler() mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] @@ -69,6 +76,7 @@ class FederationSenderTestCases(HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" @@ -145,3 +153,392 @@ class FederationSenderTestCases(HomeserverTestCase): } ], ) + + +class FederationSenderDevicesTestCases(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + state_handler=Mock(spec=["get_current_hosts_in_room"]), + federation_transport_client=Mock(spec=["send_transaction"]), + ) + + def default_config(self): + c = super().default_config() + c["send_federation"] = True + return c + + def prepare(self, reactor, clock, hs): + # stub out get_current_hosts_in_room + mock_state_handler = hs.get_state_handler() + mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] + + # stub out get_users_who_share_room_with_user so that it claims that + # `@user2:host2` is in the room + def get_users_who_share_room_with_user(user_id): + return defer.succeed({"@user2:host2"}) + + hs.get_datastore().get_users_who_share_room_with_user = ( + get_users_who_share_room_with_user + ) + + # whenever send_transaction is called, record the edu data + self.edus = [] + self.hs.get_federation_transport_client().send_transaction.side_effect = ( + self.record_transaction + ) + + def record_transaction(self, txn, json_cb): + data = json_cb() + self.edus.extend(data["edus"]) + return defer.succeed({}) + + def test_send_device_updates(self): + """Basic case: each device update should result in an EDU""" + # create a device + u1 = self.register_user("user", "pass") + self.login(u1, "pass", device_id="D1") + + # expect one edu + self.assertEqual(len(self.edus), 1) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + + # a second call should produce no new device EDUs + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + self.assertEqual(self.edus, []) + + # a second device + self.login("user", "pass", device_id="D2") + + self.assertEqual(len(self.edus), 1) + self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + def test_upload_signatures(self): + """Uploading signatures on some devices should produce updates for that user""" + + e2e_handler = self.hs.get_e2e_keys_handler() + + # register two devices + u1 = 
self.register_user("user", "pass") + self.login(u1, "pass", device_id="D1") + self.login(u1, "pass", device_id="D2") + + # expect two edus + self.assertEqual(len(self.edus), 2) + stream_id = None + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + # upload signing keys for each device + device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1") + device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2") + + # expect two more edus + self.assertEqual(len(self.edus), 2) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + + # upload master key and self-signing key + master_signing_key = generate_self_id_key() + master_key = { + "user_id": u1, + "usage": ["master"], + "keys": {key_id(master_signing_key): encode_pubkey(master_signing_key)}, + } + + # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 + selfsigning_signing_key = generate_self_id_key() + selfsigning_key = { + "user_id": u1, + "usage": ["self_signing"], + "keys": { + key_id(selfsigning_signing_key): encode_pubkey(selfsigning_signing_key) + }, + } + sign.sign_json(selfsigning_key, u1, master_signing_key) + + cross_signing_keys = { + "master_key": master_key, + "self_signing_key": selfsigning_key, + } + + self.get_success( + e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys) + ) + + # expect signing key update edu + self.assertEqual(len(self.edus), 1) + self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update") + + # sign the devices + d1_json = build_device_dict(u1, "D1", device1_signing_key) + sign.sign_json(d1_json, u1, selfsigning_signing_key) + d2_json = build_device_dict(u1, "D2", device2_signing_key) + sign.sign_json(d2_json, u1, selfsigning_signing_key) + + ret = self.get_success( + e2e_handler.upload_signatures_for_device_keys( + u1, {u1: {"D1": d1_json, "D2": d2_json}}, + ) + ) + self.assertEqual(ret["failures"], {}) + + # expect two edus, in one or two transactions. We don't know what order the + # devices will be updated. 
+ self.assertEqual(len(self.edus), 2) + stream_id = None # FIXME: there is a discontinuity in the stream IDs: see #7142 + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + if stream_id is not None: + self.assertEqual(c["prev_id"], [stream_id]) + self.assertGreaterEqual(c["stream_id"], stream_id) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2"}, devices) + + def test_delete_devices(self): + """If devices are deleted, that should result in EDUs too""" + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # expect three edus + self.assertEqual(len(self.edus), 3) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id) + stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id) + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + # expect three edus, in an unknown order + self.assertEqual(len(self.edus), 3) + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + self.assertGreaterEqual( + c.items(), + {"user_id": u1, "prev_id": [stream_id], "deleted": True}.items(), + ) + self.assertGreaterEqual(c["stream_id"], stream_id) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2", "D3"}, devices) + + def test_unreachable_server(self): + """If the destination server is unreachable, all the updates should get sent on + recovery + """ + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 4) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # for each device, there should be a single update + self.assertEqual(len(self.edus), 3) + stream_id = None + for edu in self.edus: + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else []) + if stream_id is not None: + self.assertGreaterEqual(c["stream_id"], stream_id) + stream_id = c["stream_id"] + devices = {edu["content"]["device_id"] for edu in self.edus} + self.assertEqual({"D1", "D2", "D3"}, devices) + + def test_prune_outbound_device_pokes1(self): + """If a destination is unreachable, and the updates are pruned, we should get + a single update. + + This case tests the behaviour when the server has never been reachable. 
+ """ + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 4) + + # run the prune job + self.reactor.advance(10) + self.get_success( + self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + ) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # there should be a single update for this user. + self.assertEqual(len(self.edus), 1) + edu = self.edus.pop(0) + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + + # synapse uses an empty prev_id list to indicate "needs a full resync". + self.assertEqual(c["prev_id"], []) + + def test_prune_outbound_device_pokes2(self): + """If a destination is unreachable, and the updates are pruned, we should get + a single update. + + This case tests the behaviour when the server was reachable, but then goes + offline. + """ + + # create first device + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + + # expect the update EDU + self.assertEqual(len(self.edus), 1) + self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + + # now the server goes offline + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 3) + + # run the prune job + self.reactor.advance(10) + self.get_success( + self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + ) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # ... and we should get a single update for this user. + self.assertEqual(len(self.edus), 1) + edu = self.edus.pop(0) + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + + # synapse uses an empty prev_id list to indicate "needs a full resync". + self.assertEqual(c["prev_id"], []) + + def check_device_update_edu( + self, + edu: JsonDict, + user_id: str, + device_id: str, + prev_stream_id: Optional[int], + ) -> int: + """Check that the given EDU is an update for the given device + Returns the stream_id. + """ + self.assertEqual(edu["edu_type"], "m.device_list_update") + content = edu["content"] + + expected = { + "user_id": user_id, + "device_id": device_id, + "prev_id": [prev_stream_id] if prev_stream_id is not None else [], + } + + self.assertLessEqual(expected.items(), content.items()) + if prev_stream_id is not None: + self.assertGreaterEqual(content["stream_id"], prev_stream_id) + return content["stream_id"] + + def check_signing_key_update_txn(self, txn: JsonDict,) -> None: + """Check that the txn has an EDU with a signing key update. 
+ """ + edus = txn["edus"] + self.assertEqual(len(edus), 1) + + def generate_and_upload_device_signing_key( + self, user_id: str, device_id: str + ) -> SigningKey: + """Generate a signing keypair for the given device, and upload it""" + sk = key.generate_signing_key(device_id) + + device_dict = build_device_dict(user_id, device_id, sk) + + self.get_success( + self.hs.get_e2e_keys_handler().upload_keys_for_user( + user_id, device_id, {"device_keys": device_dict}, + ) + ) + return sk + + +def generate_self_id_key() -> SigningKey: + """generate a signing key whose version is its public key + + ... as used by the cross-signing-keys. + """ + k = key.generate_signing_key("x") + k.version = encode_pubkey(k) + return k + + +def key_id(k: BaseKey) -> str: + return "%s:%s" % (k.alg, k.version) + + +def encode_pubkey(sk: SigningKey) -> str: + """Encode the public key corresponding to the given signing key as base64""" + return key.encode_verify_key_base64(key.get_verify_key(sk)) + + +def build_device_dict(user_id: str, device_id: str, sk: SigningKey): + """Build a dict representing the given device""" + return { + "user_id": user_id, + "device_id": device_id, + "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + "keys": { + "curve25519:" + device_id: "curve25519+key", + key_id(sk): encode_pubkey(sk), + }, + } diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index b08be451aa..296dc887be 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,8 +15,10 @@ # limitations under the License. import logging -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.federation.federation_server import server_matches_acl_event +from synapse.rest import admin +from synapse.rest.client.v1 import login, room from tests import unittest @@ -41,8 +44,68 @@ class ServerACLsTestCase(unittest.TestCase): self.assertTrue(server_matches_acl_event("1:2:3:4", e)) +class StateQueryTests(unittest.FederatingHomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def test_without_event_id(self): + """ + Querying v1/state/<room_id> without an event ID will return the current + known state. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.inject_room_member(room_1, "@user:other.example.com", "join") + + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/state/%s" % (room_1,) + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + self.assertEqual( + channel.json_body["room_version"], + self.hs.config.default_room_version.identifier, + ) + + members = set( + map( + lambda x: x["state_key"], + filter( + lambda x: x["type"] == "m.room.member", channel.json_body["pdus"] + ), + ) + ) + + self.assertEqual(members, {"@user:other.example.com", u1}) + self.assertEqual(len(channel.json_body["pdus"]), 6) + + def test_needs_to_be_in_room(self): + """ + Querying v1/state/<room_id> requires the server + be in the room to provide data. 
+ """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/state/%s" % (room_1,) + ) + self.render(request) + self.assertEquals(403, channel.code, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + + def _create_acl_event(content): - return FrozenEvent( + return make_event_from_dict( { "room_id": "!a:b", "event_id": "$a:b", diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py new file mode 100644 index 0000000000..27d83bb7d9 --- /dev/null +++ b/tests/federation/transport/test_server.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.federation.transport import server +from synapse.util.ratelimitutils import FederationRateLimiter + +from tests import unittest +from tests.unittest import override_config + + +class RoomDirectoryFederationTests(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + class Authenticator(object): + def authenticate_request(self, request, content): + return defer.succeed("otherserver.nottld") + + ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig()) + server.register_servlets( + homeserver, self.resource, Authenticator(), ratelimiter + ) + + @override_config({"allow_public_rooms_over_federation": False}) + def test_blocked_public_room_list_over_federation(self): + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/publicRooms" + ) + self.render(request) + self.assertEquals(403, channel.code) + + @override_config({"allow_public_rooms_over_federation": True}) + def test_open_public_room_list_over_federation(self): + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/publicRooms" + ) + self.render(request) + self.assertEquals(200, channel.code) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index b03103d96f..c01b04e1dc 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -39,8 +39,13 @@ class AuthTestCase(unittest.TestCase): self.hs.handlers = AuthHandlers(self.hs) self.auth_handler = self.hs.handlers.auth_handler self.macaroon_generator = self.hs.get_macaroon_generator() + # MAU tests - self.hs.config.max_mau_value = 50 + # AuthBlocking reads from the hs' config on initialization. 
We need to + # modify its config instead of the hs' + self.auth_blocking = self.hs.get_auth()._auth_blocking + self.auth_blocking._max_mau_value = 50 + self.small_number_of_users = 1 self.large_number_of_users = 100 @@ -82,16 +87,16 @@ class AuthTestCase(unittest.TestCase): self.hs.clock.now = 1000 token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) - user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - token + user_id = yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id(token) ) self.assertEqual("a_user", user_id) # when we advance the clock, the token should be rejected self.hs.clock.now = 6000 with self.assertRaises(synapse.api.errors.AuthError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - token + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id(token) ) @defer.inlineCallbacks @@ -99,8 +104,10 @@ class AuthTestCase(unittest.TestCase): token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) macaroon = pymacaroons.Macaroon.deserialize(token) - user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() + user_id = yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) ) self.assertEqual("a_user", user_id) @@ -109,99 +116,121 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("user_id = b_user") with self.assertRaises(synapse.api.errors.AuthError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_disabled(self): - self.hs.config.limit_usage_by_mau = False + self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_exceeded_large(self): - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_parity(self): - 
self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True # If not in monthly active cohort self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) # If in monthly active cohort self.hs.get_datastore().user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_not_exceeded(self): - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) # Ensure does not raise exception - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) def _get_macaroon(self): diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index a3aa0a1cf2..62b47f6574 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -160,6 +160,24 @@ class DeviceTestCase(unittest.HomeserverTestCase): res = self.get_success(self.handler.get_device(user1, "abc")) self.assertEqual(res["display_name"], "new display") + def test_update_device_too_long_display_name(self): + 
"""Update a device with a display name that is invalid (too long).""" + self._record_users() + + # Request to update a device display name with a new value that is longer than allowed. + update = { + "display_name": "a" + * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1) + } + self.get_failure( + self.handler.update_device(user1, "abc", update), + synapse.api.errors.SynapseError, + ) + + # Ensure the display name was not updated. + res = self.get_success(self.handler.get_device(user1, "abc")) + self.assertEqual(res["display_name"], "display 2") + def test_update_unknown_device(self): update = {"display_name": "new_display"} res = self.handler.update_device("user_id", "unknown_device_id", update) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 91c7a17070..00bb776271 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -18,25 +18,20 @@ from mock import Mock from twisted.internet import defer +import synapse +import synapse.api.errors +from synapse.api.constants import EventTypes from synapse.config.room_directory import RoomDirectoryConfig -from synapse.handlers.directory import DirectoryHandler -from synapse.rest.client.v1 import directory, room -from synapse.types import RoomAlias +from synapse.rest.client.v1 import directory, login, room +from synapse.types import RoomAlias, create_requester from tests import unittest -from tests.utils import setup_test_homeserver -class DirectoryHandlers(object): - def __init__(self, hs): - self.directory_handler = DirectoryHandler(hs) - - -class DirectoryTestCase(unittest.TestCase): +class DirectoryTestCase(unittest.HomeserverTestCase): """ Tests the directory service. """ - @defer.inlineCallbacks - def setUp(self): + def make_homeserver(self, reactor, clock): self.mock_federation = Mock() self.mock_registry = Mock() @@ -47,14 +42,12 @@ class DirectoryTestCase(unittest.TestCase): self.mock_registry.register_query_handler = register_query_handler - hs = yield setup_test_homeserver( - self.addCleanup, + hs = self.setup_test_homeserver( http_client=None, resource_for_federation=Mock(), federation_client=self.mock_federation, federation_registry=self.mock_registry, ) - hs.handlers = DirectoryHandlers(hs) self.handler = hs.get_handlers().directory_handler @@ -64,23 +57,25 @@ class DirectoryTestCase(unittest.TestCase): self.your_room = RoomAlias.from_string("#your-room:test") self.remote_room = RoomAlias.from_string("#another:remote") - @defer.inlineCallbacks + return hs + def test_get_local_association(self): - yield self.store.create_room_alias_association( - self.my_room, "!8765qwer:test", ["test"] + self.get_success( + self.store.create_room_alias_association( + self.my_room, "!8765qwer:test", ["test"] + ) ) - result = yield self.handler.get_association(self.my_room) + result = self.get_success(self.handler.get_association(self.my_room)) self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) - @defer.inlineCallbacks def test_get_remote_association(self): self.mock_federation.make_query.return_value = defer.succeed( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} ) - result = yield self.handler.get_association(self.remote_room) + result = self.get_success(self.handler.get_association(self.remote_room)) self.assertEquals( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result @@ -93,19 +88,303 @@ class DirectoryTestCase(unittest.TestCase): ignore_backoff=True, ) - @defer.inlineCallbacks def test_incoming_fed_query(self): - 
yield self.store.create_room_alias_association( - self.your_room, "!8765asdf:test", ["test"] + self.get_success( + self.store.create_room_alias_association( + self.your_room, "!8765asdf:test", ["test"] + ) ) - response = yield self.query_handlers["directory"]( - {"room_alias": "#your-room:test"} + response = self.get_success( + self.handler.on_directory_query({"room_alias": "#your-room:test"}) ) self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) +class TestCreateAlias(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_handlers().directory_handler + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = RoomAlias.from_string(self.test_alias) + + # Create a test user. + self.test_user = self.register_user("user", "pass", admin=False) + self.test_user_tok = self.login("user", "pass") + self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) + + def test_create_alias_joined_room(self): + """A user can create an alias for a room they're in.""" + self.get_success( + self.handler.create_association( + create_requester(self.test_user), self.room_alias, self.room_id, + ) + ) + + def test_create_alias_other_room(self): + """A user cannot create an alias for a room they're NOT in.""" + other_room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.get_failure( + self.handler.create_association( + create_requester(self.test_user), self.room_alias, other_room_id, + ), + synapse.api.errors.SynapseError, + ) + + def test_create_alias_admin(self): + """An admin can create an alias for a room they're NOT in.""" + other_room_id = self.helper.create_room_as( + self.test_user, tok=self.test_user_tok + ) + + self.get_success( + self.handler.create_association( + create_requester(self.admin_user), self.room_alias, other_room_id, + ) + ) + + +class TestDeleteAlias(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_handlers().directory_handler + self.state_handler = hs.get_state_handler() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = RoomAlias.from_string(self.test_alias) + + # Create a test user. + self.test_user = self.register_user("user", "pass", admin=False) + self.test_user_tok = self.login("user", "pass") + self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) + + def _create_alias(self, user): + # Create a new alias to this room. 
+ self.get_success( + self.store.create_room_alias_association( + self.room_alias, self.room_id, ["test"], user + ) + ) + + def test_delete_alias_not_allowed(self): + """A user that doesn't meet the expected guidelines cannot delete an alias.""" + self._create_alias(self.admin_user) + self.get_failure( + self.handler.delete_association( + create_requester(self.test_user), self.room_alias + ), + synapse.api.errors.AuthError, + ) + + def test_delete_alias_creator(self): + """An alias creator can delete their own alias.""" + # Create an alias from a different user. + self._create_alias(self.test_user) + + # Delete the user's alias. + result = self.get_success( + self.handler.delete_association( + create_requester(self.test_user), self.room_alias + ) + ) + self.assertEquals(self.room_id, result) + + # Confirm the alias is gone. + self.get_failure( + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, + ) + + def test_delete_alias_admin(self): + """A server admin can delete an alias created by another user.""" + # Create an alias from a different user. + self._create_alias(self.test_user) + + # Delete the user's alias as the admin. + result = self.get_success( + self.handler.delete_association( + create_requester(self.admin_user), self.room_alias + ) + ) + self.assertEquals(self.room_id, result) + + # Confirm the alias is gone. + self.get_failure( + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, + ) + + def test_delete_alias_sufficient_power(self): + """A user with a sufficient power level should be able to delete an alias.""" + self._create_alias(self.admin_user) + + # Increase the user's power level. + self.helper.send_state( + self.room_id, + "m.room.power_levels", + {"users": {self.test_user: 100}}, + tok=self.admin_user_tok, + ) + + # They can now delete the alias. + result = self.get_success( + self.handler.delete_association( + create_requester(self.test_user), self.room_alias + ) + ) + self.assertEquals(self.room_id, result) + + # Confirm the alias is gone. + self.get_failure( + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, + ) + + +class CanonicalAliasTestCase(unittest.HomeserverTestCase): + """Test modifications of the canonical alias when delete aliases. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_handlers().directory_handler + self.state_handler = hs.get_state_handler() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = self._add_alias(self.test_alias) + + def _add_alias(self, alias: str) -> RoomAlias: + """Add an alias to the test room.""" + room_alias = RoomAlias.from_string(alias) + + # Create a new alias to this room. 
+ self.get_success( + self.store.create_room_alias_association( + room_alias, self.room_id, ["test"], self.admin_user + ) + ) + return room_alias + + def _set_canonical_alias(self, content): + """Configure the canonical alias state on the room.""" + self.helper.send_state( + self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok, + ) + + def _get_canonical_alias(self): + """Get the canonical alias state of the room.""" + return self.get_success( + self.state_handler.get_current_state( + self.room_id, EventTypes.CanonicalAlias, "" + ) + ) + + def test_remove_alias(self): + """Removing an alias that is the canonical alias should remove it there too.""" + # Set this new alias as the canonical alias for this room + self._set_canonical_alias( + {"alias": self.test_alias, "alt_aliases": [self.test_alias]} + ) + + data = self._get_canonical_alias() + self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) + + # Finally, delete the alias. + self.get_success( + self.handler.delete_association( + create_requester(self.admin_user), self.room_alias + ) + ) + + data = self._get_canonical_alias() + self.assertNotIn("alias", data["content"]) + self.assertNotIn("alt_aliases", data["content"]) + + def test_remove_other_alias(self): + """Removing an alias listed as in alt_aliases should remove it there too.""" + # Create a second alias. + other_test_alias = "#test2:test" + other_room_alias = self._add_alias(other_test_alias) + + # Set the alias as the canonical alias for this room. + self._set_canonical_alias( + { + "alias": self.test_alias, + "alt_aliases": [self.test_alias, other_test_alias], + } + ) + + data = self._get_canonical_alias() + self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual( + data["content"]["alt_aliases"], [self.test_alias, other_test_alias] + ) + + # Delete the second alias. + self.get_success( + self.handler.delete_association( + create_requester(self.admin_user), other_room_alias + ) + ) + + data = self._get_canonical_alias() + self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) + + class TestCreateAliasACL(unittest.HomeserverTestCase): user_id = "@test:test" diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 8dccc6826e..e1e144b2e7 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
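The e2e-keys and federation-sender tests in this diff lean heavily on the signedjson library (`generate_signing_key`, `encode_verify_key_base64`, `sign_json`, and friends). A self-contained sketch of that signing round trip, with made-up user and key identifiers rather than anything taken from the tests, in case the calls below are unfamiliar:

import signedjson.key
import signedjson.sign

user_id = "@alice:example.com"  # illustrative user ID, not from the tests

# Generate an ed25519 master key; the argument becomes the key "version",
# so the corresponding key ID is "ed25519:master".
master_signing_key = signedjson.key.generate_signing_key("master")

# Build a cross-signing style key dict for a second (self-signing) key.
selfsigning_signing_key = signedjson.key.generate_signing_key("self")
self_signing_key = {
    "user_id": user_id,
    "usage": ["self_signing"],
    "keys": {
        "ed25519:self": signedjson.key.encode_verify_key_base64(
            signedjson.key.get_verify_key(selfsigning_signing_key)
        )
    },
}

# Sign the self-signing key with the master key; this adds a "signatures"
# section to the dict in place, which is exactly how sign.sign_json is used
# in the tests below.
signedjson.sign.sign_json(self_signing_key, user_id, master_signing_key)

# Verification is the inverse; it raises SignatureVerifyException on failure.
signedjson.sign.verify_signed_json(
    self_signing_key, user_id, signedjson.key.get_verify_key(master_signing_key)
)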
@@ -15,9 +17,11 @@ import mock +import signedjson.key as key +import signedjson.sign as sign + from twisted.internet import defer -import synapse.api.errors import synapse.handlers.e2e_keys import synapse.storage from synapse.api import errors @@ -145,3 +149,357 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}}, }, ) + + @defer.inlineCallbacks + def test_replace_master_key(self): + """uploading a new signing key should make the old signing key unavailable""" + local_user = "@boris:" + self.hs.hostname + keys1 = { + "master_key": { + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + }, + } + } + yield self.handler.upload_signing_keys_for_user(local_user, keys1) + + keys2 = { + "master_key": { + # private key: 4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw": "Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw" + }, + } + } + yield self.handler.upload_signing_keys_for_user(local_user, keys2) + + devices = yield self.handler.query_devices( + {"device_keys": {local_user: []}}, 0, local_user + ) + self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) + + @defer.inlineCallbacks + def test_reupload_signatures(self): + """re-uploading a signature should not fail""" + local_user = "@boris:" + self.hs.hostname + keys1 = { + "master_key": { + # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ": "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ" + }, + }, + "self_signing_key": { + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + "user_id": local_user, + "usage": ["self_signing"], + "keys": { + "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + }, + }, + } + master_signing_key = key.decode_signing_key_base64( + "ed25519", + "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ", + "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8", + ) + sign.sign_json(keys1["self_signing_key"], local_user, master_signing_key) + signing_key = key.decode_signing_key_base64( + "ed25519", + "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0", + ) + yield self.handler.upload_signing_keys_for_user(local_user, keys1) + + # upload two device keys, which will be signed later by the self-signing key + device_key_1 = { + "user_id": local_user, + "device_id": "abc", + "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + "keys": { + "ed25519:abc": "base64+ed25519+key", + "curve25519:abc": "base64+curve25519+key", + }, + "signatures": {local_user: {"ed25519:abc": "base64+signature"}}, + } + device_key_2 = { + "user_id": local_user, + "device_id": "def", + "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + "keys": { + "ed25519:def": "base64+ed25519+key", + "curve25519:def": "base64+curve25519+key", + }, + "signatures": {local_user: {"ed25519:def": "base64+signature"}}, + } + + yield self.handler.upload_keys_for_user( + local_user, "abc", {"device_keys": device_key_1} + ) + yield self.handler.upload_keys_for_user( + local_user, "def", {"device_keys": device_key_2} + ) + + # sign the first device key and upload it + del 
device_key_1["signatures"] + sign.sign_json(device_key_1, local_user, signing_key) + yield self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1}} + ) + + # sign the second device key and upload both device keys. The server + # should ignore the first device key since it already has a valid + # signature for it + del device_key_2["signatures"] + sign.sign_json(device_key_2, local_user, signing_key) + yield self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + ) + + device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" + device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" + devices = yield self.handler.query_devices( + {"device_keys": {local_user: []}}, 0, local_user + ) + del devices["device_keys"][local_user]["abc"]["unsigned"] + del devices["device_keys"][local_user]["def"]["unsigned"] + self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) + self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2) + + @defer.inlineCallbacks + def test_self_signing_key_doesnt_show_up_as_device(self): + """signing keys should be hidden when fetching a user's devices""" + local_user = "@boris:" + self.hs.hostname + keys1 = { + "master_key": { + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + }, + } + } + yield self.handler.upload_signing_keys_for_user(local_user, keys1) + + res = None + try: + yield self.hs.get_device_handler().check_device_registered( + user_id=local_user, + device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + initial_device_display_name="new display name", + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 400) + + res = yield self.handler.query_local_devices({local_user: None}) + self.assertDictEqual(res, {local_user: {}}) + + @defer.inlineCallbacks + def test_upload_signatures(self): + """should check signatures that are uploaded""" + # set up a user with cross-signing keys and a device. 
This user will + # try uploading signatures + local_user = "@boris:" + self.hs.hostname + device_id = "xyz" + # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA + device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY" + device_key = { + "user_id": local_user, + "device_id": device_id, + "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + "keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey}, + "signatures": {local_user: {"ed25519:xyz": "something"}}, + } + device_signing_key = key.decode_signing_key_base64( + "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" + ) + + yield self.handler.upload_keys_for_user( + local_user, device_id, {"device_keys": device_key} + ) + + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + master_key = { + "user_id": local_user, + "usage": ["master"], + "keys": {"ed25519:" + master_pubkey: master_pubkey}, + } + master_signing_key = key.decode_signing_key_base64( + "ed25519", master_pubkey, "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0" + ) + usersigning_pubkey = "Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw" + usersigning_key = { + # private key: 4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs + "user_id": local_user, + "usage": ["user_signing"], + "keys": {"ed25519:" + usersigning_pubkey: usersigning_pubkey}, + } + usersigning_signing_key = key.decode_signing_key_base64( + "ed25519", usersigning_pubkey, "4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs" + ) + sign.sign_json(usersigning_key, local_user, master_signing_key) + # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 + selfsigning_pubkey = "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ" + selfsigning_key = { + "user_id": local_user, + "usage": ["self_signing"], + "keys": {"ed25519:" + selfsigning_pubkey: selfsigning_pubkey}, + } + selfsigning_signing_key = key.decode_signing_key_base64( + "ed25519", selfsigning_pubkey, "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8" + ) + sign.sign_json(selfsigning_key, local_user, master_signing_key) + cross_signing_keys = { + "master_key": master_key, + "user_signing_key": usersigning_key, + "self_signing_key": selfsigning_key, + } + yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + + # set up another user with a master key. 
This user will be signed by + # the first user + other_user = "@otherboris:" + self.hs.hostname + other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM" + other_master_key = { + # private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI + "user_id": other_user, + "usage": ["master"], + "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, + } + yield self.handler.upload_signing_keys_for_user( + other_user, {"master_key": other_master_key} + ) + + # test various signature failures (see below) + ret = yield self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: { + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + device_id: { + "user_id": local_user, + "device_id": device_id, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2", + ], + "keys": { + "curve25519:xyz": "curve25519+key", + # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA + "ed25519:xyz": device_pubkey, + }, + "signatures": { + local_user: {"ed25519:" + selfsigning_pubkey: "something"} + }, + }, + # fails because device is unknown + # should fail with NOT_FOUND + "unknown": { + "user_id": local_user, + "device_id": "unknown", + "signatures": { + local_user: {"ed25519:" + selfsigning_pubkey: "something"} + }, + }, + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + master_pubkey: { + "user_id": local_user, + "usage": ["master"], + "keys": {"ed25519:" + master_pubkey: master_pubkey}, + "signatures": { + local_user: {"ed25519:" + device_pubkey: "something"} + }, + }, + }, + other_user: { + # fails because the device is not the user's master-signing key + # should fail with NOT_FOUND + "unknown": { + "user_id": other_user, + "device_id": "unknown", + "signatures": { + local_user: {"ed25519:" + usersigning_pubkey: "something"} + }, + }, + other_master_pubkey: { + # fails because the key doesn't match what the server has + # should fail with UNKNOWN + "user_id": other_user, + "usage": ["master"], + "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, + "something": "random", + "signatures": { + local_user: {"ed25519:" + usersigning_pubkey: "something"} + }, + }, + }, + }, + ) + + user_failures = ret["failures"][local_user] + self.assertEqual( + user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE + ) + self.assertEqual( + user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE + ) + self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND) + + other_user_failures = ret["failures"][other_user] + self.assertEqual( + other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND + ) + self.assertEqual( + other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN + ) + + # test successful signatures + del device_key["signatures"] + sign.sign_json(device_key, local_user, selfsigning_signing_key) + sign.sign_json(master_key, local_user, device_signing_key) + sign.sign_json(other_master_key, local_user, usersigning_signing_key) + ret = yield self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: {device_id: device_key, master_pubkey: master_key}, + other_user: {other_master_pubkey: other_master_key}, + }, + ) + + self.assertEqual(ret["failures"], {}) + + # fetch the signed keys/devices and make sure that the signatures are there + ret = yield self.handler.query_devices( + {"device_keys": {local_user: [], other_user: []}}, 0, local_user + ) + + self.assertEqual( + 
ret["device_keys"][local_user]["xyz"]["signatures"][local_user][ + "ed25519:" + selfsigning_pubkey + ], + device_key["signatures"][local_user]["ed25519:" + selfsigning_pubkey], + ) + self.assertEqual( + ret["master_keys"][local_user]["signatures"][local_user][ + "ed25519:" + device_id + ], + master_key["signatures"][local_user]["ed25519:" + device_id], + ) + self.assertEqual( + ret["master_keys"][other_user]["signatures"][local_user][ + "ed25519:" + usersigning_pubkey + ], + other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], + ) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index c4503c1611..70f172eb02 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -94,23 +95,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + version_etag = res["etag"] + del res["etag"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) # check we can retrieve it as a specific version res = yield self.handler.get_version_info(self.local_user, "1") + self.assertEqual(res["etag"], version_etag) + del res["etag"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) @@ -126,12 +133,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] self.assertDictEqual( res, { "version": "2", "algorithm": "m.megolm_backup.v1", "auth_data": "second_version_auth_data", + "count": 0, }, ) @@ -158,12 +167,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] self.assertDictEqual( res, { "algorithm": "m.megolm_backup.v1", "auth_data": "revised_first_version_auth_data", "version": version, + "count": 0, }, ) @@ -187,9 +198,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertEqual(res, 404) @defer.inlineCallbacks - def test_update_bad_version(self): - """Check that we get a 400 if the version in the body is missing or - doesn't match + def test_update_omitted_version(self): + """Check that the update succeeds if the version is missing from the body """ version = yield self.handler.create_version( self.local_user, @@ -197,19 +207,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - res = None - try: - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - }, - ) - except errors.SynapseError as e: - res = e.code - self.assertEqual(res, 400) + yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + }, + ) + + # check we can retrieve it as the current version + res = yield self.handler.get_version_info(self.local_user) + del res["etag"] # etag is 
opaque, so don't test its contents + self.assertDictEqual( + res, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + "count": 0, + }, + ) + + @defer.inlineCallbacks + def test_update_bad_version(self): + """Check that we get a 400 if the version in the body doesn't match + """ + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) + self.assertEqual(version, "1") res = None try: @@ -394,6 +422,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): yield self.handler.upload_room_keys(self.local_user, version, room_keys) + # get the etag to compare to future versions + res = yield self.handler.get_version_info(self.local_user) + backup_etag = res["etag"] + self.assertEqual(res["count"], 1) + new_room_keys = copy.deepcopy(room_keys) new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"] @@ -408,6 +441,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): "SSBBTSBBIEZJU0gK", ) + # the etag should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["etag"], backup_etag) + # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) @@ -417,6 +454,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the etag should NOT be equal now, since the key changed + res = yield self.handler.get_version_info(self.local_user) + self.assertNotEqual(res["etag"], backup_etag) + backup_etag = res["etag"] + # test that a session with a higher forwarded_count doesn't replace one # with a lower forwarding count new_room_key["forwarded_count"] = 2 @@ -428,6 +470,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the etag should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["etag"], backup_etag) + # TODO: check edge cases as well as the common variations here @defer.inlineCallbacks diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py new file mode 100644 index 0000000000..96fea58673 --- /dev/null +++ b/tests/handlers/test_federation.py @@ -0,0 +1,274 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +from unittest import TestCase + +from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase +from synapse.federation.federation_base import event_from_pdu_json +from synapse.logging.context import LoggingContext, run_in_background +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests import unittest + +logger = logging.getLogger(__name__) + + +class FederationTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver(http_client=None) + self.handler = hs.get_handlers().federation_handler + self.store = hs.get_datastore() + return hs + + def test_exchange_revoked_invite(self): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + # Send a 3PID invite event with an empty body so it's considered as a revoked one. + invite_token = "sometoken" + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.ThirdPartyInvite, + state_key=invite_token, + body={}, + tok=tok, + ) + + d = self.handler.on_exchange_third_party_invite_request( + room_id=room_id, + event_dict={ + "type": EventTypes.Member, + "room_id": room_id, + "sender": user_id, + "state_key": "@someone:example.org", + "content": { + "membership": "invite", + "third_party_invite": { + "display_name": "alice", + "signed": { + "mxid": "@alice:localhost", + "token": invite_token, + "signatures": { + "magic.forest": { + "ed25519:3": "fQpGIW1Snz+pwLZu6sTy2aHy/DYWWTspTJRPyNp0PKkymfIsNffysMl6ObMMFdIJhk6g6pwlIqZ54rxo8SLmAg" + } + }, + }, + }, + }, + }, + ) + + failure = self.get_failure(d, AuthError).value + + self.assertEqual(failure.code, 403, failure) + self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure) + self.assertEqual(failure.msg, "You are not invited to this room.") + + def test_rejected_message_event_state(self): + """ + Check that we store the state group correctly for rejected non-state events. + + Regression test for #6289. + """ + OTHER_SERVER = "otherserver" + OTHER_USER = "@otheruser:" + OTHER_SERVER + + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + + # pretend that another server has joined + join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id) + + # check the state group + sg = self.successResultOf( + self.store._get_state_group_for_event(join_event.event_id) + ) + + # build and send an event which will be rejected + ev = event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {}, + "room_id": room_id, + "sender": "@yetanotheruser:" + OTHER_SERVER, + "depth": join_event["depth"] + 1, + "prev_events": [join_event.event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + with LoggingContext(request="send_rejected"): + d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) + self.get_success(d) + + # that should have been rejected + e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True)) + self.assertIsNotNone(e.rejected_reason) + + # ... 
and the state group should be the same as before + sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id)) + + self.assertEqual(sg, sg2) + + def test_rejected_state_event_state(self): + """ + Check that we store the state group correctly for rejected state events. + + Regression test for #6289. + """ + OTHER_SERVER = "otherserver" + OTHER_USER = "@otheruser:" + OTHER_SERVER + + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + + # pretend that another server has joined + join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id) + + # check the state group + sg = self.successResultOf( + self.store._get_state_group_for_event(join_event.event_id) + ) + + # build and send an event which will be rejected + ev = event_from_pdu_json( + { + "type": "org.matrix.test", + "state_key": "test_key", + "content": {}, + "room_id": room_id, + "sender": "@yetanotheruser:" + OTHER_SERVER, + "depth": join_event["depth"] + 1, + "prev_events": [join_event.event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + with LoggingContext(request="send_rejected"): + d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) + self.get_success(d) + + # that should have been rejected + e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True)) + self.assertIsNotNone(e.rejected_reason) + + # ... and the state group should be the same as before + sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id)) + + self.assertEqual(sg, sg2) + + def _build_and_send_join_event(self, other_server, other_user, room_id): + join_event = self.get_success( + self.handler.on_make_join_request(other_server, room_id, other_user) + ) + # the auth code requires that a signature exists, but doesn't check that + # signature... go figure. 
+ join_event.signatures[other_server] = {"x": "y"} + with LoggingContext(request="send_join"): + d = run_in_background( + self.handler.on_send_join_request, other_server, join_event + ) + self.get_success(d) + + # sanity-check: the room should show that the new user is a member + r = self.get_success(self.store.get_current_state_ids(room_id)) + self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id) + + return join_event + + +class EventFromPduTestCase(TestCase): + def test_valid_json(self): + """Valid JSON should be turned into an event.""" + ev = event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {"bool": True, "null": None, "int": 1, "str": "foobar"}, + "room_id": "!room:test", + "sender": "@user:test", + "depth": 1, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 1234, + }, + RoomVersions.V6, + ) + + self.assertIsInstance(ev, EventBase) + + def test_invalid_numbers(self): + """Invalid values for an integer should be rejected, all floats should be rejected.""" + for value in [ + -(2 ** 53), + 2 ** 53, + 1.0, + float("inf"), + float("-inf"), + float("nan"), + ]: + with self.assertRaises(SynapseError): + event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {"foo": value}, + "room_id": "!room:test", + "sender": "@user:test", + "depth": 1, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 1234, + }, + RoomVersions.V6, + ) + + def test_invalid_nested(self): + """List and dictionaries are recursively searched.""" + with self.assertRaises(SynapseError): + event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {"foo": [{"bar": 2 ** 56}]}, + "room_id": "!room:test", + "sender": "@user:test", + "depth": 1, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 1234, + }, + RoomVersions.V6, + ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py new file mode 100644 index 0000000000..1bb25ab684 --- /dev/null +++ b/tests/handlers/test_oidc.py @@ -0,0 +1,570 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from urllib.parse import parse_qs, urlparse + +from mock import Mock, patch + +import attr +import pymacaroons + +from twisted.internet import defer +from twisted.python.failure import Failure +from twisted.web._newclient import ResponseDone + +from synapse.handlers.oidc_handler import ( + MappingException, + OidcError, + OidcHandler, + OidcMappingProvider, +) +from synapse.types import UserID + +from tests.unittest import HomeserverTestCase, override_config + + +@attr.s +class FakeResponse: + code = attr.ib() + body = attr.ib() + phrase = attr.ib() + + def deliverBody(self, protocol): + protocol.dataReceived(self.body) + protocol.connectionLost(Failure(ResponseDone())) + + +# These are a few constants that are used as config parameters in the tests. 
+ISSUER = "https://issuer/" +CLIENT_ID = "test-client-id" +CLIENT_SECRET = "test-client-secret" +BASE_URL = "https://synapse/" +CALLBACK_URL = BASE_URL + "_synapse/oidc/callback" +SCOPES = ["openid"] + +AUTHORIZATION_ENDPOINT = ISSUER + "authorize" +TOKEN_ENDPOINT = ISSUER + "token" +USERINFO_ENDPOINT = ISSUER + "userinfo" +WELL_KNOWN = ISSUER + ".well-known/openid-configuration" +JWKS_URI = ISSUER + ".well-known/jwks.json" + +# config for common cases +COMMON_CONFIG = { + "discover": False, + "authorization_endpoint": AUTHORIZATION_ENDPOINT, + "token_endpoint": TOKEN_ENDPOINT, + "jwks_uri": JWKS_URI, +} + + +# The cookie name and path don't really matter, just that it has to be coherent +# between the callback & redirect handlers. +COOKIE_NAME = b"oidc_session" +COOKIE_PATH = "/_synapse/oidc" + +MockedMappingProvider = Mock(OidcMappingProvider) + + +def simple_async_mock(return_value=None, raises=None): + # AsyncMock is not available in python3.5, this mimics part of its behaviour + async def cb(*args, **kwargs): + if raises: + raise raises + return return_value + + return Mock(side_effect=cb) + + +async def get_json(url): + # Mock get_json calls to handle jwks & oidc discovery endpoints + if url == WELL_KNOWN: + # Minimal discovery document, as defined in OpenID.Discovery + # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata + return { + "issuer": ISSUER, + "authorization_endpoint": AUTHORIZATION_ENDPOINT, + "token_endpoint": TOKEN_ENDPOINT, + "jwks_uri": JWKS_URI, + "userinfo_endpoint": USERINFO_ENDPOINT, + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + } + elif url == JWKS_URI: + return {"keys": []} + + +class OidcHandlerTestCase(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + + self.http_client = Mock(spec=["get_json"]) + self.http_client.get_json.side_effect = get_json + self.http_client.user_agent = "Synapse Test" + + config = self.default_config() + config["public_baseurl"] = BASE_URL + oidc_config = config.get("oidc_config", {}) + oidc_config["enabled"] = True + oidc_config["client_id"] = CLIENT_ID + oidc_config["client_secret"] = CLIENT_SECRET + oidc_config["issuer"] = ISSUER + oidc_config["scopes"] = SCOPES + oidc_config["user_mapping_provider"] = { + "module": __name__ + ".MockedMappingProvider" + } + config["oidc_config"] = oidc_config + + hs = self.setup_test_homeserver( + http_client=self.http_client, + proxied_http_client=self.http_client, + config=config, + ) + + self.handler = OidcHandler(hs) + + return hs + + def metadata_edit(self, values): + return patch.dict(self.handler._provider_metadata, values) + + def assertRenderedError(self, error, error_description=None): + args = self.handler._render_error.call_args[0] + self.assertEqual(args[1], error) + if error_description is not None: + self.assertEqual(args[2], error_description) + # Reset the render_error mock + self.handler._render_error.reset_mock() + + def test_config(self): + """Basic config correctly sets up the callback URL and client auth correctly.""" + self.assertEqual(self.handler._callback_url, CALLBACK_URL) + self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID) + self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) + + @override_config({"oidc_config": {"discover": True}}) + @defer.inlineCallbacks + def test_discovery(self): + """The handler should discover the endpoints from OIDC discovery document.""" + # This would throw if some metadata were 
invalid + metadata = yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + + self.assertEqual(metadata.issuer, ISSUER) + self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT) + self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT) + self.assertEqual(metadata.jwks_uri, JWKS_URI) + # FIXME: it seems like authlib does not have that defined in its metadata models + # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT) + + # subsequent calls should be cached + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_not_called() + + @override_config({"oidc_config": COMMON_CONFIG}) + @defer.inlineCallbacks + def test_no_discovery(self): + """When discovery is disabled, it should not try to load from discovery document.""" + yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_not_called() + + @override_config({"oidc_config": COMMON_CONFIG}) + @defer.inlineCallbacks + def test_load_jwks(self): + """JWKS loading is done once (then cached) if used.""" + jwks = yield defer.ensureDeferred(self.handler.load_jwks()) + self.http_client.get_json.assert_called_once_with(JWKS_URI) + self.assertEqual(jwks, {"keys": []}) + + # subsequent calls should be cached… + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_jwks()) + self.http_client.get_json.assert_not_called() + + # …unless forced + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.http_client.get_json.assert_called_once_with(JWKS_URI) + + # Throw if the JWKS uri is missing + with self.metadata_edit({"jwks_uri": None}): + with self.assertRaises(RuntimeError): + yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + + # Return empty key set if JWKS are not used + self.handler._scopes = [] # not asking the openid scope + self.http_client.get_json.reset_mock() + jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.http_client.get_json.assert_not_called() + self.assertEqual(jwks, {"keys": []}) + + @override_config({"oidc_config": COMMON_CONFIG}) + def test_validate_config(self): + """Provider metadatas are extensively validated.""" + h = self.handler + + # Default test config does not throw + h._validate_metadata() + + with self.metadata_edit({"issuer": None}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"issuer": "http://insecure/"}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"issuer": "https://invalid/?because=query"}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"authorization_endpoint": None}): + self.assertRaisesRegex( + ValueError, "authorization_endpoint", h._validate_metadata + ) + + with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}): + self.assertRaisesRegex( + ValueError, "authorization_endpoint", h._validate_metadata + ) + + with self.metadata_edit({"token_endpoint": None}): + self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata) + + with self.metadata_edit({"token_endpoint": "http://insecure/token"}): + self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata) + + with self.metadata_edit({"jwks_uri": None}): + self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata) + + with 
self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}): + self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata) + + with self.metadata_edit({"response_types_supported": ["id_token"]}): + self.assertRaisesRegex( + ValueError, "response_types_supported", h._validate_metadata + ) + + with self.metadata_edit( + {"token_endpoint_auth_methods_supported": ["client_secret_basic"]} + ): + # should not throw, as client_secret_basic is the default auth method + h._validate_metadata() + + with self.metadata_edit( + {"token_endpoint_auth_methods_supported": ["client_secret_post"]} + ): + self.assertRaisesRegex( + ValueError, + "token_endpoint_auth_methods_supported", + h._validate_metadata, + ) + + # Tests for configs that the userinfo endpoint + self.assertFalse(h._uses_userinfo) + h._scopes = [] # do not request the openid scope + self.assertTrue(h._uses_userinfo) + self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata) + + with self.metadata_edit( + {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None} + ): + # Shouldn't raise with a valid userinfo, even without + h._validate_metadata() + + @override_config({"oidc_config": {"skip_verification": True}}) + def test_skip_verification(self): + """Provider metadata validation can be disabled by config.""" + with self.metadata_edit({"issuer": "http://insecure"}): + # This should not throw + self.handler._validate_metadata() + + @defer.inlineCallbacks + def test_redirect_request(self): + """The redirect request has the right arguments & generates a valid session cookie.""" + req = Mock(spec=["addCookie"]) + url = yield defer.ensureDeferred( + self.handler.handle_redirect_request(req, b"http://client/redirect") + ) + url = urlparse(url) + auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) + + self.assertEqual(url.scheme, auth_endpoint.scheme) + self.assertEqual(url.netloc, auth_endpoint.netloc) + self.assertEqual(url.path, auth_endpoint.path) + + params = parse_qs(url.query) + self.assertEqual(params["redirect_uri"], [CALLBACK_URL]) + self.assertEqual(params["response_type"], ["code"]) + self.assertEqual(params["scope"], [" ".join(SCOPES)]) + self.assertEqual(params["client_id"], [CLIENT_ID]) + self.assertEqual(len(params["state"]), 1) + self.assertEqual(len(params["nonce"]), 1) + + # Check what is in the cookie + # note: python3.5 mock does not have the .called_once() method + calls = req.addCookie.call_args_list + self.assertEqual(len(calls), 1) # called once + # For some reason, call.args does not work with python3.5 + args = calls[0][0] + kwargs = calls[0][1] + self.assertEqual(args[0], COOKIE_NAME) + self.assertEqual(kwargs["path"], COOKIE_PATH) + cookie = args[1] + + macaroon = pymacaroons.Macaroon.deserialize(cookie) + state = self.handler._get_value_from_macaroon(macaroon, "state") + nonce = self.handler._get_value_from_macaroon(macaroon, "nonce") + redirect = self.handler._get_value_from_macaroon( + macaroon, "client_redirect_url" + ) + + self.assertEqual(params["state"], [state]) + self.assertEqual(params["nonce"], [nonce]) + self.assertEqual(redirect, "http://client/redirect") + + @defer.inlineCallbacks + def test_callback_error(self): + """Errors from the provider returned in the callback are displayed.""" + self.handler._render_error = Mock() + request = Mock(args={}) + request.args[b"error"] = [b"invalid_client"] + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_client", "") + + request.args[b"error_description"] = [b"some description"] + 
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_client", "some description") + + @defer.inlineCallbacks + def test_callback(self): + """Code callback works and display errors if something went wrong. + + A lot of scenarios are tested here: + - when the callback works, with userinfo from ID token + - when the user mapping fails + - when ID token verification fails + - when the callback works, with userinfo fetched from the userinfo endpoint + - when the userinfo fetching fails + - when the code exchange fails + """ + token = { + "type": "bearer", + "id_token": "id_token", + "access_token": "access_token", + } + userinfo = { + "sub": "foo", + "preferred_username": "bar", + } + user_id = UserID("foo", "domain.org") + self.handler._render_error = Mock(return_value=None) + self.handler._exchange_code = simple_async_mock(return_value=token) + self.handler._parse_id_token = simple_async_mock(return_value=userinfo) + self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + self.handler._auth_handler.complete_sso_login = simple_async_mock() + request = Mock(spec=["args", "getCookie", "addCookie"]) + + code = "code" + state = "state" + nonce = "nonce" + client_redirect_url = "http://client/redirect" + session = self.handler._generate_oidc_session_token( + state=state, + nonce=nonce, + client_redirect_url=client_redirect_url, + ui_auth_session_id=None, + ) + request.getCookie.return_value = session + + request.args = {} + request.args[b"code"] = [code.encode("utf-8")] + request.args[b"state"] = [state.encode("utf-8")] + + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + + self.handler._auth_handler.complete_sso_login.assert_called_once_with( + user_id, request, client_redirect_url, + ) + self.handler._exchange_code.assert_called_once_with(code) + self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) + self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._fetch_userinfo.assert_not_called() + self.handler._render_error.assert_not_called() + + # Handle mapping errors + self.handler._map_userinfo_to_user = simple_async_mock( + raises=MappingException() + ) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("mapping_error") + self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + + # Handle ID token errors + self.handler._parse_id_token = simple_async_mock(raises=Exception()) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_token") + + self.handler._auth_handler.complete_sso_login.reset_mock() + self.handler._exchange_code.reset_mock() + self.handler._parse_id_token.reset_mock() + self.handler._map_userinfo_to_user.reset_mock() + self.handler._fetch_userinfo.reset_mock() + + # With userinfo fetching + self.handler._scopes = [] # do not ask the "openid" scope + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + + self.handler._auth_handler.complete_sso_login.assert_called_once_with( + user_id, request, client_redirect_url, + ) + self.handler._exchange_code.assert_called_once_with(code) + self.handler._parse_id_token.assert_not_called() + self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._fetch_userinfo.assert_called_once_with(token) + self.handler._render_error.assert_not_called() + 
+ # Handle userinfo fetching error + self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("fetch_error") + + # Handle code exchange failure + self.handler._exchange_code = simple_async_mock( + raises=OidcError("invalid_request") + ) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request") + + @defer.inlineCallbacks + def test_callback_session(self): + """The callback verifies the session presence and validity""" + self.handler._render_error = Mock(return_value=None) + request = Mock(spec=["args", "getCookie", "addCookie"]) + + # Missing cookie + request.args = {} + request.getCookie.return_value = None + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("missing_session", "No session cookie found") + + # Missing session parameter + request.args = {} + request.getCookie.return_value = "session" + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request", "State parameter is missing") + + # Invalid cookie + request.args = {} + request.args[b"state"] = [b"state"] + request.getCookie.return_value = "session" + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_session") + + # Mismatching session + session = self.handler._generate_oidc_session_token( + state="state", + nonce="nonce", + client_redirect_url="http://client/redirect", + ui_auth_session_id=None, + ) + request.args = {} + request.args[b"state"] = [b"mismatching state"] + request.getCookie.return_value = session + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("mismatching_session") + + # Valid session + request.args = {} + request.args[b"state"] = [b"state"] + request.getCookie.return_value = session + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request") + + @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}}) + @defer.inlineCallbacks + def test_exchange_code(self): + """Code exchange behaves correctly and handles various error scenarios.""" + token = {"type": "bearer"} + token_json = json.dumps(token).encode("utf-8") + self.http_client.request = simple_async_mock( + return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) + ) + code = "code" + ret = yield defer.ensureDeferred(self.handler._exchange_code(code)) + kwargs = self.http_client.request.call_args[1] + + self.assertEqual(ret, token) + self.assertEqual(kwargs["method"], "POST") + self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + + args = parse_qs(kwargs["data"].decode("utf-8")) + self.assertEqual(args["grant_type"], ["authorization_code"]) + self.assertEqual(args["code"], [code]) + self.assertEqual(args["client_id"], [CLIENT_ID]) + self.assertEqual(args["client_secret"], [CLIENT_SECRET]) + self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) + + # Test error handling + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=400, + phrase=b"Bad Request", + body=b'{"error": "foo", "error_description": "bar"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "foo") + self.assertEqual(exc.exception.error_description, "bar") + + # Internal server error with no JSON 
body + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=500, phrase=b"Internal Server Error", body=b"Not JSON", + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "server_error") + + # Internal server error with JSON body + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=500, + phrase=b"Internal Server Error", + body=b'{"error": "internal_server_error"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "internal_server_error") + + # 4xx error without "error" field + self.http_client.request = simple_async_mock( + return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "server_error") + + # 2xx error with "error" field + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=200, phrase=b"OK", body=b'{"error": "some_error"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "some_error") diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index f70c6e7d65..05ea40a7de 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -19,9 +19,10 @@ from mock import Mock, call from signedjson.key import generate_signing_key from synapse.api.constants import EventTypes, Membership, PresenceState -from synapse.events import room_version_to_event_format +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events.builder import EventBuilder from synapse.handlers.presence import ( + EXTERNAL_PROCESS_EXPIRY, FEDERATION_PING_INTERVAL, FEDERATION_TIMEOUT, IDLE_TIMER, @@ -337,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, syncing_user_ids=set([user_id]), now=now + state, is_mine=True, syncing_user_ids={user_id}, now=now ) self.assertIsNotNone(new_state) @@ -413,6 +414,44 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEquals(state, new_state) +class PresenceHandlerTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.presence_handler = hs.get_presence_handler() + self.clock = hs.get_clock() + + def test_external_process_timeout(self): + """Test that if an external process doesn't update the records for a while + we time out their syncing users presence. + """ + process_id = 1 + user_id = "@test:server" + + # Notify handler that a user is now syncing. + self.get_success( + self.presence_handler.update_external_syncs_row( + process_id, user_id, True, self.clock.time_msec() + ) + ) + + # Check that if we wait a while without telling the handler the user has + # stopped syncing that their presence state doesn't get timed out. 
+ self.reactor.advance(EXTERNAL_PROCESS_EXPIRY / 2) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.ONLINE) + + # Check that if the external process timeout fires, then the syncing + # user gets timed out + self.reactor.advance(EXTERNAL_PROCESS_EXPIRY) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.OFFLINE) + + class PresenceJoinTestCase(unittest.HomeserverTestCase): """Tests remote servers get told about presence of users in the room when they join and when new local users join. @@ -455,8 +494,10 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self.helper.join(room_id, "@test2:server") # Mark test2 as online, test will be offline with a last_active of 0 - self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + ) ) self.reactor.pump([0]) # Wait for presence updates to be handled @@ -504,14 +545,18 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.user_id) # Mark test as online - self.presence_handler.set_state( - UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + ) ) # Mark test2 as online, test will be offline with a last_active of 0. # Note we don't join them to the room yet - self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + ) ) # Add servers to the room @@ -540,7 +585,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_state.state, PresenceState.ONLINE) self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations=set(("server2", "server3")), states=[expected_state] + destinations={"server2", "server3"}, states=[expected_state] ) def _add_new_user(self, room_id, user_id): @@ -549,7 +594,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): hostname = get_domain_from_id(user_id) - room_version = self.get_success(self.store.get_room_version(room_id)) + room_version = self.get_success(self.store.get_room_version_id(room_id)) builder = EventBuilder( state=self.state, @@ -558,7 +603,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): clock=self.clock, hostname=hostname, signing_key=self.random_signing_key, - format_version=room_version_to_event_format(room_version), + room_version=KNOWN_ROOM_VERSIONS[room_version], room_id=room_id, type=EventTypes.Member, sender=user_id, diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d60c124eec..29dd7d9c6e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,12 +14,12 @@ # limitations under the License. 
-from mock import Mock, NonCallableMock +from mock import Mock from twisted.internet import defer import synapse.types -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -55,12 +55,8 @@ class ProfileTestCase(unittest.TestCase): federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) - self.store = hs.get_datastore() self.frank = UserID.from_string("@1234ABCD:test") @@ -70,6 +66,7 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() + self.hs = hs @defer.inlineCallbacks def test_get_my_name(self): @@ -81,19 +78,58 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name(self): - yield self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), "Frank Jr.", ) + # Set displayname again + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank" + ) + ) + + self.assertEquals( + (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ) + + @defer.inlineCallbacks + def test_set_my_name_if_disabled(self): + self.hs.config.enable_set_displayname = False + + # Setting displayname for the first time is allowed + yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + + self.assertEquals( + (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ) + + # Setting displayname a second time is forbidden + d = defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) + ) + + yield self.assertFailure(d, SynapseError) + @defer.inlineCallbacks def test_set_my_name_noauth(self): - d = self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.bob), "Frank Jr." + d = defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.bob), "Frank Jr." 
+ ) ) yield self.assertFailure(d, AuthError) @@ -137,13 +173,54 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar(self): - yield self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/pic.gif", + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) ) self.assertEquals( (yield self.store.get_profile_avatar_url(self.frank.localpart)), "http://my.server/pic.gif", ) + + # Set avatar again + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/me.png", + ) + ) + + self.assertEquals( + (yield self.store.get_profile_avatar_url(self.frank.localpart)), + "http://my.server/me.png", + ) + + @defer.inlineCallbacks + def test_set_my_avatar_if_disabled(self): + self.hs.config.enable_set_avatar_url = False + + # Setting displayname for the first time is allowed + yield self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) + + self.assertEquals( + (yield self.store.get_profile_avatar_url(self.frank.localpart)), + "http://my.server/me.png", + ) + + # Set avatar a second time is forbidden + d = defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) + ) + + yield self.assertFailure(d, SynapseError) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 1e9ba3a201..ca32f993a3 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -34,7 +34,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. 
""" def make_homeserver(self, reactor, clock): - hs_config = self.default_config("test") + hs_config = self.default_config() # some of the tests rely on us having a user consent version hs_config["user_consent"] = { @@ -135,6 +135,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart="local_part"), ResourceLimitError ) + def test_auto_join_rooms_for_guests(self): + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + self.hs.config.auto_join_rooms_for_guests = False + user_id = self.get_success( + self.handler.register_user(localpart="jeff", make_guest=True), + ) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 0) + def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] @@ -175,7 +185,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_real_user = Mock(return_value=False) + self.store.is_real_user = Mock(return_value=defer.succeed(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -187,8 +197,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=1) - self.store.is_real_user = Mock(return_value=True) + self.store.count_real_users = Mock(return_value=defer.succeed(1)) + self.store.is_real_user = Mock(return_value=defer.succeed(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_handlers().directory_handler @@ -202,8 +212,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=2) - self.store.is_real_user = Mock(return_value=True) + self.store.count_real_users = Mock(return_value=defer.succeed(2)) + self.store.is_real_user = Mock(return_value=defer.succeed(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -256,8 +266,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) - @defer.inlineCallbacks - def get_or_create_user(self, requester, localpart, displayname, password_hash=None): + async def get_or_create_user( + self, requester, localpart, displayname, password_hash=None + ): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -269,16 +280,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase): one will be randomly generated. Returns: A tuple of (user_id, access_token). - Raises: - RegistrationError if there was a problem registering. 
""" if localpart is None: raise SynapseError(400, "Request must include user id") - yield self.hs.get_auth().check_auth_blocking() + await self.hs.get_auth().check_auth_blocking() need_register = True try: - yield self.handler.check_username(localpart) + await self.handler.check_username(localpart) except SynapseError as e: if e.errcode == Codes.USER_IN_USE: need_register = False @@ -290,21 +299,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase): token = self.macaroon_generator.generate_access_token(user_id) if need_register: - yield self.handler.register_with_store( + await self.handler.register_with_store( user_id=user_id, password_hash=password_hash, create_profile_with_displayname=user.localpart, ) else: - yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) + await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) - yield self.store.add_access_token_to_user( + await self.store.add_access_token_to_user( user_id=user_id, token=token, device_id=None, valid_until_ms=None ) if displayname is not None: # logger.info("setting user display name: %s -> %s", user_id, displayname) - yield self.hs.get_profile_handler().set_displayname( + await self.hs.get_profile_handler().set_displayname( user, requester, displayname, by_admin=True ) diff --git a/tests/handlers/test_roomlist.py b/tests/handlers/test_roomlist.py deleted file mode 100644 index 61eebb6985..0000000000 --- a/tests/handlers/test_roomlist.py +++ /dev/null @@ -1,39 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.handlers.room_list import RoomListNextBatch - -import tests.unittest -import tests.utils - - -class RoomListTestCase(tests.unittest.TestCase): - """ Tests RoomList's RoomListNextBatch. """ - - def setUp(self): - pass - - def test_check_read_batch_tokens(self): - batch_token = RoomListNextBatch( - stream_ordering="abcdef", - public_room_stream_id="123", - current_limit=20, - direction_is_forward=True, - ).to_token() - next_batch = RoomListNextBatch.from_token(batch_token) - self.assertEquals(next_batch.stream_ordering, "abcdef") - self.assertEquals(next_batch.public_room_stream_id, "123") - self.assertEquals(next_batch.current_limit, 20) - self.assertEquals(next_batch.direction_is_forward, True) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 7569b6fab5..d9d312f0fb 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse import storage from synapse.rest import admin from synapse.rest.client.v1 import login, room +from synapse.storage.data_stores.main import stats from tests import unittest @@ -42,16 +42,16 @@ class StatsRoomTests(unittest.HomeserverTestCase): Add the background updates we need to run. 
""" # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -61,7 +61,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -71,7 +71,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -82,21 +82,21 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) def get_all_room_state(self): - return self.store._simple_select_list( + return self.store.db.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) def _get_current_stats(self, stats_type, stat_id): - table, id_col = storage.stats.TYPE_TO_TABLE[stats_type] + table, id_col = stats.TYPE_TO_TABLE[stats_type] - cols = list(storage.stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list( - storage.stats.PER_SLICE_FIELDS[stats_type] + cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list( + stats.PER_SLICE_FIELDS[stats_type] ) end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) return self.get_success( - self.store._simple_select_one( + self.store.db.simple_select_one( table + "_historical", {id_col: stat_id, end_ts: end_ts}, cols, @@ -108,8 +108,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the stats via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_initial_room(self): """ @@ -141,8 +145,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) r = self.get_success(self.get_all_room_state()) @@ -178,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): # the position that the deltas should begin at, once they take over. 
self.hs.config.stats_enabled = True self.handler.stats_enabled = True - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_update_one( + self.store.db.simple_update_one( table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": 0}, @@ -188,14 +196,18 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Now, before the table is actually ingested, add some more events. self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) @@ -205,13 +217,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Now do the initial ingestion. self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -221,9 +233,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - self.store._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + self.store.db.updates._all_done = False + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) self.reactor.advance(86401) @@ -607,6 +623,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): """ self.hs.config.stats_enabled = False + self.handler.stats_enabled = False u1 = self.register_user("u1", "pass") u1token = self.login("u1", "pass") @@ -618,6 +635,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertIsNone(self._get_current_stats("user", u1)) self.hs.config.stats_enabled = True + self.handler.stats_enabled = True self._perform_background_initial_update() @@ -651,15 +669,15 @@ class StatsRoomTests(unittest.HomeserverTestCase): # preparation stage of the initial background update # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_delete( + self.store.db.simple_delete( "room_stats_current", {"1": 1}, "test_delete_stats" ) ) self.get_success( - self.store._simple_delete( + self.store.db.simple_delete( "user_stats_current", {"1": 1}, "test_delete_stats" ) ) @@ -671,9 +689,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): # now do the background updates - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -683,7 +701,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -693,7 +711,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - 
self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -703,8 +721,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) r1stats_complete = self._get_current_stats("room", r1) u1stats_complete = self._get_current_stats("user", u1) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 31f54bbd7d..e178d7765b 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -12,54 +12,56 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION -from synapse.handlers.sync import SyncConfig, SyncHandler +from synapse.handlers.sync import SyncConfig from synapse.types import UserID import tests.unittest import tests.utils -from tests.utils import setup_test_homeserver -class SyncTestCase(tests.unittest.TestCase): +class SyncTestCase(tests.unittest.HomeserverTestCase): """ Tests Sync Handler. """ - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) - self.sync_handler = SyncHandler(self.hs) + def prepare(self, reactor, clock, hs): + self.hs = hs + self.sync_handler = self.hs.get_sync_handler() self.store = self.hs.get_datastore() - @defer.inlineCallbacks - def test_wait_for_sync_for_user_auth_blocking(self): + # AuthBlocking reads from the hs' config on initialization. 
We need to + # modify its config instead of the hs' + self.auth_blocking = self.hs.get_auth()._auth_blocking - user_id1 = "@user1:server" - user_id2 = "@user2:server" + def test_wait_for_sync_for_user_auth_blocking(self): + user_id1 = "@user1:test" + user_id2 = "@user2:test" sync_config = self._generate_sync_config(user_id1) - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 1 + self.reactor.advance(100) # So we get not 0 time + self.auth_blocking._limit_usage_by_mau = True + self.auth_blocking._max_mau_value = 1 # Check that the happy case does not throw errors - yield self.store.upsert_monthly_active_user(user_id1) - yield self.sync_handler.wait_for_sync_for_user(sync_config) + self.get_success(self.store.upsert_monthly_active_user(user_id1)) + self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config)) # Test that global lock works - self.hs.config.hs_disabled = True - with self.assertRaises(ResourceLimitError) as e: - yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.auth_blocking._hs_disabled = True + e = self.get_failure( + self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + ) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.hs.config.hs_disabled = False + self.auth_blocking._hs_disabled = False sync_config = self._generate_sync_config(user_id2) - with self.assertRaises(ResourceLimitError) as e: - yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + e = self.get_failure( + self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + ) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def _generate_sync_config(self, user_id): return SyncConfig( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 1f2ef5d01f..2fa8d4739b 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -24,6 +24,7 @@ from synapse.api.errors import AuthError from synapse.types import UserID from tests import unittest +from tests.unittest import override_config from tests.utils import register_federation_servlets # Some local users to test with @@ -63,34 +64,39 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): mock_federation_client = Mock(spec=["put_json"]) mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) + datastores = Mock() + datastores.main = Mock( + spec=[ + # Bits that Federation needs + "prep_send_transaction", + "delivered_txn", + "get_received_txn_response", + "set_received_txn_response", + "get_destination_retry_timings", + "get_devices_by_remote", + "maybe_store_room_on_invite", + # Bits that user_directory needs + "get_user_directory_stream_pos", + "get_current_state_deltas", + "get_device_updates_by_remote", + ] + ) + + # the tests assume that we are starting at unix time 1000 + reactor.pump((1000,)) + hs = self.setup_test_homeserver( - datastore=( - Mock( - spec=[ - # Bits that Federation needs - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - "get_devices_by_remote", - # Bits that user_directory needs - "get_user_directory_stream_pos", - "get_current_state_deltas", - ] - ) - ), notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring, + replication_streams={}, ) + hs.datastores = datastores + return hs def 
prepare(self, reactor, clock, hs): - # the tests assume that we are starting at unix time 1000 - reactor.pump((1000,)) - mock_notifier = hs.get_notifier() self.on_new_event = mock_notifier.on_new_event @@ -109,7 +115,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_devices_by_remote.return_value = (0, []) + self.datastore.get_device_updates_by_remote.return_value = defer.succeed( + (0, []) + ) def get_received_txn_response(*args): return defer.succeed(None) @@ -118,19 +126,19 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.room_members = [] - def check_joined_room(room_id, user_id): + def check_user_in_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") - hs.get_auth().check_joined_room = check_joined_room + hs.get_auth().check_user_in_room = check_user_in_room def get_joined_hosts_for_room(room_id): - return set(member.domain for member in self.room_members) + return {member.domain for member in self.room_members} self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room def get_current_users_in_room(room_id): - return set(str(u) for u in self.room_members) + return {str(u) for u in self.room_members} hs.get_state_handler().get_current_users_in_room = get_current_users_in_room @@ -139,11 +147,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): defer.succeed(1) ) - self.datastore.get_current_state_deltas.return_value = None + self.datastore.get_current_state_deltas.return_value = (0, None) self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) + self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed( + ([], 0) + ) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None + self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( + None + ) def test_started_typing_local(self): self.room_members = [U_APPLE, U_BANANA] @@ -159,7 +172,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -171,6 +186,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_started_typing_remote_send(self): self.room_members = [U_APPLE, U_ONION] @@ -222,7 +238,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -234,6 +252,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_stopped_typing(self): self.room_members = [U_APPLE, U_BANANA, U_ONION] @@ -242,7 +261,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): member = RoomMember(ROOM_ID, U_APPLE.to_string()) self.handler._member_typing_until[member] 
= 1002000 - self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()]) + self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()} self.assertEquals(self.event_source.get_current_key(), 0) @@ -273,7 +292,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], @@ -294,7 +315,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -311,7 +334,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 2) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) + ) self.assertEquals( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], @@ -329,7 +354,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index c5e91a8c41..c15bce5bef 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -14,6 +14,8 @@ # limitations under the License. 
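The typing-notification tests above no longer flip hs.config attributes by hand: federation sending is switched on per test with the override_config decorator imported from tests.unittest, and reads from the event source are wrapped in get_success now that get_new_events returns a Deferred. A minimal sketch of the decorator pattern, assuming a HomeserverTestCase subclass (the class and test names below are illustrative, not part of the patch):

from tests.unittest import HomeserverTestCase, override_config

class FederationSendingConfigTestCase(HomeserverTestCase):
    @override_config({"send_federation": True})
    def test_send_federation_enabled(self):
        # The decorator layers the given dict over the test's default config
        # before the homeserver is built, so this test sees federation
        # sending enabled while other tests keep the default.
        self.assertTrue(self.hs.config.send_federation)
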
from mock import Mock +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.rest.client.v1 import login, room @@ -75,18 +77,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) - self.store.remove_from_user_dir = Mock() - self.store.remove_from_user_in_public_room = Mock() + self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None)) self.get_success(self.handler.handle_user_deactivated(s_user_id)) self.store.remove_from_user_dir.not_called() - self.store.remove_from_user_in_public_room.not_called() def test_handle_user_deactivated_regular_user(self): r_user_id = "@regular:test" self.get_success( self.store.register_user(user_id=r_user_id, password_hash=None) ) - self.store.remove_from_user_dir = Mock() + self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None)) self.get_success(self.handler.handle_user_deactivated(r_user_id)) self.store.remove_from_user_dir.called_once_with(r_user_id) @@ -114,7 +114,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() self.assertEqual( - self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} ) self.assertEqual(public_users, []) @@ -147,6 +147,98 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user3", 10)) self.assertEqual(len(s["results"]), 0) + def test_spam_checker(self): + """ + A user which fails to the spam checks will not appear in search results. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + # We do not add users to the directory until they join a room. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Check we have populated the database correctly. + shares_private = self.get_users_who_share_private_rooms() + public_users = self.get_users_in_public_rooms() + + self.assertEqual( + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} + ) + self.assertEqual(public_users, []) + + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + + # Configure a spam checker that does not filter any users. + spam_checker = self.hs.get_spam_checker() + + class AllowAll(object): + def check_username_for_spam(self, user_profile): + # Allow all users. + return False + + spam_checker.spam_checkers = [AllowAll()] + + # The results do not change: + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + + # Configure a spam checker that filters all users. + class BlockAll(object): + def check_username_for_spam(self, user_profile): + # All users are spammy. + return True + + spam_checker.spam_checkers = [BlockAll()] + + # User1 now gets no search results for any of the other users. 
+ s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + def test_legacy_spam_checker(self): + """ + A spam checker without the expected method should be ignored. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + # We do not add users to the directory until they join a room. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Check we have populated the database correctly. + shares_private = self.get_users_who_share_private_rooms() + public_users = self.get_users_in_public_rooms() + + self.assertEqual( + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} + ) + self.assertEqual(public_users, []) + + # Configure a spam checker. + spam_checker = self.hs.get_spam_checker() + # The spam checker doesn't need any methods, so create a bare object. + spam_checker.spam_checker = object() + + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + def _compress_shared(self, shared): """ Compress a list of users who share rooms dicts to a list of tuples. @@ -158,7 +250,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_in_public_rooms(self): r = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") ) ) @@ -169,7 +261,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_who_share_private_rooms(self): return self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], @@ -181,10 +273,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): Add the background updates we need to run. 
""" # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_createtables", @@ -193,7 +285,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_rooms", @@ -203,7 +295,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_users", @@ -213,7 +305,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_cleanup", @@ -255,19 +347,23 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) shares_private = self.get_users_who_share_private_rooms() public_users = self.get_users_in_public_rooms() # User 1 and User 2 are in the same public room - self.assertEqual(set(public_users), set([(u1, room), (u2, room)])) + self.assertEqual(set(public_users), {(u1, room), (u2, room)}) # User 1 and User 3 share private rooms self.assertEqual( self._compress_shared(shares_private), - set([(u1, u3, private_room), (u3, u1, private_room)]), + {(u1, u3, private_room), (u3, u1, private_room)}, ) def test_initial_share_all_users(self): @@ -290,15 +386,19 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) shares_private = self.get_users_who_share_private_rooms() public_users = self.get_users_in_public_rooms() # No users share rooms self.assertEqual(public_users, []) - self.assertEqual(self._compress_shared(shares_private), set([])) + self.assertEqual(self._compress_shared(shares_private), set()) # Despite not sharing a room, search_all_users means we get a search # result. 
diff --git a/tests/http/__init__.py b/tests/http/__init__.py index 2d5dba6464..2096ba3c91 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -20,6 +20,23 @@ from zope.interface import implementer from OpenSSL import SSL from OpenSSL.SSL import Connection from twisted.internet.interfaces import IOpenSSLServerConnectionCreator +from twisted.internet.ssl import Certificate, trustRootFromCertificates +from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401 +from twisted.web.iweb import IPolicyForHTTPS # noqa: F401 + + +def get_test_https_policy(): + """Get a test IPolicyForHTTPS which trusts the test CA cert + + Returns: + IPolicyForHTTPS + """ + ca_file = get_test_ca_cert_file() + with open(ca_file) as stream: + content = stream.read() + cert = Certificate.loadPEM(content) + trust_root = trustRootFromCertificates([cert]) + return BrowserLikePolicyForHTTPS(trustRoot=trust_root) def get_test_ca_cert_file(): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 71d7025264..562397cdda 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -31,14 +31,14 @@ from twisted.web.http_headers import Headers from twisted.web.iweb import IPolicyForHTTPS from synapse.config.homeserver import HomeServerConfig -from synapse.crypto.context_factory import ClientTLSOptionsFactory +from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.srv_resolver import Server from synapse.http.federation.well_known_resolver import ( WellKnownResolver, _cache_period_from_headers, ) -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.util.caches.ttlcache import TTLCache from tests import unittest @@ -79,7 +79,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self._config = config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") - self.tls_factory = ClientTLSOptionsFactory(config) + self.tls_factory = FederationPolicyForHTTPS(config) self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) @@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase): FakeTransport(client_protocol, self.reactor, server_tls_protocol) ) + # grab a hold of the TLS connection, in case it gets torn down + server_tls_connection = server_tls_protocol._tlsConnection + + # fish the test server back out of the server-side TLS protocol. + http_protocol = server_tls_protocol.wrappedProtocol + # give the reactor a pump to get the TLS juices flowing. self.reactor.pump((0.1,)) # check the SNI - server_name = server_tls_protocol._tlsConnection.get_servername() + server_name = server_tls_connection.get_servername() self.assertEqual( server_name, expected_sni, "Expected SNI %s but got %s" % (expected_sni, server_name), ) - # fish the test server back out of the server-side TLS protocol. 
- return server_tls_protocol.wrappedProtocol + return http_protocol @defer.inlineCallbacks def _make_get_request(self, uri): @@ -150,7 +155,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertNoResult(fetch_d) # should have reset logcontext to the sentinel - _check_logcontext(LoggingContext.sentinel) + _check_logcontext(SENTINEL_CONTEXT) try: fetch_res = yield fetch_d @@ -710,7 +715,7 @@ class MatrixFederationAgentTests(unittest.TestCase): config = default_config("test", parse=True) # Build a new agent and WellKnownResolver with a different tls factory - tls_factory = ClientTLSOptionsFactory(config) + tls_factory = FederationPolicyForHTTPS(config) agent = MatrixFederationAgent( reactor=self.reactor, tls_client_options_factory=tls_factory, @@ -1192,7 +1197,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase): def _check_logcontext(context): - current = LoggingContext.current_context() + current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index df034ab237..babc201643 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError from twisted.names import dns, error from synapse.http.federation.srv_resolver import SrvResolver -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from tests import unittest from tests.utils import MockClock @@ -54,12 +54,12 @@ class SrvResolverTestCase(unittest.TestCase): self.assertNoResult(resolve_d) # should have reset to the sentinel context - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) result = yield resolve_d # should have restored our context - self.assertIs(LoggingContext.current_context(), ctx) + self.assertIs(current_context(), ctx) return result diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index 2b01f40a42..fff4f0cbf4 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -29,14 +29,14 @@ from synapse.http.matrixfederationclient import ( MatrixFederationHttpClient, MatrixFederationRequest, ) -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from tests.server import FakeTransport from tests.unittest import HomeserverTestCase def check_logcontext(context): - current = LoggingContext.current_context() + current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) @@ -64,7 +64,7 @@ class FederationClientTests(HomeserverTestCase): self.assertNoResult(fetch_d) # should have reset logcontext to the sentinel - check_logcontext(LoggingContext.sentinel) + check_logcontext(SENTINEL_CONTEXT) try: fetch_res = yield fetch_d diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py new file mode 100644 index 0000000000..22abf76515 --- /dev/null +++ b/tests/http/test_proxyagent.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import treq + +from twisted.internet import interfaces # noqa: F401 +from twisted.internet.protocol import Factory +from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.web.http import HTTPChannel + +from synapse.http.proxyagent import ProxyAgent + +from tests.http import TestServerTLSConnectionFactory, get_test_https_policy +from tests.server import FakeTransport, ThreadedMemoryReactorClock +from tests.unittest import TestCase + +logger = logging.getLogger(__name__) + +HTTPFactory = Factory.forProtocol(HTTPChannel) + + +class MatrixFederationAgentTests(TestCase): + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + + def _make_connection( + self, client_factory, server_factory, ssl=False, expected_sni=None + ): + """Builds a test server, and completes the outgoing client connection + + Args: + client_factory (interfaces.IProtocolFactory): the the factory that the + application is trying to use to make the outbound connection. We will + invoke it to build the client Protocol + + server_factory (interfaces.IProtocolFactory): a factory to build the + server-side protocol + + ssl (bool): If true, we will expect an ssl connection and wrap + server_factory with a TLSMemoryBIOFactory + + expected_sni (bytes|None): the expected SNI value + + Returns: + IProtocol: the server Protocol returned by server_factory + """ + if ssl: + server_factory = _wrap_server_factory_for_tls(server_factory) + + server_protocol = server_factory.buildProtocol(None) + + # now, tell the client protocol factory to build the client protocol, + # and wire the output of said protocol up to the server via + # a FakeTransport. + # + # Normally this would be done by the TCP socket code in Twisted, but we are + # stubbing that out here. 
+ client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection( + FakeTransport(server_protocol, self.reactor, client_protocol) + ) + + # tell the server protocol to send its stuff back to the client, too + server_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_protocol) + ) + + if ssl: + http_protocol = server_protocol.wrappedProtocol + tls_connection = server_protocol._tlsConnection + else: + http_protocol = server_protocol + tls_connection = None + + # give the reactor a pump to get the TLS juices flowing (if needed) + self.reactor.advance(0) + + if expected_sni is not None: + server_name = tls_connection.get_servername() + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + + return http_protocol + + def test_http_request(self): + agent = ProxyAgent(self.reactor) + + self.reactor.lookups["test.com"] = "1.2.3.4" + d = agent.request(b"GET", b"http://test.com") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 80) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_https_request(self): + agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy()) + + self.reactor.lookups["test.com"] = "1.2.3.4" + d = agent.request(b"GET", b"https://test.com/abc") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 443) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, + _get_test_protocol_factory(), + ssl=True, + expected_sni=b"test.com", + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/abc") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_http_request_via_proxy(self): + agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888") + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"http://test.com") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, 
_timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 8888) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"http://test.com") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_https_request_via_proxy(self): + agent = ProxyAgent( + self.reactor, + contextFactory=get_test_https_policy(), + https_proxy=b"proxy.com", + ) + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"https://test.com/abc") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 1080) + + # make a test HTTP server, and wire up the client + proxy_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # fish the transports back out so that we can do the old switcheroo + s2c_transport = proxy_server.transport + client_protocol = s2c_transport.other + c2s_transport = client_protocol.transport + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending CONNECT request + self.assertEqual(len(proxy_server.requests), 1) + + request = proxy_server.requests[0] + self.assertEqual(request.method, b"CONNECT") + self.assertEqual(request.path, b"test.com:443") + + # tell the proxy server not to close the connection + proxy_server.persistent = True + + # this just stops the http Request trying to do a chunked response + # request.setHeader(b"Content-Length", b"0") + request.finish() + + # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel + ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory()) + ssl_protocol = ssl_factory.buildProtocol(None) + http_server = ssl_protocol.wrappedProtocol + + ssl_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, ssl_protocol) + ) + c2s_transport.other = ssl_protocol + + self.reactor.advance(0) + + server_name = ssl_protocol._tlsConnection.get_servername() + expected_sni = b"test.com" + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/abc") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + +def _wrap_server_factory_for_tls(factory, sanlist=None): + """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory + + The resultant factory will create a TLS server which 
presents a certificate + signed by our test CA, valid for the domains in `sanlist` + + Args: + factory (interfaces.IProtocolFactory): protocol factory to wrap + sanlist (iterable[bytes]): list of domains the cert should be valid for + + Returns: + interfaces.IProtocolFactory + """ + if sanlist is None: + sanlist = [b"DNS:test.com"] + + connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist) + return TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=factory + ) + + +def _get_test_protocol_factory(): + """Get a protocol Factory which will build an HTTPChannel + + Returns: + interfaces.IProtocolFactory + """ + server_factory = Factory.forProtocol(HTTPChannel) + + # Request.finish expects the factory to have a 'log' method. + server_factory.log = _log_request + + return server_factory + + +def _log_request(request): + """Implements Factory.log, which is expected by Request.finish""" + logger.info("Completed request %s", request) diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py deleted file mode 100644 index 220884311c..0000000000 --- a/tests/patch_inline_callbacks.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
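Several of the HTTP and federation tests above move from the old LoggingContext class attributes to module-level helpers: LoggingContext.current_context() becomes current_context() and LoggingContext.sentinel becomes SENTINEL_CONTEXT, both imported from synapse.logging.context. A minimal sketch of the new assertion pattern, mirroring the _check_logcontext helpers in the diff (the function name below is illustrative, not part of the patch):

from synapse.logging.context import SENTINEL_CONTEXT, current_context

def assert_sentinel_logcontext():
    # After an async call has yielded, callers expect to be back in the
    # sentinel context rather than a request-specific one.
    current = current_context()
    if current is not SENTINEL_CONTEXT:
        raise AssertionError("Expected sentinel logcontext but was %s" % (current,))
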
- -from __future__ import print_function - -import functools -import sys - -from twisted.internet import defer -from twisted.internet.defer import Deferred -from twisted.python.failure import Failure - - -def do_patch(): - """ - Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit - """ - - from synapse.logging.context import LoggingContext - - orig_inline_callbacks = defer.inlineCallbacks - - def new_inline_callbacks(f): - - orig = orig_inline_callbacks(f) - - @functools.wraps(f) - def wrapped(*args, **kwargs): - start_context = LoggingContext.current_context() - - try: - res = orig(*args, **kwargs) - except Exception: - if LoggingContext.current_context() != start_context: - err = "%s changed context from %s to %s on exception" % ( - f, - start_context, - LoggingContext.current_context(), - ) - print(err, file=sys.stderr) - raise Exception(err) - raise - - if not isinstance(res, Deferred) or res.called: - if LoggingContext.current_context() != start_context: - err = "%s changed context from %s to %s" % ( - f, - start_context, - LoggingContext.current_context(), - ) - # print the error to stderr because otherwise all we - # see in travis-ci is the 500 error - print(err, file=sys.stderr) - raise Exception(err) - return res - - if LoggingContext.current_context() != LoggingContext.sentinel: - err = ( - "%s returned incomplete deferred in non-sentinel context " - "%s (start was %s)" - ) % (f, LoggingContext.current_context(), start_context) - print(err, file=sys.stderr) - raise Exception(err) - - def check_ctx(r): - if LoggingContext.current_context() != start_context: - err = "%s completion of %s changed context from %s to %s" % ( - "Failure" if isinstance(r, Failure) else "Success", - f, - start_context, - LoggingContext.current_context(), - ) - print(err, file=sys.stderr) - raise Exception(err) - return r - - res.addBoth(check_ctx) - return res - - return wrapped - - defer.inlineCallbacks = new_inline_callbacks diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 358b593cd4..83032cc9ea 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -163,8 +163,9 @@ class EmailPusherTests(HomeserverTestCase): # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) last_stream_ordering = pushers[0]["last_stream_ordering"] @@ -173,8 +174,9 @@ class EmailPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) @@ -190,7 +192,8 @@ class EmailPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 8ce6bb62da..baf9c785f4 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -50,7 
+50,7 @@ class HTTPPusherTests(HomeserverTestCase): config = self.default_config() config["start_pushers"] = True - hs = self.setup_test_homeserver(config=config, simple_http_client=m) + hs = self.setup_test_homeserver(config=config, proxied_http_client=m) return hs @@ -102,8 +102,9 @@ class HTTPPusherTests(HomeserverTestCase): # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) last_stream_ordering = pushers[0]["last_stream_ordering"] @@ -112,8 +113,9 @@ class HTTPPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) @@ -130,8 +132,9 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) last_stream_ordering = pushers[0]["last_stream_ordering"] @@ -149,7 +152,8 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased, again pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py new file mode 100644 index 0000000000..9ae6a87d7b --- /dev/null +++ b/tests/push/test_push_rule_evaluator.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from synapse.api.room_versions import RoomVersions +from synapse.events import FrozenEvent +from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent + +from tests import unittest + + +class PushRuleEvaluatorTestCase(unittest.TestCase): + def setUp(self): + event = FrozenEvent( + { + "event_id": "$event_id", + "type": "m.room.history_visibility", + "sender": "@user:test", + "state_key": "", + "room_id": "@room:test", + "content": {"body": "foo bar baz"}, + }, + RoomVersions.V1, + ) + room_member_count = 0 + sender_power_level = 0 + power_levels = {} + self.evaluator = PushRuleEvaluatorForEvent( + event, room_member_count, sender_power_level, power_levels + ) + + def test_display_name(self): + """Check for a matching display name in the body of the event.""" + condition = { + "kind": "contains_display_name", + } + + # Blank names are skipped. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "")) + + # Check a display name that doesn't match. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found")) + + # Check a display name which matches. + self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo")) + + # A display name that matches, but not a full word does not result in a match. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba")) + + # A display name should not be interpreted as a regular expression. + self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]")) + + # A display name with spaces should work fine. + self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar")) diff --git a/tests/replication/_base.py b/tests/replication/_base.py new file mode 100644 index 0000000000..9d4f0bbe44 --- /dev/null +++ b/tests/replication/_base.py @@ -0,0 +1,307 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from typing import Any, List, Optional, Tuple + +import attr + +from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime +from twisted.internet.task import LoopingCall +from twisted.web.http import HTTPChannel + +from synapse.app.generic_worker import ( + GenericWorkerReplicationHandler, + GenericWorkerServer, +) +from synapse.http.site import SynapseRequest +from synapse.replication.http import streams +from synapse.replication.tcp.handler import ReplicationCommandHandler +from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.server import FakeTransport + +logger = logging.getLogger(__name__) + + +class BaseStreamTestCase(unittest.HomeserverTestCase): + """Base class for tests of the replication streams""" + + servlets = [ + streams.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # build a replication server + server_factory = ReplicationStreamProtocolFactory(hs) + self.streamer = hs.get_replication_streamer() + self.server = server_factory.buildProtocol(None) + + # Make a new HomeServer object for the worker + self.reactor.lookups["testserv"] = "1.2.3.4" + self.worker_hs = self.setup_test_homeserver( + http_client=None, + homeserverToUse=GenericWorkerServer, + config=self._get_worker_hs_config(), + reactor=self.reactor, + ) + + # Since we use sqlite in memory databases we need to make sure the + # databases objects are the same. + self.worker_hs.get_datastore().db = hs.get_datastore().db + + self.test_handler = self._build_replication_data_handler() + self.worker_hs.replication_data_handler = self.test_handler + + repl_handler = ReplicationCommandHandler(self.worker_hs) + self.client = ClientReplicationStreamProtocol( + self.worker_hs, "client", "test", clock, repl_handler, + ) + + self._client_transport = None + self._server_transport = None + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_app"] = "synapse.app.generic_worker" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + + def _build_replication_data_handler(self): + return TestReplicationDataHandler(self.worker_hs) + + def reconnect(self): + if self._client_transport: + self.client.close() + + if self._server_transport: + self.server.close() + + self._client_transport = FakeTransport(self.server, self.reactor) + self.client.makeConnection(self._client_transport) + + self._server_transport = FakeTransport(self.client, self.reactor) + self.server.makeConnection(self._server_transport) + + def disconnect(self): + if self._client_transport: + self._client_transport = None + self.client.close() + + if self._server_transport: + self._server_transport = None + self.server.close() + + def replicate(self): + """Tell the master side of replication that something has happened, and then + wait for the replication to occur. + """ + self.streamer.on_notifier_poke() + self.pump(0.1) + + def handle_http_replication_attempt(self) -> SynapseRequest: + """Asserts that a connection attempt was made to the master HS on the + HTTP replication port, then proxies it to the master HS object to be + handled. + + Returns: + The request object received by master HS. + """ + + # We should have an outbound connection attempt. 
+ clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8765) + + # Set up client side protocol + client_protocol = client_factory.buildProtocol(None) + + request_factory = OneShotRequestFactory() + + # Set up the server side protocol + channel = _PushHTTPChannel(self.reactor) + channel.requestFactory = request_factory + channel.site = self.site + + # Connect client to server and vice versa. + client_to_server_transport = FakeTransport( + channel, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) + + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, channel + ) + channel.makeConnection(server_to_client_transport) + + # The request will now be processed by `self.site` and the response + # streamed back. + self.reactor.advance(0) + + # We tear down the connection so it doesn't get reused without our + # knowledge. + server_to_client_transport.loseConnection() + client_to_server_transport.loseConnection() + + return request_factory.request + + def assert_request_is_get_repl_stream_updates( + self, request: SynapseRequest, stream_name: str + ): + """Asserts that the given request is a HTTP replication request for + fetching updates for given stream. + """ + + self.assertRegex( + request.path, + br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$" + % (stream_name.encode("ascii"),), + ) + + self.assertEqual(request.method, b"GET") + + +class TestReplicationDataHandler(GenericWorkerReplicationHandler): + """Drop-in for ReplicationDataHandler which just collects RDATA rows""" + + def __init__(self, hs: HomeServer): + super().__init__(hs) + + # list of received (stream_name, token, row) tuples + self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] + + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) + for r in rows: + self.received_rdata_rows.append((stream_name, token, r)) + + +@attr.s() +class OneShotRequestFactory: + """A simple request factory that generates a single `SynapseRequest` and + stores it for future use. Can only be used once. + """ + + request = attr.ib(default=None) + + def __call__(self, *args, **kwargs): + assert self.request is None + + self.request = SynapseRequest(*args, **kwargs) + return self.request + + +class _PushHTTPChannel(HTTPChannel): + """A HTTPChannel that wraps pull producers to push producers. + + This is a hack to get around the fact that HTTPChannel transparently wraps a + pull producer (which is what Synapse uses to reply to requests) with + `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush` + uses the standard reactor rather than letting us use our test reactor, which + makes it very hard to test. + """ + + def __init__(self, reactor: IReactorTime): + super().__init__() + self.reactor = reactor + + self._pull_to_push_producer = None # type: Optional[_PullToPushProducer] + + def registerProducer(self, producer, streaming): + # Convert pull producers to push producer. + if not streaming: + self._pull_to_push_producer = _PullToPushProducer( + self.reactor, producer, self + ) + producer = self._pull_to_push_producer + + super().registerProducer(producer, True) + + def unregisterProducer(self): + if self._pull_to_push_producer: + # We need to manually stop the _PullToPushProducer. 
+ self._pull_to_push_producer.stop() + + +class _PullToPushProducer: + """A push producer that wraps a pull producer. + """ + + def __init__( + self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer + ): + self._clock = Clock(reactor) + self._producer = producer + self._consumer = consumer + + # While running we use a looping call with a zero delay to call + # resumeProducing on given producer. + self._looping_call = None # type: Optional[LoopingCall] + + # We start writing next reactor tick. + self._start_loop() + + def _start_loop(self): + """Start the looping call to + """ + + if not self._looping_call: + # Start a looping call which runs every tick. + self._looping_call = self._clock.looping_call(self._run_once, 0) + + def stop(self): + """Stops calling resumeProducing. + """ + if self._looping_call: + self._looping_call.stop() + self._looping_call = None + + def pauseProducing(self): + """Implements IPushProducer + """ + self.stop() + + def resumeProducing(self): + """Implements IPushProducer + """ + self._start_loop() + + def stopProducing(self): + """Implements IPushProducer + """ + self.stop() + self._producer.stopProducing() + + def _run_once(self): + """Calls resumeProducing on producer once. + """ + + try: + self._producer.resumeProducing() + except Exception: + logger.exception("Failed to call resumeProducing") + try: + self._consumer.unregisterProducer() + except Exception: + pass + + self.stopProducing() diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 104349cdbd..56497b8476 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -13,52 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from mock import Mock, NonCallableMock +from mock import Mock -from synapse.replication.tcp.client import ( - ReplicationClientFactory, - ReplicationClientHandler, -) -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from tests.replication._base import BaseStreamTestCase -from tests import unittest -from tests.server import FakeTransport - -class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): +class BaseSlavedStoreTestCase(BaseStreamTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver( - "blue", - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), - ) - - hs.get_ratelimiter().can_do_action.return_value = (True, 0) + hs = self.setup_test_homeserver(federation_client=Mock()) return hs def prepare(self, reactor, clock, hs): + super().prepare(reactor, clock, hs) - self.master_store = self.hs.get_datastore() - self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) - self.event_id = 0 - - server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer - - self.replication_handler = ReplicationClientHandler(self.slaved_store) - client_factory = ReplicationClientFactory( - self.hs, "client_name", self.replication_handler - ) - - server = server_factory.buildProtocol(None) - client = client_factory.buildProtocol(None) - - client.makeConnection(FakeTransport(server, reactor)) + self.reconnect() - self.server_to_client_transport = FakeTransport(client, reactor) - server.makeConnection(self.server_to_client_transport) + self.master_store = hs.get_datastore() + self.slaved_store = self.worker_hs.get_datastore() + self.storage = hs.get_storage() def replicate(self): """Tell the master side of replication that something has happened, and then diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index a368117b43..1a88c7fb80 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -15,18 +15,20 @@ import logging from canonicaljson import encode_canonical_json -from synapse.events import FrozenEvent, _EventInternalMetadata -from synapse.events.snapshot import EventContext +from synapse.api.room_versions import RoomVersions +from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.roommember import RoomsForUser +from tests.server import FakeTransport + from ._base import BaseSlavedStoreTestCase -USER_ID = "@feeling:blue" -USER_ID_2 = "@bright:blue" +USER_ID = "@feeling:test" +USER_ID_2 = "@bright:test" OUTLIER = {"outlier": True} -ROOM_ID = "!room:blue" +ROOM_ID = "!room:test" logger = logging.getLogger(__name__) @@ -58,6 +60,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] return super(SlavedEventStoreTestCase, self).setUp() + def prepare(self, *args, **kwargs): + super().prepare(*args, **kwargs) + + self.get_success( + self.master_store.store_room( + ROOM_ID, USER_ID, is_public=False, room_version=RoomVersions.V1, + ) + ) + def tearDown(self): [unpatch() for unpatch in self.unpatches] @@ -90,7 +101,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): msg_dict["content"] = {} msg_dict["unsigned"]["redacted_by"] = redaction.event_id msg_dict["unsigned"]["redacted_because"] 
= redaction - redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + redacted = make_event_from_dict( + msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() + ) self.check("get_event", [msg.event_id], redacted) def test_backfilled_redactions(self): @@ -110,18 +123,20 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): msg_dict["content"] = {} msg_dict["unsigned"]["redacted_by"] = redaction.event_id msg_dict["unsigned"]["redacted_because"] = redaction - redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + redacted = make_event_from_dict( + msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() + ) self.check("get_event", [msg.event_id], redacted) def test_invites(self): self.persist(type="m.room.create", key="", creator=USER_ID) - self.check("get_invited_rooms_for_user", [USER_ID_2], []) + self.check("get_invited_rooms_for_local_user", [USER_ID_2], []) event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") self.replicate() self.check( - "get_invited_rooms_for_user", + "get_invited_rooms_for_local_user", [USER_ID_2], [ RoomsForUser( @@ -225,7 +240,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set()) # limit the replication rate - repl_transport = self.server_to_client_transport + repl_transport = self._server_transport + assert isinstance(repl_transport, FakeTransport) repl_transport.autoflush = False # build the join and message events and persist them in the same batch. @@ -234,7 +250,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) msg, msgctx = self.build_event() - self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)])) + self.get_success( + self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)]) + ) self.replicate() event_source = RoomEventSource(self.hs) @@ -290,10 +308,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if backfill: self.get_success( - self.master_store.persist_events([(event, context)], backfilled=True) + self.storage.persistence.persist_events( + [(event, context)], backfilled=True + ) ) else: - self.get_success(self.master_store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -304,7 +324,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.message", key=None, internal={}, - state=None, depth=None, prev_events=[], auth_events=[], @@ -341,18 +360,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if redacts is not None: event_dict["redacts"] = redacts - event = FrozenEvent(event_dict, internal_metadata_dict=internal) + event = make_event_from_dict(event_dict, internal_metadata_dict=internal) self.event_id += 1 - - if state is not None: - state_ids = {key: e.event_id for key, e in state.items()} - context = EventContext.with_state( - state_group=None, current_state_ids=state_ids, prev_state_ids=state_ids - ) - else: - state_handler = self.hs.get_state_handler() - context = self.get_success(state_handler.compute_event_context(event)) + state_handler = self.hs.get_state_handler() + context = self.get_success(state_handler.compute_event_context(event)) self.master_store.add_push_actions_to_staging( event.event_id, {user_id: actions for user_id, actions in push_actions} diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py 
deleted file mode 100644 index ce3835ae6a..0000000000 --- a/tests/replication/tcp/streams/_base.py +++ /dev/null @@ -1,74 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from synapse.replication.tcp.commands import ReplicateCommand -from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory - -from tests import unittest -from tests.server import FakeTransport - - -class BaseStreamTestCase(unittest.HomeserverTestCase): - """Base class for tests of the replication streams""" - - def prepare(self, reactor, clock, hs): - # build a replication server - server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer - server = server_factory.buildProtocol(None) - - # build a replication client, with a dummy handler - self.test_handler = TestReplicationClientHandler() - self.client = ClientReplicationStreamProtocol( - "client", "test", clock, self.test_handler - ) - - # wire them together - self.client.makeConnection(FakeTransport(server, reactor)) - server.makeConnection(FakeTransport(self.client, reactor)) - - def replicate(self): - """Tell the master side of replication that something has happened, and then - wait for the replication to occur. - """ - self.streamer.on_notifier_poke() - self.pump(0.1) - - def replicate_stream(self, stream, token="NOW"): - """Make the client end a REPLICATE command to set up a subscription to a stream""" - self.client.send_command(ReplicateCommand(stream, token)) - - -class TestReplicationClientHandler(object): - """Drop-in for ReplicationClientHandler which just collects RDATA rows""" - - def __init__(self): - self.received_rdata_rows = [] - - def get_streams_to_replicate(self): - return {} - - def get_currently_syncing_users(self): - return [] - - def update_connection(self, connection): - pass - - def finished_connecting(self): - pass - - def on_rdata(self, stream_name, token, rows): - for r in rows: - self.received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py new file mode 100644 index 0000000000..6a5116dd2a --- /dev/null +++ b/tests/replication/tcp/streams/test_account_data.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from synapse.replication.tcp.streams._base import ( + _STREAM_UPDATE_TARGET_ROW_COUNT, + AccountDataStream, +) + +from tests.replication._base import BaseStreamTestCase + + +class AccountDataStreamTestCase(BaseStreamTestCase): + def test_update_function_room_account_data_limit(self): + """Test replication with many room account data updates + """ + store = self.hs.get_datastore() + + # generate lots of account data updates + updates = [] + for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5): + update = "m.test_type.%i" % (i,) + self.get_success( + store.add_account_data_to_room("test_user", "test_room", update, {}) + ) + updates.append(update) + + # also one global update + self.get_success(store.add_account_data_for_user("test_user", "m.global", {})) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + + for t in updates: + (stream_name, token, row) = received_rows.pop(0) + self.assertEqual(stream_name, AccountDataStream.NAME) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, t) + self.assertEqual(row.room_id, "test_room") + + (stream_name, token, row) = received_rows.pop(0) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, "m.global") + self.assertIsNone(row.room_id) + + self.assertEqual([], received_rows) + + def test_update_function_global_account_data_limit(self): + """Test replication with many global account data updates + """ + store = self.hs.get_datastore() + + # generate lots of account data updates + updates = [] + for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5): + update = "m.test_type.%i" % (i,) + self.get_success(store.add_account_data_for_user("test_user", update, {})) + updates.append(update) + + # also one per-room update + self.get_success( + store.add_account_data_to_room("test_user", "test_room", "m.per_room", {}) + ) + + # tell the notifier to catch up to avoid duplicate rows. 
+ # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + + for t in updates: + (stream_name, token, row) = received_rows.pop(0) + self.assertEqual(stream_name, AccountDataStream.NAME) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, t) + self.assertIsNone(row.room_id) + + (stream_name, token, row) = received_rows.pop(0) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, "m.per_room") + self.assertEqual(row.room_id, "test_room") + + self.assertEqual([], received_rows) diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py new file mode 100644 index 0000000000..51bf0ef4e9 --- /dev/null +++ b/tests/replication/tcp/streams/test_events.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from synapse.api.constants import EventTypes, Membership +from synapse.events import EventBase +from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT +from synapse.replication.tcp.streams.events import ( + EventsStreamCurrentStateRow, + EventsStreamEventRow, + EventsStreamRow, +) +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.replication._base import BaseStreamTestCase +from tests.test_utils.event_injection import inject_event, inject_member_event + + +class EventsStreamTestCase(BaseStreamTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + super().prepare(reactor, clock, hs) + self.user_id = self.register_user("u1", "pass") + self.user_tok = self.login("u1", "pass") + + self.reconnect() + + self.room_id = self.helper.create_room_as(tok=self.user_tok) + self.test_handler.received_rdata_rows.clear() + + def test_update_function_event_row_limit(self): + """Test replication with many non-state events + + Checks that all events are correctly replicated when there are lots of + event rows to be replicated. + """ + # disconnect, so that we can stack up some changes + self.disconnect() + + # generate lots of non-state events. We inject them using inject_event + # so that they are not send out over replication until we call self.replicate(). 
+ events = [ + self._inject_test_event() + for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1) + ] + + # also one state event + state_event = self._inject_state_event() + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + + for event in events: + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, event.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state_event.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state_event.event_id) + + self.assertEqual([], received_rows) + + def test_update_function_huge_state_change(self): + """Test replication with many state events + + Ensures that all events are correctly replicated when there are lots of + state change rows to be replicated. + """ + + # we want to generate lots of state changes at a single stream ID. + # + # We do this by having two branches in the DAG. On one, we have a moderator + # which that generates lots of state; on the other, we de-op the moderator, + # thus invalidating all the state. 
+ + OTHER_USER = "@other_user:localhost" + + # have the user join + inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + + # Update existing power levels with mod at PL50 + pls = self.helper.get_state( + self.room_id, EventTypes.PowerLevels, tok=self.user_tok + ) + pls["users"][OTHER_USER] = 50 + self.helper.send_state( + self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + ) + + # this is the point in the DAG where we make a fork + fork_point = self.get_success( + self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + ) # type: List[str] + + events = [ + self._inject_state_event(sender=OTHER_USER) + for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT) + ] + + self.replicate() + # all those events and state changes should have landed + self.assertGreaterEqual( + len(self.test_handler.received_rdata_rows), 2 * len(events) + ) + + # disconnect, so that we can stack up the changes + self.disconnect() + self.test_handler.received_rdata_rows.clear() + + # a state event which doesn't get rolled back, to check that the state + # before the huge update comes through ok + state1 = self._inject_state_event() + + # roll back all the state by de-modding the user + prev_events = fork_point + pls["users"][OTHER_USER] = 0 + pl_event = inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) + + # one more bit of state that doesn't get rolled back + state2 = self._inject_state_event() + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + # + # we expect: + # + # - two rows for state1 + # - the PL event row, plus state rows for the PL event and each + # of the states that got reverted. 
+ # - two rows for state2 + + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + + # first check the first two rows, which should be state1 + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state1.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state1.event_id) + + # now the last two rows, which should be state2 + stream_name, token, row = received_rows.pop(-2) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state2.event_id) + + stream_name, token, row = received_rows.pop(-1) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state2.event_id) + + # that should leave us with the rows for the PL event + self.assertEqual(len(received_rows), len(events) + 2) + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_event.event_id) + + # the state rows are unsorted + state_rows = [] # type: List[EventsStreamCurrentStateRow] + for stream_name, token, row in received_rows: + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_event.event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) + + def test_update_function_state_row_limit(self): + """Test replication with many state events over several stream ids. + """ + + # we want to generate lots of state changes, but for this test, we want to + # spread out the state changes over a few stream IDs. + # + # We do this by having two branches in the DAG. On one, we have four moderators, + # each of which that generates lots of state; on the other, we de-op the users, + # thus invalidating all the state. 
+ + NUM_USERS = 4 + STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1 + + user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)] + + # have the users join + for u in user_ids: + inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + + # Update existing power levels with mod at PL50 + pls = self.helper.get_state( + self.room_id, EventTypes.PowerLevels, tok=self.user_tok + ) + pls["users"].update({u: 50 for u in user_ids}) + self.helper.send_state( + self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + ) + + # this is the point in the DAG where we make a fork + fork_point = self.get_success( + self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + ) # type: List[str] + + events = [] # type: List[EventBase] + for user in user_ids: + events.extend( + self._inject_state_event(sender=user) for _ in range(STATES_PER_USER) + ) + + self.replicate() + + # all those events and state changes should have landed + self.assertGreaterEqual( + len(self.test_handler.received_rdata_rows), 2 * len(events) + ) + + # disconnect, so that we can stack up the changes + self.disconnect() + self.test_handler.received_rdata_rows.clear() + + # now roll back all that state by de-modding the users + prev_events = fork_point + pl_events = [] + for u in user_ids: + pls["users"][u] = 0 + e = inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) + prev_events = [e.event_id] + pl_events.append(e) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + self.assertGreaterEqual(len(received_rows), len(events)) + for i in range(NUM_USERS): + # for each user, we expect the PL event row, followed by state rows for + # the PL event and each of the states that got reverted. 
+ stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_events[i].event_id) + + # the state rows are unsorted + state_rows = [] # type: List[EventsStreamCurrentStateRow] + for j in range(STATES_PER_USER + 1): + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_events[i].event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) + + self.assertEqual([], received_rows) + + event_count = 0 + + def _inject_test_event( + self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs + ) -> EventBase: + if sender is None: + sender = self.user_id + + if body is None: + body = "event %i" % (self.event_count,) + self.event_count += 1 + + return inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_event", + content={"body": body}, + **kwargs + ) + + def _inject_state_event( + self, + body: Optional[str] = None, + state_key: Optional[str] = None, + sender: Optional[str] = None, + ) -> EventBase: + if sender is None: + sender = self.user_id + + if state_key is None: + state_key = "state_%i" % (self.event_count,) + self.event_count += 1 + + if body is None: + body = "state event %s" % (state_key,) + + return inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_state_event", + state_key=state_key, + content={"body": body}, + ) diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py new file mode 100644 index 0000000000..2babea4e3e --- /dev/null +++ b/tests/replication/tcp/streams/test_federation.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.federation.send_queue import EduRow +from synapse.replication.tcp.streams.federation import FederationStream + +from tests.replication._base import BaseStreamTestCase + + +class FederationStreamTestCase(BaseStreamTestCase): + def _get_worker_hs_config(self) -> dict: + # enable federation sending on the worker + config = super()._get_worker_hs_config() + # TODO: make it so we don't need both of these + config["send_federation"] = True + config["worker_app"] = "synapse.app.federation_sender" + return config + + def test_catchup(self): + """Basic test of catchup on reconnect + + Makes sure that updates sent while we are offline are received later. 
+ """ + fed_sender = self.hs.get_federation_sender() + received_rows = self.test_handler.received_rdata_rows + + fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"}) + + self.reconnect() + self.reactor.advance(0) + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual(received_rows, []) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "federation") + + # we should have received an update row + stream_name, token, row = received_rows.pop() + self.assertEqual(stream_name, "federation") + self.assertIsInstance(row, FederationStream.FederationStreamRow) + self.assertEqual(row.type, EduRow.TypeId) + edurow = EduRow.from_data(row.data) + self.assertEqual(edurow.edu.edu_type, "m.test_edu") + self.assertEqual(edurow.edu.origin, self.hs.hostname) + self.assertEqual(edurow.edu.destination, "testdest") + self.assertEqual(edurow.edu.content, {"a": "b"}) + + self.assertEqual(received_rows, []) + + # additional updates should be transferred without an HTTP hit + fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"}) + self.reactor.advance(0) + # there should be no http hit + self.assertEqual(len(self.reactor.tcpClients), 0) + # ... but we should have a row + self.assertEqual(len(received_rows), 1) + + stream_name, token, row = received_rows.pop() + self.assertEqual(stream_name, "federation") + self.assertIsInstance(row, FederationStream.FederationStreamRow) + self.assertEqual(row.type, EduRow.TypeId) + edurow = EduRow.from_data(row.data) + self.assertEqual(edurow.edu.edu_type, "m.test1") + self.assertEqual(edurow.edu.origin, self.hs.hostname) + self.assertEqual(edurow.edu.destination, "testdest") + self.assertEqual(edurow.edu.content, {"c": "d"}) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index d5a99f6caa..56b062ecc1 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -12,35 +12,73 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.replication.tcp.streams._base import ReceiptsStreamRow -from tests.replication.tcp.streams._base import BaseStreamTestCase +# type: ignore + +from mock import Mock + +from synapse.replication.tcp.streams._base import ReceiptsStream + +from tests.replication._base import BaseStreamTestCase USER_ID = "@feeling:blue" -ROOM_ID = "!room:blue" -EVENT_ID = "$event:blue" class ReceiptsStreamTestCase(BaseStreamTestCase): + def _build_replication_data_handler(self): + return Mock(wraps=super()._build_replication_data_handler()) + def test_receipt(self): - # make the client subscribe to the receipts stream - self.replicate_stream("receipts", "NOW") + self.reconnect() # tell the master to send a new receipt self.get_success( self.hs.get_datastore().insert_receipt( - ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1} + "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} ) ) self.replicate() # there should be one RDATA command - rdata_rows = self.test_handler.received_rdata_rows + self.test_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) - self.assertEqual(rdata_rows[0][0], "receipts") - row = rdata_rows[0][2] # type: ReceiptsStreamRow - self.assertEqual(ROOM_ID, row.room_id) + row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + self.assertEqual("!room:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) - self.assertEqual(EVENT_ID, row.event_id) + self.assertEqual("$event:blue", row.event_id) self.assertEqual({"a": 1}, row.data) + + # Now let's disconnect and insert some data. + self.disconnect() + + self.test_handler.on_rdata.reset_mock() + + self.get_success( + self.hs.get_datastore().insert_receipt( + "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} + ) + ) + self.replicate() + + # Nothing should have happened as we are disconnected + self.test_handler.on_rdata.assert_not_called() + + self.reconnect() + self.pump(0.1) + + # We should now have caught up and get the missing data + self.test_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "receipts") + self.assertEqual(token, 3) + self.assertEqual(1, len(rdata_rows)) + + row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + self.assertEqual("!room2:blue", row.room_id) + self.assertEqual("m.read", row.receipt_type) + self.assertEqual(USER_ID, row.user_id) + self.assertEqual("$event2:foo", row.event_id) + self.assertEqual({"a": 2}, row.data) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py new file mode 100644 index 0000000000..fd62b26356 --- /dev/null +++ b/tests/replication/tcp/streams/test_typing.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
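Both the receipts test above and the typing test below override _build_replication_data_handler to return Mock(wraps=...), so that calls such as on_rdata can be asserted on while the real handler still runs. A minimal sketch of that mocking pattern, using only the standard library (the Handler class and its method here are illustrative stand-ins, not Synapse code):

    from unittest.mock import Mock

    class Handler:
        def on_rdata(self, stream_name, rows):
            return len(rows)

    # wraps= delegates every call through to the real object but still records it,
    # so behaviour is unchanged while assert_called_once() and call_args work.
    handler = Mock(wraps=Handler())
    handler.on_rdata("receipts", [1, 2, 3])  # returns 3, via the wrapped method
    handler.on_rdata.assert_called_once()

The tests in this commit use the standalone `mock` package ("from mock import Mock"), which behaves the same way for this purpose.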
+from mock import Mock + +from synapse.handlers.typing import RoomMember +from synapse.replication.tcp.streams import TypingStream + +from tests.replication._base import BaseStreamTestCase + +USER_ID = "@feeling:blue" + + +class TypingStreamTestCase(BaseStreamTestCase): + def _build_replication_data_handler(self): + return Mock(wraps=super()._build_replication_data_handler()) + + def test_typing(self): + typing = self.hs.get_typing_handler() + + room_id = "!bar:blue" + + self.reconnect() + + typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) + + self.reactor.advance(0) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "typing") + + self.test_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "typing") + self.assertEqual(1, len(rdata_rows)) + row = rdata_rows[0] # type: TypingStream.TypingStreamRow + self.assertEqual(room_id, row.room_id) + self.assertEqual([USER_ID], row.user_ids) + + # Now let's disconnect and insert some data. + self.disconnect() + + self.test_handler.on_rdata.reset_mock() + + typing._push_update(member=RoomMember(room_id, USER_ID), typing=False) + + self.test_handler.on_rdata.assert_not_called() + + self.reconnect() + self.pump(0.1) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "typing") + + # The from token should be the token from the last RDATA we got. + self.assertEqual(int(request.args[b"from_token"][0]), token) + + self.test_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "typing") + self.assertEqual(1, len(rdata_rows)) + row = rdata_rows[0] + self.assertEqual(room_id, row.room_id) + self.assertEqual([], row.user_ids) diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py new file mode 100644 index 0000000000..60c10a441a --- /dev/null +++ b/tests/replication/tcp/test_commands.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
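The ParseCommandTestCase that follows exercises the line format of the replication TCP protocol. As a rough illustration, reusing the exact RDATA line and assertions from the tests below (nothing here goes beyond what those tests already check):

    from synapse.replication.tcp.commands import RdataCommand, parse_command_from_line

    # An RDATA line carries: command name, stream name, instance name, token, JSON row.
    line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
    cmd = parse_command_from_line(line)
    assert isinstance(cmd, RdataCommand)
    assert (cmd.stream_name, cmd.instance_name, cmd.token) == ("events", "master", 6287863)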
+from synapse.replication.tcp.commands import ( + RdataCommand, + ReplicateCommand, + parse_command_from_line, +) + +from tests.unittest import TestCase + + +class ParseCommandTestCase(TestCase): + def test_parse_one_word_command(self): + line = "REPLICATE" + cmd = parse_command_from_line(line) + self.assertIsInstance(cmd, ReplicateCommand) + + def test_parse_rdata(self): + line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' + cmd = parse_command_from_line(line) + assert isinstance(cmd, RdataCommand) + self.assertEqual(cmd.stream_name, "events") + self.assertEqual(cmd.instance_name, "master") + self.assertEqual(cmd.token, 6287863) + + def test_parse_rdata_batch(self): + line = 'RDATA presence master batch ["@foo:example.com", "online"]' + cmd = parse_command_from_line(line) + assert isinstance(cmd, RdataCommand) + self.assertEqual(cmd.stream_name, "presence") + self.assertEqual(cmd.instance_name, "master") + self.assertIsNone(cmd.token) diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py new file mode 100644 index 0000000000..d1c15caeb0 --- /dev/null +++ b/tests/replication/tcp/test_remote_server_up.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +from twisted.internet.interfaces import IProtocol +from twisted.test.proto_helpers import StringTransport + +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory + +from tests.unittest import HomeserverTestCase + + +class RemoteServerUpTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.factory = ReplicationStreamProtocolFactory(hs) + + def _make_client(self) -> Tuple[IProtocol, StringTransport]: + """Create a new direct TCP replication connection + """ + + proto = self.factory.buildProtocol(("127.0.0.1", 0)) + transport = StringTransport() + proto.makeConnection(transport) + + # We can safely ignore the commands received during connection. + self.pump() + transport.clear() + + return proto, transport + + def test_relay(self): + """Test that Synapse will relay REMOTE_SERVER_UP commands to all + other connections, but not the one that sent it. + """ + + proto1, transport1 = self._make_client() + + # We shouldn't receive an echo. 
+ proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n") + self.pump() + self.assertEqual(transport1.value(), b"") + + # But we should see an echo if we connect another client + proto2, transport2 = self._make_client() + proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n") + + self.pump() + self.assertEqual(transport1.value(), b"") + self.assertEqual(transport2.value(), b"REMOTE_SERVER_UP example.com\n") diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py new file mode 100644 index 0000000000..5448d9f0dc --- /dev/null +++ b/tests/replication/test_federation_ack.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock + +from synapse.app.generic_worker import GenericWorkerServer +from synapse.replication.tcp.commands import FederationAckCommand +from synapse.replication.tcp.protocol import AbstractConnection +from synapse.replication.tcp.streams.federation import FederationStream + +from tests.unittest import HomeserverTestCase + + +class FederationAckTestCase(HomeserverTestCase): + def default_config(self) -> dict: + config = super().default_config() + config["worker_app"] = "synapse.app.federation_sender" + config["send_federation"] = True + return config + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer) + return hs + + def test_federation_ack_sent(self): + """A FEDERATION_ACK should be sent back after each RDATA federation + + This test checks that the federation sender is correctly sending back + FEDERATION_ACK messages. The test works by spinning up a federation_sender + worker server, and then fishing out its ReplicationCommandHandler. We wire + the RCH up to a mock connection (so that we can observe the command being sent) + and then poke in an RDATA row. + + XXX: it might be nice to do this by pretending to be a synapse master worker + (or a redis server), and having the worker connect to us via a mocked-up TCP + transport, rather than assuming that the implementation has a + ReplicationCommandHandler. 
+ """ + rch = self.hs.get_tcp_replication() + + # wire up the ReplicationCommandHandler to a mock connection + mock_connection = mock.Mock(spec=AbstractConnection) + rch.new_connection(mock_connection) + + # tell it it received an RDATA row + self.get_success( + rch.on_rdata( + "federation", + "master", + token=10, + rows=[FederationStream.FederationStreamRow(type="x", data=[1, 2, 3])], + ) + ) + + # now check that the FEDERATION_ACK was sent + mock_connection.send_command.assert_called_once() + cmd = mock_connection.send_command.call_args[0][0] + assert isinstance(cmd, FederationAckCommand) + self.assertEqual(cmd.token, 10) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 5877bb2133..977615ebef 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -13,17 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import hashlib -import hmac import json +import os +import urllib.parse +from binascii import unhexlify from mock import Mock +from twisted.internet.defer import Deferred + import synapse.rest.admin -from synapse.api.constants import UserTypes from synapse.http.server import JsonResource +from synapse.logging.context import make_deferred_yieldable from synapse.rest.admin import VersionServlet -from synapse.rest.client.v1 import events, login, room +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import groups from tests import unittest @@ -47,517 +50,440 @@ class VersionTestCase(unittest.HomeserverTestCase): ) -class UserRegisterTestCase(unittest.HomeserverTestCase): - - servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] - - def make_homeserver(self, reactor, clock): - - self.url = "/_matrix/client/r0/admin/register" +class DeleteGroupTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + groups.register_servlets, + ] - self.registration_handler = Mock() - self.identity_handler = Mock() - self.login_handler = Mock() - self.device_handler = Mock() - self.device_handler.check_device_registered = Mock(return_value="FAKE") + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() - self.datastore = Mock(return_value=Mock()) - self.datastore.get_current_state_deltas = Mock(return_value=[]) + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") - self.secrets = Mock() + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") - self.hs = self.setup_test_homeserver() + def test_delete_group(self): + # Create a new group + request, channel = self.make_request( + "POST", + "/create_group".encode("ascii"), + access_token=self.admin_user_tok, + content={"localpart": "test"}, + ) - self.hs.config.registration_shared_secret = "shared" + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.hs.get_media_repository = Mock() - self.hs.get_deactivate_account_handler = Mock() + group_id = channel.json_body["group_id"] - return self.hs + self._check_group(group_id, expect_code=200) - def test_disabled(self): - """ - If there is no shared secret, registration through this method will be - prevented. 
- """ - self.hs.config.registration_shared_secret = None + # Invite/join another user - request, channel = self.make_request("POST", self.url, b"{}") + url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user) + request, channel = self.make_request( + "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} + ) self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "Shared secret registration is not enabled", channel.json_body["error"] + url = "/groups/%s/self/accept_invite" % (group_id,) + request, channel = self.make_request( + "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} ) - - def test_get_nonce(self): - """ - Calling GET on the endpoint will return a randomised nonce, using the - homeserver's secrets provider. - """ - secrets = Mock() - secrets.token_hex = Mock(return_value="abcd") - - self.hs.get_secrets = Mock(return_value=secrets) - - request, channel = self.make_request("GET", self.url) self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(channel.json_body, {"nonce": "abcd"}) - - def test_expired_nonce(self): - """ - Calling GET on the endpoint will return a randomised nonce, which will - only last for SALT_TIMEOUT (60s). - """ - request, channel = self.make_request("GET", self.url) - self.render(request) - nonce = channel.json_body["nonce"] + # Check other user knows they're in the group + self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) + self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token)) - # 59 seconds - self.reactor.advance(59) + # Now delete the group + url = "/admin/delete_group/" + group_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + access_token=self.admin_user_tok, + content={"localpart": "test"}, + ) - body = json.dumps({"nonce": nonce}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("username must be specified", channel.json_body["error"]) - - # 61 seconds - self.reactor.advance(2) - - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + # Check group returns 404 + self._check_group(group_id, expect_code=404) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("unrecognised nonce", channel.json_body["error"]) + # Check users don't think they're in the group + self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) + self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token)) - def test_register_incorrect_nonce(self): - """ - Only the provided nonce can be used, as it's checked in the MAC. 
+ def _check_group(self, group_id, expect_code): + """Assert that trying to fetch the given group results in the given + HTTP status code """ - request, channel = self.make_request("GET", self.url) - self.render(request) - nonce = channel.json_body["nonce"] - - want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") - want_mac = want_mac.hexdigest() - - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "mac": want_mac, - } - ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("HMAC incorrect", channel.json_body["error"]) + url = "/groups/%s/profile" % (group_id,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) - def test_register_correct_nonce(self): - """ - When the correct nonce is provided, and the right key is provided, the - user is registered. - """ - request, channel = self.make_request("GET", self.url) self.render(request) - nonce = channel.json_body["nonce"] - - want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update( - nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] ) - want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "user_type": UserTypes.SUPPORT, - "mac": want_mac, - } + def _get_groups_user_is_in(self, access_token): + """Returns the list of groups the user is in (given their access token) + """ + request, channel = self.make_request( + "GET", "/joined_groups".encode("ascii"), access_token=access_token ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("@bob:test", channel.json_body["user_id"]) - - def test_nonce_reuse(self): - """ - A valid unrecognised nonce. - """ - request, channel = self.make_request("GET", self.url) - self.render(request) - nonce = channel.json_body["nonce"] - - want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") - want_mac = want_mac.hexdigest() - - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "mac": want_mac, - } - ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("@bob:test", channel.json_body["user_id"]) - - # Now, try and reuse it - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("unrecognised nonce", channel.json_body["error"]) - def test_missing_parts(self): - """ - Synapse will complain if you don't give nonce, username, password, and - mac. Admin and user_types are optional. Additional checks are done for length - and type. 
- """ - - def nonce(): - request, channel = self.make_request("GET", self.url) - self.render(request) - return channel.json_body["nonce"] - - # - # Nonce check - # - - # Must be present - body = json.dumps({}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("nonce must be specified", channel.json_body["error"]) + return channel.json_body["groups"] - # - # Username checks - # - # Must be present - body = json.dumps({"nonce": nonce()}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) +class QuarantineMediaTestCase(unittest.HomeserverTestCase): + """Test /quarantine_media admin API. + """ - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("username must be specified", channel.json_body["error"]) + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + room.register_servlets, + ] - # Must be a string - body = json.dumps({"nonce": nonce(), "username": 1234}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.hs = hs + + # Allow for uploading and downloading to/from the media repo + self.media_repo = hs.get_media_repository_resource() + self.download_resource = self.media_repo.children[b"download"] + self.upload_resource = self.media_repo.children[b"upload"] + self.image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("Invalid username", channel.json_body["error"]) + def make_homeserver(self, reactor, clock): - # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + self.fetches = [] - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("Invalid username", channel.json_body["error"]) + def get_file(destination, path, output_stream, args=None, max_size=None): + """ + Returns tuple[int,dict,str,int] of file length, response headers, + absolute URI, and response code. 
+ """ - # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + def write_to(r): + data, response = r + output_stream.write(data) + return response - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("Invalid username", channel.json_body["error"]) + d = Deferred() + d.addCallback(write_to) + self.fetches.append((d, destination, path, args)) + return make_deferred_yieldable(d) - # - # Password checks - # + client = Mock() + client.get_file = get_file - # Must be present - body = json.dumps({"nonce": nonce(), "username": "a"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("password must be specified", channel.json_body["error"]) + config = self.default_config() + config["media_store_path"] = self.media_store_path + config["thumbnail_requirements"] = {} + config["max_image_pixels"] = 2000000 - # Must be a string - body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + provider_config = { + "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + config["media_storage_providers"] = [provider_config] - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("Invalid password", channel.json_body["error"]) + hs = self.setup_test_homeserver(config=config, http_client=client) - # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + return hs - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("Invalid password", channel.json_body["error"]) + def test_quarantine_media_requires_admin(self): + self.register_user("nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("nonadmin", "pass") - # Super long - body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) + # Attempt quarantine media APIs as non-admin + url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345" + request, channel = self.make_request( + "POST", url.encode("ascii"), access_token=non_admin_user_tok, + ) self.render(request) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("Invalid password", channel.json_body["error"]) - - # - # user_type check - # + # Expect a forbidden error + self.assertEqual( + 403, + int(channel.result["code"]), + msg="Expected forbidden on quarantining media as a non-admin", + ) - # Invalid user_type - body = json.dumps( - { - "nonce": nonce(), - "username": "a", - "password": "1234", - "user_type": "invalid", - } + # And the roomID/userID endpoint + url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine" + request, channel = self.make_request( + 
"POST", url.encode("ascii"), access_token=non_admin_user_tok, ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("Invalid user type", channel.json_body["error"]) - - -class ShutdownRoomTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - events.register_servlets, - room.register_servlets, - room.register_deprecated_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.event_creation_handler = hs.get_event_creation_handler() - hs.config.user_consent_version = "1" - - consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" - self.event_creation_handler._consent_uri_builder = consent_uri_builder - - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") + # Expect a forbidden error + self.assertEqual( + 403, + int(channel.result["code"]), + msg="Expected forbidden on quarantining media as a non-admin", + ) - self.other_user = self.register_user("user", "pass") - self.other_user_token = self.login("user", "pass") + def test_quarantine_media_by_id(self): + self.register_user("id_admin", "pass", admin=True) + admin_user_tok = self.login("id_admin", "pass") - # Mark the admin user as having consented - self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + self.register_user("id_nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("id_nonadmin", "pass") - def test_shutdown_room_consent(self): - """Test that we can shutdown rooms with local users who have not - yet accepted the privacy policy. This used to fail when we tried to - force part the user from the old room. 
- """ - self.event_creation_handler._block_events_without_consent_error = None + # Upload some media into the room + response = self.helper.upload_media( + self.upload_resource, self.image_data, tok=admin_user_tok + ) - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + # Extract media ID from the response + server_name_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + server_name, media_id = server_name_and_media_id.split("/") - # Assert one user in room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([self.other_user], users_in_room) + # Attempt to access the media + request, channel = self.make_request( + "GET", + server_name_and_media_id, + shorthand=False, + access_token=non_admin_user_tok, + ) + request.render(self.download_resource) + self.pump(1.0) - # Enable require consent to send events - self.event_creation_handler._block_events_without_consent_error = "Error" + # Should be successful + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) - # Assert that the user is getting consent error - self.helper.send( - room_id, body="foo", tok=self.other_user_token, expect_code=403 + # Quarantine the media + url = "/_synapse/admin/v1/media/quarantine/%s/%s" % ( + urllib.parse.quote(server_name), + urllib.parse.quote(media_id), ) + request, channel = self.make_request("POST", url, access_token=admin_user_tok,) + self.render(request) + self.pump(1.0) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) - # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id + # Attempt to access the media request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, + "GET", + server_name_and_media_id, + shorthand=False, + access_token=admin_user_tok, ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + request.render(self.download_resource) + self.pump(1.0) - # Assert there is now no longer anyone in the room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([], users_in_room) + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_name_and_media_id + ), + ) - def test_shutdown_room_block_peek(self): - """Test that a world_readable room can no longer be peeked into after - it has been shut down. 
- """ + def test_quarantine_all_media_in_room(self, override_url_template=None): + self.register_user("room_admin", "pass", admin=True) + admin_user_tok = self.login("room_admin", "pass") - self.event_creation_handler._block_events_without_consent_error = None + non_admin_user = self.register_user("room_nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("room_nonadmin", "pass") - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + room_id = self.helper.create_room_as(non_admin_user, tok=admin_user_tok) + self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok) - # Enable world readable - url = "rooms/%s/state/m.room.history_visibility" % (room_id,) - request, channel = self.make_request( - "PUT", - url.encode("ascii"), - json.dumps({"history_visibility": "world_readable"}), - access_token=self.other_user_token, + # Upload some media + response_1 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, + response_2 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + # Extract mxcs + mxc_1 = response_1["content_uri"] + mxc_2 = response_2["content_uri"] + + # Send it into the room + self.helper.send_event( + room_id, + "m.room.message", + content={"body": "image-1", "msgtype": "m.image", "url": mxc_1}, + txn_id="111", + tok=non_admin_user_tok, + ) + self.helper.send_event( + room_id, + "m.room.message", + content={"body": "image-2", "msgtype": "m.image", "url": mxc_2}, + txn_id="222", + tok=non_admin_user_tok, + ) - # Assert we can no longer peek into the room - self._assert_peek(room_id, expect_code=403) + # Quarantine all media in the room + if override_url_template: + url = override_url_template % urllib.parse.quote(room_id) + else: + url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote( + room_id + ) + request, channel = self.make_request("POST", url, access_token=admin_user_tok,) + self.render(request) + self.pump(1.0) + self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + self.assertEqual( + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 2}, + "Expected 2 quarantined items", + ) - def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room. 
- """ + # Convert mxc URLs to server/media_id strings + server_and_media_id_1 = mxc_1[6:] + server_and_media_id_2 = mxc_2[6:] - url = "rooms/%s/initialSync" % (room_id,) + # Test that we cannot download any of the media anymore request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok + "GET", + server_and_media_id_1, + shorthand=False, + access_token=non_admin_user_tok, ) - self.render(request) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id_1 + ), ) - url = "events?timeout=0&room_id=" + room_id request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok + "GET", + server_and_media_id_2, + shorthand=False, + access_token=non_admin_user_tok, ) - self.render(request) + request.render(self.download_resource) + self.pump(1.0) + + # Should be quarantined self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id_2 + ), ) + def test_quaraantine_all_media_in_room_deprecated_api_path(self): + # Perform the above test with the deprecated API path + self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s") -class DeleteGroupTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - groups.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") + def test_quarantine_all_media_by_user(self): + self.register_user("user_admin", "pass", admin=True) + admin_user_tok = self.login("user_admin", "pass") - self.other_user = self.register_user("user", "pass") - self.other_user_token = self.login("user", "pass") + non_admin_user = self.register_user("user_nonadmin", "pass", admin=False) + non_admin_user_tok = self.login("user_nonadmin", "pass") - def test_delete_group(self): - # Create a new group - request, channel = self.make_request( - "POST", - "/create_group".encode("ascii"), - access_token=self.admin_user_tok, - content={"localpart": "test"}, + # Upload some media + response_1 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok + ) + response_2 = self.helper.upload_media( + self.upload_resource, self.image_data, tok=non_admin_user_tok ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - group_id = channel.json_body["group_id"] - - self._check_group(group_id, expect_code=200) - - # Invite/join another user + # Extract media IDs + server_and_media_id_1 = response_1["content_uri"][6:] + server_and_media_id_2 = response_2["content_uri"][6:] - url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user) - request, channel = self.make_request( - "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} + # Quarantine all media by this user + url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( + non_admin_user ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - 
- url = "/groups/%s/self/accept_invite" % (group_id,) request, channel = self.make_request( - "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} + "POST", url.encode("ascii"), access_token=admin_user_tok, ) self.render(request) + self.pump(1.0) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Check other user knows they're in the group - self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) - self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token)) - - # Now delete the group - url = "/admin/delete_group/" + group_id - request, channel = self.make_request( - "POST", - url.encode("ascii"), - access_token=self.admin_user_tok, - content={"localpart": "test"}, + self.assertEqual( + json.loads(channel.result["body"].decode("utf-8")), + {"num_quarantined": 2}, + "Expected 2 quarantined items", ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Check group returns 404 - self._check_group(group_id, expect_code=404) - - # Check users don't think they're in the group - self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) - self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token)) - - def _check_group(self, group_id, expect_code): - """Assert that trying to fetch the given group results in the given - HTTP status code - """ - - url = "/groups/%s/profile" % (group_id,) + # Attempt to access each piece of media request, channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok + "GET", + server_and_media_id_1, + shorthand=False, + access_token=non_admin_user_tok, ) + request.render(self.download_resource) + self.pump(1.0) - self.render(request) + # Should be quarantined self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id_1, + ), ) - def _get_groups_user_is_in(self, access_token): - """Returns the list of groups the user is in (given their access token) - """ + # Attempt to access each piece of media request, channel = self.make_request( - "GET", "/joined_groups".encode("ascii"), access_token=access_token + "GET", + server_and_media_id_2, + shorthand=False, + access_token=non_admin_user_tok, ) + request.render(self.download_resource) + self.pump(1.0) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - return channel.json_body["groups"] + # Should be quarantined + self.assertEqual( + 404, + int(channel.code), + msg=( + "Expected to receive a 404 on accessing quarantined media: %s" + % server_and_media_id_2 + ), + ) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py new file mode 100644 index 0000000000..faa7f381a9 --- /dev/null +++ b/tests/rest/admin/test_device.py @@ -0,0 +1,541 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import urllib.parse + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login + +from tests import unittest + + +class DeviceRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_device_handler() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + res = self.get_success(self.handler.get_devices_by_user(self.other_user)) + self.other_user_device_id = res[0]["device_id"] + + self.url = "/_synapse/admin/v2/users/%s/devices/%s" % ( + urllib.parse.quote(self.other_user), + self.other_user_device_id, + ) + + def test_no_auth(self): + """ + Try to get a device of an user without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + request, channel = self.make_request("PUT", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + request, channel = self.make_request("DELETE", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. 
+ """ + request, channel = self.make_request( + "GET", self.url, access_token=self.other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + request, channel = self.make_request( + "PUT", self.url, access_token=self.other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + request, channel = self.make_request( + "DELETE", self.url, access_token=self.other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = ( + "/_synapse/admin/v2/users/@unknown_person:test/devices/%s" + % self.other_user_device_id + ) + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + request, channel = self.make_request( + "PUT", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + request, channel = self.make_request( + "DELETE", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = ( + "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s" + % self.other_user_device_id + ) + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + request, channel = self.make_request( + "PUT", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + request, channel = self.make_request( + "DELETE", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_unknown_device(self): + """ + Tests that a lookup for a device that does not exist returns either 404 or 200. 
+        """
+        url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote(
+            self.other_user
+        )
+
+        request, channel = self.make_request(
+            "GET", url, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(404, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+        request, channel = self.make_request(
+            "PUT", url, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        request, channel = self.make_request(
+            "DELETE", url, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        # Delete unknown device returns status 200
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+    def test_update_device_too_long_display_name(self):
+        """
+        Update a device with a display name that is invalid (too long).
+        """
+        # Set initial display name.
+        update = {"display_name": "new display"}
+        self.get_success(
+            self.handler.update_device(
+                self.other_user, self.other_user_device_id, update
+            )
+        )
+
+        # Request to update a device display name with a new value that is longer than allowed.
+        update = {
+            "display_name": "a"
+            * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
+        }
+
+        body = json.dumps(update)
+        request, channel = self.make_request(
+            "PUT",
+            self.url,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+        self.render(request)
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+        # Ensure the display name was not updated.
+        request, channel = self.make_request(
+            "GET", self.url, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("new display", channel.json_body["display_name"])
+
+    def test_update_no_display_name(self):
+        """
+        Tests that an update for a device without a JSON body returns a 200
+        """
+        # Set initial display name.
+        update = {"display_name": "new display"}
+        self.get_success(
+            self.handler.update_device(
+                self.other_user, self.other_user_device_id, update
+            )
+        )
+
+        request, channel = self.make_request(
+            "PUT", self.url, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        # Ensure the display name was not updated.
+        request, channel = self.make_request(
+            "GET", self.url, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("new display", channel.json_body["display_name"])
+
+    def test_update_display_name(self):
+        """
+        Tests a normal successful update of the display name
+        """
+        # Set new display_name
+        body = json.dumps({"display_name": "new displayname"})
+        request, channel = self.make_request(
+            "PUT",
+            self.url,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+        self.render(request)
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        # Check new display_name
+        request, channel = self.make_request(
+            "GET", self.url, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("new displayname", channel.json_body["display_name"])
+
+    def test_get_device(self):
+        """
+        Tests that a normal lookup for a device is successful
+        """
+        request, channel = self.make_request(
+            "GET", self.url, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(self.other_user, channel.json_body["user_id"])
+        # Check that all fields are available
+        self.assertIn("user_id", channel.json_body)
+        self.assertIn("device_id", channel.json_body)
+        self.assertIn("display_name", channel.json_body)
+        self.assertIn("last_seen_ip", channel.json_body)
+        self.assertIn("last_seen_ts", channel.json_body)
+
+    def test_delete_device(self):
+        """
+        Tests that removing a device is successful
+        """
+        # Count the number of devices of a user.
+        res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+        number_devices = len(res)
+        self.assertEqual(1, number_devices)
+
+        # Delete device
+        request, channel = self.make_request(
+            "DELETE", self.url, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        # Ensure that the number of devices is decreased
+        res = self.get_success(self.handler.get_devices_by_user(self.other_user))
+        self.assertEqual(number_devices - 1, len(res))
+
+
+class DevicesRestTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+
+        self.url = "/_synapse/admin/v2/users/%s/devices" % urllib.parse.quote(
+            self.other_user
+        )
+
+    def test_no_auth(self):
+        """
+        Try to list devices of a user without authentication.
+        """
+        request, channel = self.make_request("GET", self.url, b"{}")
+        self.render(request)
+
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_no_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+ """ + other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "GET", self.url, access_token=other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v2/users/@unknown_person:test/devices" + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices" + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_get_devices(self): + """ + Tests that a normal lookup for devices is successfully + """ + # Create devices + number_devices = 5 + for n in range(number_devices): + self.login("user", "pass") + + # Get devices + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_devices, len(channel.json_body["devices"])) + self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"]) + # Check that all fields are available + for d in channel.json_body["devices"]: + self.assertIn("user_id", d) + self.assertIn("device_id", d) + self.assertIn("display_name", d) + self.assertIn("last_seen_ip", d) + self.assertIn("last_seen_ts", d) + + +class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_device_handler() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + + self.url = "/_synapse/admin/v2/users/%s/delete_devices" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to delete devices of an user without authentication. + """ + request, channel = self.make_request("POST", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. 
+ """ + other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "POST", self.url, access_token=other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices" + request, channel = self.make_request( + "POST", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices" + + request, channel = self.make_request( + "POST", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_unknown_devices(self): + """ + Tests that a remove of a device that does not exist returns 200. + """ + body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]}) + request, channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + # Delete unknown devices returns status 200 + self.assertEqual(200, channel.code, msg=channel.json_body) + + def test_delete_devices(self): + """ + Tests that a remove of devices is successfully + """ + + # Create devices + number_devices = 5 + for n in range(number_devices): + self.login("user", "pass") + + # Get devices + res = self.get_success(self.handler.get_devices_by_user(self.other_user)) + self.assertEqual(number_devices, len(res)) + + # Create list of device IDs + device_ids = [] + for d in res: + device_ids.append(str(d["device_id"])) + + # Delete devices + body = json.dumps({"devices": device_ids}) + request, channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + + res = self.get_success(self.handler.get_devices_by_user(self.other_user)) + self.assertEqual(0, len(res)) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py new file mode 100644 index 0000000000..54cd24bf64 --- /dev/null +++ b/tests/rest/admin/test_room.py @@ -0,0 +1,1007 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
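# Illustrative sketch only (not part of the patch): a minimal client for the
# device admin endpoints exercised by the test cases above. The base URL and
# admin access token are assumptions; the paths
# (/_synapse/admin/v2/users/<user_id>/devices[/<device_id>] and .../delete_devices)
# and the payload/response shapes are taken from the tests.
import urllib.parse

import requests  # assumed to be available; not used by the test suite itself

BASE_URL = "http://localhost:8008"    # assumption
ADMIN_TOKEN = "<admin access token>"  # assumption
HEADERS = {"Authorization": "Bearer " + ADMIN_TOKEN}


def list_devices(user_id):
    # GET returns {"devices": [{"device_id": ..., "display_name": ..., ...}]}
    url = "%s/_synapse/admin/v2/users/%s/devices" % (
        BASE_URL,
        urllib.parse.quote(user_id),
    )
    resp = requests.get(url, headers=HEADERS)
    resp.raise_for_status()
    return resp.json()["devices"]


def delete_devices(user_id, device_ids):
    # POST {"devices": [...]} removes the listed devices; per the tests,
    # unknown device IDs still yield a 200 response.
    url = "%s/_synapse/admin/v2/users/%s/delete_devices" % (
        BASE_URL,
        urllib.parse.quote(user_id),
    )
    resp = requests.post(url, headers=HEADERS, json={"devices": device_ids})
    resp.raise_for_status()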
+ +import json +import urllib.parse +from typing import List, Optional + +from mock import Mock + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import directory, events, login, room + +from tests import unittest + +"""Tests admin REST events for /rooms paths.""" + + +class ShutdownRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + """ + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + room_id, body="foo", tok=self.other_user_token, expect_code=403 + ) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert there is now no longer anyone in the room + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. 
+ """ + + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (room_id,) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + json.dumps({"history_visibility": "world_readable"}), + access_token=self.other_user_token, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode("ascii"), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert we can no longer peek into the room + self._assert_peek(room_id, expect_code=403) + + def _assert_peek(self, room_id, expect_code): + """Assert that the admin user can (or cannot) peek into the room. + """ + + url = "rooms/%s/initialSync" % (room_id,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + url = "events?timeout=0&room_id=" + room_id + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"] + ) + + +class PurgeRoomTestCase(unittest.HomeserverTestCase): + """Test /purge_room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_purge_room(self): + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # All users have to have left the room. + self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/purge_room" + request, channel = self.make_request( + "POST", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the following tables have been purged of all rows related to the room. + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "local_invites", + "room_account_data", + "room_tags", + # "state_groups", # Current impl leaves orphaned state groups around. 
+ "state_groups_state", + ): + count = self.get_success( + self.store.db.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + + +class RoomTestCase(unittest.HomeserverTestCase): + """Test /room admin API. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + def test_list_rooms(self): + """Test that we can list rooms""" + # Create 3 test rooms + total_rooms = 3 + room_ids = [] + for x in range(total_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + room_ids.append(room_id) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + + # Check request completed successfully + self.assertEqual(200, int(channel.code), msg=channel.json_body) + + # Check that response json body contains a "rooms" key + self.assertTrue( + "rooms" in channel.json_body, + msg="Response body does not " "contain a 'rooms' key", + ) + + # Check that 3 rooms were returned + self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body) + + # Check their room_ids match + returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]] + self.assertEqual(room_ids, returned_room_ids) + + # Check that all fields are available + for r in channel.json_body["rooms"]: + self.assertIn("name", r) + self.assertIn("canonical_alias", r) + self.assertIn("joined_members", r) + self.assertIn("joined_local_members", r) + self.assertIn("version", r) + self.assertIn("creator", r) + self.assertIn("encryption", r) + self.assertIn("federatable", r) + self.assertIn("public", r) + self.assertIn("join_rules", r) + self.assertIn("guest_access", r) + self.assertIn("history_visibility", r) + self.assertIn("state_events", r) + + # Check that the correct number of total rooms was returned + self.assertEqual(channel.json_body["total_rooms"], total_rooms) + + # Check that the offset is correct + # Should be 0 as we aren't paginating + self.assertEqual(channel.json_body["offset"], 0) + + # Check that the prev_batch parameter is not present + self.assertNotIn("prev_batch", channel.json_body) + + # We shouldn't receive a next token here as there's no further rooms to show + self.assertNotIn("next_batch", channel.json_body) + + def test_list_rooms_pagination(self): + """Test that we can get a full list of rooms through pagination""" + # Create 5 test rooms + total_rooms = 5 + room_ids = [] + for x in range(total_rooms): + room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + room_ids.append(room_id) + + # Set the name of the rooms so we get a consistent returned ordering + for idx, room_id in enumerate(room_ids): + self.helper.send_state( + room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + returned_room_ids = [] + start = 0 + limit = 2 + + run_count = 0 + should_repeat = True + while should_repeat: + run_count += 1 + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % ( + start, + 
limit, + "name", + ) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual( + 200, int(channel.result["code"]), msg=channel.result["body"] + ) + + self.assertTrue("rooms" in channel.json_body) + for r in channel.json_body["rooms"]: + returned_room_ids.append(r["room_id"]) + + # Check that the correct number of total rooms was returned + self.assertEqual(channel.json_body["total_rooms"], total_rooms) + + # Check that the offset is correct + # We're only getting 2 rooms each page, so should be 2 * last run_count + self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1)) + + if run_count > 1: + # Check the value of prev_batch is correct + self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2)) + + if "next_batch" not in channel.json_body: + # We have reached the end of the list + should_repeat = False + else: + # Make another query with an updated start value + start = channel.json_body["next_batch"] + + # We should've queried the endpoint 3 times + self.assertEqual( + run_count, + 3, + msg="Should've queried 3 times for 5 rooms with limit 2 per query", + ) + + # Check that we received all of the room ids + self.assertEqual(room_ids, returned_room_ids) + + url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def test_correct_room_attributes(self): + """Test the correct attributes for a room are returned""" + # Create a test room + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + test_alias = "#test:test" + test_room_name = "something" + + # Have another user join the room + user_2 = self.register_user("user4", "pass") + user_tok_2 = self.login("user4", "pass") + self.helper.join(room_id, user_2, tok=user_tok_2) + + # Create a new alias to this room + url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + {"room_id": room_id}, + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=self.admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, + "m.room.canonical_alias", + {"alias": test_alias}, + tok=self.admin_user_tok, + ) + + # Set a name for the room + self.helper.send_state( + room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok, + ) + + # Request the list of rooms + url = "/_synapse/admin/v1/rooms" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that only one room was returned + self.assertEqual(len(rooms), 1) + + # And that the value of the total_rooms key was correct + self.assertEqual(channel.json_body["total_rooms"], 1) + + # Check that the offset is correct + # We're not paginating, so should be 0 + 
self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that all provided attributes are set + r = rooms[0] + self.assertEqual(room_id, r["room_id"]) + self.assertEqual(test_room_name, r["name"]) + self.assertEqual(test_alias, r["canonical_alias"]) + + def test_room_list_sort_order(self): + """Test room list sort ordering. alphabetical name versus number of members, + reversing the order, etc. + """ + + def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str): + # Create a new alias to this room + url = "/_matrix/client/r0/directory/room/%s" % ( + urllib.parse.quote(test_alias), + ) + request, channel = self.make_request( + "PUT", + url.encode("ascii"), + {"room_id": room_id}, + access_token=admin_user_tok, + ) + self.render(request) + self.assertEqual( + 200, int(channel.result["code"]), msg=channel.result["body"] + ) + + # Set this new alias as the canonical alias for this room + self.helper.send_state( + room_id, + "m.room.aliases", + {"aliases": [test_alias]}, + tok=admin_user_tok, + state_key="test", + ) + self.helper.send_state( + room_id, + "m.room.canonical_alias", + {"alias": test_alias}, + tok=admin_user_tok, + ) + + def _order_test( + order_type: str, expected_room_list: List[str], reverse: bool = False, + ): + """Request the list of rooms in a certain order. Assert that order is what + we expect + + Args: + order_type: The type of ordering to give the server + expected_room_list: The list of room_ids in the order we expect to get + back from the server + """ + # Request the list of rooms in the given order + url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) + if reverse: + url += "&dir=b" + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check for the correct total_rooms value + self.assertEqual(channel.json_body["total_rooms"], 3) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + # Check that rooms were returned in alphabetical order + returned_order = [r["room_id"] for r in rooms] + self.assertListEqual(expected_room_list, returned_order) # order is checked + + # Create 3 test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Set room names in alphabetical order. 
room 1 -> A, 2 -> B, 3 -> C + self.helper.send_state( + room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, + ) + + # Set room canonical room aliases + _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok) + _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok) + _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok) + + # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3 + user_1 = self.register_user("bob1", "pass") + user_1_tok = self.login("bob1", "pass") + self.helper.join(room_id_2, user_1, tok=user_1_tok) + + user_2 = self.register_user("bob2", "pass") + user_2_tok = self.login("bob2", "pass") + self.helper.join(room_id_3, user_2, tok=user_2_tok) + + user_3 = self.register_user("bob3", "pass") + user_3_tok = self.login("bob3", "pass") + self.helper.join(room_id_3, user_3, tok=user_3_tok) + + # Test different sort orders, with forward and reverse directions + _order_test("name", [room_id_1, room_id_2, room_id_3]) + _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3]) + _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("joined_members", [room_id_3, room_id_2, room_id_1]) + _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1]) + _order_test( + "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True + ) + + _order_test("version", [room_id_1, room_id_2, room_id_3]) + _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("creator", [room_id_1, room_id_2, room_id_3]) + _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("encryption", [room_id_1, room_id_2, room_id_3]) + _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("federatable", [room_id_1, room_id_2, room_id_3]) + _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("public", [room_id_1, room_id_2, room_id_3]) + # Different sort order of SQlite and PostreSQL + # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True) + + _order_test("join_rules", [room_id_1, room_id_2, room_id_3]) + _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("guest_access", [room_id_1, room_id_2, room_id_3]) + _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True) + + _order_test("history_visibility", [room_id_1, room_id_2, room_id_3]) + _order_test( + "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True + ) + + _order_test("state_events", [room_id_3, room_id_2, room_id_1]) + _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) + + def test_search_term(self): + """Test that searching for a room works correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + room_name_1 = "something" + room_name_2 = "else" + + # Set the name for each room + self.helper.send_state( + room_id_1, "m.room.name", {"name": room_name_1}, 
tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + ) + + def _search_test( + expected_room_id: Optional[str], + search_term: str, + expected_http_code: int = 200, + ): + """Search for a room and check that the returned room's id is a match + + Args: + expected_room_id: The room_id expected to be returned by the API. Set + to None to expect zero results for the search + search_term: The term to search for room names with + expected_http_code: The expected http code for the request + """ + url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) + + if expected_http_code != 200: + return + + # Check that rooms were returned + self.assertTrue("rooms" in channel.json_body) + rooms = channel.json_body["rooms"] + + # Check that the expected number of rooms were returned + expected_room_count = 1 if expected_room_id else 0 + self.assertEqual(len(rooms), expected_room_count) + self.assertEqual(channel.json_body["total_rooms"], expected_room_count) + + # Check that the offset is correct + # We're not paginating, so should be 0 + self.assertEqual(channel.json_body["offset"], 0) + + # Check that there is no `prev_batch` + self.assertNotIn("prev_batch", channel.json_body) + + # Check that there is no `next_batch` + self.assertNotIn("next_batch", channel.json_body) + + if expected_room_id: + # Check that the first returned room id is correct + r = rooms[0] + self.assertEqual(expected_room_id, r["room_id"]) + + # Perform search tests + _search_test(room_id_1, "something") + _search_test(room_id_1, "thing") + + _search_test(room_id_2, "else") + _search_test(room_id_2, "se") + + _search_test(None, "foo") + _search_test(None, "bar") + _search_test(None, "", expected_http_code=400) + + def test_single_room(self): + """Test that a single room can be requested correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + room_name_1 = "something" + room_name_2 = "else" + + # Set the name for each room + self.helper.send_state( + room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + ) + + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertIn("room_id", channel.json_body) + self.assertIn("name", channel.json_body) + self.assertIn("canonical_alias", channel.json_body) + self.assertIn("joined_members", channel.json_body) + self.assertIn("joined_local_members", channel.json_body) + self.assertIn("version", channel.json_body) + self.assertIn("creator", channel.json_body) + self.assertIn("encryption", channel.json_body) + self.assertIn("federatable", channel.json_body) + self.assertIn("public", channel.json_body) + self.assertIn("join_rules", channel.json_body) + self.assertIn("guest_access", channel.json_body) + self.assertIn("history_visibility", channel.json_body) + self.assertIn("state_events", channel.json_body) + + self.assertEqual(room_id_1, 
channel.json_body["room_id"]) + + +class JoinAliasRoomTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.public_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.second_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self): + """ + If a parameter is missing, return an error + """ + body = json.dumps({"unknown_parameter": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + def test_local_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + body = json.dumps({"user_id": "@unknown:test"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_remote_user(self): + """ + Check that only local user can join rooms. + """ + body = json.dumps({"user_id": "@not:exist.bla"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "This endpoint can only be used with local users", + channel.json_body["error"], + ) + + def test_room_does_not_exist(self): + """ + Check that unknown rooms/server return error 404. + """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/!unknown:test" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("No known servers", channel.json_body["error"]) + + def test_room_is_not_valid(self): + """ + Check that invalid room names, return an error 400. 
+ """ + body = json.dumps({"user_id": self.second_user_id}) + url = "/_synapse/admin/v1/join/invalidroom" + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "invalidroom was not legal room ID or room alias", + channel.json_body["error"], + ) + + def test_join_public_room(self): + """ + Test joining a local user to a public room with "JoinRules.PUBLIC" + """ + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_not_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE" + when server admin is not member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_join_private_room_if_member(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is member of this room. + """ + private_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False + ) + self.helper.invite( + room=private_room_id, + src=self.creator, + targ=self.admin_user, + tok=self.creator_tok, + ) + self.helper.join( + room=private_room_id, user=self.admin_user, tok=self.admin_user_tok + ) + + # Validate if server admin is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + # Join user to room. 
+ + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + def test_join_private_room_if_owner(self): + """ + Test joining a local user to a private room with "JoinRules.INVITE", + when server admin is owner of this room. + """ + private_room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok, is_public=False + ) + url = "/_synapse/admin/v1/join/{}".format(private_room_id) + body = json.dumps({"user_id": self.second_user_id}) + + request, channel = self.make_request( + "POST", + url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["room_id"]) + + # Validate if user is a member of the room + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + ) + self.render(request) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py new file mode 100644 index 0000000000..cca5f548e6 --- /dev/null +++ b/tests/rest/admin/test_user.py @@ -0,0 +1,947 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
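# Illustrative sketch only (not part of the patch): driving the admin join
# endpoint covered by JoinAliasRoomTestCase above. The base URL and access
# token are assumptions; the path (/_synapse/admin/v1/join/<room id or alias>),
# the {"user_id": ...} body and the {"room_id": ...} response mirror the tests,
# which also show that only local users can be joined and that the requester
# must be a server admin.
import urllib.parse

import requests  # assumed to be available

BASE_URL = "http://localhost:8008"    # assumption
ADMIN_TOKEN = "<admin access token>"  # assumption


def admin_join(room_id_or_alias, user_id):
    url = "%s/_synapse/admin/v1/join/%s" % (
        BASE_URL,
        urllib.parse.quote(room_id_or_alias),
    )
    resp = requests.post(
        url,
        headers={"Authorization": "Bearer " + ADMIN_TOKEN},
        json={"user_id": user_id},
    )
    resp.raise_for_status()
    return resp.json()["room_id"]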
+ +import hashlib +import hmac +import json +import urllib.parse + +from mock import Mock + +import synapse.rest.admin +from synapse.api.constants import UserTypes +from synapse.api.errors import HttpResponseException, ResourceLimitError +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import sync + +from tests import unittest +from tests.unittest import override_config + + +class UserRegisterTestCase(unittest.HomeserverTestCase): + + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] + + def make_homeserver(self, reactor, clock): + + self.url = "/_matrix/client/r0/admin/register" + + self.registration_handler = Mock() + self.identity_handler = Mock() + self.login_handler = Mock() + self.device_handler = Mock() + self.device_handler.check_device_registered = Mock(return_value="FAKE") + + self.datastore = Mock(return_value=Mock()) + self.datastore.get_current_state_deltas = Mock(return_value=(0, [])) + + self.secrets = Mock() + + self.hs = self.setup_test_homeserver() + + self.hs.config.registration_shared_secret = "shared" + + self.hs.get_media_repository = Mock() + self.hs.get_deactivate_account_handler = Mock() + + return self.hs + + def test_disabled(self): + """ + If there is no shared secret, registration through this method will be + prevented. + """ + self.hs.config.registration_shared_secret = None + + request, channel = self.make_request("POST", self.url, b"{}") + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + "Shared secret registration is not enabled", channel.json_body["error"] + ) + + def test_get_nonce(self): + """ + Calling GET on the endpoint will return a randomised nonce, using the + homeserver's secrets provider. + """ + secrets = Mock() + secrets.token_hex = Mock(return_value="abcd") + + self.hs.get_secrets = Mock(return_value=secrets) + + request, channel = self.make_request("GET", self.url) + self.render(request) + + self.assertEqual(channel.json_body, {"nonce": "abcd"}) + + def test_expired_nonce(self): + """ + Calling GET on the endpoint will return a randomised nonce, which will + only last for SALT_TIMEOUT (60s). + """ + request, channel = self.make_request("GET", self.url) + self.render(request) + nonce = channel.json_body["nonce"] + + # 59 seconds + self.reactor.advance(59) + + body = json.dumps({"nonce": nonce}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("username must be specified", channel.json_body["error"]) + + # 61 seconds + self.reactor.advance(2) + + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("unrecognised nonce", channel.json_body["error"]) + + def test_register_incorrect_nonce(self): + """ + Only the provided nonce can be used, as it's checked in the MAC. 
+ """ + request, channel = self.make_request("GET", self.url) + self.render(request) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("HMAC incorrect", channel.json_body["error"]) + + def test_register_correct_nonce(self): + """ + When the correct nonce is provided, and the right key is provided, the + user is registered. + """ + request, channel = self.make_request("GET", self.url) + self.render(request) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update( + nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" + ) + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "user_type": UserTypes.SUPPORT, + "mac": want_mac, + } + ) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["user_id"]) + + def test_nonce_reuse(self): + """ + A valid unrecognised nonce. + """ + request, channel = self.make_request("GET", self.url) + self.render(request) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + ) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["user_id"]) + + # Now, try and reuse it + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("unrecognised nonce", channel.json_body["error"]) + + def test_missing_parts(self): + """ + Synapse will complain if you don't give nonce, username, password, and + mac. Admin and user_types are optional. Additional checks are done for length + and type. 
+        """
+
+        def nonce():
+            request, channel = self.make_request("GET", self.url)
+            self.render(request)
+            return channel.json_body["nonce"]
+
+        #
+        # Nonce check
+        #
+
+        # Must be present
+        body = json.dumps({})
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("nonce must be specified", channel.json_body["error"])
+
+        #
+        # Username checks
+        #
+
+        # Must be present
+        body = json.dumps({"nonce": nonce()})
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("username must be specified", channel.json_body["error"])
+
+        # Must be a string
+        body = json.dumps({"nonce": nonce(), "username": 1234})
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("Invalid username", channel.json_body["error"])
+
+        # Must not have null bytes
+        body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("Invalid username", channel.json_body["error"])
+
+        # Must not be too long
+        body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("Invalid username", channel.json_body["error"])
+
+        #
+        # Password checks
+        #
+
+        # Must be present
+        body = json.dumps({"nonce": nonce(), "username": "a"})
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("password must be specified", channel.json_body["error"])
+
+        # Must be a string
+        body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("Invalid password", channel.json_body["error"])
+
+        # Must not have null bytes
+        body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("Invalid password", channel.json_body["error"])
+
+        # Must not be too long
+        body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("Invalid password", channel.json_body["error"])
+
+        #
+        # user_type check
+        #
+
+        # Invalid user_type
+        body = json.dumps(
+            {
+                "nonce": nonce(),
+                "username": "a",
+                "password": "1234",
+                "user_type": "invalid",
+            }
+        )
+        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        self.render(request)
+ + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("Invalid user type", channel.json_body["error"]) + + @override_config( + {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} + ) + def test_register_mau_limit_reached(self): + """ + Check we can register a user via the shared secret registration API + even if the MAU limit is reached. + """ + handler = self.hs.get_registration_handler() + store = self.hs.get_datastore() + + # Set monthly active users to the limit + store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value) + # Check that the blocking of monthly active users is working as expected + # The registration of a new user fails due to the limit + self.get_failure( + handler.register_user(localpart="local_part"), ResourceLimitError + ) + + # Register new user with admin API + request, channel = self.make_request("GET", self.url) + self.render(request) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update( + nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" + ) + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "user_type": UserTypes.SUPPORT, + "mac": want_mac, + } + ) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["user_id"]) + + +class UsersListTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + url = "/_synapse/admin/v2/users" + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.register_user("user1", "pass1", admin=False) + self.register_user("user2", "pass2", admin=False) + + def test_no_auth(self): + """ + Try to list users without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"]) + + def test_all_users(self): + """ + List all users, including deactivated users. + """ + request, channel = self.make_request( + "GET", + self.url + "?deactivated=true", + b"{}", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(3, len(channel.json_body["users"])) + self.assertEqual(3, channel.json_body["total"]) + + +class UserRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( + self.other_user + ) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. 
+ """ + url = "/_synapse/admin/v2/users/@bob:test" + + request, channel = self.make_request( + "GET", url, access_token=self.other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("You are not a server admin", channel.json_body["error"]) + + request, channel = self.make_request( + "PUT", url, access_token=self.other_user_token, content=b"{}", + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("You are not a server admin", channel.json_body["error"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + + request, channel = self.make_request( + "GET", + "/_synapse/admin/v2/users/@unknown_person:test", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) + + def test_create_server_admin(self): + """ + Check that a new admin user is created successfully. + """ + url = "/_synapse/admin/v2/users/@bob:test" + + # Create user (server admin) + body = json.dumps( + { + "password": "abc123", + "admin": True, + "displayname": "Bob's name", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "avatar_url": None, + } + ) + + request, channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(True, channel.json_body["admin"]) + + # Get user + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(True, channel.json_body["admin"]) + self.assertEqual(False, channel.json_body["is_guest"]) + self.assertEqual(False, channel.json_body["deactivated"]) + + def test_create_user(self): + """ + Check that a new regular user is created successfully. 
+ """ + url = "/_synapse/admin/v2/users/@bob:test" + + # Create user + body = json.dumps( + { + "password": "abc123", + "admin": False, + "displayname": "Bob's name", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } + ) + + request, channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(False, channel.json_body["admin"]) + + # Get user + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("Bob's name", channel.json_body["displayname"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(False, channel.json_body["admin"]) + self.assertEqual(False, channel.json_body["is_guest"]) + self.assertEqual(False, channel.json_body["deactivated"]) + + @override_config( + {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} + ) + def test_create_user_mau_limit_reached_active_admin(self): + """ + Check that an admin can register a new user via the admin API + even if the MAU limit is reached. + Admin user was active before creating user. + """ + + handler = self.hs.get_registration_handler() + + # Sync to set admin user to active + # before limit of monthly active users is reached + request, channel = self.make_request( + "GET", "/sync", access_token=self.admin_user_tok + ) + self.render(request) + + if channel.code != 200: + raise HttpResponseException( + channel.code, channel.result["reason"], channel.result["body"] + ) + + # Set monthly active users to the limit + self.store.get_monthly_active_count = Mock( + return_value=self.hs.config.max_mau_value + ) + # Check that the blocking of monthly active users is working as expected + # The registration of a new user fails due to the limit + self.get_failure( + handler.register_user(localpart="local_part"), ResourceLimitError + ) + + # Register new user with admin API + url = "/_synapse/admin/v2/users/@bob:test" + + # Create user + body = json.dumps({"password": "abc123", "admin": False}) + + request, channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["admin"]) + + @override_config( + {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} + ) + def test_create_user_mau_limit_reached_passive_admin(self): + """ + Check that an admin can register a new user via the admin API + even if the MAU limit is reached. + Admin user was not active before creating user. 
+        """
+
+        handler = self.hs.get_registration_handler()
+
+        # Set monthly active users to the limit
+        self.store.get_monthly_active_count = Mock(
+            return_value=self.hs.config.max_mau_value
+        )
+        # Check that the blocking of monthly active users is working as expected
+        # The registration of a new user fails due to the limit
+        self.get_failure(
+            handler.register_user(localpart="local_part"), ResourceLimitError
+        )
+
+        # Register new user with admin API
+        url = "/_synapse/admin/v2/users/@bob:test"
+
+        # Create user
+        body = json.dumps({"password": "abc123", "admin": False})
+
+        request, channel = self.make_request(
+            "PUT",
+            url,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+        self.render(request)
+
+        # Admin user is not blocked by mau anymore
+        self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@bob:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["admin"])
+
+    @override_config(
+        {
+            "email": {
+                "enable_notifs": True,
+                "notif_for_new_users": True,
+                "notif_from": "test@example.com",
+            },
+            "public_baseurl": "https://example.com",
+        }
+    )
+    def test_create_user_email_notif_for_new_users(self):
+        """
+        Check that a new regular user is created successfully and
+        gets an email pusher.
+        """
+        url = "/_synapse/admin/v2/users/@bob:test"
+
+        # Create user
+        body = json.dumps(
+            {
+                "password": "abc123",
+                "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+            }
+        )
+
+        request, channel = self.make_request(
+            "PUT",
+            url,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+        self.render(request)
+
+        self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@bob:test", channel.json_body["name"])
+        self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+        self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+
+        pushers = self.get_success(
+            self.store.get_pushers_by({"user_name": "@bob:test"})
+        )
+        pushers = list(pushers)
+        self.assertEqual(len(pushers), 1)
+        self.assertEqual("@bob:test", pushers[0]["user_name"])
+
+    @override_config(
+        {
+            "email": {
+                "enable_notifs": False,
+                "notif_for_new_users": False,
+                "notif_from": "test@example.com",
+            },
+            "public_baseurl": "https://example.com",
+        }
+    )
+    def test_create_user_email_no_notif_for_new_users(self):
+        """
+        Check that a new regular user is created successfully and
+        does not get an email pusher.
+        """
+        url = "/_synapse/admin/v2/users/@bob:test"
+
+        # Create user
+        body = json.dumps(
+            {
+                "password": "abc123",
+                "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+            }
+        )
+
+        request, channel = self.make_request(
+            "PUT",
+            url,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+        self.render(request)
+
+        self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@bob:test", channel.json_body["name"])
+        self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+        self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+
+        pushers = self.get_success(
+            self.store.get_pushers_by({"user_name": "@bob:test"})
+        )
+        pushers = list(pushers)
+        self.assertEqual(len(pushers), 0)
+
+    def test_set_password(self):
+        """
+        Test setting a new password for another user.
+        """
+
+        # Change password
+        body = json.dumps({"password": "hahaha"})
+
+        request, channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+    def test_set_displayname(self):
+        """
+        Test setting the displayname of another user.
+        """
+
+        # Modify user
+        body = json.dumps({"displayname": "foobar"})
+
+        request, channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual("foobar", channel.json_body["displayname"])
+
+        # Get user
+        request, channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual("foobar", channel.json_body["displayname"])
+
+    def test_set_threepid(self):
+        """
+        Test setting a threepid for another user.
+        """
+
+        # Delete old and add new threepid to user
+        body = json.dumps(
+            {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}
+        )
+
+        request, channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+        self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+
+        # Get user
+        request, channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+        self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+
+    def test_deactivate_user(self):
+        """
+        Test deactivating another user.
+        """
+
+        # Deactivate user
+        body = json.dumps({"deactivated": True})
+
+        request, channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(True, channel.json_body["deactivated"])
+        # The user is deactivated, so the threepid will be deleted
+
+        # Get user
+        request, channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(True, channel.json_body["deactivated"])
+
+    def test_set_user_as_admin(self):
+        """
+        Test setting the admin flag on a user.
+ """ + + # Set a user as an admin + body = json.dumps({"admin": True}) + + request, channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["admin"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_other_user, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["admin"]) + + def test_accidental_deactivation_prevention(self): + """ + Ensure an account can't accidentally be deactivated by using a str value + for the deactivated body parameter + """ + url = "/_synapse/admin/v2/users/@bob:test" + + # Create user + body = json.dumps({"password": "abc123"}) + + request, channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("bob", channel.json_body["displayname"]) + + # Get user + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("bob", channel.json_body["displayname"]) + self.assertEqual(0, channel.json_body["deactivated"]) + + # Change password (and use a str for deactivate instead of a bool) + body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops! + + request, channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + + # Check user is not deactivated + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("bob", channel.json_body["displayname"]) + + # Ensure they're still alive + self.assertEqual(0, channel.json_body["deactivated"]) diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py new file mode 100644 index 0000000000..5e9c07ebf3 --- /dev/null +++ b/tests/rest/client/test_ephemeral_message.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
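+
+# These tests exercise Synapse's ephemeral ("self-destructing") message support,
+# which the test case below switches on via config["enable_ephemeral_messages"].
+# As a rough sketch (the YAML spelling is an assumption here; the test sets the
+# parsed config value directly), the equivalent homeserver.yaml line would be:
+#
+#   enable_ephemeral_messages: true
+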
+from synapse.api.constants import EventContentFields, EventTypes
+from synapse.rest import admin
+from synapse.rest.client.v1 import room
+
+from tests import unittest
+
+
+class EphemeralMessageTestCase(unittest.HomeserverTestCase):
+
+    user_id = "@user:test"
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+
+        config["enable_ephemeral_messages"] = True
+
+        self.hs = self.setup_test_homeserver(config=config)
+        return self.hs
+
+    def prepare(self, reactor, clock, homeserver):
+        self.room_id = self.helper.create_room_as(self.user_id)
+
+    def test_message_expiry_no_delay(self):
+        """Tests that sending a message with an m.self_destruct_after field set to the
+        past results in that event being deleted right away.
+        """
+        # Send a message in the room that has expired. From here, the reactor clock is
+        # at 200ms, so 0 is in the past, and even if that wasn't the case and the clock
+        # is at 0ms the code path is the same if the event's expiry timestamp is the
+        # current timestamp.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "hello",
+                EventContentFields.SELF_DESTRUCT_AFTER: 0,
+            },
+        )
+        event_id = res["event_id"]
+
+        # Check that we can't retrieve the content of the event.
+        event_content = self.get_event(self.room_id, event_id)["content"]
+        self.assertFalse(bool(event_content), event_content)
+
+    def test_message_expiry_delay(self):
+        """Tests that sending a message with an m.self_destruct_after field set to the
+        future results in that event not being deleted right away, but advancing the
+        clock to after that expiry timestamp causes the event to be deleted.
+        """
+        # Send a message in the room that'll expire in 1s.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "hello",
+                EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000,
+            },
+        )
+        event_id = res["event_id"]
+
+        # Check that we can retrieve the content of the event before it has expired.
+        event_content = self.get_event(self.room_id, event_id)["content"]
+        self.assertTrue(bool(event_content), event_content)
+
+        # Advance the clock to after the deletion.
+        self.reactor.advance(1)
+
+        # Check that we can't retrieve the content of the event anymore.
+        event_content = self.get_event(self.room_id, event_id)["content"]
+        self.assertFalse(bool(event_content), event_content)
+
+    def get_event(self, room_id, event_id, expected_code=200):
+        url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+        request, channel = self.make_request("GET", url)
+        self.render(request)
+
+        self.assertEqual(channel.code, expected_code, channel.result)
+
+        return channel.json_body
diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
new file mode 100644
index 0000000000..913ea3c98e
--- /dev/null
+++ b/tests/rest/client/test_power_levels.py
@@ -0,0 +1,205 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import sync + +from tests.unittest import HomeserverTestCase + + +class PowerLevelsTestCase(HomeserverTestCase): + """Tests that power levels are enforced in various situations""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, hs): + # register a room admin, moderator and regular user + self.admin_user_id = self.register_user("admin", "pass") + self.admin_access_token = self.login("admin", "pass") + self.mod_user_id = self.register_user("mod", "pass") + self.mod_access_token = self.login("mod", "pass") + self.user_user_id = self.register_user("user", "pass") + self.user_access_token = self.login("user", "pass") + + # Create a room + self.room_id = self.helper.create_room_as( + self.admin_user_id, tok=self.admin_access_token + ) + + # Invite the other users + self.helper.invite( + room=self.room_id, + src=self.admin_user_id, + tok=self.admin_access_token, + targ=self.mod_user_id, + ) + self.helper.invite( + room=self.room_id, + src=self.admin_user_id, + tok=self.admin_access_token, + targ=self.user_user_id, + ) + + # Make the other users join the room + self.helper.join( + room=self.room_id, user=self.mod_user_id, tok=self.mod_access_token + ) + self.helper.join( + room=self.room_id, user=self.user_user_id, tok=self.user_access_token + ) + + # Mod the mod + room_power_levels = self.helper.get_state( + self.room_id, "m.room.power_levels", tok=self.admin_access_token, + ) + + # Update existing power levels with mod at PL50 + room_power_levels["users"].update({self.mod_user_id: 50}) + + self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + ) + + def test_non_admins_cannot_enable_room_encryption(self): + # have the mod try to enable room encryption + self.helper.send_state( + self.room_id, + "m.room.encryption", + {"algorithm": "m.megolm.v1.aes-sha2"}, + tok=self.mod_access_token, + expect_code=403, # expect failure + ) + + # have the user try to enable room encryption + self.helper.send_state( + self.room_id, + "m.room.encryption", + {"algorithm": "m.megolm.v1.aes-sha2"}, + tok=self.user_access_token, + expect_code=403, # expect failure + ) + + def test_non_admins_cannot_send_server_acl(self): + # have the mod try to send a server ACL + self.helper.send_state( + self.room_id, + "m.room.server_acl", + { + "allow": ["*"], + "allow_ip_literals": False, + "deny": ["*.evil.com", "evil.com"], + }, + tok=self.mod_access_token, + expect_code=403, # expect failure + ) + + # have the user try to send a server ACL + self.helper.send_state( + self.room_id, + "m.room.server_acl", + { + "allow": ["*"], + "allow_ip_literals": False, + "deny": ["*.evil.com", "evil.com"], + }, + tok=self.user_access_token, + expect_code=403, # expect failure + ) + + def 
test_non_admins_cannot_tombstone_room(self): + # Create another room that will serve as our "upgraded room" + self.upgraded_room_id = self.helper.create_room_as( + self.admin_user_id, tok=self.admin_access_token + ) + + # have the mod try to send a tombstone event + self.helper.send_state( + self.room_id, + "m.room.tombstone", + { + "body": "This room has been replaced", + "replacement_room": self.upgraded_room_id, + }, + tok=self.mod_access_token, + expect_code=403, # expect failure + ) + + # have the user try to send a tombstone event + self.helper.send_state( + self.room_id, + "m.room.tombstone", + { + "body": "This room has been replaced", + "replacement_room": self.upgraded_room_id, + }, + tok=self.user_access_token, + expect_code=403, # expect failure + ) + + def test_admins_can_enable_room_encryption(self): + # have the admin try to enable room encryption + self.helper.send_state( + self.room_id, + "m.room.encryption", + {"algorithm": "m.megolm.v1.aes-sha2"}, + tok=self.admin_access_token, + expect_code=200, # expect success + ) + + def test_admins_can_send_server_acl(self): + # have the admin try to send a server ACL + self.helper.send_state( + self.room_id, + "m.room.server_acl", + { + "allow": ["*"], + "allow_ip_literals": False, + "deny": ["*.evil.com", "evil.com"], + }, + tok=self.admin_access_token, + expect_code=200, # expect success + ) + + def test_admins_can_tombstone_room(self): + # Create another room that will serve as our "upgraded room" + self.upgraded_room_id = self.helper.create_room_as( + self.admin_user_id, tok=self.admin_access_token + ) + + # have the admin try to send a tombstone event + self.helper.send_state( + self.room_id, + "m.room.tombstone", + { + "body": "This room has been replaced", + "replacement_room": self.upgraded_room_id, + }, + tok=self.admin_access_token, + expect_code=200, # expect success + ) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py new file mode 100644 index 0000000000..95475bb651 --- /dev/null +++ b/tests/rest/client/test_retention.py @@ -0,0 +1,293 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
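+
+# These tests exercise the message retention / purge support. The test cases below
+# configure it through config["retention"]; as a rough sketch (the YAML layout and
+# the day-based shorthand durations are illustrative assumptions, since the tests
+# set the values in milliseconds directly), the equivalent homeserver.yaml section
+# would look something like:
+#
+#   retention:
+#     enabled: true
+#     default_policy:
+#       min_lifetime: 1d
+#       max_lifetime: 3d
+#     allowed_lifetime_min: 1d
+#     allowed_lifetime_max: 3d
+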
+from mock import Mock + +from synapse.api.constants import EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.visibility import filter_events_for_client + +from tests import unittest + +one_hour_ms = 3600000 +one_day_ms = one_hour_ms * 24 + + +class RetentionTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["retention"] = { + "enabled": True, + "default_policy": { + "min_lifetime": one_day_ms, + "max_lifetime": one_day_ms * 3, + }, + "allowed_lifetime_min": one_day_ms, + "allowed_lifetime_max": one_day_ms * 3, + } + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_retention_state_event(self): + """Tests that the server configuration can limit the values a user can set to the + room's retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": one_day_ms * 4}, + tok=self.token, + expect_code=400, + ) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": one_hour_ms}, + tok=self.token, + expect_code=400, + ) + + def test_retention_event_purged_with_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by a state event. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the room's retention period to 2 days. + lifetime = one_day_ms * 2 + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": lifetime}, + tok=self.token, + ) + + self._test_retention_event_purged(room_id, one_day_ms * 1.5) + + def test_retention_event_purged_without_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by the server's configuration's default retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention_event_purged(room_id, one_day_ms * 2) + + def test_visibility(self): + """Tests that synapse.visibility.filter_events_for_client correctly filters out + outdated events + """ + store = self.hs.get_datastore() + storage = self.hs.get_storage() + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + events = [] + + # Send a first event, which should be filtered out at the end of the test. + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) + + # Get the event from the store so that we end up with a FrozenEvent that we can + # give to filter_events_for_client. We need to do this now because the event won't + # be in the database anymore after it has expired. + events.append(self.get_success(store.get_event(resp.get("event_id")))) + + # Advance the time by 2 days. We're using the default retention policy, therefore + # after this the first event will still be valid. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Send another event, which shouldn't get filtered out. 
+        resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+        valid_event_id = resp.get("event_id")
+
+        events.append(self.get_success(store.get_event(valid_event_id)))
+
+        # Advance the time by another 2 days. After this, the first event should be
+        # outdated but not the second one.
+        self.reactor.advance(one_day_ms * 2 / 1000)
+
+        # Run filter_events_for_client with our list of FrozenEvents.
+        filtered_events = self.get_success(
+            filter_events_for_client(storage, self.user_id, events)
+        )
+
+        # We should only get one event back.
+        self.assertEqual(len(filtered_events), 1, filtered_events)
+        # That event should be the second, not outdated event.
+        self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
+
+    def _test_retention_event_purged(self, room_id, increment):
+        # Get the create event so that, later, we can check we can still access it.
+        message_handler = self.hs.get_message_handler()
+        create_event = self.get_success(
+            message_handler.get_room_data(self.user_id, room_id, EventTypes.Create)
+        )
+
+        # Send a first event to the room. This is the event we'll want to be purged at the
+        # end of the test.
+        resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
+
+        expired_event_id = resp.get("event_id")
+
+        # Check that we can retrieve the event.
+        expired_event = self.get_event(room_id, expired_event_id)
+        self.assertEqual(
+            expired_event.get("content", {}).get("body"), "1", expired_event
+        )
+
+        # Advance the time.
+        self.reactor.advance(increment / 1000)
+
+        # Send another event. We need this because the purge job won't purge the most
+        # recent event in the room.
+        resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
+
+        valid_event_id = resp.get("event_id")
+
+        # Advance the time again. Now our first event should have expired but our second
+        # one should still be kept.
+        self.reactor.advance(increment / 1000)
+
+        # Check that the event has been purged from the database.
+        self.get_event(room_id, expired_event_id, expected_code=404)
+
+        # Check that the event that hasn't been purged can still be retrieved.
+        valid_event = self.get_event(room_id, valid_event_id)
+        self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
+
+        # Check that we can still access state events that were sent before the event that
+        # has been purged.
+        self.get_event(room_id, create_event.event_id)
+
+    def get_event(self, room_id, event_id, expected_code=200):
+        url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+        request, channel = self.make_request("GET", url, access_token=self.token)
+        self.render(request)
+
+        self.assertEqual(channel.code, expected_code, channel.result)
+
+        return channel.json_body
+
+
+class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+        config["retention"] = {
+            "enabled": True,
+        }
+
+        mock_federation_client = Mock(spec=["backfill"])
+
+        self.hs = self.setup_test_homeserver(
+            config=config, federation_client=mock_federation_client,
+        )
+        return self.hs
+
+    def prepare(self, reactor, clock, homeserver):
+        self.user_id = self.register_user("user", "password")
+        self.token = self.login("user", "password")
+
+    def test_no_default_policy(self):
+        """Tests that an event doesn't get expired if there is neither a default retention
+        policy nor a policy specific to the room.
+ """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention(room_id) + + def test_state_policy(self): + """Tests that an event gets correctly expired if there is no default retention + policy but there's a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the maximum lifetime to 35 days so that the first event gets expired but not + # the second one. + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": one_day_ms * 35}, + tok=self.token, + ) + + self._test_retention(room_id, expected_code_for_first_event=404) + + def _test_retention(self, room_id, expected_code_for_first_event=200): + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) + + first_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, first_event_id) + self.assertEqual( + expired_event.get("content", {}).get("body"), "1", expired_event + ) + + # Advance the time by a month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) + + second_event_id = resp.get("event_id") + + # Advance the time by another month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Check if the event has been purged from the database. + first_event = self.get_event( + room_id, first_event_id, expected_code=expected_code_for_first_event + ) + + if expected_code_for_first_event == 200: + self.assertEqual( + first_event.get("content", {}).get("body"), "1", first_event + ) + + # Check that the event that hasn't been purged can still be retrieved. 
+ second_event = self.get_event(room_id, second_event_id) + self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index a3d7e3c046..171632e195 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -2,7 +2,7 @@ from mock import Mock, call from twisted.internet import defer, reactor -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache from synapse.util import Clock @@ -52,14 +52,14 @@ class HttpTransactionCacheTestCase(unittest.TestCase): def test(): with LoggingContext("c") as c1: res = yield self.cache.fetch_or_execute(self.mock_key, cb) - self.assertIs(LoggingContext.current_context(), c1) + self.assertIs(current_context(), c1) self.assertEqual(res, "yay") # run the test twice in parallel d = defer.gatherResults([test(), test()]) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) yield d - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) @defer.inlineCallbacks def test_does_not_cache_exceptions(self): @@ -81,11 +81,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase): yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: self.assertEqual(e.args[0], "boo") - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertEqual(res, self.mock_http_response) - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) @defer.inlineCallbacks def test_does_not_cache_failures(self): @@ -107,11 +107,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase): yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: self.assertEqual(e.args[0], "boo") - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertEqual(res, self.mock_http_response) - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) @defer.inlineCallbacks def test_cleans_up(self): diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index f340b7e851..f75520877f 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -15,7 +15,7 @@ """ Tests REST events for /events paths.""" -from mock import Mock, NonCallableMock +from mock import Mock import synapse.rest.admin from synapse.rest.client.v1 import events, login, room @@ -40,17 +40,13 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): config["enable_registration"] = True config["auto_join_rooms"] = [] - hs = self.setup_test_homeserver( - config=config, 
ratelimiter=NonCallableMock(spec_set=["can_do_action"]) - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) + hs = self.setup_test_homeserver(config=config) hs.get_handlers().federation_handler = Mock() return hs - def prepare(self, hs, reactor, clock): + def prepare(self, reactor, clock, hs): # register an account self.user_id = self.register_user("sid1", "pass") @@ -134,3 +130,30 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): # someone else set topic, expect 6 (join,send,topic,join,send,topic) pass + + +class GetEventsTestCase(unittest.HomeserverTestCase): + servlets = [ + events.register_servlets, + room.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def prepare(self, hs, reactor, clock): + + # register an account + self.user_id = self.register_user("sid1", "pass") + self.token = self.login(self.user_id, "pass") + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + def test_get_event_via_events(self): + resp = self.helper.send(self.room_id, tok=self.token) + event_id = resp["event_id"] + + request, channel = self.make_request( + "GET", "/events/" + event_id, access_token=self.token, + ) + self.render(request) + self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index eae5411325..9033f09fd2 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1,7 +1,13 @@ import json +import time +import urllib.parse + +from mock import Mock + +import jwt import synapse.rest.admin -from synapse.rest.client.v1 import login +from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices from synapse.rest.client.v2_alpha.account import WhoamiRestServlet @@ -17,12 +23,12 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, + logout.register_servlets, devices.register_servlets, lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), ] def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() self.hs.config.enable_registration = True self.hs.config.registrations_require_3pid = [] @@ -31,10 +37,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): return self.hs + @override_config( + { + "rc_login": { + "address": {"per_second": 0.17, "burst_count": 5}, + # Prevent the account login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "account": {"per_second": 10000, "burst_count": 10000}, + } + } + ) def test_POST_ratelimiting_per_address(self): - self.hs.config.rc_login_address.burst_count = 5 - self.hs.config.rc_login_address.per_second = 0.17 - # Create different users so we're sure not to be bothered by the per-user # ratelimiter. 
for i in range(0, 6): @@ -73,10 +89,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_login": { + "account": {"per_second": 0.17, "burst_count": 5}, + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000}, + } + } + ) def test_POST_ratelimiting_per_account(self): - self.hs.config.rc_login_account.burst_count = 5 - self.hs.config.rc_login_account.per_second = 0.17 - self.register_user("kermit", "monkey") for i in range(0, 6): @@ -112,10 +138,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_login": { + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000}, + "failed_attempts": {"per_second": 0.17, "burst_count": 5}, + } + } + ) def test_POST_ratelimiting_per_account_failed_attempts(self): - self.hs.config.rc_login_failed_attempts.burst_count = 5 - self.hs.config.rc_login_failed_attempts.per_second = 0.17 - self.register_user("kermit", "monkey") for i in range(0, 6): @@ -252,3 +288,370 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEquals(channel.code, 200, channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_session_can_hard_logout_after_being_soft_logged_out(self): + self.register_user("kermit", "monkey") + + # log in as normal + access_token = self.login("kermit", "monkey") + + # we should now be able to make requests with the access token + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # Now try to hard logout this session + request, channel = self.make_request( + b"POST", "/logout", access_token=access_token + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): + self.register_user("kermit", "monkey") + + # log in as normal + access_token = self.login("kermit", "monkey") + + # we should now be able to make requests with the access token + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... 
and we should be soft-logouted + request, channel = self.make_request( + b"GET", TEST_URL, access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # Now try to hard log out all of the user's sessions + request, channel = self.make_request( + b"POST", "/logout/all", access_token=access_token + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + +class CASTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.base_url = "https://matrix.goodserver.com/" + self.redirect_path = "_synapse/client/login/sso/redirect/confirm" + + config = self.default_config() + config["cas_config"] = { + "enabled": True, + "server_url": "https://fake.test", + "service_url": "https://matrix.goodserver.com:8448", + } + + cas_user_id = "username" + self.user_id = "@%s:test" % cas_user_id + + async def get_raw(uri, args): + """Return an example response payload from a call to the `/proxyValidate` + endpoint of a CAS server, copied from + https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 + + This needs to be returned by an async function (as opposed to set as the + mock's return value) because the corresponding Synapse code awaits on it. + """ + return ( + """ + <cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'> + <cas:authenticationSuccess> + <cas:user>%s</cas:user> + <cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket> + <cas:proxies> + <cas:proxy>https://proxy2/pgtUrl</cas:proxy> + <cas:proxy>https://proxy1/pgtUrl</cas:proxy> + </cas:proxies> + </cas:authenticationSuccess> + </cas:serviceResponse> + """ + % cas_user_id + ) + + mocked_http_client = Mock(spec=["get_raw"]) + mocked_http_client.get_raw.side_effect = get_raw + + self.hs = self.setup_test_homeserver( + config=config, proxied_http_client=mocked_http_client, + ) + + return self.hs + + def prepare(self, reactor, clock, hs): + self.deactivate_account_handler = hs.get_deactivate_account_handler() + + def test_cas_redirect_confirm(self): + """Tests that the SSO login flow serves a confirmation page before redirecting a + user to the redirect URL. + """ + base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl" + redirect_url = "https://dodgy-site.com/" + + url_parts = list(urllib.parse.urlparse(base_url)) + query = dict(urllib.parse.parse_qsl(url_parts[4])) + query.update({"redirectUrl": redirect_url}) + query.update({"ticket": "ticket"}) + url_parts[4] = urllib.parse.urlencode(query) + cas_ticket_url = urllib.parse.urlunparse(url_parts) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + # Test that the response is HTML. + self.assertEqual(channel.code, 200) + content_type_header_value = "" + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type_header_value = header[1].decode("utf8") + + self.assertTrue(content_type_header_value.startswith("text/html")) + + # Test that the body isn't empty. 
+ self.assertTrue(len(channel.result["body"]) > 0) + + # And that it contains our redirect link + self.assertIn(redirect_url, channel.result["body"].decode("UTF-8")) + + @override_config( + { + "sso": { + "client_whitelist": [ + "https://legit-site.com/", + "https://other-site.com/", + ] + } + } + ) + def test_cas_redirect_whitelisted(self): + """Tests that the SSO login flow serves a redirect to a whitelisted url + """ + self._test_redirect("https://legit-site.com/") + + @override_config({"public_baseurl": "https://example.com"}) + def test_cas_redirect_login_fallback(self): + self._test_redirect("https://example.com/_matrix/static/client/login") + + def _test_redirect(self, redirect_url): + """Tests that the SSO login flow serves a redirect for the given redirect URL.""" + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + self.assertEqual(channel.code, 302) + location_headers = channel.headers.getRawHeaders("Location") + self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) + + @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) + def test_deactivated_user(self): + """Logging in as a deactivated account should error.""" + redirect_url = "https://legit-site.com/" + + # First login (to create the user). + self._test_redirect(redirect_url) + + # Deactivate the account. + self.get_success( + self.deactivate_account_handler.deactivate_account(self.user_id, False) + ) + + # Request the CAS ticket. + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + # Because the user is deactivated they are served an error template. 
+        self.assertEqual(channel.code, 403)
+        self.assertIn(b"SSO account deactivated", channel.result["body"])
+
+
+class JWTTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+    ]
+
+    jwt_secret = "secret"
+
+    def make_homeserver(self, reactor, clock):
+        self.hs = self.setup_test_homeserver()
+        self.hs.config.jwt_enabled = True
+        self.hs.config.jwt_secret = self.jwt_secret
+        self.hs.config.jwt_algorithm = "HS256"
+        return self.hs
+
+    def jwt_encode(self, token, secret=jwt_secret):
+        return jwt.encode(token, secret, "HS256").decode("ascii")
+
+    def jwt_login(self, *args):
+        params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        self.render(request)
+        return channel
+
+    def test_login_jwt_valid_registered(self):
+        self.register_user("kermit", "monkey")
+        channel = self.jwt_login({"sub": "kermit"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+    def test_login_jwt_valid_unregistered(self):
+        channel = self.jwt_login({"sub": "frog"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@frog:test")
+
+    def test_login_jwt_invalid_signature(self):
+        channel = self.jwt_login({"sub": "frog"}, "notsecret")
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+    def test_login_jwt_expired(self):
+        channel = self.jwt_login({"sub": "frog", "exp": 864000})
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.json_body["error"], "JWT expired")
+
+    def test_login_jwt_not_before(self):
+        now = int(time.time())
+        channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+    def test_login_no_sub(self):
+        channel = self.jwt_login({"username": "root"})
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+
+    def test_login_no_token(self):
+        params = json.dumps({"type": "m.login.jwt"})
+        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        self.render(request)
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
+
+
+# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
+# RS256, with a public key configured in synapse as "jwt_secret", and tokens
+# signed by the private key.
+class JWTPubKeyTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        login.register_servlets,
+    ]
+
+    # This key's pubkey is used as the jwt_secret setting of synapse. Valid
+    # tokens are signed by this and validated using the pubkey. It is generated
+    # with `openssl genrsa 512` (not a secure way to generate real keys, but
+    # good enough for tests!)
+ jwt_privatekey = "\n".join( + [ + "-----BEGIN RSA PRIVATE KEY-----", + "MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB", + "492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk", + "yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/", + "kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq", + "TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN", + "ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA", + "tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=", + "-----END RSA PRIVATE KEY-----", + ] + ) + + # Generated with `openssl rsa -in foo.key -pubout`, with the the above + # private key placed in foo.key (jwt_privatekey). + jwt_pubkey = "\n".join( + [ + "-----BEGIN PUBLIC KEY-----", + "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7", + "TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==", + "-----END PUBLIC KEY-----", + ] + ) + + # This key is used to sign tokens that shouldn't be accepted by synapse. + # Generated just like jwt_privatekey. + bad_privatekey = "\n".join( + [ + "-----BEGIN RSA PRIVATE KEY-----", + "MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv", + "gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L", + "R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY", + "uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I", + "eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb", + "iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0", + "KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m", + "-----END RSA PRIVATE KEY-----", + ] + ) + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.jwt_enabled = True + self.hs.config.jwt_secret = self.jwt_pubkey + self.hs.config.jwt_algorithm = "RS256" + return self.hs + + def jwt_encode(self, token, secret=jwt_privatekey): + return jwt.encode(token, secret, "RS256").decode("ascii") + + def jwt_login(self, *args): + params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)}) + request, channel = self.make_request(b"POST", LOGIN_URL, params) + self.render(request) + return channel + + def test_login_jwt_valid(self): + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + def test_login_jwt_invalid_signature(self): + channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.json_body["error"], "Invalid JWT") diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 66c2b68707..0fdff79aa7 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -15,6 +15,8 @@ from mock import Mock +from twisted.internet import defer + from synapse.rest.client.v1 import presence from synapse.types import UserID @@ -36,6 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) hs.presence_handler = Mock() + hs.presence_handler.set_state.return_value = defer.succeed(None) return hs diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 140d8b3772..8df58b4a63 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -52,6 +52,14 @@ class 
MockHandlerProfileTestCase(unittest.TestCase): ] ) + self.mock_handler.get_displayname.return_value = defer.succeed(Mock()) + self.mock_handler.set_displayname.return_value = defer.succeed(Mock()) + self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock()) + self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock()) + self.mock_handler.check_profile_query_allowed.return_value = defer.succeed( + Mock() + ) + hs = yield setup_test_homeserver( self.addCleanup, "test", @@ -63,7 +71,7 @@ class MockHandlerProfileTestCase(unittest.TestCase): ) def _get_user_by_req(request=None, allow_guest=False): - return synapse.types.create_requester(myid) + return defer.succeed(synapse.types.create_requester(myid)) hs.get_auth().get_user_by_req = _get_user_by_req @@ -229,6 +237,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): config = self.default_config() config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True self.hs = self.setup_test_homeserver(config=config) return self.hs @@ -301,6 +310,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True self.hs = self.setup_test_homeserver(config=config) return self.hs diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index fe741637f5..4886bbb401 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd # Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,14 +20,18 @@ import json -from mock import Mock, NonCallableMock +from mock import Mock from six.moves.urllib import parse as urlparse from twisted.internet import defer import synapse.rest.admin -from synapse.api.constants import Membership -from synapse.rest.client.v1 import login, profile, room +from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.handlers.pagination import PurgeStatus +from synapse.rest.client.v1 import directory, login, profile, room +from synapse.rest.client.v2_alpha import account +from synapse.types import JsonDict, RoomAlias +from synapse.util.stringutils import random_string from tests import unittest @@ -40,13 +46,8 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + "red", http_client=None, federation_client=Mock(), ) - self.ratelimiter = self.hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) self.hs.get_federation_handler = Mock(return_value=Mock()) @@ -484,6 +485,15 @@ class RoomsCreateTestCase(RoomBase): self.render(request) self.assertEquals(400, channel.code) + def test_post_room_invitees_invalid_mxid(self): + # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 + # Note the trailing space in the MXID here! 
+ request, channel = self.make_request( + "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' + ) + self.render(request) + self.assertEquals(400, channel.code) + class RoomTopicTestCase(RoomBase): """ Tests /rooms/$room_id/topic REST events. """ @@ -802,6 +812,78 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) + def test_room_messages_purge(self): + store = self.hs.get_datastore() + pagination_handler = self.hs.get_pagination_handler() + + # Send a first message in the room, which will be removed by the purge. + first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] + first_token = self.get_success( + store.get_topological_token_for_event(first_event_id) + ) + + # Send a second message in the room, which won't be removed, and which we'll + # use as the marker to purge events before. + second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] + second_token = self.get_success( + store.get_topological_token_for_event(second_event_id) + ) + + # Send a third event in the room to ensure we don't fall under any edge case + # due to our marker being the latest forward extremity in the room. + self.helper.send(self.room_id, "message 3") + + # Check that we get the first and second message when querying /messages. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) + + # Purge every event before the second event. + purge_id = random_string(16) + pagination_handler._purges_by_id[purge_id] = PurgeStatus() + self.get_success( + pagination_handler._purge_history( + purge_id=purge_id, + room_id=self.room_id, + token=second_token, + delete_local_events=True, + ) + ) + + # Check that we only get the second message through /message now that the first + # has been purged. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) + + # Check that we get no event, but also no error, when querying /messages with + # the token that was pointing at the first event, because we don't have it + # anymore. 
+ request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) + class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ @@ -998,3 +1080,899 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): res_displayname = channel.json_body["content"]["displayname"] self.assertEqual(res_displayname, self.displayname, channel.result) + + +class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): + """Tests that clients can add a "reason" field to membership events and + that they get correctly added to the generated events and propagated. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) + + def test_join_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/join".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_leave_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_kick_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/kick".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_ban_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/ban".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_unban_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/unban".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_invite_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + 
"/_matrix/client/r0/rooms/{}/invite".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_reject_invite_reason(self): + self.helper.invite( + self.room_id, + src=self.creator, + targ=self.second_user_id, + tok=self.creator_tok, + ) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def _check_for_reason(self, reason): + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( + self.room_id, self.second_user_id + ), + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + event_content = channel.json_body + + self.assertEqual(event_content.get("reason"), reason, channel.result) + + +class LabelsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + profile.register_servlets, + ] + + # Filter that should only catch messages with the label "#fun". + FILTER_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + # Filter that should only catch messages without the label "#fun". + FILTER_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + # Filter that should only catch messages with the label "#work" but without the label + # "#notfun". 
+    FILTER_LABELS_NOT_LABELS = {
+        "types": [EventTypes.Message],
+        "org.matrix.labels": ["#work"],
+        "org.matrix.not_labels": ["#notfun"],
+    }
+
+    def prepare(self, reactor, clock, homeserver):
+        self.user_id = self.register_user("test", "test")
+        self.tok = self.login("test", "test")
+        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+    def test_context_filter_labels(self):
+        """Test that we can filter by a label on a /context request."""
+        event_id = self._send_labelled_messages_in_room()
+
+        request, channel = self.make_request(
+            "GET",
+            "/rooms/%s/context/%s?filter=%s"
+            % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.result)
+
+        events_before = channel.json_body["events_before"]
+
+        self.assertEqual(
+            len(events_before), 1, [event["content"] for event in events_before]
+        )
+        self.assertEqual(
+            events_before[0]["content"]["body"], "with right label", events_before[0]
+        )
+
+        events_after = channel.json_body["events_after"]
+
+        self.assertEqual(
+            len(events_after), 1, [event["content"] for event in events_after]
+        )
+        self.assertEqual(
+            events_after[0]["content"]["body"], "with right label", events_after[0]
+        )
+
+    def test_context_filter_not_labels(self):
+        """Test that we can filter by the absence of a label on a /context request."""
+        event_id = self._send_labelled_messages_in_room()
+
+        request, channel = self.make_request(
+            "GET",
+            "/rooms/%s/context/%s?filter=%s"
+            % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.result)
+
+        events_before = channel.json_body["events_before"]
+
+        self.assertEqual(
+            len(events_before), 1, [event["content"] for event in events_before]
+        )
+        self.assertEqual(
+            events_before[0]["content"]["body"], "without label", events_before[0]
+        )
+
+        events_after = channel.json_body["events_after"]
+
+        self.assertEqual(
+            len(events_after), 2, [event["content"] for event in events_after]
+        )
+        self.assertEqual(
+            events_after[0]["content"]["body"], "with wrong label", events_after[0]
+        )
+        self.assertEqual(
+            events_after[1]["content"]["body"], "with two wrong labels", events_after[1]
+        )
+
+    def test_context_filter_labels_not_labels(self):
+        """Test that we can filter by both a label and the absence of another label on a
+        /context request.
+ """ + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 0, [event["content"] for event in events_before] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + + def test_messages_filter_labels(self): + """Test that we can filter by a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_messages_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 4, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "without label", events[1]) + self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2]) + self.assertEqual( + events[3]["content"]["body"], "with two wrong labels", events[3] + ) + + def test_messages_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /messages request. 
+ """ + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % ( + self.room_id, + self.tok, + token, + json.dumps(self.FILTER_LABELS_NOT_LABELS), + ), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def test_search_filter_labels(self): + """Test that we can filter by a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 2, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with right label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "with right label", + results[1]["result"]["content"]["body"], + ) + + def test_search_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 4, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "without label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "without label", + results[1]["result"]["content"]["body"], + ) + self.assertEqual( + results[2]["result"]["content"]["body"], + "with wrong label", + results[2]["result"]["content"]["body"], + ) + self.assertEqual( + results[3]["result"]["content"]["body"], + "with two wrong labels", + results[3]["result"]["content"]["body"], + ) + + def test_search_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /search request. + """ + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 1, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with wrong label", + results[0]["result"]["content"]["body"], + ) + + def _send_labelled_messages_in_room(self): + """Sends several messages to a room with different labels (or without any) to test + filtering by label. 
+ Returns: + The ID of the event to use if we're testing filtering on /context. + """ + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + # Return this event's ID when we test filtering in /context requests. + event_id = res["event_id"] + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + return event_id + + +class ContextTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + account.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.tok = self.login("user", "password") + self.room_id = self.helper.create_room_as( + self.user_id, tok=self.tok, is_public=False + ) + + self.other_user_id = self.register_user("user2", "password") + self.other_tok = self.login("user2", "password") + + self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok) + self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok) + + def test_erased_sender(self): + """Test that an erasure request results in the requester's events being hidden + from any new member of the room. + """ + + # Send a bunch of events in the room. + + self.helper.send(self.room_id, "message 1", tok=self.tok) + self.helper.send(self.room_id, "message 2", tok=self.tok) + event_id = self.helper.send(self.room_id, "message 3", tok=self.tok)["event_id"] + self.helper.send(self.room_id, "message 4", tok=self.tok) + self.helper.send(self.room_id, "message 5", tok=self.tok) + + # Check that we can still see the messages before the erasure request. 
+ + request, channel = self.make_request( + "GET", + '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' + % (self.room_id, event_id), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual(len(events_before), 2, events_before) + self.assertEqual( + events_before[0].get("content", {}).get("body"), + "message 2", + events_before[0], + ) + self.assertEqual( + events_before[1].get("content", {}).get("body"), + "message 1", + events_before[1], + ) + + self.assertEqual( + channel.json_body["event"].get("content", {}).get("body"), + "message 3", + channel.json_body["event"], + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual(len(events_after), 2, events_after) + self.assertEqual( + events_after[0].get("content", {}).get("body"), + "message 4", + events_after[0], + ) + self.assertEqual( + events_after[1].get("content", {}).get("body"), + "message 5", + events_after[1], + ) + + # Deactivate the first account and erase the user's data. + + deactivate_account_handler = self.hs.get_deactivate_account_handler() + self.get_success( + deactivate_account_handler.deactivate_account(self.user_id, erase_data=True) + ) + + # Invite another user in the room. This is needed because messages will be + # pruned only if the user wasn't a member of the room when the messages were + # sent. + + invited_user_id = self.register_user("user3", "password") + invited_tok = self.login("user3", "password") + + self.helper.invite( + self.room_id, self.other_user_id, invited_user_id, tok=self.other_tok + ) + self.helper.join(self.room_id, invited_user_id, tok=invited_tok) + + # Check that a user that joined the room after the erasure request can't see + # the messages anymore. 
+ + request, channel = self.make_request( + "GET", + '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' + % (self.room_id, event_id), + access_token=invited_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual(len(events_before), 2, events_before) + self.assertDictEqual(events_before[0].get("content"), {}, events_before[0]) + self.assertDictEqual(events_before[1].get("content"), {}, events_before[1]) + + self.assertDictEqual( + channel.json_body["event"].get("content"), {}, channel.json_body["event"] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual(len(events_after), 2, events_after) + self.assertDictEqual(events_after[0].get("content"), {}, events_after[0]) + self.assertEqual(events_after[1].get("content"), {}, events_after[1]) + + +class RoomAliasListTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + def test_no_aliases(self): + res = self._get_aliases(self.room_owner_tok) + self.assertEqual(res["aliases"], []) + + def test_not_in_room(self): + self.register_user("user", "test") + user_tok = self.login("user", "test") + res = self._get_aliases(user_tok, expected_code=403) + self.assertEqual(res["errcode"], "M_FORBIDDEN") + + def test_admin_user(self): + alias1 = self._random_alias() + self._set_alias_via_directory(alias1) + + self.register_user("user", "test", admin=True) + user_tok = self.login("user", "test") + + res = self._get_aliases(user_tok) + self.assertEqual(res["aliases"], [alias1]) + + def test_with_aliases(self): + alias1 = self._random_alias() + alias2 = self._random_alias() + + self._set_alias_via_directory(alias1) + self._set_alias_via_directory(alias2) + + res = self._get_aliases(self.room_owner_tok) + self.assertEqual(set(res["aliases"]), {alias1, alias2}) + + def test_peekable_room(self): + alias1 = self._random_alias() + self._set_alias_via_directory(alias1) + + self.helper.send_state( + self.room_id, + EventTypes.RoomHistoryVisibility, + body={"history_visibility": "world_readable"}, + tok=self.room_owner_tok, + ) + + self.register_user("user", "test") + user_tok = self.login("user", "test") + + res = self._get_aliases(user_tok) + self.assertEqual(res["aliases"], [alias1]) + + def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. 
returns the json response object.""" + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2432/rooms/%s/aliases" + % (self.room_id,), + access_token=access_token, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + if expected_code == 200: + self.assertIsInstance(res["aliases"], list) + return res + + def _random_alias(self) -> str: + return RoomAlias(random_string(5), self.hs.hostname).to_string() + + def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + url = "/_matrix/client/r0/directory/room/" + alias + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.room_owner_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + +class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + self.alias = "#alias:test" + self._set_alias_via_directory(self.alias) + + def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + url = "/_matrix/client/r0/directory/room/" + alias + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.room_owner_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "GET", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + access_token=self.room_owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "PUT", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + json.dumps(content), + access_token=self.room_owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def test_canonical_alias(self): + """Test a basic alias message.""" + # There is no canonical alias to start with. + self._get_canonical_alias(expected_code=404) + + # Create an alias. + self._set_canonical_alias({"alias": self.alias}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + # Now remove the alias. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alt_aliases(self): + """Test a canonical alias message with alt_aliases.""" + # Create an alias. 
+        self._set_canonical_alias({"alt_aliases": [self.alias]})
+
+        # Canonical alias now exists!
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {"alt_aliases": [self.alias]})
+
+        # Now remove the alt_aliases.
+        self._set_canonical_alias({})
+
+        # There is an alias event, but it is empty.
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {})
+
+    def test_alias_alt_aliases(self):
+        """Test a canonical alias message with an alias and alt_aliases."""
+        # Create an alias.
+        self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Canonical alias now exists!
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Now remove the alias and alt_aliases.
+        self._set_canonical_alias({})
+
+        # There is an alias event, but it is empty.
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {})
+
+    def test_partial_modify(self):
+        """Test removing only the alt_aliases."""
+        # Create an alias.
+        self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Canonical alias now exists!
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Now remove the alt_aliases.
+        self._set_canonical_alias({"alias": self.alias})
+
+        # The canonical alias remains, but the alt_aliases are gone.
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {"alias": self.alias})
+
+    def test_add_alias(self):
+        """Test adding an alias to the alt_aliases list."""
+        # Create an additional alias.
+        second_alias = "#second:test"
+        self._set_alias_via_directory(second_alias)
+
+        # Add the canonical alias.
+        self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Then add the second alias.
+        self._set_canonical_alias(
+            {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
+        )
+
+        # Both aliases should now be listed.
+ res = self._get_canonical_alias() + self.assertEqual( + res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + def test_bad_data(self): + """Invalid data for alt_aliases should cause errors.""" + self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": None}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 0}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 1}, expected_code=400) + self._set_canonical_alias({"alt_aliases": False}, expected_code=400) + self._set_canonical_alias({"alt_aliases": True}, expected_code=400) + self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) + + def test_bad_alias(self): + """An alias which does not point to the room raises a SynapseError.""" + self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 30fb77bac8..18260bb90e 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -16,7 +16,7 @@ """Tests REST events for /rooms paths.""" -from mock import Mock, NonCallableMock +from mock import Mock from twisted.internet import defer @@ -39,17 +39,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + "red", http_client=None, federation_client=Mock(), ) self.event_source = hs.get_event_sources().sources["typing"] - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) - hs.get_handlers().federation_handler = Mock() def get_user_by_access_token(token=None, allow_guest=False): @@ -109,7 +103,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + events = self.get_success( + self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + ) self.assertEquals( events[0], [ diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index cdded88b7f..22d734e763 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,9 +18,12 @@ import json import time +from typing import Any, Dict, Optional import attr +from twisted.web.resource import Resource + from synapse.api.constants import Membership from tests.server import make_request, render @@ -33,7 +39,7 @@ class RestHelper(object): resource = attr.ib() auth_user_id = attr.ib() - def create_room_as(self, room_creator, is_public=True, tok=None): + def create_room_as(self, room_creator=None, is_public=True, tok=None): temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" @@ -106,13 +112,22 @@ class RestHelper(object): self.auth_user_id = temp_id def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) if body is None: body = "body_text_here" - path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) content = {"msgtype": "m.text", "body": body} + + return self.send_event( + room_id, "m.room.message", content, txn_id, tok, expect_code + ) + + def send_event( + self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200 + ): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + + path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id) if tok: path = path + "?access_token=%s" % tok @@ -128,7 +143,34 @@ class RestHelper(object): return channel.json_body - def send_state(self, room_id, event_type, body, tok, expect_code=200, state_key=""): + def _read_write_state( + self, + room_id: str, + event_type: str, + body: Optional[Dict[str, Any]], + tok: str, + expect_code: int = 200, + state_key: str = "", + method: str = "GET", + ) -> Dict: + """Read or write some state from a given room + + Args: + room_id: + event_type: The type of state event + body: Body that is sent when making the request. The content of the state event. 
+ If None, the request to the server will have an empty body + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + method: "GET" or "PUT" for reading or writing state, respectively + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( room_id, event_type, @@ -137,9 +179,13 @@ class RestHelper(object): if tok: path = path + "?access_token=%s" % tok - request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(body).encode("utf8") - ) + # Set request body if provided + content = b"" + if body is not None: + content = json.dumps(body).encode("utf8") + + request, channel = make_request(self.hs.get_reactor(), method, path, content) + render(request, self.resource, self.hs.get_reactor()) assert int(channel.result["code"]) == expect_code, ( @@ -148,3 +194,94 @@ class RestHelper(object): ) return channel.json_body + + def get_state( + self, + room_id: str, + event_type: str, + tok: str, + expect_code: int = 200, + state_key: str = "", + ): + """Gets some state from a room + + Args: + room_id: + event_type: The type of state event + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + return self._read_write_state( + room_id, event_type, None, tok, expect_code, state_key, method="GET" + ) + + def send_state( + self, + room_id: str, + event_type: str, + body: Dict[str, Any], + tok: str, + expect_code: int = 200, + state_key: str = "", + ): + """Set some state in a room + + Args: + room_id: + event_type: The type of state event + body: Body that is sent when making the request. The content of the state event. 
+ tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + return self._read_write_state( + room_id, event_type, body, tok, expect_code, state_key, method="PUT" + ) + + def upload_media( + self, + resource: Resource, + image_data: bytes, + tok: str, + filename: str = "test.png", + expect_code: int = 200, + ) -> dict: + """Upload a piece of test media to the media repo + Args: + resource: The resource that will handle the upload request + image_data: The image data to upload + tok: The user token to use during the upload + filename: The filename of the media to be uploaded + expect_code: The return code to expect from attempting to upload the media + """ + image_length = len(image_data) + path = "/_matrix/media/r0/upload?filename=%s" % (filename,) + request, channel = make_request( + self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok + ) + request.requestHeaders.addRawHeader( + b"Content-Length", str(image_length).encode("UTF-8") + ) + request.render(resource) + self.hs.get_reactor().pump([100]) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + return channel.json_body diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 920de41de4..3ab611f618 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -23,8 +23,9 @@ from email.parser import Parser import pkg_resources import synapse.rest.admin -from synapse.api.constants import LoginType -from synapse.rest.client.v1 import login +from synapse.api.constants import LoginType, Membership +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register from tests import unittest @@ -45,7 +46,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Email config. self.email_attempts = [] - def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): + async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): self.email_attempts.append(msg) return @@ -178,6 +179,22 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the new password self.attempt_wrong_password_login("kermit", new_password) + @unittest.override_config({"request_token_inhibit_3pid_errors": True}) + def test_password_reset_bad_email_inhibit_error(self): + """Test that triggering a password reset with an email address that isn't bound + to an account doesn't leak the lack of binding for that address if configured + that way. 
+ """ + self.register_user("kermit", "monkey") + self.login("kermit", "monkey") + + email = "test@example.com" + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertIsNotNone(session_id) + def _request_token(self, email, client_secret): request, channel = self.make_request( "POST", @@ -244,16 +261,72 @@ class DeactivateTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, account.register_servlets, + room.register_servlets, ] def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver() - return hs + self.hs = self.setup_test_homeserver() + return self.hs def test_deactivate_account(self): user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") + self.deactivate(user_id, tok) + + store = self.hs.get_datastore() + + # Check that the user has been marked as deactivated. + self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) + + # Check that this access token has been invalidated. + request, channel = self.make_request("GET", "account/whoami") + self.render(request) + self.assertEqual(request.code, 401) + + @unittest.INFO + def test_pending_invites(self): + """Tests that deactivating a user rejects every pending invite for them.""" + store = self.hs.get_datastore() + + inviter_id = self.register_user("inviter", "test") + inviter_tok = self.login("inviter", "test") + + invitee_id = self.register_user("invitee", "test") + invitee_tok = self.login("invitee", "test") + + # Make @inviter:test invite @invitee:test in a new room. + room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok) + self.helper.invite( + room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok + ) + + # Make sure the invite is here. + pending_invites = self.get_success( + store.get_invited_rooms_for_local_user(invitee_id) + ) + self.assertEqual(len(pending_invites), 1, pending_invites) + self.assertEqual(pending_invites[0].room_id, room_id, pending_invites) + + # Deactivate @invitee:test. + self.deactivate(invitee_id, invitee_tok) + + # Check that the invite isn't there anymore. + pending_invites = self.get_success( + store.get_invited_rooms_for_local_user(invitee_id) + ) + self.assertEqual(len(pending_invites), 0, pending_invites) + + # Check that the membership of @invitee:test in the room is now "leave". + memberships = self.get_success( + store.get_rooms_for_local_user_where_membership_is( + invitee_id, [Membership.LEAVE] + ) + ) + self.assertEqual(len(memberships), 1, memberships) + self.assertEqual(memberships[0].room_id, room_id, memberships) + + def deactivate(self, user_id, tok): request_data = json.dumps( { "auth": { @@ -270,12 +343,303 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(request.code, 200) - store = self.hs.get_datastore() - # Check that the user has been marked as deactivated. - self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) +class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. 
+ self.email_attempts = [] + + async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): + self.email_attempts.append(msg) + + config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + return self.hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.user_id = self.register_user("kermit", "test") + self.user_id_tok = self.login("kermit", "test") + self.email = "test@example.com" + self.url_3pid = b"account/3pid" + + def test_add_email(self): + """Test adding an email to profile + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) - # Check that this access token has been invalidated. - request, channel = self.make_request("GET", "account/whoami") self.render(request) - self.assertEqual(request.code, 401) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_add_email_if_disabled(self): + """Test adding email to profile when doing so is disallowed + """ + self.hs.config.enable_3pid_changes = False + + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email(self): + """Test deleting an email from profile + """ + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + 
access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email_if_disabled(self): + """Test deleting an email from profile when disallowed + """ + self.hs.config.enable_3pid_changes = False + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_cant_add_email_without_clicking_link(self): + """Test that we do actually need to click the link in the email + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to add email without clicking the link + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. 
+ """ + client_secret = "foobar" + session_id = "weasle" + + # Attempt to add email without even requesting an email + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def _request_token(self, email, client_secret): + request, channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + return channel.json_body["sid"] + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + request, channel = self.make_request("GET", path, shorthand=False) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index b9ef46e8fb..293ccfba2b 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -12,22 +12,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +from typing import List, Union from twisted.internet.defer import succeed import synapse.rest.admin from synapse.api.constants import LoginType -from synapse.rest.client.v2_alpha import auth, register +from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker +from synapse.http.site import SynapseRequest +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import auth, devices, register +from synapse.types import JsonDict from tests import unittest +from tests.server import FakeChannel + + +class DummyRecaptchaChecker(UserInteractiveAuthChecker): + def __init__(self, hs): + super().__init__(hs) + self.recaptcha_attempts = [] + + def check_auth(self, authdict, clientip): + self.recaptcha_attempts.append((authdict, clientip)) + return succeed(True) + + +class DummyPasswordChecker(UserInteractiveAuthChecker): + def check_auth(self, authdict, clientip): + return succeed(authdict["identifier"]["user"]) class FallbackAuthTests(unittest.HomeserverTestCase): servlets = [ auth.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, register.register_servlets, ] hijack_auth = False @@ -44,28 +63,55 @@ class FallbackAuthTests(unittest.HomeserverTestCase): return hs def prepare(self, reactor, clock, hs): + self.recaptcha_checker = DummyRecaptchaChecker(hs) auth_handler = hs.get_auth_handler() + auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker - self.recaptcha_attempts = [] + def register(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Make a register request.""" + request, channel = self.make_request( + "POST", "register", body + ) # type: SynapseRequest, FakeChannel + self.render(request) - def _recaptcha(authdict, clientip): - self.recaptcha_attempts.append((authdict, clientip)) - return succeed(True) + self.assertEqual(request.code, expected_response) + return channel - auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha + def recaptcha( + self, session: str, expected_post_response: int, post_session: str = None + ) -> None: + """Get and respond to a fallback recaptcha. 
Returns the second request.""" + if post_session is None: + post_session = session - @unittest.INFO - def test_fallback_captcha(self): + request, channel = self.make_request( + "GET", "auth/m.login.recaptcha/fallback/web?session=" + session + ) # type: SynapseRequest, FakeChannel + self.render(request) + self.assertEqual(request.code, 200) request, channel = self.make_request( "POST", - "register", - {"username": "user", "type": "m.login.password", "password": "bar"}, + "auth/m.login.recaptcha/fallback/web?session=" + + post_session + + "&g-recaptcha-response=a", ) self.render(request) + self.assertEqual(request.code, expected_post_response) + # The recaptcha handler is called with the response given + attempts = self.recaptcha_checker.recaptcha_attempts + self.assertEqual(len(attempts), 1) + self.assertEqual(attempts[0][0]["response"], "a") + + @unittest.INFO + def test_fallback_captcha(self): + """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec - self.assertEqual(request.code, 401) + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"}, + ) + # Grab the session session = channel.json_body["session"] # Assert our configured public key is being given @@ -73,39 +119,198 @@ class FallbackAuthTests(unittest.HomeserverTestCase): channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" ) - request, channel = self.make_request( - "GET", "auth/m.login.recaptcha/fallback/web?session=" + session + # Complete the recaptcha step. + self.recaptcha(session, 200) + + # also complete the dummy auth + self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) + + # Now we should have fulfilled a complete auth flow, including + # the recaptcha fallback step, we can then send a + # request to the register API with the session in the authdict. + channel = self.register(200, {"auth": {"session": session}}) + + # We're given a registered user. + self.assertEqual(channel.json_body["user_id"], "@user:test") + + def test_complete_operation_unknown_session(self): + """ + Attempting to mark an invalid session as complete should error. + """ + # Make the initial request to register. (Later on a different password + # will be used.) + # Returns a 401 as per the spec + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"} ) + + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + # Attempt to complete the recaptcha step with an unknown session. + # This results in an error. + self.recaptcha(session, 400, session + "unknown") + + +class UIAuthTests(unittest.HomeserverTestCase): + servlets = [ + auth.register_servlets, + devices.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + auth_handler = hs.get_auth_handler() + auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs) + + self.user_pass = "pass" + self.user = self.register_user("test", self.user_pass) + self.user_tok = self.login("test", self.user_pass) + + def get_device_ids(self) -> List[str]: + # Get the list of devices so one can be deleted. 
+ request, channel = self.make_request( + "GET", "devices", access_token=self.user_tok, + ) # type: SynapseRequest, FakeChannel self.render(request) + + # Get the ID of the device. self.assertEqual(request.code, 200) + return [d["device_id"] for d in channel.json_body["devices"]] + def delete_device( + self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b"" + ) -> FakeChannel: + """Delete an individual device.""" request, channel = self.make_request( - "POST", - "auth/m.login.recaptcha/fallback/web?session=" - + session - + "&g-recaptcha-response=a", - ) + "DELETE", "devices/" + device, body, access_token=self.user_tok + ) # type: SynapseRequest, FakeChannel self.render(request) - self.assertEqual(request.code, 200) - # The recaptcha handler is called with the response given - self.assertEqual(len(self.recaptcha_attempts), 1) - self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a") + # Ensure the response is sane. + self.assertEqual(request.code, expected_response) - # also complete the dummy auth + return channel + + def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Delete 1 or more devices.""" + # Note that this uses the delete_devices endpoint so that we can modify + # the payload half-way through some tests. request, channel = self.make_request( - "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} - ) + "POST", "delete_devices", body, access_token=self.user_tok, + ) # type: SynapseRequest, FakeChannel self.render(request) - # Now we should have fufilled a complete auth flow, including - # the recaptcha fallback step, we can then send a - # request to the register API with the session in the authdict. - request, channel = self.make_request( - "POST", "register", {"auth": {"session": session}} + # Ensure the response is sane. + self.assertEqual(request.code, expected_response) + + return channel + + def test_ui_auth(self): + """ + Test user interactive authentication outside of registration. + """ + device_id = self.get_device_ids()[0] + + # Attempt to delete this device. + # Returns a 401 as per the spec + channel = self.delete_device(device_id, 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + device_id, + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, ) - self.render(request) - self.assertEqual(channel.code, 200) - # We're given a registered user. - self.assertEqual(channel.json_body["user_id"], "@user:test") + def test_can_change_body(self): + """ + The client dict can be modified during the user interactive authentication session. + + Note that it is not spec compliant to modify the client dict during a + user interactive authentication session, but many clients currently do. + + When Synapse is updated to be spec compliant, the call to re-use the + session ID should be rejected. + """ + # Create a second login. + self.login("test", self.user_pass) + + device_ids = self.get_device_ids() + self.assertEqual(len(device_ids), 2) + + # Attempt to delete the first device. 
+ # Returns a 401 as per the spec + channel = self.delete_devices(401, {"devices": [device_ids[0]]}) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. + self.delete_devices( + 200, + { + "devices": [device_ids[1]], + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_cannot_change_uri(self): + """ + The initial requested URI cannot be modified during the user interactive authentication session. + """ + # Create a second login. + self.login("test", self.user_pass) + + device_ids = self.get_device_ids() + self.assertEqual(len(device_ids), 2) + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_device(device_ids[0], 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. This results in an error. + self.delete_device( + device_ids[1], + 403, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index f42a8efbf4..e0e9e94fbf 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -92,7 +92,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.render(request) - self.assertEqual(channel.result["code"], b"400") + self.assertEqual(channel.result["code"], b"404") self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) # Currently invalid params do not have an appropriate errcode diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py new file mode 100644 index 0000000000..c57072f50c --- /dev/null +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.api.constants import LoginType +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import account, password_policy, register + +from tests import unittest + + +class PasswordPolicyTestCase(unittest.HomeserverTestCase): + """Tests the password policy feature and its compliance with MSC2000. + + When validating a password, Synapse does the necessary checks in this order: + + 1. Password is long enough + 2. 
Password contains digit(s) + 3. Password contains symbol(s) + 4. Password contains uppercase letter(s) + 5. Password contains lowercase letter(s) + + For each test below that checks whether a password triggers the right error code, + that test provides a password good enough to pass the previous tests, but not the + one it is currently testing (nor any test that comes afterward). + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + password_policy.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.register_url = "/_matrix/client/r0/register" + self.policy = { + "enabled": True, + "minimum_length": 10, + "require_digit": True, + "require_symbol": True, + "require_lowercase": True, + "require_uppercase": True, + } + + config = self.default_config() + config["password_config"] = { + "policy": self.policy, + } + + hs = self.setup_test_homeserver(config=config) + return hs + + def test_get_policy(self): + """Tests if the /password_policy endpoint returns the configured policy.""" + + request, channel = self.make_request( + "GET", "/_matrix/client/r0/password_policy" + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual( + channel.json_body, + { + "m.minimum_length": 10, + "m.require_digit": True, + "m.require_symbol": True, + "m.require_lowercase": True, + "m.require_uppercase": True, + }, + channel.result, + ) + + def test_password_too_short(self): + request_data = json.dumps({"username": "kermit", "password": "shorty"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result, + ) + + def test_password_no_digit(self): + request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result, + ) + + def test_password_no_symbol(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result, + ) + + def test_password_no_uppercase(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result, + ) + + def test_password_no_lowercase(self): + request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) + request, channel = self.make_request("POST", self.register_url, request_data) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result, + ) + + def test_password_compliant(self): + request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) + request, channel = 
self.make_request("POST", self.register_url, request_data) + self.render(request) + + # Getting a 401 here means the password has passed validation and the server has + # responded with a list of registration flows. + self.assertEqual(channel.code, 401, channel.result) + + def test_password_change(self): + """This doesn't test every possible use case, only that hitting /account/password + triggers the password validation code. + """ + compliant_password = "C0mpl!antpassword" + not_compliant_password = "notcompliantpassword" + + user_id = self.register_user("kermit", compliant_password) + tok = self.login("kermit", compliant_password) + + request_data = json.dumps( + { + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + }, + } + ) + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/account/password", + request_data, + access_token=tok, + ) + self.render(request) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index ab4d7d70d0..7deaf5b24a 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -25,28 +25,26 @@ import synapse.rest.admin from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login +from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import account, account_validity, register, sync from tests import unittest +from tests.unittest import override_config class RegisterRestServletTestCase(unittest.HomeserverTestCase): - servlets = [register.register_servlets] - - def make_homeserver(self, reactor, clock): - - self.url = b"/_matrix/client/r0/register" + servlets = [ + login.register_servlets, + register.register_servlets, + synapse.rest.admin.register_servlets, + ] + url = b"/_matrix/client/r0/register" - self.hs = self.setup_test_homeserver() - self.hs.config.enable_registration = True - self.hs.config.registrations_require_3pid = [] - self.hs.config.auto_join_rooms = [] - self.hs.config.enable_registration_captcha = False - self.hs.config.allow_guest_access = True - - return self.hs + def default_config(self): + config = super().default_config() + config["allow_guest_access"] = True + return config def test_POST_appservice_registration_valid(self): user_id = "@as_user_kermit:test" @@ -149,10 +147,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting_guest(self): - self.hs.config.rc_registration.burst_count = 5 - self.hs.config.rc_registration.per_second = 0.17 - for i in range(0, 6): url = self.url + b"?kind=guest" request, channel = self.make_request(b"POST", url, b"{}") @@ -171,10 +167,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self): - self.hs.config.rc_registration.burst_count = 5 - self.hs.config.rc_registration.per_second = 0.17 - for i in 
range(0, 6): params = { "username": "kermit" + str(i), @@ -199,6 +193,115 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + def test_advertised_flows(self): + request, channel = self.make_request(b"POST", self.url, b"{}") + self.render(request) + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + # with the stock config, we only expect the dummy flow + self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows)) + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "enable_registration_captcha": True, + "user_consent": { + "version": "1", + "template_dir": "/", + "require_at_registration": True, + }, + "account_threepid_delegates": { + "email": "https://id_server", + "msisdn": "https://id_server", + }, + } + ) + def test_advertised_flows_captcha_and_terms_and_3pids(self): + request, channel = self.make_request(b"POST", self.url, b"{}") + self.render(request) + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + self.assertCountEqual( + [ + ["m.login.recaptcha", "m.login.terms", "m.login.dummy"], + ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"], + ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"], + [ + "m.login.recaptcha", + "m.login.terms", + "m.login.msisdn", + "m.login.email.identity", + ], + ], + (f["stages"] for f in flows), + ) + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "registrations_require_3pid": ["email"], + "disable_msisdn_registration": True, + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_advertised_flows_no_msisdn_email_required(self): + request, channel = self.make_request(b"POST", self.url, b"{}") + self.render(request) + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + # with the stock config, we expect all four combinations of 3pid + self.assertCountEqual( + [["m.login.email.identity"]], (f["stages"] for f in flows) + ) + + @unittest.override_config( + { + "request_token_inhibit_3pid_errors": True, + "public_baseurl": "https://test_server", + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_request_token_existing_email_inhibit_error(self): + """Test that requesting a token via this endpoint doesn't leak existing + associations if configured that way. 
+ """ + user_id = self.register_user("kermit", "monkey") + self.login("kermit", "monkey") + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.hs.get_datastore().user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + self.assertIsNotNone(channel.json_body.get("sid")) + class AccountValidityTestCase(unittest.HomeserverTestCase): @@ -207,6 +310,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, sync.register_servlets, + logout.register_servlets, account_validity.register_servlets, ] @@ -299,6 +403,39 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) + def test_logging_out_expired_user(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_matrix/client/unstable/admin/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + request, channel = self.make_request( + b"POST", url, request_data, access_token=admin_tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Try to log the user out + request, channel = self.make_request(b"POST", "/logout", access_token=tok) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Log the user in again (allowed for expired accounts) + tok = self.login("kermit", "monkey") + + # Try to log out all of the user's sessions + request, channel = self.make_request(b"POST", "/logout/all", access_token=tok) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): @@ -330,9 +467,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # Email config. self.email_attempts = [] - def sendmail(*args, **kwargs): + async def sendmail(*args, **kwargs): self.email_attempts.append((args, kwargs)) - return config["email"] = { "enable_notifs": True, diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 71895094bd..fa3a3ec1bd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -from mock import Mock +import json import synapse.rest.admin +from synapse.api.constants import EventContentFields, EventTypes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync @@ -26,14 +27,12 @@ from tests.server import TimedOutException class FilterTestCase(unittest.HomeserverTestCase): user_id = "@apple:test" - servlets = [sync.register_servlets] - - def make_homeserver(self, reactor, clock): - - hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock() - ) - return hs + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] def test_sync_argless(self): request, channel = self.make_request("GET", "/sync") @@ -41,16 +40,14 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertTrue( - set( - [ - "next_batch", - "rooms", - "presence", - "account_data", - "to_device", - "device_lists", - ] - ).issubset(set(channel.json_body.keys())) + { + "next_batch", + "rooms", + "presence", + "account_data", + "to_device", + "device_lists", + }.issubset(set(channel.json_body.keys())) ) def test_sync_presence_disabled(self): @@ -64,11 +61,149 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertTrue( - set( - ["next_batch", "rooms", "account_data", "to_device", "device_lists"] - ).issubset(set(channel.json_body.keys())) + { + "next_batch", + "rooms", + "account_data", + "to_device", + "device_lists", + }.issubset(set(channel.json_body.keys())) + ) + + +class SyncFilterTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def test_sync_filter_labels(self): + """Test that we can filter by a label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_sync_filter_not_labels(self): + """Test that we can filter by the absence of a label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 3, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) + self.assertEqual( + events[2]["content"]["body"], "with two wrong labels", events[2] + ) + + def test_sync_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def 
_test_sync_filter_labels(self, sync_filter): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + room_id = self.helper.create_room_as(user_id, tok=tok) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=tok, ) + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + request, channel = self.make_request( + "GET", "/sync?filter=%s" % sync_filter, access_token=tok + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] + class SyncTypingTests(unittest.HomeserverTestCase): diff --git a/synapse/rest/media/v0/__init__.py b/tests/rest/key/__init__.py index e69de29bb2..e69de29bb2 100644 --- a/synapse/rest/media/v0/__init__.py +++ b/tests/rest/key/__init__.py diff --git a/tests/rest/key/v2/__init__.py b/tests/rest/key/v2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/rest/key/v2/__init__.py diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py new file mode 100644 index 0000000000..99eb477149 --- /dev/null +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -0,0 +1,257 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import urllib.parse +from io import BytesIO, StringIO + +from mock import Mock + +import signedjson.key +from canonicaljson import encode_canonical_json +from nacl.signing import SigningKey +from signedjson.sign import sign_json + +from twisted.web.resource import NoResource + +from synapse.crypto.keyring import PerspectivesKeyFetcher +from synapse.http.site import SynapseRequest +from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.storage.keys import FetchKeyResult +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.stringutils import random_string + +from tests import unittest +from tests.server import FakeChannel, wait_until_result +from tests.utils import default_config + + +class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + self.http_client = Mock() + return self.setup_test_homeserver(http_client=self.http_client) + + def create_test_json_resource(self): + return create_resource_tree( + {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() + ) + + def expect_outgoing_key_request( + self, server_name: str, signing_key: SigningKey + ) -> None: + """ + Tell the mock http client to expect an outgoing GET request for the given key + """ + + def get_json(destination, path, ignore_backoff=False, **kwargs): + self.assertTrue(ignore_backoff) + self.assertEqual(destination, server_name) + key_id = "%s:%s" % (signing_key.alg, signing_key.version) + self.assertEqual( + path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),) + ) + + response = { + "server_name": server_name, + "old_verify_keys": {}, + "valid_until_ts": 200 * 1000, + "verify_keys": { + key_id: { + "key": signedjson.key.encode_verify_key_base64( + signing_key.verify_key + ) + } + }, + } + sign_json(response, server_name, signing_key) + return response + + self.http_client.get_json.side_effect = get_json + + +class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): + def make_notary_request(self, server_name: str, key_id: str) -> dict: + """Send a GET request to the test server requesting the given key. + + Checks that the response is a 200 and returns the decoded json body. 
+ """ + channel = FakeChannel(self.site, self.reactor) + req = SynapseRequest(channel) + req.content = BytesIO(b"") + req.requestReceived( + b"GET", + b"/_matrix/key/v2/query/%s/%s" + % (server_name.encode("utf-8"), key_id.encode("utf-8")), + b"1.1", + ) + wait_until_result(self.reactor, req) + self.assertEqual(channel.code, 200) + resp = channel.json_body + return resp + + def test_get_key(self): + """Fetch a remote key""" + SERVER_NAME = "remote.server" + testkey = signedjson.key.generate_signing_key("ver1") + self.expect_outgoing_key_request(SERVER_NAME, testkey) + + resp = self.make_notary_request(SERVER_NAME, "ed25519:ver1") + keys = resp["server_keys"] + self.assertEqual(len(keys), 1) + + self.assertIn("ed25519:ver1", keys[0]["verify_keys"]) + self.assertEqual(len(keys[0]["verify_keys"]), 1) + + # it should be signed by both the origin server and the notary + self.assertIn(SERVER_NAME, keys[0]["signatures"]) + self.assertIn(self.hs.hostname, keys[0]["signatures"]) + + def test_get_own_key(self): + """Fetch our own key""" + testkey = signedjson.key.generate_signing_key("ver1") + self.expect_outgoing_key_request(self.hs.hostname, testkey) + + resp = self.make_notary_request(self.hs.hostname, "ed25519:ver1") + keys = resp["server_keys"] + self.assertEqual(len(keys), 1) + + # it should be signed by both itself, and the notary signing key + sigs = keys[0]["signatures"] + self.assertEqual(len(sigs), 1) + self.assertIn(self.hs.hostname, sigs) + oursigs = sigs[self.hs.hostname] + self.assertEqual(len(oursigs), 2) + + # the requested key should be present in the verify_keys section + self.assertIn("ed25519:ver1", keys[0]["verify_keys"]) + + +class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): + """End-to-end tests of the perspectives fetch case + + The idea here is to actually wire up a PerspectivesKeyFetcher to the notary + endpoint, to check that the two implementations are compatible. + """ + + def default_config(self): + config = super().default_config() + + # replace the signing key with our own + self.hs_signing_key = signedjson.key.generate_signing_key("kssk") + strm = StringIO() + signedjson.key.write_signing_keys(strm, [self.hs_signing_key]) + config["signing_key"] = strm.getvalue() + + return config + + def prepare(self, reactor, clock, homeserver): + # make a second homeserver, configured to use the first one as a key notary + self.http_client2 = Mock() + config = default_config(name="keyclient") + config["trusted_key_servers"] = [ + { + "server_name": self.hs.hostname, + "verify_keys": { + "ed25519:%s" + % ( + self.hs_signing_key.version, + ): signedjson.key.encode_verify_key_base64( + self.hs_signing_key.verify_key + ) + }, + } + ] + self.hs2 = self.setup_test_homeserver( + http_client=self.http_client2, config=config + ) + + # wire up outbound POST /key/v2/query requests from hs2 so that they + # will be forwarded to hs1 + def post_json(destination, path, data): + self.assertEqual(destination, self.hs.hostname) + self.assertEqual( + path, "/_matrix/key/v2/query", + ) + + channel = FakeChannel(self.site, self.reactor) + req = SynapseRequest(channel) + req.content = BytesIO(encode_canonical_json(data)) + + req.requestReceived( + b"POST", path.encode("utf-8"), b"1.1", + ) + wait_until_result(self.reactor, req) + self.assertEqual(channel.code, 200) + resp = channel.json_body + return resp + + self.http_client2.post_json.side_effect = post_json + + def test_get_key(self): + """Fetch a key belonging to a random server""" + # make up a key to be fetched. 
+ testkey = signedjson.key.generate_signing_key("abc") + + # we expect hs1 to make a regular key request to the target server + self.expect_outgoing_key_request("targetserver", testkey) + keyid = "ed25519:%s" % (testkey.version,) + + fetcher = PerspectivesKeyFetcher(self.hs2) + d = fetcher.get_keys({"targetserver": {keyid: 1000}}) + res = self.get_success(d) + self.assertIn("targetserver", res) + keyres = res["targetserver"][keyid] + assert isinstance(keyres, FetchKeyResult) + self.assertEqual( + signedjson.key.encode_verify_key_base64(keyres.verify_key), + signedjson.key.encode_verify_key_base64(testkey.verify_key), + ) + + def test_get_notary_key(self): + """Fetch a key belonging to the notary server""" + # make up a key to be fetched. We randomise the keyid to try to get it to + # appear before the key server signing key sometimes (otherwise we bail out + # before fetching its signature) + testkey = signedjson.key.generate_signing_key(random_string(5)) + + # we expect hs1 to make a regular key request to itself + self.expect_outgoing_key_request(self.hs.hostname, testkey) + keyid = "ed25519:%s" % (testkey.version,) + + fetcher = PerspectivesKeyFetcher(self.hs2) + d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}}) + res = self.get_success(d) + self.assertIn(self.hs.hostname, res) + keyres = res[self.hs.hostname][keyid] + assert isinstance(keyres, FetchKeyResult) + self.assertEqual( + signedjson.key.encode_verify_key_base64(keyres.verify_key), + signedjson.key.encode_verify_key_base64(testkey.verify_key), + ) + + def test_get_notary_keyserver_key(self): + """Fetch the notary's keyserver key""" + # we expect hs1 to make a regular key request to itself + self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key) + keyid = "ed25519:%s" % (self.hs_signing_key.version,) + + fetcher = PerspectivesKeyFetcher(self.hs2) + d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}}) + res = self.get_success(d) + self.assertIn(self.hs.hostname, res) + keyres = res[self.hs.hostname][keyid] + assert isinstance(keyres, FetchKeyResult) + self.assertEqual( + signedjson.key.encode_verify_key_base64(keyres.verify_key), + signedjson.key.encode_verify_key_base64(self.hs_signing_key.verify_key), + ) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index bc662b61db..1ca648ef2b 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -18,10 +18,16 @@ import os import shutil import tempfile from binascii import unhexlify +from io import BytesIO +from typing import Optional from mock import Mock from six.moves.urllib import parse +import attr +import PIL.Image as Image +from parameterized import parameterized_class + from twisted.internet.defer import Deferred from synapse.logging.context import make_deferred_yieldable @@ -94,6 +100,68 @@ class MediaStorageTests(unittest.HomeserverTestCase): self.assertEqual(test_body, body) +@attr.s +class _TestImage: + """An image for testing thumbnailing with the expected results + + Attributes: + data: The raw image to thumbnail + content_type: The type of the image as a content type, e.g. "image/png" + extension: The extension associated with the format, e.g. ".png" + expected_cropped: The expected bytes from cropped thumbnailing, or None if + test should just check for success. + expected_scaled: The expected bytes from scaled thumbnailing, or None if + test should just check for a valid image returned. 
+ """ + + data = attr.ib(type=bytes) + content_type = attr.ib(type=bytes) + extension = attr.ib(type=bytes) + expected_cropped = attr.ib(type=Optional[bytes]) + expected_scaled = attr.ib(type=Optional[bytes]) + + +@parameterized_class( + ("test_image",), + [ + # smol png + ( + _TestImage( + unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ), + b"image/png", + b".png", + unhexlify( + b"89504e470d0a1a0a0000000d4948445200000020000000200806" + b"000000737a7af40000001a49444154789cedc101010000008220" + b"ffaf6e484001000000ef0610200001194334ee0000000049454e" + b"44ae426082" + ), + unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000d49444154789c636060606000000005" + b"0001a5f645400000000049454e44ae426082" + ), + ), + ), + # small lossless webp + ( + _TestImage( + unhexlify( + b"524946461a000000574542505650384c0d0000002f0000001007" + b"1011118888fe0700" + ), + b"image/webp", + b".webp", + None, + None, + ), + ), + ], +) class MediaRepoTests(unittest.HomeserverTestCase): hijack_auth = True @@ -149,19 +217,13 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.media_repo = hs.get_media_repository_resource() self.download_resource = self.media_repo.children[b"download"] + self.thumbnail_resource = self.media_repo.children[b"thumbnail"] - # smol png - self.end_content = unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000a49444154789c63000100000500010d" - b"0a2db40000000049454e44ae426082" - ) + self.media_id = "example.com/12345" def _req(self, content_disposition): - request, channel = self.make_request( - "GET", "example.com/12345", shorthand=False - ) + request, channel = self.make_request("GET", self.media_id, shorthand=False) request.render(self.download_resource) self.pump() @@ -170,19 +232,19 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(len(self.fetches), 1) self.assertEqual(self.fetches[0][1], "example.com") self.assertEqual( - self.fetches[0][2], "/_matrix/media/v1/download/example.com/12345" + self.fetches[0][2], "/_matrix/media/v1/download/" + self.media_id ) self.assertEqual(self.fetches[0][3], {"allow_remote": "false"}) headers = { - b"Content-Length": [b"%d" % (len(self.end_content))], - b"Content-Type": [b"image/png"], + b"Content-Length": [b"%d" % (len(self.test_image.data))], + b"Content-Type": [self.test_image.content_type], } if content_disposition: headers[b"Content-Disposition"] = [content_disposition] self.fetches[0][0].callback( - (self.end_content, (len(self.end_content), headers)) + (self.test_image.data, (len(self.test_image.data), headers)) ) self.pump() @@ -195,12 +257,15 @@ class MediaRepoTests(unittest.HomeserverTestCase): If the filename is filename=<ascii> then Synapse will decode it as an ASCII string, and use filename= in the response. 
""" - channel = self._req(b"inline; filename=out.png") + channel = self._req(b"inline; filename=out" + self.test_image.extension) headers = channel.headers - self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"]) self.assertEqual( - headers.getRawHeaders(b"Content-Disposition"), [b"inline; filename=out.png"] + headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] + ) + self.assertEqual( + headers.getRawHeaders(b"Content-Disposition"), + [b"inline; filename=out" + self.test_image.extension], ) def test_disposition_filenamestar_utf8escaped(self): @@ -210,13 +275,17 @@ class MediaRepoTests(unittest.HomeserverTestCase): response. """ filename = parse.quote("\u2603".encode("utf8")).encode("ascii") - channel = self._req(b"inline; filename*=utf-8''" + filename + b".png") + channel = self._req( + b"inline; filename*=utf-8''" + filename + self.test_image.extension + ) headers = channel.headers - self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"]) + self.assertEqual( + headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] + ) self.assertEqual( headers.getRawHeaders(b"Content-Disposition"), - [b"inline; filename*=utf-8''" + filename + b".png"], + [b"inline; filename*=utf-8''" + filename + self.test_image.extension], ) def test_disposition_none(self): @@ -227,5 +296,39 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel = self._req(None) headers = channel.headers - self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"]) + self.assertEqual( + headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] + ) self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) + + def test_thumbnail_crop(self): + self._test_thumbnail("crop", self.test_image.expected_cropped) + + def test_thumbnail_scale(self): + self._test_thumbnail("scale", self.test_image.expected_scaled) + + def _test_thumbnail(self, method, expected_body): + params = "?width=32&height=32&method=" + method + request, channel = self.make_request( + "GET", self.media_id + params, shorthand=False + ) + request.render(self.thumbnail_resource) + self.pump() + + headers = { + b"Content-Length": [b"%d" % (len(self.test_image.data))], + b"Content-Type": [self.test_image.content_type], + } + self.fetches[0][0].callback( + (self.test_image.data, (len(self.test_image.data), headers)) + ) + self.pump() + + self.assertEqual(channel.code, 200) + if expected_body is not None: + self.assertEqual( + channel.result["body"], expected_body, channel.result["body"] + ) + else: + # ensure that the result is at least some valid image + Image.open(BytesIO(channel.result["body"])) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 976652aee8..2826211f32 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -74,6 +74,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) config["url_preview_ip_range_whitelist"] = ("1.1.1.1",) config["url_preview_url_blacklist"] = [] + config["url_preview_accept_language"] = [ + "en-UK", + "en-US;q=0.9", + "fr;q=0.8", + "*;q=0.7", + ] self.storage_path = self.mktemp() self.media_store_path = self.mktemp() @@ -247,6 +253,41 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") + def test_overlong_title(self): + self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] + + end_content = ( + b"<html><head>" + b"<title>" 
+ b"x" * 2000 + b"</title>" + b'<meta property="og:description" content="hi" />' + b"</head></html>" + ) + + request, channel = self.make_request( + "GET", "url_preview?url=http://matrix.org", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="windows-1251"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + res = channel.json_body + # We should only see the `og:description` field, as `title` is too long and should be stripped out + self.assertCountEqual(["og:description"], res.keys()) + def test_ipaddr(self): """ IP addresses can be previewed directly. @@ -472,3 +513,52 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.pump() self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {}) + + def test_accept_language_config_option(self): + """ + Accept-Language header is sent to the remote server + """ + self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] + + # Build and make a request to the server + request, channel = self.make_request( + "GET", "url_preview?url=http://example.com", shorthand=False + ) + request.render(self.preview_url) + self.pump() + + # Extract Synapse's tcp client + client = self.reactor.tcpClients[0][2].buildProtocol(None) + + # Build a fake remote server to reply with + server = AccumulatingProtocol() + + # Connect the two together + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + + # Tell Synapse that it has received some data from the remote server + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content + ) + + # Move the reactor along until we get a response on our original channel + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} + ) + + # Check that the server received the Accept-Language header as part + # of the request from Synapse + self.assertIn( + ( + b"Accept-Language: en-UK\r\n" + b"Accept-Language: en-US;q=0.9\r\n" + b"Accept-Language: fr;q=0.8\r\n" + b"Accept-Language: *;q=0.7" + ), + server.data, + ) diff --git a/tests/server.py b/tests/server.py index e397ebe8fa..1644710aa0 100644 --- a/tests/server.py +++ b/tests/server.py @@ -20,6 +20,7 @@ from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http import unquote from twisted.web.http_headers import Headers +from twisted.web.server import Site from synapse.http.site import SynapseRequest from synapse.util import Clock @@ -42,6 +43,7 @@ class FakeChannel(object): wire). 
""" + site = attr.ib(type=Site) _reactor = attr.ib() result = attr.ib(default=attr.Factory(dict)) _producer = None @@ -161,7 +163,11 @@ def make_request( path = path.encode("ascii") # Decorate it to be the full path, if we're using shorthand - if shorthand and not path.startswith(b"/_matrix"): + if ( + shorthand + and not path.startswith(b"/_matrix") + and not path.startswith(b"/_synapse") + ): path = b"/_matrix/client/r0/" + path path = path.replace(b"//", b"/") @@ -172,9 +178,9 @@ def make_request( content = content.encode("utf8") site = FakeSite() - channel = FakeChannel(reactor) + channel = FakeChannel(site, reactor) - req = request(site, channel) + req = request(channel) req.process = lambda: b"" req.content = BytesIO(content) req.postpath = list(map(unquote, path[1:].split(b"/"))) @@ -298,41 +304,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): Set up a synchronous test server, driven by the reactor used by the homeserver. """ - d = _sth(cleanup_func, *args, **kwargs).result + server = _sth(cleanup_func, *args, **kwargs) - if isinstance(d, Failure): - d.raiseException() + database = server.config.database.get_single_database() # Make the thread pool synchronous. - clock = d.get_clock() - pool = d.get_db_pool() - - def runWithConnection(func, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runWithConnection, - func, - *args, - **kwargs - ) - - def runInteraction(interaction, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runInteraction, - interaction, - *args, - **kwargs - ) + clock = server.get_clock() + + for database in server.get_datastores().databases: + pool = database._db_pool + + def runWithConnection(func, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs + ) + + def runInteraction(interaction, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + interaction, + *args, + **kwargs + ) - if pool: pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction pool.threadpool = ThreadPool(clock._reactor) pool.running = True - return d + + return server def get_clock(): @@ -375,6 +382,7 @@ class FakeTransport(object): disconnecting = False disconnected = False + connected = True buffer = attr.ib(default=b"") producer = attr.ib(default=None) autoflush = attr.ib(default=True) @@ -391,11 +399,25 @@ class FakeTransport(object): self.disconnecting = True if self._protocol: self._protocol.connectionLost(reason) - self.disconnected = True + + # if we still have data to write, delay until that is done + if self.buffer: + logger.info( + "FakeTransport: Delaying disconnect until buffer is flushed" + ) + else: + self.connected = False + self.disconnected = True def abortConnection(self): logger.info("FakeTransport: abortConnection()") - self.loseConnection() + + if not self.disconnecting: + self.disconnecting = True + if self._protocol: + self._protocol.connectionLost(None) + + self.disconnected = True def pauseProducing(self): if not self.producer: @@ -426,6 +448,9 @@ class FakeTransport(object): self._reactor.callLater(0.0, _produce) def write(self, byt): + if self.disconnecting: + raise Exception("Writing to disconnecting FakeTransport") + self.buffer = self.buffer + byt # always actually do the write asynchronously. 
Some protocols (notably the @@ -470,6 +495,10 @@ class FakeTransport(object): if self.buffer and self.autoflush: self._reactor.callLater(0.0, self.flush) + if not self.buffer and self.disconnecting: + logger.info("FakeTransport: Buffer now empty, completing disconnect") + self.disconnected = True + def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol: """ diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index cdf89e3383..99908edba3 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -17,27 +17,43 @@ from mock import Mock from twisted.internet import defer -from synapse.api.constants import EventTypes, ServerNoticeMsgType +from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import sync from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) from tests import unittest +from tests.unittest import override_config +from tests.utils import default_config class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): - hs_config = self.default_config("test") - hs_config["server_notices"] = { - "system_mxid_localpart": "server", - "system_mxid_display_name": "test display name", - "system_mxid_avatar_url": None, - "room_name": "Server Notices", - } + def default_config(self): + config = default_config("test") + + config.update( + { + "admin_contact": "mailto:user@test.com", + "limit_usage_by_mau": True, + "server_notices": { + "system_mxid_localpart": "server", + "system_mxid_display_name": "test display name", + "system_mxid_avatar_url": None, + "room_name": "Server Notices", + }, + } + ) + + # apply any additional config which was specified via the override_config + # decorator. 
+ if self._extra_config is not None: + config.update(self._extra_config) - hs = self.setup_test_homeserver(config=hs_config) - return hs + return config def prepare(self, reactor, clock, hs): self.server_notices_sender = self.hs.get_server_notices_sender() @@ -52,54 +68,41 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(1000) ) - self._send_notice = self._rlsn._server_notices_manager.send_notice - self._rlsn._server_notices_manager.send_notice = Mock() - self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None)) - self._rlsn._store.get_events = Mock(return_value=defer.succeed({})) - + self._rlsn._server_notices_manager.send_notice = Mock( + return_value=defer.succeed(Mock()) + ) self._send_notice = self._rlsn._server_notices_manager.send_notice - self.hs.config.limit_usage_by_mau = True self.user_id = "@user_id:test" - # self.server_notices_mxid = "@server:test" - # self.server_notices_mxid_display_name = None - # self.server_notices_mxid_avatar_url = None - # self.server_notices_room_name = "Server Notices" - - self._rlsn._server_notices_manager.get_notice_room_for_user = Mock( - returnValue="" + self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( + return_value=defer.succeed("!something:localhost") ) - self._rlsn._store.add_tag_to_room = Mock() - self._rlsn._store.get_tags_for_room = Mock(return_value={}) - self.hs.config.admin_contact = "mailto:user@test.com" - - def test_maybe_send_server_notice_to_user_flag_off(self): - """Tests cases where the flags indicate nothing to do""" - # test hs disabled case - self.hs.config.hs_disabled = True + self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) + self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({})) + @override_config({"hs_disabled": True}) + def test_maybe_send_server_notice_disabled_hs(self): + """If the HS is disabled, we should not send notices""" self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) - self._send_notice.assert_not_called() - # Test when mau limiting disabled - self.hs.config.hs_disabled = False - self.hs.config.limit_usage_by_mau = False - self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + @override_config({"limit_usage_by_mau": False}) + def test_maybe_send_server_notice_to_user_flag_off(self): + """If mau limiting is disabled, we should not send notices""" + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( return_value=defer.succeed({"123": mock_event}) ) - self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event self._send_notice.assert_called_once() @@ -109,7 +112,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user has blocked notice, but notice ought to be there (NOOP) """ self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, "foo") + return_value=defer.succeed(None), 
side_effect=ResourceLimitError(403, "foo") ) mock_event = Mock( @@ -118,6 +121,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._store.get_events = Mock( return_value=defer.succeed({"123": mock_event}) ) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() @@ -126,20 +130,19 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, but should have one """ - self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, "foo") + return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check contents, but 2 calls == set blocking event - self.assertTrue(self._send_notice.call_count == 2) + self.assertEqual(self._send_notice.call_count, 2) def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -150,7 +153,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) @@ -158,8 +161,87 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._send_notice.assert_not_called() + @override_config({"mau_limit_alerting": False}) + def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self): + """ + Test that when server is over MAU limit and alerting is suppressed, then + an alert message is not sent into the room + """ + self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), + side_effect=ResourceLimitError( + 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER + ), + ) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + + self.assertEqual(self._send_notice.call_count, 0) + + @override_config({"mau_limit_alerting": False}) + def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): + """ + Test that when a server is disabled, that MAU limit alerting is ignored. + """ + self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), + side_effect=ResourceLimitError( + 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED + ), + ) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + + # Would be better to check contents, but 2 calls == set blocking event + self.assertEqual(self._send_notice.call_count, 2) + + @override_config({"mau_limit_alerting": False}) + def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self): + """ + When the room is already in a blocked state, test that when alerting + is suppressed that the room is returned to an unblocked state. 
+ """ + self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), + side_effect=ResourceLimitError( + 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER + ), + ) + + self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( + return_value=defer.succeed((True, [])) + ) + + mock_event = Mock( + type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} + ) + self._rlsn._store.get_events = Mock( + return_value=defer.succeed({"123": mock_event}) + ) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + + self._send_notice.assert_called_once() + class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def default_config(self): + c = super().default_config() + c["server_notices"] = { + "system_mxid_localpart": "server", + "system_mxid_display_name": None, + "system_mxid_avatar_url": None, + "room_name": "Test Server Notice Room", + } + c["limit_usage_by_mau"] = True + c["max_mau_value"] = 5 + c["admin_contact"] = "mailto:user@test.com" + return c + def prepare(self, reactor, clock, hs): self.store = self.hs.get_datastore() self.server_notices_sender = self.hs.get_server_notices_sender() @@ -173,22 +255,14 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): if not isinstance(self._rlsn, ResourceLimitsServerNotices): raise Exception("Failed to find reference to ResourceLimitsServerNotices") - self.hs.config.limit_usage_by_mau = True - self.hs.config.hs_disabled = False - self.hs.config.max_mau_value = 5 - self.hs.config.server_notices_mxid = "@server:test" - self.hs.config.server_notices_mxid_display_name = None - self.hs.config.server_notices_mxid_avatar_url = None - self.hs.config.server_notices_room_name = "Test Server Notice Room" - self.user_id = "@user_id:test" - self.hs.config.admin_contact = "mailto:user@test.com" - def test_server_notice_only_sent_once(self): self.store.get_monthly_active_count = Mock(return_value=1000) - self.store.user_last_seen_monthly_active = Mock(return_value=1000) + self.store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed(1000) + ) # Call the function multiple times to ensure we only send the notice once self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -198,7 +272,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): # Now lets get the last load of messages in the service notice room and # check that there is only one server notice room_id = self.get_success( - self.server_notices_manager.get_notice_room_for_user(self.user_id) + self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id) ) token = self.get_success(self.event_source.get_current_token()) @@ -218,3 +292,86 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): count += 1 self.assertEqual(count, 1) + + def test_no_invite_without_notice(self): + """Tests that a user doesn't get invited to a server notices room without a + server notice being sent. + + The scenario for this test is a single user on a server where the MAU limit + hasn't been reached (since it's the only user and the limit is 5), so users + shouldn't receive a server notice. 
+ """ + self.register_user("user", "password") + tok = self.login("user", "password") + + request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) + self.render(request) + + invites = channel.json_body["rooms"]["invite"] + self.assertEqual(len(invites), 0, invites) + + def test_invite_with_notice(self): + """Tests that, if the MAU limit is hit, the server notices user invites each user + to a room in which it has sent a notice. + """ + user_id, tok, room_id = self._trigger_notice_and_join() + + # Sync again to retrieve the events in the room, so we can check whether this + # room has a notice in it. + request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) + self.render(request) + + # Scan the events in the room to search for a message from the server notices + # user. + events = channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] + notice_in_room = False + for event in events: + if ( + event["type"] == EventTypes.Message + and event["sender"] == self.hs.config.server_notices_mxid + ): + notice_in_room = True + + self.assertTrue(notice_in_room, "No server notice in room") + + def _trigger_notice_and_join(self): + """Creates enough active users to hit the MAU limit and trigger a system notice + about it, then joins the system notices room with one of the users created. + + Returns: + user_id (str): The ID of the user that joined the room. + tok (str): The access token of the user that joined the room. + room_id (str): The ID of the room that's been joined. + """ + user_id = None + tok = None + invites = [] + + # Register as many users as the MAU limit allows. + for i in range(self.hs.config.max_mau_value): + localpart = "user%d" % i + user_id = self.register_user(localpart, "password") + tok = self.login(localpart, "password") + + # Sync with the user's token to mark the user as active. + request, channel = self.make_request( + "GET", "/sync?timeout=0", access_token=tok, + ) + self.render(request) + + # Also retrieves the list of invites for this user. We don't care about that + # one except if we're processing the last user, which should have received an + # invite to a room with a server notice about the MAU limit being reached. + # We could also pick another user and sync with it, which would return an + # invite to a system notices room, but it doesn't matter which user we're + # using so we use the last one because it saves us an extra sync. + invites = channel.json_body["rooms"]["invite"] + + # Make sure we have an invite to process. + self.assertEqual(len(invites), 1, invites) + + # Join the room. 
+ room_id = list(invites.keys())[0] + self.helper.join(room=room_id, user=user_id, tok=tok) + + return user_id, tok, room_id diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 8d3845c870..a44960203e 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -22,7 +22,7 @@ import attr from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.event_auth import auth_types_for_event -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store from synapse.types import EventID @@ -58,6 +58,7 @@ class FakeEvent(object): self.type = type self.state_key = state_key self.content = content + self.room_id = ROOM_ID def to_event(self, auth_events, prev_events): """Given the auth_events and prev_events, convert to a Frozen Event @@ -88,7 +89,7 @@ class FakeEvent(object): if self.state_key is not None: event_dict["state_key"] = self.state_key - return FrozenEvent(event_dict) + return make_event_from_dict(event_dict) # All graphs start with this set of events @@ -418,6 +419,7 @@ class StateTestCase(unittest.TestCase): state_before = dict(state_at_event[prev_events[0]]) else: state_d = resolve_events_with_store( + ROOM_ID, RoomVersions.V2.identifier, [state_at_event[n] for n in prev_events], event_map=event_map, @@ -565,6 +567,7 @@ class SimpleParamStateTestCase(unittest.TestCase): # Test that we correctly handle passing `None` as the event_map state_d = resolve_events_with_store( + ROOM_ID, RoomVersions.V2.identifier, [self.state_at_bob, self.state_at_charlie], event_map=None, @@ -600,7 +603,7 @@ class TestStateResolutionStore(object): return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} - def get_auth_chain(self, event_ids): + def _get_auth_chain(self, event_ids): """Gets the full auth chain for a set of events (including rejected events). @@ -614,7 +617,6 @@ class TestStateResolutionStore(object): Args: event_ids (list): The event IDs of the events to fetch the auth chain for. Must be state events. - Returns: Deferred[list[str]]: List of event IDs of the auth chain. 
""" @@ -634,3 +636,9 @@ class TestStateResolutionStore(object): stack.append(aid) return list(result) + + def get_auth_chain_difference(self, auth_sets): + chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] + + common = set(chains[0]).intersection(*chains[1:]) + return set(chains[0]).union(*chains[1:]) - common diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index dd49a14524..5a50e4fdd4 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -25,8 +25,8 @@ from synapse.util.caches.descriptors import Cache, cached from tests import unittest -class CacheTestCase(unittest.TestCase): - def setUp(self): +class CacheTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): self.cache = Cache("test") def test_empty(self): @@ -96,7 +96,7 @@ class CacheTestCase(unittest.TestCase): cache.get(3) -class CacheDecoratorTestCase(unittest.TestCase): +class CacheDecoratorTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def test_passthrough(self): class A(object): @@ -197,7 +197,7 @@ class CacheDecoratorTestCase(unittest.TestCase): a.func.prefill(("foo",), ObservableDeferred(d)) - self.assertEquals(a.func("foo"), d.result) + self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) @defer.inlineCallbacks @@ -239,7 +239,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount2 = [0] class A(object): - @cached(max_entries=4) # HACK: This makes it 2 due to cache factor + @cached(max_entries=2) def func(self, key): callcount[0] += 1 return key @@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.table_name = "table_" + hs.get_secrets().token_hex(6) self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "create", lambda x, *a: x.execute(*a), "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)" @@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "index", lambda x, *a: x.execute(*a), "CREATE UNIQUE INDEX %sindex ON %s(id, username)" @@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["hello"], ["there"]] self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.db.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -367,13 +367,13 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.db.simple_select_list( self.table_name, None, ["id, username, value"] ) ) self.assertEqual( set(self._dump_to_tuple(res)), - set([(1, "user1", "hello"), (2, "user2", "there")]), + {(1, "user1", "hello"), (2, "user2", "there")}, ) # Update only user2 @@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["bleb"]] self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.db.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -394,11 +394,11 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.db.simple_select_list( self.table_name, None, ["id, username, value"] ) ) self.assertEqual( set(self._dump_to_tuple(res)), - set([(1, "user1", "hello"), (2, 
"user2", "bleb")]), + {(1, "user1", "hello"), (2, "user2", "bleb")}, ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 622b16a071..ef296e7dab 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -24,10 +24,11 @@ from twisted.internet import defer from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError -from synapse.storage.appservice import ( +from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.storage.database import Database, make_conn from tests import unittest from tests.utils import setup_test_homeserver @@ -42,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_token = "token1" @@ -54,7 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - self.store = ApplicationServiceStore(hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + self.store = ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -65,14 +69,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): pass def _add_appservice(self, as_token, id, url, hs_token, sender): - as_yaml = dict( - url=url, - as_token=as_token, - hs_token=hs_token, - id=id, - sender_localpart=sender, - namespaces={}, - ) + as_yaml = { + "url": url, + "as_token": as_token, + "hs_token": hs_token, + "id": id, + "sender_localpart": sender, + "namespaces": {}, + } # use the token as the filename with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) @@ -106,12 +110,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] - self.db_pool = hs.get_db_pool() - self.engine = hs.database_engine - self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"}, @@ -123,17 +124,25 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - self.store = TestTransactionStore(hs.get_db_conn(), hs) + # We assume there is only one database in these tests + database = hs.get_datastores().databases[0] + self.db_pool = database._db_pool + self.engine = database.engine - def _add_service(self, url, as_token, id): - as_yaml = dict( - url=url, - as_token=as_token, - hs_token="something", - id=id, - sender_localpart="a_sender", - namespaces={}, + db_config = hs.config.get_single_database() + self.store = TestTransactionStore( + database, make_conn(db_config, self.engine), hs ) + + def _add_service(self, url, as_token, id): + as_yaml = { + "url": url, + "as_token": as_token, + "hs_token": "something", + "id": id, + "sender_localpart": "a_sender", + "namespaces": {}, + } # use the token as the filename with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) @@ -375,15 +384,15 @@ class 
ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) self.assertEquals(2, len(services)) self.assertEquals( - set([self.as_list[2]["id"], self.as_list[0]["id"]]), - set([services[0].id, services[1].id]), + {self.as_list[2]["id"], self.as_list[0]["id"]}, + {services[0].id, services[1].id}, ) # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, db_conn, hs): - super(TestTransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TestTransactionStore, self).__init__(database, db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.TestCase): @@ -413,10 +422,13 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] - ApplicationServiceStore(hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -428,11 +440,14 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) @@ -449,11 +464,14 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 9fabe3fbc0..940b166129 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -2,74 +2,90 @@ from mock import Mock from twisted.internet import defer +from synapse.storage.background_updates import BackgroundUpdater + from tests import unittest -from tests.utils import setup_test_homeserver -class BackgroundUpdateTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) - self.store = hs.get_datastore() - self.clock = hs.get_clock() +class BackgroundUpdateTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater + # the base test class should have run the real bg updates for us + self.assertTrue( + self.get_success(self.updates.has_completed_background_updates()) + ) self.update_handler = Mock() - - yield self.store.register_background_update_handler( + self.updates.register_background_update_handler( "test_update", self.update_handler ) - # run the real background updates, to get them out the way - # (perhaps we should run them as part of the 
test HS setup, since we - # run all of the other schema setup stuff there?) - while True: - res = yield self.store.do_next_background_update(1000) - if res is None: - break - - @defer.inlineCallbacks def test_do_background_update(self): - desired_count = 1000 + # the time we claim each update takes duration_ms = 42 + # the target runtime for each bg update + target_background_update_duration_ms = 50000 + + store = self.hs.get_datastore() + self.get_success( + store.db.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + # first step: make a bit of progress @defer.inlineCallbacks def update(progress, count): - self.clock.advance_time_msec(count * duration_ms) + yield self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield self.store.runInteraction( + yield store.db.runInteraction( "update_progress", - self.store._background_update_progress_txn, + self.updates._background_update_progress_txn, "test_update", progress, ) return count self.update_handler.side_effect = update - - yield self.store.start_background_update("test_update", {"my_key": 1}) - self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) - self.assertIsNotNone(result) + res = self.get_success( + self.updates.do_next_background_update( + target_background_update_duration_ms + ), + by=0.1, + ) + self.assertFalse(res) + + # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( - {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE ) # second step: complete the update + # we should now get run with a much bigger number of items to update @defer.inlineCallbacks def update(progress, count): - yield self.store._end_background_update("test_update") + self.assertEqual(progress, {"my_key": 2}) + self.assertAlmostEqual( + count, target_background_update_duration_ms / duration_ms, places=0, + ) + yield self.updates._end_background_update("test_update") return count self.update_handler.side_effect = update self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) - self.assertIsNotNone(result) - self.update_handler.assert_called_once_with({"my_key": 2}, desired_count) + result = self.get_success( + self.updates.do_next_background_update(target_background_update_duration_ms) + ) + self.assertFalse(result) + self.update_handler.assert_called_once() # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) - self.assertIsNone(result) + result = self.get_success( + self.updates.do_next_background_update(target_background_update_duration_ms) + ) + self.assertTrue(result) self.assertFalse(self.update_handler.called) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index c778de1f0c..278961c331 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -21,6 +21,7 @@ from mock import Mock from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.engines import create_engine from tests import unittest @@ -50,22 +51,25 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True - config.event_cache_size = 1 - 
config.database_config = {"name": "sqlite3"} - engine = create_engine(config.database_config) + config.caches = Mock() + config.caches.event_cache_size = 1 + hs = TestHomeServer("test", config=config) + + sqlite_config = {"name": "sqlite3"} + engine = create_engine(sqlite_config) fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False - hs = TestHomeServer( - "test", db_pool=self.db_pool, config=config, database_engine=fake_engine - ) - self.datastore = SQLBaseStore(None, hs) + db = Database(Mock(), Mock(config=sqlite_config), fake_engine) + db._db_pool = self.db_pool + + self.datastore = SQLBaseStore(db, None, hs) @defer.inlineCallbacks def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.db.simple_insert( table="tablename", values={"columname": "Value"} ) @@ -77,7 +81,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.db.simple_insert( table="tablename", # Use OrderedDict() so we can assert on the SQL generated values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), @@ -92,7 +96,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore._simple_select_one_onecol( + value = yield self.datastore.db.simple_select_one_onecol( table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" ) @@ -106,7 +110,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.db.simple_select_one( table="tablename", keyvalues={"keycol": "TheKey"}, retcols=["colA", "colB", "colC"], @@ -122,7 +126,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.db.simple_select_one( table="tablename", keyvalues={"keycol": "Not here"}, retcols=["colA"], @@ -137,7 +141,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore._simple_select_list( + ret = yield self.datastore.db.simple_select_list( table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] ) @@ -150,7 +154,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.db.simple_update_one( table="tablename", keyvalues={"keycol": "TheKey"}, updatevalues={"columnname": "New Value"}, @@ -165,7 +169,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.db.simple_update_one( table="tablename", keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), @@ -180,7 +184,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_delete_one( + yield self.datastore.db.simple_delete_one( table="tablename", keyvalues={"keycol": "Go away"} ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 
e9e2d5337c..43425c969a 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -14,7 +14,13 @@ # limitations under the License. import os.path +from unittest.mock import patch +from mock import Mock + +import synapse.rest.admin +from synapse.api.constants import EventTypes +from synapse.rest.client.v1 import login, room from synapse.storage import prepare_database from synapse.types import Requester, UserID @@ -33,17 +39,21 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] def run_background_update(self): """Re run the background update to clean up the extremities. """ # Make sure we don't clash with in progress updates. - self.assertTrue(self.store._all_done, "Background updates are still ongoing") + self.assertTrue( + self.store.db.updates._all_done, "Background updates are still ongoing" + ) schema_path = os.path.join( prepare_database.dir_path, + "data_stores", + "main", "schema", "delta", "54", @@ -54,14 +64,20 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): prepare_database.executescript(txn, schema_path) self.get_success( - self.store.runInteraction("test_delete_forward_extremities", run_delta_file) + self.store.db.runInteraction( + "test_delete_forward_extremities", run_delta_file + ) ) # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_soft_failed_extremities_handled_correctly(self): """Test that extremities are correctly calculated in the presence of @@ -118,7 +134,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -156,7 +172,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -211,9 +227,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual( - set(latest_event_ids), set((event_id_a, event_id_b, event_id_c)) - ) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c}) # Run the background update and check it did the right thing self.run_background_update() @@ -221,10 +235,18 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): 
latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c])) + self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c}) class CleanupExtremDummyEventsTestCase(HomeserverTestCase): + CONSENT_VERSION = "1" + EXTREMITIES_COUNT = 50 + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + ] + def make_homeserver(self, reactor, clock): config = self.default_config() config["cleanup_extremities_with_dummy_events"] = True @@ -233,28 +255,39 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.store = homeserver.get_datastore() self.room_creator = homeserver.get_room_creation_handler() + self.event_creator_handler = homeserver.get_event_creation_handler() # Create a test user and room - self.user = UserID("alice", "test") + self.user = UserID.from_string(self.register_user("user1", "password")) + self.token1 = self.login("user1", "password") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] + self.event_creator = homeserver.get_event_creation_handler() + homeserver.config.user_consent_version = self.CONSENT_VERSION def test_send_dummy_event(self): - # Create a bushy graph with 50 extremities. + self._create_extremity_rich_graph() - event_id_start = self.create_and_send_event(self.room_id, self.user) - - for _ in range(50): - self.create_and_send_event( - self.room_id, self.user, prev_event_ids=[event_id_start] - ) + # Pump the reactor repeatedly so that the background updates have a + # chance to run. + self.pump(10 * 60) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(len(latest_event_ids), 50) + self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) + @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0) + def test_send_dummy_events_when_insufficient_power(self): + self._create_extremity_rich_graph() + # Criple power levels + self.helper.send_state( + self.room_id, + EventTypes.PowerLevels, + body={"users": {str(self.user): -1}}, + tok=self.token1, + ) # Pump the reactor repeatedly so that the background updates have a # chance to run. self.pump(10 * 60) @@ -262,4 +295,108 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) + # Check that the room has not been pruned + self.assertTrue(len(latest_event_ids) > 10) + + # New user with regular levels + user2 = self.register_user("user2", "password") + token2 = self.login("user2", "password") + self.helper.join(self.room_id, user2, tok=token2) + self.pump(10 * 60) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) + + @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0) + def test_send_dummy_event_without_consent(self): + self._create_extremity_rich_graph() + self._enable_consent_checking() + + # Pump the reactor repeatedly so that the background updates have a + # chance to run. 
Attempt to add dummy event with user that has not consented + # Check that dummy event send fails. + self.pump(10 * 60) + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT) + + # Create new user, and add consent + user2 = self.register_user("user2", "password") + token2 = self.login("user2", "password") + self.get_success( + self.store.user_set_consent_version(user2, self.CONSENT_VERSION) + ) + self.helper.join(self.room_id, user2, tok=token2) + + # Background updates should now cause a dummy event to be added to the graph + self.pump(10 * 60) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) + + @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250) + def test_expiry_logic(self): + """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion() + expires old entries correctly. + """ + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "1" + ] = 100000 + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "2" + ] = 200000 + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "3" + ] = 300000 + self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() + # All entries within time frame + self.assertEqual( + len( + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion + ), + 3, + ) + # Oldest room to expire + self.pump(1) + self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() + self.assertEqual( + len( + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion + ), + 2, + ) + # All rooms to expire + self.pump(2) + self.assertEqual( + len( + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion + ), + 0, + ) + + def _create_extremity_rich_graph(self): + """Helper method to create bushy graph on demand""" + + event_id_start = self.create_and_send_event(self.room_id, self.user) + + for _ in range(self.EXTREMITIES_COUNT): + self.create_and_send_event( + self.room_id, self.user, prev_event_ids=[event_id_start] + ) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(len(latest_event_ids), 50) + + def _enable_consent_checking(self): + """Helper method to enable consent checking""" + self.event_creator._block_events_without_consent_error = "No consent from user" + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creator._consent_uri_builder = consent_uri_builder diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 09305c3bf1..3b483bc7f0 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -23,6 +23,7 @@ from synapse.http.site import XForwardedForRequest from synapse.rest.client.v1 import login from tests import unittest +from tests.unittest import override_config class ClientIpStoreTestCase(unittest.HomeserverTestCase): @@ -37,9 +38,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(12345678) user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", 
"user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) @@ -47,15 +52,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(10) result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", - "access_token": "access_token", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 12345678000, @@ -82,7 +86,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -113,7 +117,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -134,9 +138,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"limit_usage_by_mau": False, "max_mau_value": 50}) def test_disabled_monthly_active_user(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.max_mau_value = 50 user_id = "@user:server" self.get_success( self.store.insert_client_ip( @@ -146,9 +149,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_adding_monthly_active_user_when_full(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 lots_of_users = 100 user_id = "@user:server" @@ -163,9 +165,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_adding_monthly_active_user_when_space(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 user_id = "@user:server" active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) @@ -181,9 +182,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_updating_monthly_active_user_when_space(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 user_id = "@user:server" self.get_success(self.store.register_user(user_id=user_id, password_hash=None)) @@ -201,6 +201,173 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) + def test_devices_last_seen_bg_update(self): + # First make sure we have completed all updates. 
+ while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) + + user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", device_id + ) + ) + # Force persisting to disk + self.reactor.advance(200) + + # But clear the associated entry in devices table + self.get_success( + self.store.db.simple_update( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + updatevalues={"last_seen": None, "ip": None, "user_agent": None}, + desc="test_devices_last_seen_bg_update", + ) + ) + + # We should now get nulls when querying + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, device_id) + ) + + r = result[(user_id, device_id)] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": device_id, + "ip": None, + "user_agent": None, + "last_seen": None, + }, + r, + ) + + # Register the background update to run again. + self.get_success( + self.store.db.simple_insert( + table="background_updates", + values={ + "update_name": "devices_last_seen", + "progress_json": "{}", + "depends_on": None, + }, + ) + ) + + # ... and tell the DataStore that it hasn't finished all updates yet + self.store.db.updates._all_done = False + + # Now let's actually drive the updates to completion + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) + + # We should now get the correct result again + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, device_id) + ) + + r = result[(user_id, device_id)] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": device_id, + "ip": "ip", + "user_agent": "user_agent", + "last_seen": 0, + }, + r, + ) + + def test_old_user_ips_pruned(self): + # First make sure we have completed all updates. + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) + + user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", device_id + ) + ) + + # Force persisting to disk + self.reactor.advance(200) + + # We should see that in the DB + result = self.get_success( + self.store.db.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], + desc="get_user_ip_and_agents", + ) + ) + + self.assertEqual( + result, + [ + { + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "device_id": device_id, + "last_seen": 0, + } + ], + ) + + # Now advance by a couple of months + self.reactor.advance(60 * 24 * 60 * 60) + + # We should get no results. 
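        # (Old rows in user_ips are expected to be culled by the store's periodic
        # pruning once they pass the retention window, which the two-month
        # advance above exceeds.)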
+ result = self.get_success( + self.store.db.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], + desc="get_user_ip_and_agents", + ) + ) + + self.assertEqual(result, []) + + # But we should still get the correct values for the device + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, device_id) + ) + + r = result[(user_id, device_id)] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": device_id, + "ip": "ip", + "user_agent": "user_agent", + "last_seen": 0, + }, + r, + ) + class ClientIpAuthTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py new file mode 100644 index 0000000000..5a77c84962 --- /dev/null +++ b/tests/storage/test_database.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.database import make_tuple_comparison_clause +from synapse.storage.engines import BaseDatabaseEngine + +from tests import unittest + + +def _stub_db_engine(**kwargs) -> BaseDatabaseEngine: + # returns a DatabaseEngine, circumventing the abc mechanism + # any kwargs are set as attributes on the class before instantiating it + t = type( + "TestBaseDatabaseEngine", + (BaseDatabaseEngine,), + dict(BaseDatabaseEngine.__dict__), + ) + # defeat the abc mechanism + t.__abstractmethods__ = set() + for k, v in kwargs.items(): + setattr(t, k, v) + return t(None, None) + + +class TupleComparisonClauseTestCase(unittest.TestCase): + def test_native_tuple_comparison(self): + db_engine = _stub_db_engine(supports_tuple_comparison=True) + clause, args = make_tuple_comparison_clause(db_engine, [("a", 1), ("b", 2)]) + self.assertEqual(clause, "(a,b) > (?,?)") + self.assertEqual(args, [1, 2]) + + def test_emulated_tuple_comparison(self): + db_engine = _stub_db_engine(supports_tuple_comparison=False) + clause, args = make_tuple_comparison_clause( + db_engine, [("a", 1), ("b", 2), ("c", 3)] + ) + self.assertEqual( + clause, "(a >= ? AND (a > ? OR (b >= ? AND (b > ? 
OR c > ?))))" + ) + self.assertEqual(args, [1, 1, 2, 2, 3]) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 3cc18f9f1c..c2539b353a 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) @defer.inlineCallbacks - def test_get_devices_by_remote(self): + def test_get_device_updates_by_remote(self): device_ids = ["device_id1", "device_id2"] # Add two device updates with a single stream_id @@ -81,63 +81,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) # Get all device updates ever meant for this remote - now_stream_id, device_updates = yield self.store.get_devices_by_remote( + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( "somehost", -1, limit=100 ) # Check original device_ids are contained within these updates self._check_devices_in_updates(device_ids, device_updates) - @defer.inlineCallbacks - def test_get_devices_by_remote_limited(self): - # Test breaking the update limit in 1, 101, and 1 device_id segments - - # first add one device - device_ids1 = ["device_id0"] - yield self.store.add_device_change_to_streams( - "user_id", device_ids1, ["someotherhost"] - ) - - # then add 101 - device_ids2 = ["device_id" + str(i + 1) for i in range(101)] - yield self.store.add_device_change_to_streams( - "user_id", device_ids2, ["someotherhost"] - ) - - # then one more - device_ids3 = ["newdevice"] - yield self.store.add_device_change_to_streams( - "user_id", device_ids3, ["someotherhost"] - ) - - # - # now read them back. - # - - # first we should get a single update - now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", -1, limit=100 - ) - self._check_devices_in_updates(device_ids1, device_updates) - - # Then we should get an empty list back as the 101 devices broke the limit - now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", now_stream_id, limit=100 - ) - self.assertEqual(len(device_updates), 0) - - # The 101 devices should've been cleared, so we should now just get one device - # update - now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", now_stream_id, limit=100 - ) - self._check_devices_in_updates(device_ids3, device_updates) - def _check_devices_in_updates(self, expected_device_ids, device_updates): """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) - received_device_ids = {update["device_id"] for update in device_updates} + received_device_ids = { + update["device_id"] for edu_type, update in device_updates + } self.assertEqual(received_device_ids, set(expected_device_ids)) @defer.inlineCallbacks diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py new file mode 100644 index 0000000000..35dafbb904 --- /dev/null +++ b/tests/storage/test_e2e_room_keys.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from tests import unittest + +# sample room_key data for use in the tests +room_key = { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": False, + "session_data": "SSBBTSBBIEZJU0gK", +} + + +class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver("server", http_client=None) + self.store = hs.get_datastore() + return hs + + def test_room_keys_version_delete(self): + # test that deleting a room key backup deletes the keys + version1 = self.get_success( + self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) + ) + + self.get_success( + self.store.add_e2e_room_keys( + "user_id", version1, [("room", "session", room_key)] + ) + ) + + version2 = self.get_success( + self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) + ) + + self.get_success( + self.store.add_e2e_room_keys( + "user_id", version2, [("room", "session", room_key)] + ) + ) + + # make sure the keys were stored properly + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1)) + self.assertEqual(len(keys["rooms"]), 1) + + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2)) + self.assertEqual(len(keys["rooms"]), 1) + + # delete version1 + self.get_success(self.store.delete_e2e_room_keys_version("user_id", version1)) + + # make sure the key from version1 is gone, and the key from version2 is + # still there + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1)) + self.assertEqual(len(keys["rooms"]), 0) + + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2)) + self.assertEqual(len(keys["rooms"]), 1) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index c8ece15284..398d546280 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -38,7 +38,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] - self.assertDictContainsSubset({"keys": json, "device_display_name": None}, dev) + self.assertDictContainsSubset(json, dev) @defer.inlineCallbacks def test_reupload_key(self): @@ -68,7 +68,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): self.assertIn("device", res["user"]) dev = res["user"]["device"] self.assertDictContainsSubset( - {"keys": json, "device_display_name": "display_name"}, dev + {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev ) @defer.inlineCallbacks @@ -80,10 +80,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.store_device("user2", "device1", None) yield self.store.store_device("user2", "device2", None) - yield self.store.set_e2e_device_keys("user1", "device1", now, "json11") - yield self.store.set_e2e_device_keys("user1", "device2", now, "json12") - yield self.store.set_e2e_device_keys("user2", "device1", now, "json21") - yield self.store.set_e2e_device_keys("user2", "device2", now, "json22") + yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) + yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) + yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) + yield self.store.set_e2e_device_keys("user2", "device2", now, 
{"key": "json22"}) res = yield self.store.get_e2e_device_keys( (("user1", "device1"), ("user2", "device2")) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 86c7ac350d..3aeec0dc0f 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -13,19 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - import tests.unittest import tests.utils -class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) +class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks def test_get_prev_events_for_room(self): room_id = "@ROOM:local" @@ -57,21 +52,182 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): "(event_id, algorithm, hash) " "VALUES (?, 'sha256', ?)" ), - (event_id, b"ffff"), + (event_id, bytearray(b"ffff")), ) - for i in range(0, 11): - yield self.store.runInteraction("insert", insert_event, i) + for i in range(0, 20): + self.get_success(self.store.db.runInteraction("insert", insert_event, i)) - # this should get the last five and five others - r = yield self.store.get_prev_events_for_room(room_id) + # this should get the last ten + r = self.get_success(self.store.get_prev_events_for_room(room_id)) self.assertEqual(10, len(r)) - for i in range(0, 5): - el = r[i] - depth = el[2] - self.assertEqual(10 - i, depth) - - for i in range(5, 5): - el = r[i] - depth = el[2] - self.assertLessEqual(5, depth) + for i in range(0, 10): + self.assertEqual("$event_%i:local" % (19 - i), r[i]) + + def test_get_rooms_with_many_extremities(self): + room1 = "#room1" + room2 = "#room2" + room3 = "#room3" + + def insert_event(txn, i, room_id): + event_id = "$event_%i:local" % i + txn.execute( + ( + "INSERT INTO event_forward_extremities (room_id, event_id) " + "VALUES (?, ?)" + ), + (room_id, event_id), + ) + + for i in range(0, 20): + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room1) + ) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room2) + ) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room3) + ) + + # Test simple case + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [])) + self.assertEqual(len(r), 3) + + # Does filter work? + + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1])) + self.assertTrue(room2 in r) + self.assertTrue(room3 in r) + self.assertEqual(len(r), 2) + + r = self.get_success( + self.store.get_rooms_with_many_extremities(5, 5, [room1, room2]) + ) + self.assertEqual(r, [room3]) + + # Does filter and limit work? + + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) + self.assertTrue(r == [room2] or r == [room3]) + + def test_auth_difference(self): + room_id = "@ROOM:local" + + # The silly auth graph we use to test the auth difference algorithm, + # where the top are the most recent events. 
+ # + # A B + # \ / + # D E + # \ | + # ` F C + # | /| + # G ´ | + # | \ | + # H I + # | | + # K J + + auth_graph = { + "a": ["e"], + "b": ["e"], + "c": ["g", "i"], + "d": ["f"], + "e": ["f"], + "f": ["g"], + "g": ["h", "i"], + "h": ["k"], + "i": ["j"], + "k": [], + "j": [], + } + + depth_map = { + "a": 7, + "b": 7, + "c": 4, + "d": 6, + "e": 6, + "f": 5, + "g": 3, + "h": 2, + "i": 2, + "k": 1, + "j": 1, + } + + # We rudely fiddle with the appropriate tables directly, as that's much + # easier than constructing events properly. + + def insert_event(txn, event_id, stream_ordering): + + depth = depth_map[event_id] + + self.store.db.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + "stream_ordering": stream_ordering, + }, + ) + + self.store.db.simple_insert_many_txn( + txn, + table="event_auth", + values=[ + {"event_id": event_id, "room_id": room_id, "auth_id": a} + for a in auth_graph[event_id] + ], + ) + + next_stream_ordering = 0 + for event_id in auth_graph: + next_stream_ordering += 1 + self.get_success( + self.store.db.runInteraction( + "insert", insert_event, event_id, next_stream_ordering + ) + ) + + # Now actually test that various combinations give the right result: + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a", "c"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "d", "e"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success(self.store.get_auth_chain_difference([{"a"}])) + self.assertSetEqual(difference, set()) diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index f26ff57a18..a7b85004e5 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): events = [(3, 2), (6, 2), (4, 6)] for event_count, extrems in events: - info = self.get_success(room_creator.create_room(requester, {})) + info, _ = self.get_success(room_creator.create_room(requester, {})) room_id = info["room_id"] last_event = None @@ -59,24 +59,22 @@ class ExtremStatisticsTestCase(HomeserverTestCase): ) ) - expected = set( - [ - b'synapse_forward_extremities_bucket{le="1.0"} 0.0', - b'synapse_forward_extremities_bucket{le="2.0"} 2.0', - b'synapse_forward_extremities_bucket{le="3.0"} 2.0', - b'synapse_forward_extremities_bucket{le="5.0"} 2.0', - b'synapse_forward_extremities_bucket{le="7.0"} 3.0', - b'synapse_forward_extremities_bucket{le="10.0"} 3.0', - b'synapse_forward_extremities_bucket{le="15.0"} 3.0', - b'synapse_forward_extremities_bucket{le="20.0"} 3.0', - b'synapse_forward_extremities_bucket{le="50.0"} 3.0', - 
b'synapse_forward_extremities_bucket{le="100.0"} 3.0', - b'synapse_forward_extremities_bucket{le="200.0"} 3.0', - b'synapse_forward_extremities_bucket{le="500.0"} 3.0', - b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', - b"synapse_forward_extremities_count 3.0", - b"synapse_forward_extremities_sum 10.0", - ] - ) + expected = { + b'synapse_forward_extremities_bucket{le="1.0"} 0.0', + b'synapse_forward_extremities_bucket{le="2.0"} 2.0', + b'synapse_forward_extremities_bucket{le="3.0"} 2.0', + b'synapse_forward_extremities_bucket{le="5.0"} 2.0', + b'synapse_forward_extremities_bucket{le="7.0"} 3.0', + b'synapse_forward_extremities_bucket{le="10.0"} 3.0', + b'synapse_forward_extremities_bucket{le="15.0"} 3.0', + b'synapse_forward_extremities_bucket{le="20.0"} 3.0', + b'synapse_forward_extremities_bucket{le="50.0"} 3.0', + b'synapse_forward_extremities_bucket{le="100.0"} 3.0', + b'synapse_forward_extremities_bucket{le="200.0"} 3.0', + b'synapse_forward_extremities_bucket{le="500.0"} 3.0', + b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', + b"synapse_forward_extremities_count 3.0", + b"synapse_forward_extremities_sum 10.0", + } self.assertEqual(items, expected) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index b114c6fb1d..b45bc9c115 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -35,6 +35,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): def setUp(self): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() + self.persist_events_store = hs.get_datastores().persist_events @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): @@ -55,7 +56,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield self.store.runInteraction( + counts = yield self.store.db.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) self.assertEquals( @@ -74,20 +75,20 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield self.store.add_push_actions_to_staging( event.event_id, {user_id: action} ) - yield self.store.runInteraction( + yield self.store.db.runInteraction( "", - self.store._set_push_actions_for_event_and_users_txn, + self.persist_events_store._set_push_actions_for_event_and_users_txn, [(event, None)], [(event, None)], ) def _rotate(stream): - return self.store.runInteraction( + return self.store.db.runInteraction( "", self.store._rotate_notifs_before_txn, stream ) def _mark_read(stream, depth): - return self.store.runInteraction( + return self.store.db.runInteraction( "", self.store._remove_old_push_actions_before_txn, room_id, @@ -116,7 +117,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store._simple_delete( + yield self.store.db.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) @@ -135,7 +136,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store._simple_insert( + return self.store.db.simple_insert( "events", { "stream_ordering": so, diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py new file mode 100644 index 0000000000..55e9ecf264 --- /dev/null +++ 
b/tests/storage/test_id_generators.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.storage.database import Database +from synapse.storage.util.id_generators import MultiWriterIdGenerator + +from tests.unittest import HomeserverTestCase +from tests.utils import USE_POSTGRES_FOR_TESTS + + +class MultiWriterIdGeneratorTestCase(HomeserverTestCase): + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.db = self.store.db # type: Database + + self.get_success(self.db.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn): + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create(conn): + return MultiWriterIdGenerator( + conn, + self.db, + instance_name=instance_name, + table="foobar", + instance_column="instance_name", + id_column="stream_id", + sequence_name="foobar_seq", + ) + + return self.get_success(self.db.runWithConnection(_create)) + + def _insert_rows(self, instance_name: str, number: int): + def _insert(txn): + for _ in range(number): + txn.execute( + "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", + (instance_name,), + ) + + self.get_success(self.db.runInteraction("test_single_instance", _insert)) + + def test_empty(self): + """Test an ID generator against an empty database gives sensible + current positions. + """ + + id_gen = self._create_id_generator() + + # The table is empty so we expect an empty map for positions + self.assertEqual(id_gen.get_positions(), {}) + + def test_single_instance(self): + """Test that reads and writes from a single process are handled + correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + async def _get_next_async(): + with await id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 8) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + self.get_success(_get_next_async()) + + self.assertEqual(id_gen.get_positions(), {"master": 8}) + self.assertEqual(id_gen.get_current_token("master"), 8) + + def test_multi_instance(self): + """Test that reads and writes from multiple processes are handled + correctly. 
+ """ + self._insert_rows("first", 3) + self._insert_rows("second", 4) + + first_id_gen = self._create_id_generator("first") + second_id_gen = self._create_id_generator("second") + + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(first_id_gen.get_current_token("first"), 3) + self.assertEqual(first_id_gen.get_current_token("second"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + async def _get_next_async(): + with await first_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 8) + + self.assertEqual( + first_id_gen.get_positions(), {"first": 3, "second": 7} + ) + + self.get_success(_get_next_async()) + + self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7}) + + # However the ID gen on the second instance won't have seen the update + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) + + # ... but calling `get_next` on the second instance should give a unique + # stream ID + + async def _get_next_async(): + with await second_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 9) + + self.assertEqual( + second_id_gen.get_positions(), {"first": 3, "second": 7} + ) + + self.get_success(_get_next_async()) + + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) + + # If the second ID gen gets told about the first, it correctly updates + second_id_gen.advance("first", 8) + self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) + + def test_get_next_txn(self): + """Test that the `get_next_txn` function works correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + def _get_next_txn(txn): + stream_id = id_gen.get_next_txn(txn) + self.assertEqual(stream_id, 8) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + self.get_success(self.db.runInteraction("test", _get_next_txn)) + + self.assertEqual(id_gen.get_positions(), {"master": 8}) + self.assertEqual(id_gen.get_current_token("master"), 8) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index e07ff01201..95f309fbbc 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -14,6 +14,7 @@ # limitations under the License. 
import signedjson.key +import unpaddedbase64 from twisted.internet.defer import Deferred @@ -21,11 +22,17 @@ from synapse.storage.keys import FetchKeyResult import tests.unittest -KEY_1 = signedjson.key.decode_verify_key_base64( - "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" + +def decode_verify_key_base64(key_id: str, key_base64: str): + key_bytes = unpaddedbase64.decode_base64(key_base64) + return signedjson.key.decode_verify_key_bytes(key_id, key_bytes) + + +KEY_1 = decode_verify_key_base64( + "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" ) -KEY_2 = signedjson.key.decode_verify_key_base64( - "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" +KEY_2 = decode_verify_key_base64( + "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" ) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py new file mode 100644 index 0000000000..ab0df5ea93 --- /dev/null +++ b/tests/storage/test_main.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Awesome Technologies Innovationslabor GmbH +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from synapse.types import UserID + +from tests import unittest +from tests.utils import setup_test_homeserver + + +class DataStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + hs = yield setup_test_homeserver(self.addCleanup) + + self.store = hs.get_datastore() + + self.user = UserID.from_string("@abcde:test") + self.displayname = "Frank" + + @defer.inlineCallbacks + def test_get_users_paginate(self): + yield self.store.register_user(self.user.to_string(), "pass") + yield self.store.create_profile(self.user.localpart) + yield self.store.set_profile_displayname(self.user.localpart, self.displayname) + + users, total = yield self.store.get_users_paginate( + 0, 10, name="bc", guests=False + ) + + self.assertEquals(1, total) + self.assertEquals(self.displayname, users.pop()["displayname"]) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 1494650d10..9c04e92577 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -19,152 +19,222 @@ from twisted.internet import defer from synapse.api.constants import UserTypes from tests import unittest +from tests.unittest import default_config, override_config FORTY_DAYS = 40 * 24 * 60 * 60 +def gen_3pids(count): + """Generate `count` threepids as a list.""" + return [ + {"medium": "email", "address": "user%i@matrix.org" % i} for i in range(count) + ] + + class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def default_config(self): + config = default_config("test") + + config.update({"limit_usage_by_mau": True, "max_mau_value": 50}) + + # apply any additional config which was specified via the override_config + # decorator. 
+ if self._extra_config is not None: + config.update(self._extra_config) - hs = self.setup_test_homeserver() - self.store = hs.get_datastore() - hs.config.limit_usage_by_mau = True - hs.config.max_mau_value = 50 + return config + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() # Advance the clock a bit reactor.advance(FORTY_DAYS) - return hs - + @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)}) def test_initialise_reserved_users(self): - self.hs.config.max_mau_value = 5 + threepids = self.hs.config.mau_limits_reserved_threepids + + # register three users, of which two have reserved 3pids, and a third + # which is a support user. user1 = "@user1:server" - user1_email = "user1@matrix.org" + user1_email = threepids[0]["address"] user2 = "@user2:server" - user2_email = "user2@matrix.org" + user2_email = threepids[1]["address"] user3 = "@user3:server" - user3_email = "user3@matrix.org" - - threepids = [ - {"medium": "email", "address": user1_email}, - {"medium": "email", "address": user2_email}, - {"medium": "email", "address": user3_email}, - ] - # -1 because user3 is a support user and does not count - user_num = len(threepids) - 1 - self.store.register_user(user_id=user1, password_hash=None) - self.store.register_user(user_id=user2, password_hash=None) - self.store.register_user( - user_id=user3, password_hash=None, user_type=UserTypes.SUPPORT + self.get_success(self.store.register_user(user_id=user1)) + self.get_success(self.store.register_user(user_id=user2)) + self.get_success( + self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT) ) - self.pump() now = int(self.hs.get_clock().time_msec()) - self.store.user_add_threepid(user1, "email", user1_email, now, now) - self.store.user_add_threepid(user2, "email", user2_email, now, now) + self.get_success( + self.store.user_add_threepid(user1, "email", user1_email, now, now) + ) + self.get_success( + self.store.user_add_threepid(user2, "email", user2_email, now, now) + ) - self.store.runInteraction( - "initialise", self.store._initialise_reserved_users, threepids + # XXX why are we doing this here? this function is only run at startup + # so it is odd to re-run it here. + self.get_success( + self.store.db.runInteraction( + "initialise", self.store._initialise_reserved_users, threepids + ) ) - self.pump() - active_count = self.store.get_monthly_active_count() + # the number of users we expect will be counted against the mau limit + # -1 because user3 is a support user and does not count + user_num = len(threepids) - 1 - # Test total counts, ensure user3 (support user) is not counted - self.assertEquals(self.get_success(active_count), user_num) + # Check the number of active users. Ensure user3 (support user) is not counted + active_count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(active_count, user_num) - # Test user is marked as active - timestamp = self.store.user_last_seen_monthly_active(user1) - self.assertTrue(self.get_success(timestamp)) - timestamp = self.store.user_last_seen_monthly_active(user2) - self.assertTrue(self.get_success(timestamp)) + # Test each of the registered users is marked as active + timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1)) + self.assertGreater(timestamp, 0) + timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2)) + self.assertGreater(timestamp, 0) - # Test that users are never removed from the db. 
+ # Test that users with reserved 3pids are not removed from the MAU table + # XXX some of this is redundant. poking things into the config shouldn't + # work, and in any case it's not obvious what we expect to happen when + # we advance the reactor. self.hs.config.max_mau_value = 0 - self.reactor.advance(FORTY_DAYS) + self.hs.config.max_mau_value = 5 - self.store.reap_monthly_active_users() - self.pump() + self.get_success(self.store.reap_monthly_active_users()) - active_count = self.store.get_monthly_active_count() - self.assertEquals(self.get_success(active_count), user_num) + active_count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(active_count, user_num) - # Test that regular users are removed from the db + # Add some more users and check they are counted as active ru_count = 2 - self.store.upsert_monthly_active_user("@ru1:server") - self.store.upsert_monthly_active_user("@ru2:server") - self.pump() - active_count = self.store.get_monthly_active_count() - self.assertEqual(self.get_success(active_count), user_num + ru_count) - self.hs.config.max_mau_value = user_num - self.store.reap_monthly_active_users() - self.pump() + self.get_success(self.store.upsert_monthly_active_user("@ru1:server")) + self.get_success(self.store.upsert_monthly_active_user("@ru2:server")) + + active_count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(active_count, user_num + ru_count) - active_count = self.store.get_monthly_active_count() - self.assertEquals(self.get_success(active_count), user_num) + # now run the reaper and check that the number of active users is reduced + # to max_mau_value + self.get_success(self.store.reap_monthly_active_users()) + + active_count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(active_count, 3) def test_can_insert_and_count_mau(self): - count = self.store.get_monthly_active_count() - self.assertEqual(0, self.get_success(count)) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) - self.store.upsert_monthly_active_user("@user:server") - self.pump() + d = self.store.upsert_monthly_active_user("@user:server") + self.get_success(d) - count = self.store.get_monthly_active_count() - self.assertEqual(1, self.get_success(count)) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 1) def test_user_last_seen_monthly_active(self): user_id1 = "@user1:server" user_id2 = "@user2:server" user_id3 = "@user3:server" - result = self.store.user_last_seen_monthly_active(user_id1) - self.assertFalse(self.get_success(result) == 0) + result = self.get_success(self.store.user_last_seen_monthly_active(user_id1)) + self.assertNotEqual(result, 0) - self.store.upsert_monthly_active_user(user_id1) - self.store.upsert_monthly_active_user(user_id2) - self.pump() + self.get_success(self.store.upsert_monthly_active_user(user_id1)) + self.get_success(self.store.upsert_monthly_active_user(user_id2)) - result = self.store.user_last_seen_monthly_active(user_id1) - self.assertGreater(self.get_success(result), 0) + result = self.get_success(self.store.user_last_seen_monthly_active(user_id1)) + self.assertGreater(result, 0) - result = self.store.user_last_seen_monthly_active(user_id3) - self.assertNotEqual(self.get_success(result), 0) + result = self.get_success(self.store.user_last_seen_monthly_active(user_id3)) + self.assertNotEqual(result, 0) + @override_config({"max_mau_value": 5}) def test_reap_monthly_active_users(self): - 
self.hs.config.max_mau_value = 5 initial_users = 10 for i in range(initial_users): - self.store.upsert_monthly_active_user("@user%d:server" % i) - self.pump() + self.get_success( + self.store.upsert_monthly_active_user("@user%d:server" % i) + ) - count = self.store.get_monthly_active_count() - self.assertTrue(self.get_success(count), initial_users) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, initial_users) - self.store.reap_monthly_active_users() - self.pump() - count = self.store.get_monthly_active_count() - self.assertEquals( - self.get_success(count), initial_users - self.hs.config.max_mau_value - ) + d = self.store.reap_monthly_active_users() + self.get_success(d) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, self.hs.config.max_mau_value) self.reactor.advance(FORTY_DAYS) - self.store.reap_monthly_active_users() - self.pump() - count = self.store.get_monthly_active_count() - self.assertEquals(self.get_success(count), 0) + d = self.store.reap_monthly_active_users() + self.get_success(d) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) + + # Note that below says mau_limit (no s), this is the name of the config + # value, although it gets stored on the config object as mau_limits. + @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)}) + def test_reap_monthly_active_users_reserved_users(self): + """ Tests that reaping correctly handles reaping where reserved users are + present""" + threepids = self.hs.config.mau_limits_reserved_threepids + initial_users = len(threepids) + reserved_user_number = initial_users - 1 + for i in range(initial_users): + user = "@user%d:server" % i + email = "user%d@matrix.org" % i + + self.get_success(self.store.upsert_monthly_active_user(user)) + + # Need to ensure that the most recent entries in the + # monthly_active_users table are reserved + now = int(self.hs.get_clock().time_msec()) + if i != 0: + self.get_success( + self.store.register_user(user_id=user, password_hash=None) + ) + self.get_success( + self.store.user_add_threepid(user, "email", email, now, now) + ) + + d = self.store.db.runInteraction( + "initialise", self.store._initialise_reserved_users, threepids + ) + self.get_success(d) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, initial_users) + + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEqual(len(users), reserved_user_number) + + d = self.store.reap_monthly_active_users() + self.get_success(d) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, self.hs.config.max_mau_value) def test_populate_monthly_users_is_guest(self): # Test that guest users are not added to mau list user_id = "@user_id:host" - self.store.register_user(user_id=user_id, password_hash=None, make_guest=True) + + d = self.store.register_user( + user_id=user_id, password_hash=None, make_guest=True + ) + self.get_success(d) + self.store.upsert_monthly_active_user = Mock() - self.store.populate_monthly_active_users(user_id) - self.pump() + + d = self.store.populate_monthly_active_users(user_id) + self.get_success(d) + self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self): @@ -175,8 +245,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_last_seen_monthly_active = Mock( 
return_value=defer.succeed(None) ) - self.store.populate_monthly_active_users("user_id") - self.pump() + d = self.store.populate_monthly_active_users("user_id") + self.get_success(d) + self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self): @@ -186,80 +257,132 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) - self.store.populate_monthly_active_users("user_id") - self.pump() + + d = self.store.populate_monthly_active_users("user_id") + self.get_success(d) + self.store.upsert_monthly_active_user.assert_not_called() def test_get_reserved_real_user_account(self): # Test no reserved users, or reserved threepids - count = self.store.get_registered_reserved_users_count() - self.assertEquals(self.get_success(count), 0) - # Test reserved users but no registered users + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEqual(len(users), 0) + # Test reserved users but no registered users user1 = "@user1:example.com" user2 = "@user2:example.com" + user1_email = "user1@example.com" user2_email = "user2@example.com" threepids = [ {"medium": "email", "address": user1_email}, {"medium": "email", "address": user2_email}, ] + self.hs.config.mau_limits_reserved_threepids = threepids - self.store.runInteraction( + d = self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) + self.get_success(d) - self.pump() - count = self.store.get_registered_reserved_users_count() - self.assertEquals(self.get_success(count), 0) + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEqual(len(users), 0) - # Test reserved registed users - self.store.register_user(user_id=user1, password_hash=None) - self.store.register_user(user_id=user2, password_hash=None) - self.pump() + # Test reserved registered users + self.get_success(self.store.register_user(user_id=user1, password_hash=None)) + self.get_success(self.store.register_user(user_id=user2, password_hash=None)) now = int(self.hs.get_clock().time_msec()) self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) - count = self.store.get_registered_reserved_users_count() - self.assertEquals(self.get_success(count), len(threepids)) + + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEqual(len(users), len(threepids)) def test_support_user_not_add_to_mau_limits(self): support_user_id = "@support:test" - count = self.store.get_monthly_active_count() - self.pump() - self.assertEqual(self.get_success(count), 0) - self.store.register_user( + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) + + d = self.store.register_user( user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT ) + self.get_success(d) - self.store.upsert_monthly_active_user(support_user_id) - count = self.store.get_monthly_active_count() - self.pump() - self.assertEqual(self.get_success(count), 0) + d = self.store.upsert_monthly_active_user(support_user_id) + self.get_success(d) - def test_track_monthly_users_without_cap(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.mau_stats_only = True - self.hs.config.max_mau_value = 1 # should not matter + d = self.store.get_monthly_active_count() + count = self.get_success(d) + self.assertEqual(count, 0) 
- count = self.store.get_monthly_active_count() - self.assertEqual(0, self.get_success(count)) + # Note that the max_mau_value setting should not matter. + @override_config( + {"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1} + ) + def test_track_monthly_users_without_cap(self): + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(0, count) - self.store.upsert_monthly_active_user("@user1:server") - self.store.upsert_monthly_active_user("@user2:server") - self.pump() + self.get_success(self.store.upsert_monthly_active_user("@user1:server")) + self.get_success(self.store.upsert_monthly_active_user("@user2:server")) - count = self.store.get_monthly_active_count() - self.assertEqual(2, self.get_success(count)) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(2, count) + @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.mau_stats_only = False self.store.upsert_monthly_active_user = Mock() - self.store.populate_monthly_active_users("@user:sever") - self.pump() + self.get_success(self.store.populate_monthly_active_users("@user:sever")) self.store.upsert_monthly_active_user.assert_not_called() + + def test_get_monthly_active_count_by_service(self): + appservice1_user1 = "@appservice1_user1:example.com" + appservice1_user2 = "@appservice1_user2:example.com" + + appservice2_user1 = "@appservice2_user1:example.com" + native_user1 = "@native_user1:example.com" + + service1 = "service1" + service2 = "service2" + native = "native" + + self.get_success( + self.store.register_user( + user_id=appservice1_user1, password_hash=None, appservice_id=service1 + ) + ) + self.get_success( + self.store.register_user( + user_id=appservice1_user2, password_hash=None, appservice_id=service1 + ) + ) + self.get_success( + self.store.register_user( + user_id=appservice2_user1, password_hash=None, appservice_id=service2 + ) + ) + self.get_success( + self.store.register_user(user_id=native_user1, password_hash=None) + ) + + count = self.get_success(self.store.get_monthly_active_count_by_service()) + self.assertEqual(count, {}) + + self.get_success(self.store.upsert_monthly_active_user(native_user1)) + self.get_success(self.store.upsert_monthly_active_user(appservice1_user1)) + self.get_success(self.store.upsert_monthly_active_user(appservice1_user2)) + self.get_success(self.store.upsert_monthly_active_user(appservice2_user1)) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 4) + + d = self.store.get_monthly_active_count_by_service() + result = self.get_success(d) + + self.assertEqual(result[service1], 2) + self.assertEqual(result[service2], 1) + self.assertEqual(result[native], 1) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 45824bd3b2..9b6f7211ae 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -16,7 +16,6 @@ from twisted.internet import defer -from synapse.storage.profile import ProfileStore from synapse.types import UserID from tests import unittest @@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.store = ProfileStore(hs.get_db_conn(), hs) + self.store = hs.get_datastore() self.u_frank = UserID.from_string("@frank:test") diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 
f671599cb8..b9fafaa1a6 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -40,23 +40,24 @@ class PurgeTests(HomeserverTestCase): third = self.helper.send(self.room_id, body="test3") last = self.helper.send(self.room_id, body="test4") - storage = self.hs.get_datastore() + store = self.hs.get_datastore() + storage = self.hs.get_storage() # Get the topological token - event = storage.get_topological_token_for_event(last["event_id"]) + event = store.get_topological_token_for_event(last["event_id"]) self.pump() event = self.successResultOf(event) # Purge everything before this topological token - purge = storage.purge_history(self.room_id, event, True) + purge = storage.purge_events.purge_history(self.room_id, event, True) self.pump() self.assertEqual(self.successResultOf(purge), None) # Try and get the events - get_first = storage.get_event(first["event_id"]) - get_second = storage.get_event(second["event_id"]) - get_third = storage.get_event(third["event_id"]) - get_last = storage.get_event(last["event_id"]) + get_first = store.get_event(first["event_id"]) + get_second = store.get_event(second["event_id"]) + get_third = store.get_event(third["event_id"]) + get_last = store.get_event(last["event_id"]) self.pump() # 1-3 should fail and last will succeed, meaning that 1-3 are deleted diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index deecfad9fb..db3667dc43 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -39,6 +39,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -73,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -95,7 +96,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -116,7 +117,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) + + return event def test_redact(self): self.get_success( @@ -235,8 +238,11 @@ class RedactionTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def build(self, prev_event_ids): built_event = yield self._base_builder.build(prev_event_ids) - built_event.event_id = self._event_id - built_event._event_dict["event_id"] = self._event_id + + built_event._event_id = self._event_id + built_event._dict["event_id"] = self._event_id + assert built_event.event_id == self._event_id + return built_event @property @@ -261,7 +267,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) - self.get_success(self.store.persist_event(event_1, context_1)) + self.get_success(self.storage.persistence.persist_event(event_1, context_1)) event_2, context_2 = self.get_success( self.event_creation_handler.create_new_client_event( @@ -280,7 +286,7 @@ class 
RedactionTestCase(unittest.HomeserverTestCase): ) ) ) - self.get_success(self.store.persist_event(event_2, context_2)) + self.get_success(self.storage.persistence.persist_event(event_2, context_2)) # fetch one of the redactions fetched = self.get_success(self.store.get_event(redaction_event_id1)) @@ -335,7 +341,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) event_json = self.get_success( - self.store._simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -353,7 +359,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.reactor.advance(60 * 60 * 2) event_json = self.get_success( - self.store._simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -361,3 +367,72 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) self.assert_dict({"content": {}}, json.loads(event_json)) + + def test_redact_redaction(self): + """Tests that we can redact a redaction and can fetch it again. + """ + + self.get_success( + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + ) + + msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) + + first_redact_event = self.get_success( + self.inject_redaction( + self.room1, msg_event.event_id, self.u_alice, "Redacting message" + ) + ) + + self.get_success( + self.inject_redaction( + self.room1, + first_redact_event.event_id, + self.u_alice, + "Redacting redaction", + ) + ) + + # Now lets jump to the future where we have censored the redaction event + # in the DB. + self.reactor.advance(60 * 60 * 24 * 31) + + # We just want to check that fetching the event doesn't raise an exception. + self.get_success( + self.store.get_event(first_redact_event.event_id, allow_none=True) + ) + + def test_store_redacted_redaction(self): + """Tests that we can store a redacted redaction. + """ + + self.get_success( + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + ) + + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "foo"}, + }, + ) + + redaction_event, context = self.get_success( + self.event_creation_handler.create_new_client_event(builder) + ) + + self.get_success( + self.storage.persistence.persist_event(redaction_event, context) + ) + + # Now lets jump to the future where we have censored the redaction event + # in the DB. + self.reactor.advance(60 * 60 * 24 * 31) + + # We just want to check that fetching the event doesn't raise an exception. 
+ self.get_success( + self.store.get_event(redaction_event.event_id, allow_none=True) + ) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 4578cc3b60..71a40a0a49 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.db_pool = hs.get_db_pool() self.store = hs.get_datastore() @@ -44,12 +43,14 @@ class RegistrationStoreTestCase(unittest.TestCase): # TODO(paul): Surely this field should be 'user_id', not 'name' "name": self.user_id, "password_hash": self.pwhash, + "admin": 0, "is_guest": 0, "consent_version": None, "consent_server_notice_sent": None, "appservice_id": None, "creation_ts": 1000, "user_type": None, + "deactivated": 0, }, (yield self.store.get_user_by_id(self.user_id)), ) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 1bee45706f..3b78d48896 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes +from synapse.api.room_versions import RoomVersions from synapse.types import RoomAlias, RoomID, UserID from tests import unittest @@ -40,6 +41,7 @@ class RoomStoreTestCase(unittest.TestCase): self.room.to_string(), room_creator_user_id=self.u_creator.to_string(), is_public=True, + room_version=RoomVersions.V1, ) @defer.inlineCallbacks @@ -53,6 +55,17 @@ class RoomStoreTestCase(unittest.TestCase): (yield self.store.get_room(self.room.to_string())), ) + @defer.inlineCallbacks + def test_get_room_with_stats(self): + self.assertDictContainsSubset( + { + "room_id": self.room.to_string(), + "creator": self.u_creator.to_string(), + "public": True, + }, + (yield self.store.get_room_with_stats(self.room.to_string())), + ) + class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks @@ -62,17 +75,21 @@ class RoomEventsStoreTestCase(unittest.TestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") yield self.store.store_room( - self.room.to_string(), room_creator_user_id="@creator:text", is_public=True + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, ) @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield self.store.persist_event( + yield self.storage.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 447a3c6ffb..5dd46005e6 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -16,13 +16,14 @@ from unittest.mock import Mock -from synapse.api.constants import EventTypes, Membership -from synapse.api.room_versions import RoomVersions +from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room from synapse.types import Requester, UserID from tests import unittest +from tests.test_utils import event_injection +from tests.utils import TestHomeServer class RoomMemberStoreTestCase(unittest.HomeserverTestCase): @@ -39,13 +40,11 @@ class 
RoomMemberStoreTestCase(unittest.HomeserverTestCase): ) return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor, clock, hs: TestHomeServer): # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastore() - self.event_builder_factory = hs.get_event_builder_factory() - self.event_creation_handler = hs.get_event_creation_handler() self.u_alice = self.register_user("alice", "pass") self.t_alice = self.login("alice", "pass") @@ -54,33 +53,13 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # User elsewhere on another host self.u_charlie = UserID.from_string("@charlie:elsewhere") - def inject_room_member(self, room, user, membership, replaces_state=None): - builder = self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": EventTypes.Member, - "sender": user, - "state_key": user, - "room_id": room, - "content": {"membership": membership}, - }, - ) - - event, context = self.get_success( - self.event_creation_handler.create_new_client_event(builder) - ) - - self.get_success(self.store.persist_event(event, context)) - - return event - def test_one_member(self): # Alice creates the room, and is automatically joined self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) rooms_for_user = self.get_success( - self.store.get_rooms_for_user_where_membership_is( + self.store.get_rooms_for_local_user_where_membership_is( self.u_alice, [Membership.JOIN] ) ) @@ -137,6 +116,52 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # It now knows about Charlie's server. self.assertEqual(self.store._known_servers_count, 2) + def test_get_joined_users_from_context(self): + room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + bob_event = event_injection.inject_member_event( + self.hs, room, self.u_bob, Membership.JOIN + ) + + # first, create a regular event + event, context = event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[bob_event.event_id], + type="m.test.1", + content={}, + ) + + users = self.get_success( + self.store.get_joined_users_from_context(event, context) + ) + self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) + + # Regression test for #7376: create a state event whose key matches bob's + # user_id, but which is *not* a membership event, and persist that; then check + # that `get_joined_users_from_context` returns the correct users for the next event. + non_member_event = event_injection.inject_event( + self.hs, + room_id=room, + sender=self.u_bob, + prev_event_ids=[bob_event.event_id], + type="m.test.2", + state_key=self.u_bob, + content={}, + ) + event, context = event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[non_member_event.event_id], + type="m.test.3", + content={}, + ) + users = self.get_success( + self.store.get_joined_users_from_context(event, context) + ) + self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) + class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): @@ -145,8 +170,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def test_can_rerun_update(self): # First make sure we have completed all updates. 
- while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Now let's create a room, which will insert a membership user = UserID("alice", "test") @@ -155,7 +184,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( table="background_updates", values={ "update_name": "current_state_events_membership", @@ -166,8 +195,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store._all_done = False + self.store.db.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 5c2cf3c2db..0b88308ff4 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -34,6 +34,8 @@ class StateStoreTestCase(tests.unittest.TestCase): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -43,7 +45,10 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") yield self.store.store_room( - self.room.to_string(), room_creator_user_id="@creator:text", is_public=True + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, ) @defer.inlineCallbacks @@ -63,7 +68,7 @@ class StateStoreTestCase(tests.unittest.TestCase): builder ) - yield self.store.persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @@ -82,7 +87,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups_ids( + state_group_map = yield self.storage.state.get_state_groups_ids( self.room, [e2.event_id] ) self.assertEqual(len(state_group_map), 1) @@ -101,7 +106,9 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) + state_group_map = yield self.storage.state.get_state_groups( + self.room, [e2.event_id] + ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] @@ -141,7 +148,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.store.get_state_for_event(e5.event_id) + state = yield self.storage.state.get_state_for_event(e5.event_id) self.assertIsNotNone(e4) @@ -157,21 +164,21 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # 
check we can filter to the m.room.name event (with a '' state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -181,7 +188,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( types={EventTypes.Member: {self.u_alice.to_string()}}, @@ -199,7 +206,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -215,13 +222,18 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) + group_ids = yield self.storage.state.get_state_groups_ids( + room_id, [e5.event_id] + ) group = list(group_ids.keys())[0] # test _get_state_for_group_using_cache correctly filters out members # with types=[] - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -237,8 +249,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -250,8 +265,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with wildcard types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -267,8 +285,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + 
self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -287,8 +308,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -304,8 +328,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -317,8 +344,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False @@ -331,9 +361,11 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### # deliberately remove e2 (room name) from the _state_group_cache - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( - group - ) + ( + is_all, + known_absent, + state_dict_ids, + ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, True) self.assertEqual(known_absent, set()) @@ -346,21 +378,23 @@ class StateStoreTestCase(tests.unittest.TestCase): ) state_dict_ids.pop((e2.type, e2.state_key)) - self.store._state_group_cache.invalidate(group) - self.store._state_group_cache.update( - sequence=self.store._state_group_cache.sequence, + self.state_datastore._state_group_cache.invalidate(group) + self.state_datastore._state_group_cache.update( + sequence=self.state_datastore._state_group_cache.sequence, key=group, value=state_dict_ids, # list fetched keys so it knows it's partial fetched_keys=((e1.type, e1.state_key),), ) - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( - group - ) + ( + is_all, + known_absent, + state_dict_ids, + ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, False) - self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) + self.assertEqual(known_absent, {(e1.type, e1.state_key)}) self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id}) ############################################ @@ -369,8 +403,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, 
state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -381,8 +418,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -394,8 +434,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # wildcard types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -405,8 +448,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -424,8 +470,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -435,8 +484,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -448,8 +500,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False @@ -459,8 +514,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + 
self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index a771d5af29..8e817e2c7f 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.util.retryutils import MAX_RETRY_INTERVAL + from tests.unittest import HomeserverTestCase @@ -45,3 +47,12 @@ class TransactionStoreTestCase(HomeserverTestCase): """ d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) self.get_success(d) + + def test_large_destination_retry(self): + d = self.store.set_destination_retry_timings( + "example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL + ) + self.get_success(d) + + d = self.store.get_destination_retry_timings("example.com") + self.get_success(d) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index d7d244ce97..6a545d2eb0 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -15,8 +15,6 @@ from twisted.internet import defer -from synapse.storage import UserDirectoryStore - from tests import unittest from tests.utils import setup_test_homeserver @@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs) + self.store = self.hs.get_datastore() # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 8b2741d277..69b4c5d6c2 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -18,7 +18,8 @@ import unittest from synapse import event_auth from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict +from synapse.types import get_domain_from_id class EventAuthTestCase(unittest.TestCase): @@ -37,7 +38,7 @@ class EventAuthTestCase(unittest.TestCase): # creator should be able to send state event_auth.check( - RoomVersions.V1.identifier, + RoomVersions.V1, _random_state_event(creator), auth_events, do_sig_check=False, @@ -47,11 +48,11 @@ class EventAuthTestCase(unittest.TestCase): self.assertRaises( AuthError, event_auth.check, - RoomVersions.V1.identifier, + RoomVersions.V1, _random_state_event(joiner), auth_events, do_sig_check=False, - ), + ) def test_state_default_level(self): """ @@ -76,7 +77,7 @@ class EventAuthTestCase(unittest.TestCase): self.assertRaises( AuthError, event_auth.check, - RoomVersions.V1.identifier, + RoomVersions.V1, _random_state_event(pleb), auth_events, do_sig_check=False, @@ -84,11 +85,112 @@ class EventAuthTestCase(unittest.TestCase): # king should be able to send state event_auth.check( - RoomVersions.V1.identifier, - _random_state_event(king), + RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False, + ) + + def test_alias_event(self): + """Alias events have special behavior up through room version 6.""" + creator = "@creator:example.com" + other = "@other:example.com" + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + } + + # 
creator should be able to send aliases + event_auth.check( + RoomVersions.V1, _alias_event(creator), auth_events, do_sig_check=False, + ) + + # Reject an event with no state key. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V1, + _alias_event(creator, state_key=""), + auth_events, + do_sig_check=False, + ) + + # If the domain of the sender does not match the state key, reject. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V1, + _alias_event(creator, state_key="test.com"), + auth_events, + do_sig_check=False, + ) + + # Note that the member does *not* need to be in the room. + event_auth.check( + RoomVersions.V1, _alias_event(other), auth_events, do_sig_check=False, + ) + + def test_msc2432_alias_event(self): + """After MSC2432, alias events have no special behavior.""" + creator = "@creator:example.com" + other = "@other:example.com" + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + } + + # creator should be able to send aliases + event_auth.check( + RoomVersions.V6, _alias_event(creator), auth_events, do_sig_check=False, + ) + + # No particular checks are done on the state key. + event_auth.check( + RoomVersions.V6, + _alias_event(creator, state_key=""), auth_events, do_sig_check=False, ) + event_auth.check( + RoomVersions.V6, + _alias_event(creator, state_key="test.com"), + auth_events, + do_sig_check=False, + ) + + # Per standard auth rules, the member must be in the room. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, _alias_event(other), auth_events, do_sig_check=False, + ) + + def test_msc2209(self): + """ + Notifications power levels get checked due to MSC2209. + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.power_levels", ""): _power_levels_event( + creator, {"state_default": "30", "users": {pleb: "30"}} + ), + ("m.room.member", pleb): _join_event(pleb), + } + + # pleb should be able to modify the notifications power level. + event_auth.check( + RoomVersions.V1, + _power_levels_event(pleb, {"notifications": {"room": 100}}), + auth_events, + do_sig_check=False, + ) + + # But an MSC2209 room rejects this change. 
+ with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _power_levels_event(pleb, {"notifications": {"room": 100}}), + auth_events, + do_sig_check=False, + ) # helpers for making events @@ -97,7 +199,7 @@ TEST_ROOM_ID = "!test:room" def _create_event(user_id): - return FrozenEvent( + return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), @@ -109,7 +211,7 @@ def _create_event(user_id): def _join_event(user_id): - return FrozenEvent( + return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), @@ -122,7 +224,7 @@ def _join_event(user_id): def _power_levels_event(sender, content): - return FrozenEvent( + return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), @@ -134,8 +236,21 @@ def _power_levels_event(sender, content): ) +def _alias_event(sender, **kwargs): + data = { + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.aliases", + "sender": sender, + "state_key": get_domain_from_id(sender), + "content": {"aliases": []}, + } + data.update(**kwargs) + return make_event_from_dict(data) + + def _random_state_event(sender): - return FrozenEvent( + return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), diff --git a/tests/test_federation.py b/tests/test_federation.py index a73f18f88e..c662195eec 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,17 +1,18 @@ from mock import Mock -from twisted.internet.defer import maybeDeferred, succeed +from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext from synapse.types import Requester, UserID from synapse.util import Clock +from synapse.util.retryutils import NotRetryingDestination from tests import unittest from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver -class MessageAcceptTests(unittest.TestCase): +class MessageAcceptTests(unittest.HomeserverTestCase): def setUp(self): self.http_client = Mock() @@ -27,20 +28,25 @@ class MessageAcceptTests(unittest.TestCase): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room = room_creator.create_room( - our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + room_deferred = ensureDeferred( + room_creator.create_room( + our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + ) ) self.reactor.advance(0.1) - self.room_id = self.successResultOf(room)["room_id"] + self.room_id = self.successResultOf(room_deferred)[0]["room_id"] + + self.store = self.homeserver.get_datastore() # Figure out what the most recent event is most_recent = self.successResultOf( maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, + self.room_id, ) )[0] - join_event = FrozenEvent( + join_event = make_event_from_dict( { "room_id": self.room_id, "sender": "@baduser:test.serv", @@ -58,15 +64,19 @@ class MessageAcceptTests(unittest.TestCase): ) self.handler = self.homeserver.get_handlers().federation_handler - self.handler.do_auth = lambda *a, **b: succeed(True) + self.handler.do_auth = lambda origin, event, context, auth_events: succeed( + context + ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, 
pdus, **k: succeed( pdus ) # Send the join, it should return None (which is not an error) - d = self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True + d = ensureDeferred( + self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) ) self.reactor.advance(1) self.assertEqual(self.successResultOf(d), None) @@ -74,9 +84,7 @@ class MessageAcceptTests(unittest.TestCase): # Make sure we actually joined the room self.assertEqual( self.successResultOf( - maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id - ) + maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) )[0], "$join:test.serv", ) @@ -96,13 +104,11 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( - maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id - ) + maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) )[0] # Now lie about an event - lying_event = FrozenEvent( + lying_event = make_event_from_dict( { "room_id": self.room_id, "sender": "@baduser:test.serv", @@ -118,8 +124,10 @@ class MessageAcceptTests(unittest.TestCase): ) with LoggingContext(request="lying_event"): - d = self.handler.on_receive_pdu( - "test.serv", lying_event, sent_to_us_directly=True + d = ensureDeferred( + self.handler.on_receive_pdu( + "test.serv", lying_event, sent_to_us_directly=True + ) ) # Step the reactor, so the database fetches come back @@ -136,7 +144,121 @@ class MessageAcceptTests(unittest.TestCase): ) # Make sure the invalid event isn't there - extrem = maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id - ) + extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") + + def test_retry_device_list_resync(self): + """Tests that device lists are marked as stale if they couldn't be synced, and + that stale device lists are retried periodically. + """ + remote_user_id = "@john:test_remote" + remote_origin = "test_remote" + + # Track the number of attempts to resync the user's device list. + self.resync_attempts = 0 + + # When this function is called, increment the number of resync attempts (only if + # we're querying devices for the right user ID), then raise a + # NotRetryingDestination error to fail the resync gracefully. + def query_user_devices(destination, user_id): + if user_id == remote_user_id: + self.resync_attempts += 1 + + raise NotRetryingDestination(0, 0, destination) + + # Register the mock on the federation client. + federation_client = self.homeserver.get_federation_client() + federation_client.query_user_devices = Mock(side_effect=query_user_devices) + + # Register a mock on the store so that the incoming update doesn't fail because + # we don't share a room with the user. + store = self.homeserver.get_datastore() + store.get_rooms_for_user = Mock(return_value=["!someroom:test"]) + + # Manually inject a fake device list update. We need this update to include at + # least one prev_id so that the user's device list will need to be retried. 
+ device_list_updater = self.homeserver.get_device_handler().device_list_updater + self.get_success( + device_list_updater.incoming_device_list_update( + origin=remote_origin, + edu_content={ + "deleted": False, + "device_display_name": "Mobile", + "device_id": "QBUAZIFURK", + "prev_id": [5], + "stream_id": 6, + "user_id": remote_user_id, + }, + ) + ) + + # Check that there was one resync attempt. + self.assertEqual(self.resync_attempts, 1) + + # Check that the resync attempt failed and caused the user's device list to be + # marked as stale. + need_resync = self.get_success( + store.get_user_ids_requiring_device_list_resync() + ) + self.assertIn(remote_user_id, need_resync) + + # Check that waiting for 30 seconds caused Synapse to retry resyncing the device + # list. + self.reactor.advance(30) + self.assertEqual(self.resync_attempts, 2) + + def test_cross_signing_keys_retry(self): + """Tests that resyncing a device list correctly processes cross-signing keys from + the remote server. + """ + remote_user_id = "@john:test_remote" + remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" + remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" + + # Register mock device list retrieval on the federation client. + federation_client = self.homeserver.get_federation_client() + federation_client.query_user_devices = Mock( + return_value={ + "user_id": remote_user_id, + "stream_id": 1, + "devices": [], + "master_key": { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + }, + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + remote_self_signing_key: remote_self_signing_key + }, + }, + } + ) + + # Resync the device list. + device_handler = self.homeserver.get_device_handler() + self.get_success( + device_handler.device_list_updater.user_device_resync(remote_user_id), + ) + + # Retrieve the cross-signing keys for this user. + keys = self.get_success( + self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]), + ) + self.assertTrue(remote_user_id in keys) + + # Check that the master key is the one returned by the mock. + master_key = keys[remote_user_id]["master"] + self.assertEqual(len(master_key["keys"]), 1) + self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys()) + self.assertTrue(remote_master_key in master_key["keys"].values()) + + # Check that the self-signing key is the one returned by the mock. 
+ self_signing_key = keys[remote_user_id]["self_signing"] + self.assertEqual(len(self_signing_key["keys"]), 1) + self.assertTrue( + "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(), + ) + self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values()) diff --git a/tests/test_mau.py b/tests/test_mau.py index 1fbe0d51ff..49667ed7f4 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -17,40 +17,44 @@ import json -from mock import Mock - from synapse.api.constants import LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.rest.client.v2_alpha import register, sync from tests import unittest +from tests.unittest import override_config +from tests.utils import default_config class TestMauLimit(unittest.HomeserverTestCase): servlets = [register.register_servlets, sync.register_servlets] - def make_homeserver(self, reactor, clock): + def default_config(self): + config = default_config("test") - self.hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock() + config.update( + { + "registrations_require_3pid": [], + "limit_usage_by_mau": True, + "max_mau_value": 2, + "mau_trial_days": 0, + "server_notices": { + "system_mxid_localpart": "server", + "room_name": "Test Server Notice Room", + }, + } ) - self.store = self.hs.get_datastore() + # apply any additional config which was specified via the override_config + # decorator. + if self._extra_config is not None: + config.update(self._extra_config) - self.hs.config.registrations_require_3pid = [] - self.hs.config.enable_registration_captcha = False - self.hs.config.recaptcha_public_key = [] + return config - self.hs.config.limit_usage_by_mau = True - self.hs.config.hs_disabled = False - self.hs.config.max_mau_value = 2 - self.hs.config.mau_trial_days = 0 - self.hs.config.server_notices_mxid = "@server:red" - self.hs.config.server_notices_mxid_display_name = None - self.hs.config.server_notices_mxid_avatar_url = None - self.hs.config.server_notices_room_name = "Test Server Notice Room" - return self.hs + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() def test_simple_deny_mau(self): # Create and sync so that the MAU counts get updated @@ -59,6 +63,9 @@ class TestMauLimit(unittest.HomeserverTestCase): token2 = self.create_user("kermit2") self.do_sync_for_user(token2) + # check we're testing what we think we are: there should be two active users + self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2) + # We've created and activated two users, we shouldn't be able to # register new users with self.assertRaises(SynapseError) as cm: @@ -78,7 +85,7 @@ class TestMauLimit(unittest.HomeserverTestCase): # Advance time by 31 days self.reactor.advance(31 * 24 * 60 * 60) - self.store.reap_monthly_active_users() + self.get_success(self.store.reap_monthly_active_users()) self.reactor.advance(0) @@ -86,9 +93,8 @@ class TestMauLimit(unittest.HomeserverTestCase): token3 = self.create_user("kermit3") self.do_sync_for_user(token3) + @override_config({"mau_trial_days": 1}) def test_trial_delay(self): - self.hs.config.mau_trial_days = 1 - # We should be able to register more than the limit initially token1 = self.create_user("kermit1") self.do_sync_for_user(token1) @@ -120,6 +126,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.code, 403) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + @override_config({"mau_trial_days": 1}) def 
test_trial_users_cant_come_back(self): self.hs.config.mau_trial_days = 1 @@ -140,8 +147,7 @@ class TestMauLimit(unittest.HomeserverTestCase): # Advance by 2 months so everyone falls out of MAU self.reactor.advance(60 * 24 * 60 * 60) - self.store.reap_monthly_active_users() - self.reactor.advance(0) + self.get_success(self.store.reap_monthly_active_users()) # We can create as many new users as we want token4 = self.create_user("kermit4") @@ -168,11 +174,11 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.code, 403) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + @override_config( + # max_mau_value should not matter + {"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True} + ) def test_tracked_but_not_limited(self): - self.hs.config.max_mau_value = 1 # should not matter - self.hs.config.limit_usage_by_mau = False - self.hs.config.mau_stats_only = True - # Simply being able to create 2 users indicates that the # limit was not reached. token1 = self.create_user("kermit1") diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 270f853d60..f5f63d8ed6 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,6 +15,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest +from synapse.util.caches.descriptors import Cache from tests import unittest @@ -129,3 +130,36 @@ class BuildInfoTests(unittest.TestCase): self.assertTrue(b"osversion=" in items[0]) self.assertTrue(b"pythonversion=" in items[0]) self.assertTrue(b"version=" in items[0]) + + +class CacheMetricsTests(unittest.HomeserverTestCase): + def test_cache_metric(self): + """ + Caches produce metrics reflecting their state when scraped. + """ + CACHE_NAME = "cache_metrics_test_fgjkbdfg" + cache = Cache(CACHE_NAME, max_entries=777) + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + + cache.prefill("1", "hi") + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py new file mode 100644 index 0000000000..7657bddea5 --- /dev/null +++ b/tests/test_phone_home.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import resource + +import mock + +from synapse.app.homeserver import phone_stats_home + +from tests.unittest import HomeserverTestCase + + +class PhoneHomeStatsTestCase(HomeserverTestCase): + def test_performance_frozen_clock(self): + """ + If time doesn't move, don't error out. + """ + past_stats = [ + (self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF)) + ] + stats = {} + self.get_success(phone_stats_home(self.hs, stats, past_stats)) + self.assertEqual(stats["cpu_average"], 0) + + def test_performance_100(self): + """ + 1 second of usage over 1 second is 100% CPU usage. + """ + real_res = resource.getrusage(resource.RUSAGE_SELF) + old_resource = mock.Mock(spec=real_res) + old_resource.ru_utime = real_res.ru_utime - 1 + old_resource.ru_stime = real_res.ru_stime + old_resource.ru_maxrss = real_res.ru_maxrss + + past_stats = [(self.hs.get_clock().time(), old_resource)] + stats = {} + self.reactor.advance(1) + self.get_success(phone_stats_home(self.hs, stats, past_stats)) + self.assertApproximates(stats["cpu_average"], 100, tolerance=2.5) diff --git a/tests/test_server.py b/tests/test_server.py index 98fef21d55..e9a43b1e45 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -23,8 +23,13 @@ from twisted.test.proto_helpers import AccumulatingProtocol from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET -from synapse.api.errors import Codes, SynapseError -from synapse.http.server import JsonResource +from synapse.api.errors import Codes, RedirectException, SynapseError +from synapse.http.server import ( + DirectServeResource, + JsonResource, + OptionsResource, + wrap_html_request_handler, +) from synapse.http.site import SynapseSite, logger from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock @@ -164,6 +169,157 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") +class OptionsResourceTests(unittest.TestCase): + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + + class DummyResource(Resource): + isLeaf = True + + def render(self, request): + return request.path + + # Setup a resource with some children. + self.resource = OptionsResource() + self.resource.putChild(b"res", DummyResource()) + + def _make_request(self, method, path): + """Create a request from the method/path and return a channel with the response.""" + request, channel = make_request(self.reactor, method, path, shorthand=False) + request.prepath = [] # This doesn't get set properly by make_request. + + # Create a site and query for the resource. + site = SynapseSite("test", "site_tag", {}, self.resource, "1.0") + request.site = site + resource = site.getResourceFor(request) + + # Finally, render the resource and return the channel. 
+ render(request, resource, self.reactor) + return channel + + def test_unknown_options_request(self): + """An OPTIONS requests to an unknown URL still returns 200 OK.""" + channel = self._make_request(b"OPTIONS", b"/foo/") + self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.result["body"], b"{}") + + # Ensure the correct CORS headers have been added + self.assertTrue( + channel.headers.hasHeader(b"Access-Control-Allow-Origin"), + "has CORS Origin header", + ) + self.assertTrue( + channel.headers.hasHeader(b"Access-Control-Allow-Methods"), + "has CORS Methods header", + ) + self.assertTrue( + channel.headers.hasHeader(b"Access-Control-Allow-Headers"), + "has CORS Headers header", + ) + + def test_known_options_request(self): + """An OPTIONS requests to an known URL still returns 200 OK.""" + channel = self._make_request(b"OPTIONS", b"/res/") + self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.result["body"], b"{}") + + # Ensure the correct CORS headers have been added + self.assertTrue( + channel.headers.hasHeader(b"Access-Control-Allow-Origin"), + "has CORS Origin header", + ) + self.assertTrue( + channel.headers.hasHeader(b"Access-Control-Allow-Methods"), + "has CORS Methods header", + ) + self.assertTrue( + channel.headers.hasHeader(b"Access-Control-Allow-Headers"), + "has CORS Headers header", + ) + + def test_unknown_request(self): + """A non-OPTIONS request to an unknown URL should 404.""" + channel = self._make_request(b"GET", b"/foo/") + self.assertEqual(channel.result["code"], b"404") + + def test_known_request(self): + """A non-OPTIONS request to an known URL should query the proper resource.""" + channel = self._make_request(b"GET", b"/res/") + self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.result["body"], b"/res/") + + +class WrapHtmlRequestHandlerTests(unittest.TestCase): + class TestResource(DirectServeResource): + callback = None + + @wrap_html_request_handler + async def _async_render_GET(self, request): + return await self.callback(request) + + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + + def test_good_response(self): + def callback(request): + request.write(b"response") + request.finish() + + res = WrapHtmlRequestHandlerTests.TestResource() + res.callback = callback + + request, channel = make_request(self.reactor, b"GET", b"/path") + render(request, res, self.reactor) + + self.assertEqual(channel.result["code"], b"200") + body = channel.result["body"] + self.assertEqual(body, b"response") + + def test_redirect_exception(self): + """ + If the callback raises a RedirectException, it is turned into a 30x + with the right location. 
+ """ + + def callback(request, **kwargs): + raise RedirectException(b"/look/an/eagle", 301) + + res = WrapHtmlRequestHandlerTests.TestResource() + res.callback = callback + + request, channel = make_request(self.reactor, b"GET", b"/path") + render(request, res, self.reactor) + + self.assertEqual(channel.result["code"], b"301") + headers = channel.result["headers"] + location_headers = [v for k, v in headers if k == b"Location"] + self.assertEqual(location_headers, [b"/look/an/eagle"]) + + def test_redirect_exception_with_cookie(self): + """ + If the callback raises a RedirectException which sets a cookie, that is + returned too + """ + + def callback(request, **kwargs): + e = RedirectException(b"/no/over/there", 304) + e.cookies.append(b"session=yespls") + raise e + + res = WrapHtmlRequestHandlerTests.TestResource() + res.callback = callback + + request, channel = make_request(self.reactor, b"GET", b"/path") + render(request, res, self.reactor) + + self.assertEqual(channel.result["code"], b"304") + headers = channel.result["headers"] + location_headers = [v for k, v in headers if k == b"Location"] + self.assertEqual(location_headers, [b"/no/over/there"]) + cookies_headers = [v for k, v in headers if k == b"Set-Cookie"] + self.assertEqual(cookies_headers, [b"session=yespls"]) + + class SiteTestCase(unittest.HomeserverTestCase): def test_lose_connection(self): """ diff --git a/tests/test_state.py b/tests/test_state.py index 610ec9fb46..66f22f6813 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -20,7 +20,8 @@ from twisted.internet import defer from synapse.api.auth import Auth from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict +from synapse.events.snapshot import EventContext from synapse.state import StateHandler, StateResolutionHandler from tests import unittest @@ -65,7 +66,7 @@ def create_event( d.update(kwargs) - event = FrozenEvent(d) + event = make_event_from_dict(d) return event @@ -118,7 +119,7 @@ class StateGroupStore(object): def register_event_id_state_group(self, event_id, state_group): self._event_to_state_group[event_id] = state_group - def get_room_version(self, room_id): + def get_room_version_id(self, room_id): return RoomVersions.V1.identifier @@ -158,10 +159,12 @@ class Graph(object): class StateTestCase(unittest.TestCase): def setUp(self): self.store = StateGroupStore() + storage = Mock(main=self.store, state=self.store) hs = Mock( spec_set=[ "config", "get_datastore", + "get_storage", "get_auth", "get_state_handler", "get_clock", @@ -174,6 +177,7 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) + hs.get_storage.return_value = storage self.state = StateHandler(hs) self.event_id = 0 @@ -195,16 +199,22 @@ class StateTestCase(unittest.TestCase): self.store.register_events(graph.walk()) - context_store = {} + context_store = {} # type: dict[str, EventContext] for event in graph.walk(): context = yield self.state.compute_event_context(event) self.store.register_event_context(event, context) context_store[event.event_id] = context - prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) + ctx_c = context_store["C"] + ctx_d = context_store["D"] + + prev_state_ids = yield ctx_d.get_prev_state_ids() self.assertEqual(2, len(prev_state_ids)) + 
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) + self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) + @defer.inlineCallbacks def test_branch_basic_conflict(self): graph = Graph( @@ -238,11 +248,16 @@ class StateTestCase(unittest.TestCase): self.store.register_event_context(event, context) context_store[event.event_id] = context - prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) + # C ends up winning the resolution between B and C - self.assertSetEqual( - {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()} - ) + ctx_c = context_store["C"] + ctx_d = context_store["D"] + + prev_state_ids = yield ctx_d.get_prev_state_ids() + self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values())) + + self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) + self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) @defer.inlineCallbacks def test_branch_have_banned_conflict(self): @@ -289,11 +304,16 @@ class StateTestCase(unittest.TestCase): self.store.register_event_context(event, context) context_store[event.event_id] = context - prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store) + # C ends up winning the resolution between C and D because bans win over other + # changes - self.assertSetEqual( - {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()} - ) + ctx_c = context_store["C"] + ctx_e = context_store["E"] + + prev_state_ids = yield ctx_e.get_prev_state_ids() + self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values())) + self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event) + self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group) @defer.inlineCallbacks def test_branch_have_perms_conflict(self): @@ -357,11 +377,17 @@ class StateTestCase(unittest.TestCase): self.store.register_event_context(event, context) context_store[event.event_id] = context - prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) + # B ends up winning the resolution between B and C because power levels + # win over other changes. 
- self.assertSetEqual( - {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()} - ) + ctx_b = context_store["B"] + ctx_d = context_store["D"] + + prev_state_ids = yield ctx_d.get_prev_state_ids() + self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values())) + + self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event) + self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) def _add_depths(self, nodes, edges): def _get_depth(ev): @@ -387,13 +413,16 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event, old_state=old_state) - current_state_ids = yield context.get_current_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() + self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - self.assertEqual( - set(e.event_id for e in old_state), set(current_state_ids.values()) + current_state_ids = yield context.get_current_state_ids() + self.assertCountEqual( + (e.event_id for e in old_state), current_state_ids.values() ) - self.assertIsNotNone(context.state_group) + self.assertIsNotNone(context.state_group_before_event) + self.assertEqual(context.state_group_before_event, context.state_group) @defer.inlineCallbacks def test_annotate_with_old_state(self): @@ -407,12 +436,19 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event, old_state=old_state) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() + self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - self.assertEqual( - set(e.event_id for e in old_state), set(prev_state_ids.values()) + current_state_ids = yield context.get_current_state_ids() + self.assertCountEqual( + (e.event_id for e in old_state + [event]), current_state_ids.values() ) + self.assertIsNotNone(context.state_group_before_event) + self.assertNotEqual(context.state_group_before_event, context.state_group) + self.assertEqual(context.state_group_before_event, context.prev_group) + self.assertEqual({("state", ""): event.event_id}, context.delta_ids) + @defer.inlineCallbacks def test_trivial_annotate_message(self): prev_event_id = "prev_event_id" @@ -437,10 +473,10 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual( - set([e.event_id for e in old_state]), set(current_state_ids.values()) + {e.event_id for e in old_state}, set(current_state_ids.values()) ) self.assertEqual(group_name, context.state_group) @@ -469,11 +505,9 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() - self.assertEqual( - set([e.event_id for e in old_state]), set(prev_state_ids.values()) - ) + self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values())) self.assertIsNotNone(context.state_group) @@ -510,7 +544,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(len(current_state_ids), 6) @@ -552,7 +586,7 @@ class StateTestCase(unittest.TestCase): event, 
prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(len(current_state_ids), 6) @@ -607,7 +641,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) @@ -635,7 +669,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")]) diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 52739fbabc..5c2817cf28 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -28,6 +28,21 @@ from tests import unittest class TermsTestCase(unittest.HomeserverTestCase): servlets = [register_servlets] + def default_config(self): + config = super().default_config() + config.update( + { + "public_baseurl": "https://example.org/", + "user_consent": { + "version": "1.0", + "policy_name": "My Cool Privacy Policy", + "template_dir": "/", + "require_at_registration": True, + }, + } + ) + return config + def prepare(self, reactor, clock, hs): self.clock = MemoryReactorClock() self.hs_clock = Clock(self.clock) @@ -35,19 +50,11 @@ class TermsTestCase(unittest.HomeserverTestCase): self.registration_handler = Mock() self.auth_handler = Mock() self.device_handler = Mock() - hs.config.enable_registration = True - hs.config.registrations_require_3pid = [] - hs.config.auto_join_rooms = [] - hs.config.enable_registration_captcha = False def test_ui_auth(self): - self.hs.config.user_consent_at_registration = True - self.hs.config.user_consent_policy_name = "My Cool Privacy Policy" - self.hs.config.public_baseurl = "https://example.org/" - self.hs.config.user_consent_version = "1.0" - # Do a UI auth request - request, channel = self.make_request(b"POST", self.url, b"{}") + request_data = json.dumps({"username": "kermit", "password": "monkey"}) + request, channel = self.make_request(b"POST", self.url, request_data) self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) diff --git a/tests/test_types.py b/tests/test_types.py index 9ab5f829b0..480bea1bdc 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -17,18 +17,15 @@ from synapse.api.errors import SynapseError from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart from tests import unittest -from tests.utils import TestHomeServer -mock_homeserver = TestHomeServer(hostname="my.domain") - -class UserIDTestCase(unittest.TestCase): +class UserIDTestCase(unittest.HomeserverTestCase): def test_parse(self): - user = UserID.from_string("@1234abcd:my.domain") + user = UserID.from_string("@1234abcd:test") self.assertEquals("1234abcd", user.localpart) - self.assertEquals("my.domain", user.domain) - self.assertEquals(True, mock_homeserver.is_mine(user)) + self.assertEquals("test", user.domain) + self.assertEquals(True, self.hs.is_mine(user)) def test_pase_empty(self): with self.assertRaises(SynapseError): @@ -48,13 +45,13 @@ class UserIDTestCase(unittest.TestCase): self.assertTrue(userA != userB) -class 
RoomAliasTestCase(unittest.TestCase): +class RoomAliasTestCase(unittest.HomeserverTestCase): def test_parse(self): - room = RoomAlias.from_string("#channel:my.domain") + room = RoomAlias.from_string("#channel:test") self.assertEquals("channel", room.localpart) - self.assertEquals("my.domain", room.domain) - self.assertEquals(True, mock_homeserver.is_mine(room)) + self.assertEquals("test", room.domain) + self.assertEquals(True, self.hs.is_mine(room)) def test_build(self): room = RoomAlias("channel", "my.domain") @@ -78,7 +75,7 @@ class GroupIDTestCase(unittest.TestCase): self.fail("Parsing '%s' should raise exception" % id_string) except SynapseError as exc: self.assertEqual(400, exc.code) - self.assertEqual("M_UNKNOWN", exc.errcode) + self.assertEqual("M_INVALID_PARAM", exc.errcode) class MapUsernameTestCase(unittest.TestCase): diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index a7310cf12a..7b345b03bb 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2019 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,3 +17,22 @@ """ Utilities for running the unit tests """ +from typing import Awaitable, TypeVar + +TV = TypeVar("TV") + + +def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: + """Get the result from an Awaitable which should have completed + + Asserts that the given awaitable has a result ready, and returns its value + """ + i = awaitable.__await__() + try: + next(i) + except StopIteration as e: + # awaitable returned a result + return e.value + + # if next didn't raise, the awaitable hasn't completed. + raise Exception("awaitable has not yet completed") diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py new file mode 100644 index 0000000000..431e9f8e5e --- /dev/null +++ b/tests/test_utils/event_injection.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import synapse.server +from synapse.api.constants import EventTypes +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.types import Collection + +from tests.test_utils import get_awaitable_result + + +""" +Utility functions for poking events into the storage of the server under test. 
+""" + + +def inject_member_event( + hs: synapse.server.HomeServer, + room_id: str, + sender: str, + membership: str, + target: Optional[str] = None, + extra_content: Optional[dict] = None, + **kwargs +) -> EventBase: + """Inject a membership event into a room.""" + if target is None: + target = sender + + content = {"membership": membership} + if extra_content: + content.update(extra_content) + + return inject_event( + hs, + room_id=room_id, + type=EventTypes.Member, + sender=sender, + state_key=target, + content=content, + **kwargs + ) + + +def inject_event( + hs: synapse.server.HomeServer, + room_version: Optional[str] = None, + prev_event_ids: Optional[Collection[str]] = None, + **kwargs +) -> EventBase: + """Inject a generic event into a room + + Args: + hs: the homeserver under test + room_version: the version of the room we're inserting into. + if not specified, will be looked up + prev_event_ids: prev_events for the event. If not specified, will be looked up + kwargs: fields for the event to be created + """ + test_reactor = hs.get_reactor() + + event, context = create_event(hs, room_version, prev_event_ids, **kwargs) + + d = hs.get_storage().persistence.persist_event(event, context) + test_reactor.advance(0) + get_awaitable_result(d) + + return event + + +def create_event( + hs: synapse.server.HomeServer, + room_version: Optional[str] = None, + prev_event_ids: Optional[Collection[str]] = None, + **kwargs +) -> Tuple[EventBase, EventContext]: + test_reactor = hs.get_reactor() + + if room_version is None: + d = hs.get_datastore().get_room_version_id(kwargs["room_id"]) + test_reactor.advance(0) + room_version = get_awaitable_result(d) + + builder = hs.get_event_builder_factory().for_room_version( + KNOWN_ROOM_VERSIONS[room_version], kwargs + ) + d = hs.get_event_creation_handler().create_new_client_event( + builder, prev_event_ids=prev_event_ids + ) + test_reactor.advance(0) + event, context = get_awaitable_result(d) + + return event, context diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 18f1a0035d..f7381b2885 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -14,6 +14,8 @@ # limitations under the License. import logging +from mock import Mock + from twisted.internet import defer from twisted.internet.defer import succeed @@ -36,6 +38,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.store = self.hs.get_datastore() + self.storage = self.hs.get_storage() yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") @@ -62,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): events_to_filter.append(evt) filtered = yield filter_events_for_server( - self.store, "test_server", events_to_filter + self.storage, "test_server", events_to_filter ) # the result should be 5 redacted events, and 5 unredacted events. @@ -100,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): # ... and the filtering happens. 
filtered = yield filter_events_for_server( - self.store, "test_server", events_to_filter + self.storage, "test_server", events_to_filter ) for i in range(0, len(events_to_filter)): @@ -137,7 +140,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): event, context = yield self.event_creation_handler.create_new_client_event( builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks @@ -159,7 +162,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks @@ -180,7 +183,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks @@ -257,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): logger.info("Starting filtering") start = time.time() + + storage = Mock() + storage.main = test_store + storage.state = test_store + filtered = yield filter_events_for_server( test_store, "test_server", events_to_filter ) diff --git a/tests/unittest.py b/tests/unittest.py index 561cebc223..6b6f224e9c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,29 +14,47 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import gc import hashlib import hmac +import inspect import logging import time +from typing import Optional, Tuple, Type, TypeVar, Union from mock import Mock from canonicaljson import json -from twisted.internet.defer import Deferred, succeed +from twisted.internet.defer import Deferred, ensureDeferred, succeed from twisted.python.threadpool import ThreadPool from twisted.trial import unittest -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, Membership from synapse.config.homeserver import HomeServerConfig +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.federation.transport import server as federation_server from synapse.http.server import JsonResource -from synapse.http.site import SynapseRequest -from synapse.logging.context import LoggingContext +from synapse.http.site import SynapseRequest, SynapseSite +from synapse.logging.context import ( + SENTINEL_CONTEXT, + LoggingContext, + current_context, + set_current_context, +) from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester - -from tests.server import get_clock, make_request, render, setup_test_homeserver +from synapse.util.ratelimitutils import FederationRateLimiter + +from tests.server import ( + FakeChannel, + get_clock, + make_request, + render, + setup_test_homeserver, +) +from tests.test_utils import event_injection from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -64,6 +83,9 @@ def around(target): return _around +T = TypeVar("T") + + class TestCase(unittest.TestCase): """A subclass of twisted.trial's TestCase which looks for 'loglevel' attributes on both itself and its individual test methods, to override the @@ -80,10 +102,10 @@ class TestCase(unittest.TestCase): def setUp(orig): # if we're not starting in the sentinel logcontext, then to be honest # all future bets are off. - if LoggingContext.current_context() is not LoggingContext.sentinel: + if current_context(): self.fail( "Test starting with non-sentinel logging context %s" - % (LoggingContext.current_context(),) + % (current_context(),) ) old_level = logging.getLogger().level @@ -105,7 +127,7 @@ class TestCase(unittest.TestCase): # force a GC to workaround problems with deferreds leaking logcontexts when # they are GCed (see the logcontext docs) gc.collect() - LoggingContext.set_current_context(LoggingContext.sentinel) + set_current_context(SENTINEL_CONTEXT) return ret @@ -203,6 +225,15 @@ class HomeserverTestCase(TestCase): # Register the resources self.resource = self.create_test_json_resource() + # create a site to wrap the resource. + self.site = SynapseSite( + logger_name="synapse.access.http.fake", + site_tag="test", + config={}, + resource=self.resource, + server_version_string="1", + ) + from tests.rest.client.v1.utils import RestHelper self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None)) @@ -285,14 +316,11 @@ class HomeserverTestCase(TestCase): return resource - def default_config(self, name="test"): + def default_config(self): """ Get a default HomeServer config dict. - - Args: - name (str): The homeserver name/domain. """ - config = default_config(name) + config = default_config("test") # apply any additional config which was specified via the override_config # decorator. 
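# --- Illustrative sketch (editor's addition, not part of the patch) ---------
# The tests/unittest.py and tests/test_mau.py hunks above rely on the
# override_config decorator (defined later in tests/unittest.py) feeding extra
# homeserver config into HomeserverTestCase.default_config() through
# self._extra_config. A minimal, self-contained model of that pattern follows;
# build_config() and the way _extra_config is attached to the function are
# assumptions made for illustration, not Synapse's exact implementation.

def override_config(extra_config):
    """Stash extra homeserver config on the decorated test function."""

    def decorator(func):
        func._extra_config = extra_config
        return func

    return decorator


def build_config(test_function, base=None):
    """Model of HomeserverTestCase.default_config(): start from a base config,
    then apply any additional config which was specified via the
    override_config decorator."""
    config = dict(base or {"server_name": "test"})
    extra = getattr(test_function, "_extra_config", None)
    if extra is not None:
        config.update(extra)
    return config


@override_config({"mau_trial_days": 1})
def test_trial_delay():
    pass


# Usage: the decorated test sees the merged config.
assert build_config(test_trial_delay)["mau_trial_days"] == 1
assert build_config(test_trial_delay)["server_name"] == "test"
# ---------------------------------------------------------------------------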
@@ -318,14 +346,14 @@ class HomeserverTestCase(TestCase): def make_request( self, - method, - path, - content=b"", - access_token=None, - request=SynapseRequest, - shorthand=True, - federation_auth_origin=None, - ): + method: Union[bytes, str], + path: Union[bytes, str], + content: Union[bytes, dict] = b"", + access_token: Optional[str] = None, + request: Type[T] = SynapseRequest, + shorthand: bool = True, + federation_auth_origin: str = None, + ) -> Tuple[T, FakeChannel]: """ Create a SynapseRequest at the path using the method and containing the given content. @@ -392,13 +420,17 @@ class HomeserverTestCase(TestCase): config_obj.parse_config_dict(config, "", "") kwargs["config"] = config_obj + async def run_bg_updates(): + with LoggingContext("run_bg_updates", request="run_bg_updates-1"): + while not await stor.db.updates.has_completed_background_updates(): + await stor.db.updates.do_next_background_update(1) + hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() - # Run the database background updates. - if hasattr(stor, "do_next_background_update"): - while not self.get_success(stor.has_completed_background_updates()): - self.get_success(stor.do_next_background_update(1)) + # Run the database background updates, when running against "master". + if hs.__class__.__name__ == "TestHomeServer": + self.get_success(run_bg_updates()) return hs @@ -409,6 +441,8 @@ class HomeserverTestCase(TestCase): self.reactor.pump([by] * 100) def get_success(self, d, by=0.0): + if inspect.isawaitable(d): + d = ensureDeferred(d) if not isinstance(d, Deferred): return d self.pump(by=by) @@ -418,6 +452,8 @@ class HomeserverTestCase(TestCase): """ Run a Deferred and get a Failure from it. The failure must be of the type `exc`. """ + if inspect.isawaitable(d): + d = ensureDeferred(d) if not isinstance(d, Deferred): return d self.pump() @@ -441,7 +477,7 @@ class HomeserverTestCase(TestCase): # Create the user request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register") self.render(request) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, 200, msg=channel.result) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -461,6 +497,7 @@ class HomeserverTestCase(TestCase): "password": password, "admin": admin, "mac": want_mac, + "inhibit_login": True, } ) request, channel = self.make_request( @@ -509,10 +546,6 @@ class HomeserverTestCase(TestCase): secrets = self.hs.get_secrets() requester = Requester(user, None, False, None, None) - prev_events_and_hashes = None - if prev_event_ids: - prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids] - event, context = self.get_success( event_creator.create_event( requester, @@ -522,7 +555,7 @@ class HomeserverTestCase(TestCase): "sender": user.to_string(), "content": {"body": secrets.token_hex(), "msgtype": "m.text"}, }, - prev_events_and_hashes=prev_events_and_hashes, + prev_event_ids=prev_event_ids, ) ) @@ -538,7 +571,7 @@ class HomeserverTestCase(TestCase): Add the given event as an extremity to the room. 
""" self.get_success( - self.hs.get_datastore()._simple_insert( + self.hs.get_datastore().db.simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", @@ -559,6 +592,46 @@ class HomeserverTestCase(TestCase): self.render(request) self.assertEqual(channel.code, 403, channel.result) + def inject_room_member(self, room: str, user: str, membership: Membership) -> None: + """ + Inject a membership event into a room. + + Deprecated: use event_injection.inject_room_member directly + + Args: + room: Room ID to inject the event into. + user: MXID of the user to inject the membership for. + membership: The membership type. + """ + event_injection.inject_member_event(self.hs, room, user, membership) + + +class FederatingHomeserverTestCase(HomeserverTestCase): + """ + A federating homeserver that authenticates incoming requests as `other.example.com`. + """ + + def prepare(self, reactor, clock, homeserver): + class Authenticator(object): + def authenticate_request(self, request, content): + return succeed("other.example.com") + + ratelimiter = FederationRateLimiter( + clock, + FederationRateLimitConfig( + window_size=1, + sleep_limit=1, + sleep_msec=1, + reject_limit=1000, + concurrent_requests=1000, + ), + ) + federation_server.register_servlets( + homeserver, self.resource, Authenticator(), ratelimiter + ) + + return super().prepare(reactor, clock, homeserver) + def override_config(extra_config): """A decorator which can be applied to test functions to give additional HS config diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 5713870f48..4d2b9e0d64 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -22,8 +22,10 @@ from twisted.internet import defer, reactor from synapse.api.errors import SynapseError from synapse.logging.context import ( + SENTINEL_CONTEXT, LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, ) from synapse.util.caches import descriptors @@ -194,7 +196,7 @@ class DescriptorTestCase(unittest.TestCase): with LoggingContext() as c1: c1.name = "c1" r = yield obj.fn(1) - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) return r def check_result(r): @@ -204,12 +206,12 @@ class DescriptorTestCase(unittest.TestCase): # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) d1.addCallback(check_result) # and another d2 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) d2.addCallback(check_result) # let the lookup complete @@ -239,14 +241,14 @@ class DescriptorTestCase(unittest.TestCase): try: d = obj.fn(1) self.assertEqual( - LoggingContext.current_context(), LoggingContext.sentinel + current_context(), SENTINEL_CONTEXT, ) yield d self.fail("No exception thrown") except SynapseError: pass - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) # the cache should now be empty self.assertEqual(len(obj.fn.cache.cache), 0) @@ -255,7 +257,7 @@ class DescriptorTestCase(unittest.TestCase): # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), 
SENTINEL_CONTEXT) return d1 @@ -310,14 +312,14 @@ class DescriptorTestCase(unittest.TestCase): obj.mock.return_value = ["spam", "eggs"] r = obj.fn(1, 2) - self.assertEqual(r, ["spam", "eggs"]) + self.assertEqual(r.result, ["spam", "eggs"]) obj.mock.assert_called_once_with(1, 2) obj.mock.reset_mock() # a call with different params should call the mock again obj.mock.return_value = ["chips"] r = obj.fn(1, 3) - self.assertEqual(r, ["chips"]) + self.assertEqual(r.result, ["chips"]) obj.mock.assert_called_once_with(1, 3) obj.mock.reset_mock() @@ -325,9 +327,9 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(len(obj.fn.cache.cache), 3) r = obj.fn(1, 2) - self.assertEqual(r, ["spam", "eggs"]) + self.assertEqual(r.result, ["spam", "eggs"]) r = obj.fn(1, 3) - self.assertEqual(r, ["chips"]) + self.assertEqual(r.result, ["chips"]) obj.mock.assert_not_called() def test_cache_iterable_with_sync_exception(self): @@ -366,10 +368,10 @@ class CachedListDescriptorTestCase(unittest.TestCase): @descriptors.cachedList("fn", "args1", inlineCallbacks=True) def list_fn(self, args1, arg2): - assert LoggingContext.current_context().request == "c1" + assert current_context().request == "c1" # we want this to behave like an asynchronous function yield run_on_reactor() - assert LoggingContext.current_context().request == "c1" + assert current_context().request == "c1" return self.mock(args1, arg2) with LoggingContext() as c1: @@ -377,9 +379,9 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj = Cls() obj.mock.return_value = {10: "fish", 20: "chips"} d1 = obj.list_fn([10, 20], 2) - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) r = yield d1 - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) obj.mock.assert_called_once_with([10, 20], 2) self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py index f60918069a..17fd86d02d 100644 --- a/tests/util/test_async_utils.py +++ b/tests/util/test_async_utils.py @@ -16,7 +16,12 @@ from twisted.internet import defer from twisted.internet.defer import CancelledError, Deferred from twisted.internet.task import Clock -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import ( + SENTINEL_CONTEXT, + LoggingContext, + PreserveLoggingContext, + current_context, +) from synapse.util.async_helpers import timeout_deferred from tests.unittest import TestCase @@ -79,10 +84,10 @@ class TimeoutDeferredTest(TestCase): # the errbacks should be run in the test logcontext def errback(res, deferred_name): self.assertIs( - LoggingContext.current_context(), + current_context(), context_one, "errback %s run in unexpected logcontext %s" - % (deferred_name, LoggingContext.current_context()), + % (deferred_name, current_context()), ) return res @@ -90,7 +95,7 @@ class TimeoutDeferredTest(TestCase): original_deferred.addErrback(errback, "orig") timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock) self.assertNoResult(timing_out_d) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) timing_out_d.addErrback(errback, "timingout") self.clock.pump((1.0,)) @@ -99,4 +104,4 @@ class TimeoutDeferredTest(TestCase): blocking_was_cancelled[0], "non-completing deferred was not cancelled" ) self.failureResultOf(timing_out_d, 
defer.TimeoutError) - self.assertIs(LoggingContext.current_context(), context_one) + self.assertIs(current_context(), context_one) diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 50bc7702d2..49ffeebd0e 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -21,7 +21,7 @@ from tests.utils import MockClock from .. import unittest -class ExpiringCacheTestCase(unittest.TestCase): +class ExpiringCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): clock = MockClock() cache = ExpiringCache("test", clock, max_len=1) diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py new file mode 100644 index 0000000000..0ab0a91483 --- /dev/null +++ b/tests/util/test_itertools.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.util.iterutils import chunk_seq + +from tests.unittest import TestCase + + +class ChunkSeqTests(TestCase): + def test_short_seq(self): + parts = chunk_seq("123", 8) + + self.assertEqual( + list(parts), ["123"], + ) + + def test_long_seq(self): + parts = chunk_seq("abcdefghijklmnop", 8) + + self.assertEqual( + list(parts), ["abcdefgh", "ijklmnop"], + ) + + def test_uneven_parts(self): + parts = chunk_seq("abcdefghijklmnop", 5) + + self.assertEqual( + list(parts), ["abcde", "fghij", "klmno", "p"], + ) + + def test_empty_input(self): + parts = chunk_seq([], 5) + + self.assertEqual( + list(parts), [], + ) diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 0ec8ef90ce..ca3858b184 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -19,7 +19,7 @@ from six.moves import range from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError -from synapse.logging.context import LoggingContext +from synapse.logging.context import LoggingContext, current_context from synapse.util import Clock from synapse.util.async_helpers import Linearizer @@ -45,6 +45,38 @@ class LinearizerTestCase(unittest.TestCase): with (yield d2): pass + @defer.inlineCallbacks + def test_linearizer_is_queued(self): + linearizer = Linearizer() + + key = object() + + d1 = linearizer.queue(key) + cm1 = yield d1 + + # Since d1 gets called immediately, "is_queued" should return false. + self.assertFalse(linearizer.is_queued(key)) + + d2 = linearizer.queue(key) + self.assertFalse(d2.called) + + # Now d2 is queued up behind successful completion of cm1 + self.assertTrue(linearizer.is_queued(key)) + + with cm1: + self.assertFalse(d2.called) + + # cm1 still not done, so d2 still queued. 
+ self.assertTrue(linearizer.is_queued(key)) + + # And now d2 is called and nothing is in the queue again + self.assertFalse(linearizer.is_queued(key)) + + with (yield d2): + self.assertFalse(linearizer.is_queued(key)) + + self.assertFalse(linearizer.is_queued(key)) + def test_lots_of_queued_things(self): # we have one slow thing, and lots of fast things queued up behind it. # it should *not* explode the stack. @@ -54,11 +86,11 @@ class LinearizerTestCase(unittest.TestCase): def func(i, sleep=False): with LoggingContext("func(%s)" % i) as lc: with (yield linearizer.queue("")): - self.assertEqual(LoggingContext.current_context(), lc) + self.assertEqual(current_context(), lc) if sleep: yield Clock(reactor).sleep(0) - self.assertEqual(LoggingContext.current_context(), lc) + self.assertEqual(current_context(), lc) func(0, sleep=True) for i in range(1, 100): diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 8b8455c8b7..95301c013c 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -2,8 +2,10 @@ import twisted.python.failure from twisted.internet import defer, reactor from synapse.logging.context import ( + SENTINEL_CONTEXT, LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, nested_logging_context, run_in_background, @@ -15,7 +17,7 @@ from .. import unittest class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): - self.assertEquals(LoggingContext.current_context().request, value) + self.assertEquals(current_context().request, value) def test_with_context(self): with LoggingContext() as context_one: @@ -41,7 +43,7 @@ class LoggingContextTestCase(unittest.TestCase): self._check_test_key("one") def _test_run_in_background(self, function): - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() callback_completed = [False] @@ -71,7 +73,7 @@ class LoggingContextTestCase(unittest.TestCase): # make sure that the context was reset before it got thrown back # into the reactor try: - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) d2.callback(None) except BaseException: d2.errback(twisted.python.failure.Failure()) @@ -108,7 +110,7 @@ class LoggingContextTestCase(unittest.TestCase): async def testfunc(): self._check_test_key("one") d = Clock(reactor).sleep(0) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) await d self._check_test_key("one") @@ -129,14 +131,14 @@ class LoggingContextTestCase(unittest.TestCase): reactor.callLater(0, d.callback, None) return d - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() with LoggingContext() as context_one: context_one.request = "one" d1 = make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) yield d1 @@ -145,14 +147,14 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def test_make_deferred_yieldable_with_chained_deferreds(self): - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() with LoggingContext() as context_one: context_one.request = "one" d1 = make_deferred_yieldable(_chained_deferred_function()) # make sure that the context was reset by make_deferred_yieldable - 
self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) yield d1 @@ -179,6 +181,30 @@ class LoggingContextTestCase(unittest.TestCase): nested_context = nested_logging_context(suffix="bar") self.assertEqual(nested_context.request, "foo-bar") + @defer.inlineCallbacks + def test_make_deferred_yieldable_with_await(self): + # an async function which retuns an incomplete coroutine, but doesn't + # follow the synapse rules. + + async def blocking_function(): + d = defer.Deferred() + reactor.callLater(0, d.callback, None) + await d + + sentinel_context = current_context() + + with LoggingContext() as context_one: + context_one.request = "one" + + d1 = make_deferred_yieldable(blocking_function()) + # make sure that the context was reset by make_deferred_yieldable + self.assertIs(current_context(), sentinel_context) + + yield d1 + + # now it should be restored + self._check_test_key("one") + # a function which returns a deferred which has been "called", but # which had a function which returned another incomplete deferred on diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 786947375d..0adb2174af 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -22,7 +22,7 @@ from synapse.util.caches.treecache import TreeCache from .. import unittest -class LruCacheTestCase(unittest.TestCase): +class LruCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): cache = LruCache(1) cache["key"] = "value" @@ -84,7 +84,7 @@ class LruCacheTestCase(unittest.TestCase): self.assertEquals(len(cache), 0) -class LruCacheCallbacksTestCase(unittest.TestCase): +class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_get(self): m = Mock() cache = LruCache(1) @@ -233,7 +233,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m3.call_count, 1) -class LruCacheSizedTestCase(unittest.TestCase): +class LruCacheSizedTestCase(unittest.HomeserverTestCase): def test_evict(self): cache = LruCache(5, size_callback=len) cache["key1"] = [0] diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py deleted file mode 100644 index 1a44f72425..0000000000 --- a/tests/util/test_snapshot_cache.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet.defer import Deferred - -from synapse.util.caches.snapshot_cache import SnapshotCache - -from .. 
import unittest - - -class SnapshotCacheTestCase(unittest.TestCase): - def setUp(self): - self.cache = SnapshotCache() - self.cache.DURATION_MS = 1 - - def test_get_set(self): - # Check that getting a missing key returns None - self.assertEquals(self.cache.get(0, "key"), None) - - # Check that setting a key with a deferred returns - # a deferred that resolves when the initial deferred does - d = Deferred() - set_result = self.cache.set(0, "key", d) - self.assertIsNotNone(set_result) - self.assertFalse(set_result.called) - - # Check that getting the key before the deferred has resolved - # returns a deferred that resolves when the initial deferred does. - get_result_at_10 = self.cache.get(10, "key") - self.assertIsNotNone(get_result_at_10) - self.assertFalse(get_result_at_10.called) - - # Check that the returned deferreds resolve when the initial deferred - # does. - d.callback("v") - self.assertTrue(set_result.called) - self.assertTrue(get_result_at_10.called) - - # Check that getting the key after the deferred has resolved - # before the cache expires returns a resolved deferred. - get_result_at_11 = self.cache.get(11, "key") - self.assertIsNotNone(get_result_at_11) - if isinstance(get_result_at_11, Deferred): - # The cache may return the actual result rather than a deferred - self.assertTrue(get_result_at_11.called) - - # Check that getting the key after the deferred has resolved - # after the cache expires returns None - get_result_at_12 = self.cache.get(12, "key") - self.assertIsNone(get_result_at_12) diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index f2be63706b..13b753e367 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -1,11 +1,9 @@ -from mock import patch - from synapse.util.caches.stream_change_cache import StreamChangeCache from tests import unittest -class StreamChangeCacheTests(unittest.TestCase): +class StreamChangeCacheTests(unittest.HomeserverTestCase): """ Tests for StreamChangeCache. """ @@ -28,26 +26,33 @@ class StreamChangeCacheTests(unittest.TestCase): cache.entity_has_changed("user@foo.com", 6) cache.entity_has_changed("bar@baz.net", 7) + # also test multiple things changing on the same stream ID + cache.entity_has_changed("user2@foo.com", 8) + cache.entity_has_changed("bar2@baz.net", 8) + # If it's been changed after that stream position, return True self.assertTrue(cache.has_entity_changed("user@foo.com", 4)) self.assertTrue(cache.has_entity_changed("bar@baz.net", 4)) + self.assertTrue(cache.has_entity_changed("bar2@baz.net", 4)) + self.assertTrue(cache.has_entity_changed("user2@foo.com", 4)) # If it's been changed at that stream position, return False self.assertFalse(cache.has_entity_changed("user@foo.com", 6)) + self.assertFalse(cache.has_entity_changed("user2@foo.com", 8)) # If there's no changes after that stream position, return False self.assertFalse(cache.has_entity_changed("user@foo.com", 7)) + self.assertFalse(cache.has_entity_changed("user2@foo.com", 9)) # If the entity does not exist, return False. - self.assertFalse(cache.has_entity_changed("not@here.website", 7)) + self.assertFalse(cache.has_entity_changed("not@here.website", 9)) # If we request before the stream cache's earliest known position, # return True, whether it's a known entity or not. 
self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) self.assertTrue(cache.has_entity_changed("not@here.website", 0)) - @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0) - def test_has_entity_changed_pops_off_start(self): + def test_entity_has_changed_pops_off_start(self): """ StreamChangeCache.entity_has_changed will respect the max size and purge the oldest items upon reaching that max size. @@ -64,11 +69,20 @@ class StreamChangeCacheTests(unittest.TestCase): # The oldest item has been popped off self.assertTrue("user@foo.com" not in cache._entity_to_key) + self.assertEqual( + cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"], + ) + self.assertIsNone(cache.get_all_entities_changed(1)) + # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) self.assertEqual( - set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key) + {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) + self.assertEqual( + cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net"], + ) + self.assertIsNone(cache.get_all_entities_changed(1)) def test_get_all_entities_changed(self): """ @@ -80,18 +94,52 @@ class StreamChangeCacheTests(unittest.TestCase): cache.entity_has_changed("user@foo.com", 2) cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("anotheruser@foo.com", 3) cache.entity_has_changed("user@elsewhere.org", 4) - self.assertEqual( - cache.get_all_entities_changed(1), - ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], - ) - self.assertEqual( - cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"] - ) + r = cache.get_all_entities_changed(1) + + # either of these are valid + ok1 = [ + "user@foo.com", + "bar@baz.net", + "anotheruser@foo.com", + "user@elsewhere.org", + ] + ok2 = [ + "user@foo.com", + "anotheruser@foo.com", + "bar@baz.net", + "user@elsewhere.org", + ] + self.assertTrue(r == ok1 or r == ok2) + + r = cache.get_all_entities_changed(2) + self.assertTrue(r == ok1[1:] or r == ok2[1:]) + self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) self.assertEqual(cache.get_all_entities_changed(0), None) + # ... later, things gest more updates + cache.entity_has_changed("user@foo.com", 5) + cache.entity_has_changed("bar@baz.net", 5) + cache.entity_has_changed("anotheruser@foo.com", 6) + + ok1 = [ + "user@elsewhere.org", + "user@foo.com", + "bar@baz.net", + "anotheruser@foo.com", + ] + ok2 = [ + "user@elsewhere.org", + "bar@baz.net", + "user@foo.com", + "anotheruser@foo.com", + ] + r = cache.get_all_entities_changed(3) + self.assertTrue(r == ok1 or r == ok2) + def test_has_any_entity_changed(self): """ StreamChangeCache.has_any_entity_changed will return True if any @@ -137,7 +185,7 @@ class StreamChangeCacheTests(unittest.TestCase): cache.get_entities_changed( ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2 ), - set(["bar@baz.net", "user@elsewhere.org"]), + {"bar@baz.net", "user@elsewhere.org"}, ) # Query all the entries mid-way through the stream, but include one @@ -153,7 +201,7 @@ class StreamChangeCacheTests(unittest.TestCase): ], stream_pos=2, ), - set(["bar@baz.net", "user@elsewhere.org"]), + {"bar@baz.net", "user@elsewhere.org"}, ) # Query all the entries, but before the first known point. 
We will get @@ -168,21 +216,13 @@ class StreamChangeCacheTests(unittest.TestCase): ], stream_pos=0, ), - set( - [ - "user@foo.com", - "bar@baz.net", - "user@elsewhere.org", - "not@here.website", - ] - ), + {"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"}, ) # Query a subset of the entries mid-way through the stream. We should # only get back the subset. self.assertEqual( - cache.get_entities_changed(["bar@baz.net"], stream_pos=2), - set(["bar@baz.net"]), + cache.get_entities_changed(["bar@baz.net"], stream_pos=2), {"bar@baz.net"}, ) def test_max_pos(self): diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py new file mode 100644 index 0000000000..4f4da29a98 --- /dev/null +++ b/tests/util/test_stringutils.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api.errors import SynapseError +from synapse.util.stringutils import assert_valid_client_secret + +from .. import unittest + + +class StringUtilsTestCase(unittest.TestCase): + def test_client_secret_regex(self): + """Ensure that client_secret does not contain illegal characters""" + good = [ + "abcde12345", + "ABCabc123", + "_--something==_", + "...--==-18913", + "8Dj2odd-e9asd.cd==_--ddas-secret-", + # We temporarily allow : characters: https://github.com/matrix-org/synapse/issues/6766 + # To be removed in a future release + "SECRET:1234567890", + ] + + bad = [ + "--+-/secret", + "\\dx--dsa288", + "", + "AAS//", + "asdj**", + ">X><Z<!!-)))", + "a@b.com", + ] + + for client_secret in good: + assert_valid_client_secret(client_secret) + + for client_secret in bad: + with self.assertRaises(SynapseError): + assert_valid_client_secret(client_secret) diff --git a/tests/utils.py b/tests/utils.py index 46ef2959f2..59c020a051 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -30,19 +30,16 @@ from twisted.internet import defer, reactor from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error from synapse.api.room_versions import RoomVersions +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.federation.transport import server as federation_server from synapse.http.server import HttpServer -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine -from synapse.storage.prepare_database import ( - _get_or_create_schema_state, - _setup_new_database, - prepare_database, -) +from synapse.storage.prepare_database import prepare_database from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. 
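The hunk above, like many of the tests/util/ hunks earlier in this diff, swaps the class-level LoggingContext.current_context() / LoggingContext.sentinel API for the module-level current_context(), set_current_context() and SENTINEL_CONTEXT helpers. A minimal sketch of the new-style usage, built only from names imported in this diff (the "example" request id is arbitrary):

    from synapse.logging.context import (
        SENTINEL_CONTEXT,
        LoggingContext,
        current_context,
        set_current_context,
    )

    # outside any logging context we are at the sentinel
    assert current_context() is SENTINEL_CONTEXT

    with LoggingContext() as ctx:
        ctx.request = "example"
        assert current_context() is ctx

    # leaving the block restores the previous (sentinel) context
    assert current_context() is SENTINEL_CONTEXT

    # tests reset explicitly in teardown, as TestCase.tearDown does above
    set_current_context(SENTINEL_CONTEXT)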
@@ -77,7 +74,10 @@ def setupdb(): db_conn.autocommit = True cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) - cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,)) + cur.execute( + "CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' " + "template=template0;" % (POSTGRES_BASE_DB,) + ) cur.close() db_conn.close() @@ -88,11 +88,7 @@ def setupdb(): host=POSTGRES_HOST, password=POSTGRES_PASSWORD, ) - cur = db_conn.cursor() - _get_or_create_schema_state(cur, db_engine) - _setup_new_database(cur, db_engine) - db_conn.commit() - cur.close() + prepare_database(db_conn, db_engine, None) db_conn.close() def _cleanup(): @@ -117,6 +113,7 @@ def default_config(name, parse=False): """ config_dict = { "server_name": name, + "send_federation": False, "media_store_path": "media", "uploads_path": "uploads", # the test signing key is just an arbitrary ed25519 key to keep the config @@ -145,7 +142,6 @@ def default_config(name, parse=False): "limit_usage_by_mau": False, "hs_disabled": False, "hs_disabled_message": "", - "hs_disabled_limit_type": "", "max_mau_value": 50, "mau_trial_days": 0, "mau_stats_only": False, @@ -171,6 +167,7 @@ def default_config(name, parse=False): # disable user directory updates, because they get done in the # background, which upsets the test runner. "update_user_directory": False, + "caches": {"global_factor": 1}, } if parse: @@ -185,7 +182,6 @@ class TestHomeServer(HomeServer): DATASTORE_CLASS = DataStore -@defer.inlineCallbacks def setup_test_homeserver( cleanup_func, name="test", @@ -222,7 +218,7 @@ def setup_test_homeserver( if USE_POSTGRES_FOR_TESTS: test_db = "synapse_test_%s" % uuid.uuid4().hex - config.database_config = { + database_config = { "name": "psycopg2", "args": { "database": test_db, @@ -234,12 +230,15 @@ def setup_test_homeserver( }, } else: - config.database_config = { + database_config = { "name": "sqlite3", "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, } - db_engine = create_engine(config.database_config) + database = DatabaseConnectionConfig("master", database_config) + config.database.databases = [database] + + db_engine = create_engine(database.config) # Create the database before we actually try and connect to it, based off # the template database we generate in setupdb() @@ -259,39 +258,30 @@ def setup_test_homeserver( cur.close() db_conn.close() - # we need to configure the connection pool to run the on_new_connection - # function, so that we can test code that uses custom sqlite functions - # (like rank). 
- config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection - if datastore is None: hs = homeserverToUse( name, config=config, - db_config=config.database_config, version_string="Synapse/tests", - database_engine=db_engine, tls_server_context_factory=Mock(), tls_client_options_factory=Mock(), reactor=reactor, **kargs ) - # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to - # date db - if not isinstance(db_engine, PostgresEngine): - db_conn = hs.get_db_conn() - yield prepare_database(db_conn, db_engine, config) - db_conn.commit() - db_conn.close() + hs.setup() + if homeserverToUse.__name__ == "TestHomeServer": + hs.setup_master() + + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] - else: # We need to do cleanup on PostgreSQL def cleanup(): import psycopg2 # Close all the db pools - hs.get_db_pool().close() + database._db_pool.close() dropped = False @@ -330,17 +320,12 @@ def setup_test_homeserver( # Register the cleanup hook cleanup_func(cleanup) - hs.setup() - if homeserverToUse.__name__ == "TestHomeServer": - hs.setup_master() else: hs = homeserverToUse( name, - db_pool=None, datastore=datastore, config=config, version_string="Synapse/tests", - database_engine=db_engine, tls_server_context_factory=Mock(), tls_client_options_factory=Mock(), reactor=reactor, @@ -351,10 +336,15 @@ def setup_test_homeserver( # Need to let the HS build an auth handler and then mess with it # because AuthHandler's constructor requires the HS, so we can't make one # beforehand and pass it in to the HS's constructor (chicken / egg) - hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest() - hs.get_auth_handler().validate_hash = ( - lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h - ) + async def hash(p): + return hashlib.md5(p.encode("utf8")).hexdigest() + + hs.get_auth_handler().hash = hash + + async def validate_hash(p, h): + return hashlib.md5(p.encode("utf8")).hexdigest() == h + + hs.get_auth_handler().validate_hash = validate_hash fed = kargs.get("resource_for_federation", None) if fed: @@ -463,7 +453,9 @@ class MockHttpResource(HttpServer): try: args = [urlparse.unquote(u) for u in matcher.groups()] - (code, response) = yield func(mock_request, *args) + (code, response) = yield defer.ensureDeferred( + func(mock_request, *args) + ) return code, response except CodeMessageException as e: return (e.code, cs_error(e.msg, code=e.errcode)) @@ -510,10 +502,10 @@ class MockClock(object): return self.time() * 1000 def call_later(self, delay, callback, *args, **kwargs): - current_context = LoggingContext.current_context() + ctx = current_context() def wrapped_callback(): - LoggingContext.thread_local.current_context = current_context + set_current_context(ctx) callback(*args, **kwargs) t = [self.now + delay, wrapped_callback, False] @@ -521,8 +513,8 @@ class MockClock(object): return t - def looping_call(self, function, interval): - self.loopers.append([function, interval / 1000.0, self.now]) + def looping_call(self, function, interval, *args, **kwargs): + self.loopers.append([function, interval / 1000.0, self.now, args, kwargs]) def cancel_call_later(self, timer, ignore_errs=False): if timer[2]: @@ -552,9 +544,9 @@ class MockClock(object): self.timers.append(t) for looped in self.loopers: - func, interval, last = looped + func, interval, last, args, kwargs = looped if last + interval < self.now: - func() + func(*args, **kwargs) looped[2] = self.now def advance_time_msec(self, ms): @@ -655,10 +647,18 @@ 
def create_room(hs, room_id, creator_id): creator_id (str) """ + persistence_store = hs.get_storage().persistence store = hs.get_datastore() event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() + yield store.store_room( + room_id=room_id, + room_creator_user_id=creator_id, + is_public=False, + room_version=RoomVersions.V1, + ) + builder = event_builder_factory.for_room_version( RoomVersions.V1, { @@ -672,4 +672,4 @@ def create_room(hs, room_id, creator_id): event, context = yield event_creation_handler.create_new_client_event(builder) - yield store.persist_event(event, context) + yield persistence_store.persist_event(event, context) diff --git a/tox.ini b/tox.ini index 1bce10a4ce..463a34d137 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = packaging, py35, py36, py37, check_codestyle, check_isort +envlist = packaging, py35, py36, py37, py38, check_codestyle, check_isort [base] basepython = python3.7 @@ -102,6 +102,15 @@ commands = {envbindir}/coverage run "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:} +[testenv:benchmark] +deps = + {[base]deps} + pyperf +setenv = + SYNAPSE_POSTGRES = 1 +commands = + python -m synmark {posargs:} + [testenv:packaging] skip_install=True deps = @@ -114,15 +123,17 @@ skip_install = True basepython = python3.6 deps = flake8 - black + flake8-comprehensions + black==19.10b0 # We pin so that our tests don't start failing on new releases of black. commands = python -m black --check --diff . - /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/hash_password scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}" + /bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}" + {toxinidir}/scripts-dev/config-lint.sh [testenv:check_isort] skip_install = True deps = isort -commands = /bin/sh -c "isort -c -df -sp setup.cfg -rc synapse tests" +commands = /bin/sh -c "isort -c -df -sp setup.cfg -rc synapse tests scripts-dev scripts" [testenv:check-newsfragment] skip_install = True @@ -157,16 +168,52 @@ commands= coverage html [testenv:mypy] -basepython = python3.7 skip_install = True deps = {[base]deps} - mypy + mypy==0.750 mypy-zope - typeshed env = MYPYPATH = stubs/ extras = all -commands = mypy --show-traceback \ +commands = mypy \ + synapse/api \ + synapse/appservice \ + synapse/config \ + synapse/event_auth.py \ + synapse/events/spamcheck.py \ + synapse/federation \ + synapse/handlers/auth.py \ + synapse/handlers/cas_handler.py \ + synapse/handlers/directory.py \ + synapse/handlers/oidc_handler.py \ + synapse/handlers/presence.py \ + synapse/handlers/room_member.py \ + synapse/handlers/room_member_worker.py \ + synapse/handlers/saml_handler.py \ + synapse/handlers/sync.py \ + synapse/handlers/ui_auth \ + synapse/http/server.py \ + synapse/http/site.py \ synapse/logging/ \ - synapse/config/ + synapse/metrics \ + synapse/module_api \ + synapse/push/pusherpool.py \ + synapse/push/push_rule_evaluator.py \ + synapse/replication \ + synapse/rest \ + synapse/spam_checker_api \ + synapse/storage/data_stores/main/ui_auth.py \ + synapse/storage/database.py \ + synapse/storage/engines \ + synapse/storage/util \ + synapse/streams \ + synapse/util/caches/stream_change_cache.py \ + tests/replication \ + tests/test_utils \ + tests/rest/client/v2_alpha/test_auth.py \ + tests/util/test_stream_change_cache.py + +# To find all folders that pass mypy you run: +# +# find synapse/* -type d -not -name __pycache__ -exec bash -c "mypy '{}' 
> /dev/null" \; -print |
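A usage note for the MockClock change earlier in this diff: looping_call() now forwards extra positional and keyword arguments to the callback, and advance_time() fires any looper whose interval has elapsed. A small, self-contained sketch; the poke callback and its arguments are invented for illustration:

    from tests.utils import MockClock

    clock = MockClock()
    calls = []

    def poke(room_id, limit=None):
        calls.append((room_id, limit))

    # the interval is given in milliseconds; extra args are passed through
    clock.looping_call(poke, 1000, "!room:test", limit=5)

    clock.advance_time(2)  # two seconds pass, so the looper fires once
    assert calls == [("!room:test", 5)]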