185 files changed, 3393 insertions, 1985 deletions
diff --git a/.ci/scripts/test_old_deps.sh b/.ci/scripts/test_old_deps.sh index 7d0625fa86..478c8d639a 100755 --- a/.ci/scripts/test_old_deps.sh +++ b/.ci/scripts/test_old_deps.sh @@ -69,7 +69,7 @@ with open('pyproject.toml', 'w') as f: " python3 -c "$REMOVE_DEV_DEPENDENCIES" -pipx install poetry==1.1.12 +pipx install poetry==1.1.14 ~/.local/bin/poetry lock echo "::group::Patched pyproject.toml" diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 50d28c68ee..c3638c35eb 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1,3 +1,16 @@ +# Commits in this file will be removed from GitHub blame results. +# +# To use this file locally, use: +# git blame --ignore-revs-file="path/to/.git-blame-ignore-revs" <files> +# +# or configure the `blame.ignoreRevsFile` option in your git config. +# +# If ignoring a pull request that was not squash merged, only the merge +# commit needs to be put here. Child commits will be resolved from it. + +# Run black (#3679). +8b3d9b6b199abb87246f982d5db356f1966db925 + # Black reformatting (#5482). 32e7c9e7f20b57dd081023ac42d6931a8da9b3a3 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4bc29c8207..c8b033e8a4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -328,9 +328,6 @@ jobs: - arrangement: monolith database: Postgres - - arrangement: workers - database: Postgres - steps: - name: Run actions/checkout@v2 for synapse uses: actions/checkout@v2 @@ -346,6 +343,30 @@ jobs: shell: bash name: Run Complement Tests + # XXX When complement with workers is stable, move this back into the standard + # "complement" matrix above. + # + # See https://github.com/matrix-org/synapse/issues/13161 + complement-workers: + if: "${{ !failure() && !cancelled() }}" + needs: linting-done + runs-on: ubuntu-latest + + steps: + - name: Run actions/checkout@v2 for synapse + uses: actions/checkout@v2 + with: + path: synapse + + - name: Prepare Complement's Prerequisites + run: synapse/.ci/scripts/setup_complement_prerequisites.sh + + - run: | + set -o pipefail + POSTGRES=1 WORKERS=1 COMPLEMENT_DIR=`pwd`/complement synapse/scripts-dev/complement.sh -json 2>&1 | gotestfmt + shell: bash + name: Run Complement Tests + # a job which marks all the other jobs as complete, thus allowing PRs to be merged. tests-done: if: ${{ always() }} diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml index f35e82297f..dd8e6fbb1c 100644 --- a/.github/workflows/twisted_trunk.yml +++ b/.github/workflows/twisted_trunk.yml @@ -127,12 +127,12 @@ jobs: run: | set -x DEBIAN_FRONTEND=noninteractive sudo apt-get install -yqq python3 pipx - pipx install poetry==1.1.12 + pipx install poetry==1.1.14 poetry remove -n twisted poetry add -n --extras tls git+https://github.com/twisted/twisted.git#trunk poetry lock --no-update - # NOT IN 1.1.12 poetry lock --check + # NOT IN 1.1.14 poetry lock --check working-directory: synapse - run: | diff --git a/CHANGES.md b/CHANGES.md index 143ab4e0e5..1d123abc19 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,9 @@ +Synapse vNext +============= + +As of this release, Synapse no longer allows the tasks of verifying email address ownership, and password reset confirmation, to be delegated to an identity server. For more information, see the [upgrade notes](https://matrix-org.github.io/synapse/v1.64/upgrade.html#upgrading-to-v1640). + + Synapse 1.63.1 (2022-07-20) =========================== @@ -91,7 +97,6 @@ Internal Changes - More aggressively rotate push actions. 
([\#13211](https://github.com/matrix-org/synapse/issues/13211)) - Add `max_line_length` setting for Python files to the `.editorconfig`. Contributed by @sumnerevans @ Beeper. ([\#13228](https://github.com/matrix-org/synapse/issues/13228)) - Synapse 1.62.0 (2022-07-05) =========================== @@ -99,7 +104,6 @@ No significant changes since 1.62.0rc3. Authors of spam-checker plugins should consult the [upgrade notes](https://github.com/matrix-org/synapse/blob/release-v1.62/docs/upgrade.md#upgrading-to-v1620) to learn about the enriched signatures for spam checker callbacks, which are supported with this release of Synapse. - Synapse 1.62.0rc3 (2022-07-04) ============================== @@ -139,7 +143,7 @@ Bugfixes - Update [MSC3786](https://github.com/matrix-org/matrix-spec-proposals/pull/3786) implementation to check `state_key`. ([\#12939](https://github.com/matrix-org/synapse/issues/12939)) - Fix a bug introduced in Synapse 1.58 where Synapse would not report full version information when installed from a git checkout. This is a best-effort affair and not guaranteed to be stable. ([\#12973](https://github.com/matrix-org/synapse/issues/12973)) - Fix a bug introduced in Synapse 1.60 where Synapse would fail to start if the `sqlite3` module was not available. ([\#12979](https://github.com/matrix-org/synapse/issues/12979)) -- Fix a bug where non-standard information was required when requesting the `/hierarchy` API over federation. Introduced +- Fix a bug where non-standard information was required when requesting the `/hierarchy` API over federation. Introduced in Synapse v1.41.0. ([\#12991](https://github.com/matrix-org/synapse/issues/12991)) - Fix a long-standing bug which meant that rate limiting was not restrictive enough in some cases. ([\#13018](https://github.com/matrix-org/synapse/issues/13018)) - Fix a bug introduced in Synapse 1.58 where profile requests for a malformed user ID would cause an internal error. Synapse now returns 400 Bad Request in this situation. ([\#13041](https://github.com/matrix-org/synapse/issues/13041)) diff --git a/changelog.d/12942.misc b/changelog.d/12942.misc new file mode 100644 index 0000000000..acb2558d57 --- /dev/null +++ b/changelog.d/12942.misc @@ -0,0 +1 @@ +Use lower isolation level when purging rooms to avoid serialization errors. Contributed by Nick @ Beeper. diff --git a/changelog.d/12943.misc b/changelog.d/12943.misc new file mode 100644 index 0000000000..f66bb3ec32 --- /dev/null +++ b/changelog.d/12943.misc @@ -0,0 +1 @@ +Remove code which incorrectly attempted to reconcile state with remote servers when processing incoming events. diff --git a/changelog.d/12967.removal b/changelog.d/12967.removal new file mode 100644 index 0000000000..0aafd6a4d9 --- /dev/null +++ b/changelog.d/12967.removal @@ -0,0 +1 @@ +Drop tables used for groups/communities. diff --git a/changelog.d/13038.feature b/changelog.d/13038.feature new file mode 100644 index 0000000000..1278f1b4e9 --- /dev/null +++ b/changelog.d/13038.feature @@ -0,0 +1 @@ +Provide more info on why we don't have any thumbnails to serve. diff --git a/changelog.d/13094.misc b/changelog.d/13094.misc new file mode 100644 index 0000000000..f1e55ae476 --- /dev/null +++ b/changelog.d/13094.misc @@ -0,0 +1 @@ +Make the AS login method call `Auth.get_user_by_req` for checking the AS token.
diff --git a/changelog.d/13172.misc b/changelog.d/13172.misc new file mode 100644 index 0000000000..124a1b3662 --- /dev/null +++ b/changelog.d/13172.misc @@ -0,0 +1 @@ +Always use a version of canonicaljson that supports the C implementation of frozendict. diff --git a/changelog.d/13175.misc b/changelog.d/13175.misc new file mode 100644 index 0000000000..f273b3d6ca --- /dev/null +++ b/changelog.d/13175.misc @@ -0,0 +1 @@ +Add prometheus counters for ephemeral events and to device messages pushed to app services. Contributed by Brad @ Beeper. diff --git a/changelog.d/13192.removal b/changelog.d/13192.removal new file mode 100644 index 0000000000..a7dffd1c48 --- /dev/null +++ b/changelog.d/13192.removal @@ -0,0 +1 @@ +Drop support for delegating email verification to an external server. diff --git a/changelog.d/13198.misc b/changelog.d/13198.misc new file mode 100644 index 0000000000..5aef2432df --- /dev/null +++ b/changelog.d/13198.misc @@ -0,0 +1 @@ +Refactor receipts servlet logic to avoid duplicated code. diff --git a/changelog.d/13208.feature b/changelog.d/13208.feature new file mode 100644 index 0000000000..b0c5f090ee --- /dev/null +++ b/changelog.d/13208.feature @@ -0,0 +1 @@ +Add a `room_type` field in the responses for the list room and room details admin API. Contributed by @andrewdoh. \ No newline at end of file diff --git a/changelog.d/13215.misc b/changelog.d/13215.misc new file mode 100644 index 0000000000..3da35addb3 --- /dev/null +++ b/changelog.d/13215.misc @@ -0,0 +1 @@ +Preparation for database schema simplifications: populate `state_key` and `rejection_reason` for existing rows in the `events` table. diff --git a/changelog.d/13218.misc b/changelog.d/13218.misc new file mode 100644 index 0000000000..b1c8e5c747 --- /dev/null +++ b/changelog.d/13218.misc @@ -0,0 +1 @@ +Remove unused database table `event_reference_hashes`. diff --git a/changelog.d/13220.feature b/changelog.d/13220.feature new file mode 100644 index 0000000000..9b0240fdc8 --- /dev/null +++ b/changelog.d/13220.feature @@ -0,0 +1 @@ +Add support for room version 10. diff --git a/changelog.d/13224.misc b/changelog.d/13224.misc new file mode 100644 index 0000000000..41f8693b74 --- /dev/null +++ b/changelog.d/13224.misc @@ -0,0 +1 @@ +Further reduce queries used sending events when creating new rooms. Contributed by Nick @ Beeper (@fizzadar). diff --git a/changelog.d/13231.doc b/changelog.d/13231.doc new file mode 100644 index 0000000000..e750f9da49 --- /dev/null +++ b/changelog.d/13231.doc @@ -0,0 +1 @@ +Provide an example of using the Admin API. Contributed by @jejo86. diff --git a/changelog.d/13233.doc b/changelog.d/13233.doc new file mode 100644 index 0000000000..3eaea7c5e3 --- /dev/null +++ b/changelog.d/13233.doc @@ -0,0 +1 @@ +Move the documentation for how URL previews work to the URL preview module. diff --git a/changelog.d/13239.removal b/changelog.d/13239.removal new file mode 100644 index 0000000000..8f6045176d --- /dev/null +++ b/changelog.d/13239.removal @@ -0,0 +1 @@ +Drop support for calling `/_matrix/client/v3/account/3pid/bind` without an `id_access_token`, which was not permitted by the spec. Contributed by @Vetchu. \ No newline at end of file diff --git a/changelog.d/13240.misc b/changelog.d/13240.misc new file mode 100644 index 0000000000..0567e47d64 --- /dev/null +++ b/changelog.d/13240.misc @@ -0,0 +1 @@ +Call the v2 identity service `/3pid/unbind` endpoint, rather than v1. 
\ No newline at end of file diff --git a/changelog.d/13242.misc b/changelog.d/13242.misc new file mode 100644 index 0000000000..7f8ec0815f --- /dev/null +++ b/changelog.d/13242.misc @@ -0,0 +1 @@ +Use an asynchronous cache wrapper for the get event cache. Contributed by Nick @ Beeper (@fizzadar). diff --git a/changelog.d/13251.misc b/changelog.d/13251.misc new file mode 100644 index 0000000000..526369e403 --- /dev/null +++ b/changelog.d/13251.misc @@ -0,0 +1 @@ +Optimise federation sender and appservice pusher event stream processing queries. Contributed by Nick @ Beeper (@fizzadar). diff --git a/changelog.d/13253.misc b/changelog.d/13253.misc new file mode 100644 index 0000000000..cba6b9ee0f --- /dev/null +++ b/changelog.d/13253.misc @@ -0,0 +1 @@ +Preparatory work for a per-room rate limiter on joins. diff --git a/changelog.d/13254.misc b/changelog.d/13254.misc new file mode 100644 index 0000000000..cba6b9ee0f --- /dev/null +++ b/changelog.d/13254.misc @@ -0,0 +1 @@ +Preparatory work for a per-room rate limiter on joins. diff --git a/changelog.d/13255.misc b/changelog.d/13255.misc new file mode 100644 index 0000000000..cba6b9ee0f --- /dev/null +++ b/changelog.d/13255.misc @@ -0,0 +1 @@ +Preparatory work for a per-room rate limiter on joins. diff --git a/changelog.d/13257.misc b/changelog.d/13257.misc new file mode 100644 index 0000000000..5fc1388520 --- /dev/null +++ b/changelog.d/13257.misc @@ -0,0 +1 @@ +Log the stack when waiting for an entire room to be un-partial stated. diff --git a/changelog.d/13258.misc b/changelog.d/13258.misc new file mode 100644 index 0000000000..a187c46aa6 --- /dev/null +++ b/changelog.d/13258.misc @@ -0,0 +1 @@ +Fix spurious warning when fetching state after a missing prev event. diff --git a/changelog.d/13260.misc b/changelog.d/13260.misc new file mode 100644 index 0000000000..b55ff32c76 --- /dev/null +++ b/changelog.d/13260.misc @@ -0,0 +1 @@ +Clean-up tests for notifications. diff --git a/changelog.d/13261.doc b/changelog.d/13261.doc new file mode 100644 index 0000000000..3eaea7c5e3 --- /dev/null +++ b/changelog.d/13261.doc @@ -0,0 +1 @@ +Move the documentation for how URL previews work to the URL preview module. diff --git a/changelog.d/13263.bugfix b/changelog.d/13263.bugfix new file mode 100644 index 0000000000..91e1d1e7eb --- /dev/null +++ b/changelog.d/13263.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.15.0 where adding a user through the Synapse Admin API with a phone number would fail if the "enable_email_notifs" and "email_notifs_for_new_users" options were enabled. Contributed by @thomasweston12. diff --git a/changelog.d/13266.misc b/changelog.d/13266.misc new file mode 100644 index 0000000000..d583acb81b --- /dev/null +++ b/changelog.d/13266.misc @@ -0,0 +1 @@ +Do not fail build if complement with workers fails. diff --git a/changelog.d/13267.misc b/changelog.d/13267.misc new file mode 100644 index 0000000000..a334414320 --- /dev/null +++ b/changelog.d/13267.misc @@ -0,0 +1 @@ +Don't pull out state in `compute_event_context` for unconflicted state. diff --git a/changelog.d/13270.bugfix b/changelog.d/13270.bugfix new file mode 100644 index 0000000000..d023b25eea --- /dev/null +++ b/changelog.d/13270.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.40 where a user invited to a restricted room would be briefly unable to join. 
diff --git a/changelog.d/13271.doc b/changelog.d/13271.doc new file mode 100644 index 0000000000..b50e60d029 --- /dev/null +++ b/changelog.d/13271.doc @@ -0,0 +1 @@ +Add another `contrib` script to help set up worker processes. Contributed by @villepeh. diff --git a/changelog.d/13274.misc b/changelog.d/13274.misc new file mode 100644 index 0000000000..a334414320 --- /dev/null +++ b/changelog.d/13274.misc @@ -0,0 +1 @@ +Don't pull out state in `compute_event_context` for unconflicted state. diff --git a/changelog.d/13276.feature b/changelog.d/13276.feature new file mode 100644 index 0000000000..068d158ed5 --- /dev/null +++ b/changelog.d/13276.feature @@ -0,0 +1 @@ +Add per-room rate limiting for room joins. For each room, Synapse now monitors the rate of join events in that room, and throttles additional joins if that rate grows too large. diff --git a/changelog.d/13278.bugfix b/changelog.d/13278.bugfix new file mode 100644 index 0000000000..49e9377c79 --- /dev/null +++ b/changelog.d/13278.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where in rare instances Synapse could store the incorrect state for a room after a state resolution. diff --git a/changelog.d/13279.misc b/changelog.d/13279.misc new file mode 100644 index 0000000000..a083d2af2a --- /dev/null +++ b/changelog.d/13279.misc @@ -0,0 +1 @@ +Reduce the rebuild time for the complement-synapse docker image. diff --git a/changelog.d/13281.misc b/changelog.d/13281.misc new file mode 100644 index 0000000000..dea51d1362 --- /dev/null +++ b/changelog.d/13281.misc @@ -0,0 +1 @@ +Don't pull out the full state when creating an event. diff --git a/changelog.d/13284.misc b/changelog.d/13284.misc new file mode 100644 index 0000000000..fa9743a10e --- /dev/null +++ b/changelog.d/13284.misc @@ -0,0 +1 @@ +Update locked version of `frozendict` to 2.3.2, which has a fix for a memory leak. diff --git a/changelog.d/13285.misc b/changelog.d/13285.misc new file mode 100644 index 0000000000..b7bcbadb5b --- /dev/null +++ b/changelog.d/13285.misc @@ -0,0 +1 @@ +Upgrade from Poetry 1.1.12 to 1.1.14, to fix bugs when locking packages. diff --git a/changelog.d/13296.bugfix b/changelog.d/13296.bugfix new file mode 100644 index 0000000000..ff0eb2b4a1 --- /dev/null +++ b/changelog.d/13296.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.18.0 where the `synapse_pushers` metric would overcount pushers when they are replaced. diff --git a/changelog.d/13297.misc b/changelog.d/13297.misc new file mode 100644 index 0000000000..545a62369f --- /dev/null +++ b/changelog.d/13297.misc @@ -0,0 +1 @@ +Use `HTTPStatus` constants in place of literals in tests. \ No newline at end of file diff --git a/changelog.d/13299.misc b/changelog.d/13299.misc new file mode 100644 index 0000000000..a9d5566873 --- /dev/null +++ b/changelog.d/13299.misc @@ -0,0 +1 @@ +Improve performance of query `_get_subset_users_in_room_with_profiles`. diff --git a/changelog.d/13300.misc b/changelog.d/13300.misc new file mode 100644 index 0000000000..ee58add3c4 --- /dev/null +++ b/changelog.d/13300.misc @@ -0,0 +1 @@ +Up batch size of `bulk_get_push_rules` and `_get_joined_profiles_from_event_ids`. diff --git a/changelog.d/13303.misc b/changelog.d/13303.misc new file mode 100644 index 0000000000..03f64ab171 --- /dev/null +++ b/changelog.d/13303.misc @@ -0,0 +1 @@ +Remove unnecessary `json.dumps` from tests.
\ No newline at end of file diff --git a/changelog.d/13307.misc b/changelog.d/13307.misc new file mode 100644 index 0000000000..45b628ce13 --- /dev/null +++ b/changelog.d/13307.misc @@ -0,0 +1 @@ +Don't pull out the full state when creating an event. \ No newline at end of file diff --git a/changelog.d/13308.misc b/changelog.d/13308.misc new file mode 100644 index 0000000000..7f8ec0815f --- /dev/null +++ b/changelog.d/13308.misc @@ -0,0 +1 @@ +Use an asynchronous cache wrapper for the get event cache. Contributed by Nick @ Beeper (@fizzadar). diff --git a/changelog.d/13310.misc b/changelog.d/13310.misc new file mode 100644 index 0000000000..eaf570e058 --- /dev/null +++ b/changelog.d/13310.misc @@ -0,0 +1 @@ +Reduce memory usage of sending dummy events. diff --git a/changelog.d/13311.misc b/changelog.d/13311.misc new file mode 100644 index 0000000000..4be81c675c --- /dev/null +++ b/changelog.d/13311.misc @@ -0,0 +1 @@ +Prevent formatting changes of [#3679](https://github.com/matrix-org/synapse/pull/3679) from appearing in `git blame`. \ No newline at end of file diff --git a/changelog.d/13314.doc b/changelog.d/13314.doc new file mode 100644 index 0000000000..75c71ef27a --- /dev/null +++ b/changelog.d/13314.doc @@ -0,0 +1 @@ +Add notes when config options were changed. Contributed by @behrmann. diff --git a/changelog.d/13318.misc b/changelog.d/13318.misc new file mode 100644 index 0000000000..f5cd26b862 --- /dev/null +++ b/changelog.d/13318.misc @@ -0,0 +1 @@ +Validate federation destinations and log an error if a destination is invalid. diff --git a/changelog.d/13320.misc b/changelog.d/13320.misc new file mode 100644 index 0000000000..d33cf3a25a --- /dev/null +++ b/changelog.d/13320.misc @@ -0,0 +1 @@ +Fix `FederationClient.get_pdu()` returning events from the cache as `outliers` instead of original events we saw over federation. diff --git a/changelog.d/13323.misc b/changelog.d/13323.misc new file mode 100644 index 0000000000..3caa94a2f6 --- /dev/null +++ b/changelog.d/13323.misc @@ -0,0 +1 @@ +Reduce memory usage of state caches. diff --git a/changelog.d/13326.removal b/changelog.d/13326.removal new file mode 100644 index 0000000000..8112286671 --- /dev/null +++ b/changelog.d/13326.removal @@ -0,0 +1 @@ +Stop building `.deb` packages for Ubuntu 21.10 (Impish Indri), which has reached end of life. diff --git a/changelog.d/13328.misc b/changelog.d/13328.misc new file mode 100644 index 0000000000..d15fb5fc37 --- /dev/null +++ b/changelog.d/13328.misc @@ -0,0 +1 @@ +Add type hints to `trace` decorator. diff --git a/contrib/workers-bash-scripts/create-multiple-workers.md b/contrib/workers-bash-scripts/create-multiple-generic-workers.md index ad5142fe15..d303101429 100644 --- a/contrib/workers-bash-scripts/create-multiple-workers.md +++ b/contrib/workers-bash-scripts/create-multiple-generic-workers.md @@ -1,4 +1,4 @@ -# Creating multiple workers with a bash script +# Creating multiple generic workers with a bash script Setting up multiple worker configuration files manually can be time-consuming. You can alternatively create multiple worker configuration files with a simple `bash` script.
For example: diff --git a/contrib/workers-bash-scripts/create-multiple-stream-writers.md b/contrib/workers-bash-scripts/create-multiple-stream-writers.md new file mode 100644 index 0000000000..0d2ca780a6 --- /dev/null +++ b/contrib/workers-bash-scripts/create-multiple-stream-writers.md @@ -0,0 +1,145 @@ +# Creating multiple stream writers with a bash script + +This script creates multiple [stream writer](https://github.com/matrix-org/synapse/blob/develop/docs/workers.md#stream-writers) workers. + +Stream writers require both replication and HTTP listeners. + +It also prints out the example lines for the Synapse main configuration file. + +Remember to route each necessary endpoint directly to the worker associated with it. + +If you run the script as-is, it will create workers with the replication listener starting from port 8034 and a regular HTTP listener starting from port 8044. If you don't need all of the stream writers listed in the script, just remove them from the ```STREAM_WRITERS``` array. + +```sh +#!/bin/bash + +# Start with these replication and http ports. +# The script loop starts with the exact port and then increments it by one. +REP_START_PORT=8034 +HTTP_START_PORT=8044 + +# Stream writer workers to generate. Feel free to add or remove them as you wish. +# Event persister ("events") isn't included here as it does not require its +# own HTTP listener. + +STREAM_WRITERS+=( "presence" "typing" "receipts" "to_device" "account_data" ) + +NUM_WRITERS=$(expr ${#STREAM_WRITERS[@]}) + +i=0 + +while [ $i -lt "$NUM_WRITERS" ] +do +cat << EOF > ${STREAM_WRITERS[$i]}_stream_writer.yaml +worker_app: synapse.app.generic_worker +worker_name: ${STREAM_WRITERS[$i]}_stream_writer + +# The replication listener on the main synapse process. +worker_replication_host: 127.0.0.1 +worker_replication_http_port: 9093 + +worker_listeners: + - type: http + port: $(expr $REP_START_PORT + $i) + resources: + - names: [replication] + + - type: http + port: $(expr $HTTP_START_PORT + $i) + resources: + - names: [client] + +worker_log_config: /etc/matrix-synapse/stream-writer-log.yaml +EOF +HOMESERVER_YAML_INSTANCE_MAP+=$" ${STREAM_WRITERS[$i]}_stream_writer: + host: 127.0.0.1 + port: $(expr $REP_START_PORT + $i) +" + +HOMESERVER_YAML_STREAM_WRITERS+=$" ${STREAM_WRITERS[$i]}: ${STREAM_WRITERS[$i]}_stream_writer +" + +((i++)) +done + +cat << EXAMPLECONFIG +# Add these lines to your homeserver.yaml. +# Don't forget to configure your reverse proxy and +# necessary endpoints to their respective worker. + +# See https://github.com/matrix-org/synapse/blob/develop/docs/workers.md +# for more information. + +# Remember: Under NO circumstances should the replication +# listener be exposed to the public internet; +# it has no authentication and is unencrypted. + +instance_map: +$HOMESERVER_YAML_INSTANCE_MAP +stream_writers: +$HOMESERVER_YAML_STREAM_WRITERS +EXAMPLECONFIG +``` + +Copy the code above and save it to a file ```create_stream_writers.sh``` (for example). + +Make the script executable by running ```chmod +x create_stream_writers.sh```. + +## Run the script to create workers and print out a sample configuration + +Simply run the script to create YAML files in the current folder and print out the required configuration for ```homeserver.yaml```. + +```console +$ ./create_stream_writers.sh + +# Add these lines to your homeserver.yaml. +# Don't forget to configure your reverse proxy and +# necessary endpoints to their respective worker.
+ +# See https://github.com/matrix-org/synapse/blob/develop/docs/workers.md +# for more information + +# Remember: Under NO circumstances should the replication +# listener be exposed to the public internet; +# it has no authentication and is unencrypted. + +instance_map: + presence_stream_writer: + host: 127.0.0.1 + port: 8034 + typing_stream_writer: + host: 127.0.0.1 + port: 8035 + receipts_stream_writer: + host: 127.0.0.1 + port: 8036 + to_device_stream_writer: + host: 127.0.0.1 + port: 8037 + account_data_stream_writer: + host: 127.0.0.1 + port: 8038 + +stream_writers: + presence: presence_stream_writer + typing: typing_stream_writer + receipts: receipts_stream_writer + to_device: to_device_stream_writer + account_data: account_data_stream_writer +``` + +Simply copy-and-paste the output to an appropriate place in your Synapse main configuration file. + +## Write directly to Synapse configuration file + +You could also write the output directly to the homeserver's main configuration file. **This, however, is not recommended** as even a small typo (such as replacing >> with >) can erase the entire ```homeserver.yaml```. + +If you do this, back up your original configuration file first: + +```console +# Back up homeserver.yaml first +cp /etc/matrix-synapse/homeserver.yaml /etc/matrix-synapse/homeserver.yaml.bak + +# Create workers and write output to your homeserver.yaml +./create_stream_writers.sh >> /etc/matrix-synapse/homeserver.yaml +``` diff --git a/docker/Dockerfile b/docker/Dockerfile index 22707ed142..f4d8e6c925 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -45,7 +45,7 @@ RUN \ # We install poetry in its own build stage to avoid its dependencies conflicting with # synapse's dependencies. -# We use a specific commit from poetry's master branch instead of our usual 1.1.12, +# We use a specific commit from poetry's master branch instead of our usual 1.1.14, # to incorporate fixes to some bugs in `poetry export`. This commit corresponds to # https://github.com/python-poetry/poetry/pull/5156 and # https://github.com/python-poetry/poetry/issues/5141 ; diff --git a/docker/complement/Dockerfile b/docker/complement/Dockerfile index 8bec0f6116..c5e7984a28 100644 --- a/docker/complement/Dockerfile +++ b/docker/complement/Dockerfile @@ -4,42 +4,58 @@ # # Instructions for building this image from those it depends on is detailed in this guide: # https://github.com/matrix-org/synapse/blob/develop/docker/README-testing.md#testing-with-postgresql-and-single-or-multi-process-synapse -ARG SYNAPSE_VERSION=latest -FROM matrixdotorg/synapse-workers:$SYNAPSE_VERSION - -# Install postgresql -RUN apt-get update && \ - DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -yqq postgresql-13 - -# Configure a user and create a database for Synapse -RUN pg_ctlcluster 13 main start && su postgres -c "echo \ - \"ALTER USER postgres PASSWORD 'somesecret'; \ - CREATE DATABASE synapse \ - ENCODING 'UTF8' \ - LC_COLLATE='C' \ - LC_CTYPE='C' \ - template=template0;\" | psql" && pg_ctlcluster 13 main stop -# Extend the shared homeserver config to disable rate-limiting, -# set Complement's static shared secret, enable registration, amongst other -# tweaks to get Synapse ready for testing. -# To do this, we copy the old template out of the way and then include it -# with Jinja2.
-RUN mv /conf/shared.yaml.j2 /conf/shared-orig.yaml.j2 -COPY conf/workers-shared-extra.yaml.j2 /conf/shared.yaml.j2 - -WORKDIR /data +ARG SYNAPSE_VERSION=latest -COPY conf/postgres.supervisord.conf /etc/supervisor/conf.d/postgres.conf +# first of all, we create a base image with a postgres server and database, +# which we can copy into the target image. For repeated rebuilds, this is +# much faster than apt installing postgres each time. +# +# This trick only works because (a) the Synapse image happens to have all the +# shared libraries that postgres wants, (b) we use a postgres image based on +# the same debian version as Synapse's docker image (so the versions of the +# shared libraries match). -# Copy the entrypoint -COPY conf/start_for_complement.sh / +FROM postgres:13-bullseye AS postgres_base + # initialise the database cluster in /var/lib/postgresql + RUN gosu postgres initdb --locale=C --encoding=UTF-8 --auth-host password -# Expose nginx's listener ports -EXPOSE 8008 8448 + # Configure a password and create a database for Synapse + RUN echo "ALTER USER postgres PASSWORD 'somesecret'" | gosu postgres postgres --single + RUN echo "CREATE DATABASE synapse" | gosu postgres postgres --single -ENTRYPOINT ["/start_for_complement.sh"] +# now build the final image, based on the Synapse image. -# Update the healthcheck to have a shorter check interval -HEALTHCHECK --start-period=5s --interval=1s --timeout=1s \ - CMD /bin/sh /healthcheck.sh +FROM matrixdotorg/synapse-workers:$SYNAPSE_VERSION + # copy the postgres installation over from the image we built above + RUN adduser --system --uid 999 postgres --home /var/lib/postgresql + COPY --from=postgres_base /var/lib/postgresql /var/lib/postgresql + COPY --from=postgres_base /usr/lib/postgresql /usr/lib/postgresql + COPY --from=postgres_base /usr/share/postgresql /usr/share/postgresql + RUN mkdir /var/run/postgresql && chown postgres /var/run/postgresql + ENV PATH="${PATH}:/usr/lib/postgresql/13/bin" + ENV PGDATA=/var/lib/postgresql/data + + # Extend the shared homeserver config to disable rate-limiting, + # set Complement's static shared secret, enable registration, amongst other + # tweaks to get Synapse ready for testing. + # To do this, we copy the old template out of the way and then include it + # with Jinja2. 
+ RUN mv /conf/shared.yaml.j2 /conf/shared-orig.yaml.j2 + COPY conf/workers-shared-extra.yaml.j2 /conf/shared.yaml.j2 + + WORKDIR /data + + COPY conf/postgres.supervisord.conf /etc/supervisor/conf.d/postgres.conf + + # Copy the entrypoint + COPY conf/start_for_complement.sh / + + # Expose nginx's listener ports + EXPOSE 8008 8448 + + ENTRYPOINT ["/start_for_complement.sh"] + + # Update the healthcheck to have a shorter check interval + HEALTHCHECK --start-period=5s --interval=1s --timeout=1s \ + CMD /bin/sh /healthcheck.sh diff --git a/docker/complement/conf/postgres.supervisord.conf b/docker/complement/conf/postgres.supervisord.conf index 5dae3e6330..b88bfc772e 100644 --- a/docker/complement/conf/postgres.supervisord.conf +++ b/docker/complement/conf/postgres.supervisord.conf @@ -1,5 +1,5 @@ [program:postgres] -command=/usr/local/bin/prefix-log /usr/bin/pg_ctlcluster 13 main start --foreground +command=/usr/local/bin/prefix-log gosu postgres postgres # Only start if START_POSTGRES=1 autostart=%(ENV_START_POSTGRES)s diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index b5f675bc73..9e554a865e 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -67,6 +67,10 @@ rc_joins: per_second: 9999 burst_count: 9999 +rc_joins_per_room: + per_second: 9999 + burst_count: 9999 + rc_3pid_validation: per_second: 1000 burst_count: 1000 diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 2e8c4e2137..2d56b084e2 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -35,7 +35,6 @@ - [Application Services](application_services.md) - [Server Notices](server_notices.md) - [Consent Tracking](consent_tracking.md) - - [URL Previews](development/url_previews.md) - [User Directory](user_directory.md) - [Message Retention Policies](message_retention_policies.md) - [Pluggable Modules](modules/index.md) diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index d4873f9490..9aa489e4a3 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -59,6 +59,7 @@ The following fields are possible in the JSON response body: - `guest_access` - Whether guests can join the room. One of: ["can_join", "forbidden"]. - `history_visibility` - Who can see the room history. One of: ["invited", "joined", "shared", "world_readable"]. - `state_events` - Total number of state_events of a room. Complexity of the room. + - `room_type` - The type of the room taken from the room's creation event; for example "m.space" if the room is a space. If the room does not define a type, the value will be `null`. * `offset` - The current pagination offset in rooms. This parameter should be used instead of `next_token` for room offset as `next_token` is not intended to be parsed. @@ -101,7 +102,8 @@ A response body like the following is returned: "join_rules": "invite", "guest_access": null, "history_visibility": "shared", - "state_events": 93534 + "state_events": 93534, + "room_type": "m.space" }, ... (8 hidden items) ... 
{ @@ -118,7 +120,8 @@ A response body like the following is returned: "join_rules": "invite", "guest_access": null, "history_visibility": "shared", - "state_events": 8345 + "state_events": 8345, + "room_type": null } ], "offset": 0, @@ -151,7 +154,8 @@ A response body like the following is returned: "join_rules": "invite", "guest_access": null, "history_visibility": "shared", - "state_events": 8 + "state_events": 8, + "room_type": null } ], "offset": 0, @@ -184,7 +188,8 @@ A response body like the following is returned: "join_rules": "invite", "guest_access": null, "history_visibility": "shared", - "state_events": 93534 + "state_events": 93534, + "room_type": null }, ... (98 hidden items) ... { @@ -201,7 +206,8 @@ A response body like the following is returned: "join_rules": "invite", "guest_access": null, "history_visibility": "shared", - "state_events": 8345 + "state_events": 8345, + "room_type": "m.space" } ], "offset": 0, @@ -238,7 +244,9 @@ A response body like the following is returned: "join_rules": "invite", "guest_access": null, "history_visibility": "shared", - "state_events": 93534 + "state_events": 93534, + "room_type": "m.space" + }, ... (48 hidden items) ... { @@ -255,7 +263,9 @@ A response body like the following is returned: "join_rules": "invite", "guest_access": null, "history_visibility": "shared", - "state_events": 8345 + "state_events": 8345, + "room_type": null + } ], "offset": 100, @@ -290,6 +300,8 @@ The following fields are possible in the JSON response body: * `guest_access` - Whether guests can join the room. One of: ["can_join", "forbidden"]. * `history_visibility` - Who can see the room history. One of: ["invited", "joined", "shared", "world_readable"]. * `state_events` - Total number of state_events of a room. Complexity of the room. +* `room_type` - The type of the room taken from the room's creation event; for example "m.space" if the room is a space. + If the room does not define a type, the value will be `null`. The API is: @@ -317,7 +329,8 @@ A response body like the following is returned: "join_rules": "invite", "guest_access": null, "history_visibility": "shared", - "state_events": 93534 + "state_events": 93534, + "room_type": "m.space" } ``` diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 1235f1cb95..0871cfebf5 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -544,7 +544,7 @@ Gets a list of all local media that a specific `user_id` has created. These are media that the user has uploaded themselves ([local media](../media_repository.md#local-media)), as well as [URL preview images](../media_repository.md#url-previews) requested by the user if the -[feature is enabled](../development/url_previews.md). +[feature is enabled](../usage/configuration/config_documentation.md#url_preview_enabled). By default, the response is ordered by descending creation date and ascending media ID. The newest media is on top. You can change the order with parameters diff --git a/docs/development/dependencies.md b/docs/development/dependencies.md index 8ef7d357d8..236856a6b0 100644 --- a/docs/development/dependencies.md +++ b/docs/development/dependencies.md @@ -237,3 +237,28 @@ poetry run pip install build && poetry run python -m build because [`build`](https://github.com/pypa/build) is a standardish tool which doesn't require poetry. (It's what we use in CI too). However, you could try `poetry build` too. + + +# Troubleshooting + +## Check the version of poetry with `poetry --version`. 
+ +At the time of writing, the 1.2 series is beta only. We have seen some examples +where the lockfiles generated by 1.2 prereleases aren't interpreted correctly +by poetry 1.1.x. For now, use poetry 1.1.14, which includes a critical +[change](https://github.com/python-poetry/poetry/pull/5973) needed to remain +[compatible with PyPI](https://github.com/pypi/warehouse/pull/11775). + +It can also be useful to check the version of `poetry-core` in use. If you've +installed `poetry` with `pipx`, try `pipx runpip poetry list | grep poetry-core`. + +## Clear caches: `poetry cache clear --all pypi`. + +Poetry caches a bunch of information about packages that isn't readily available +from PyPI. (This is what makes poetry seem slow when doing the first +`poetry install`.) Try `poetry cache list` and `poetry cache clear --all +<name of cache>` to see if that fixes things. + +## Try `--verbose` or `--dry-run` arguments. + +These are sometimes useful to see what poetry's internal logic is. diff --git a/docs/development/url_previews.md b/docs/development/url_previews.md deleted file mode 100644 index 154b9a5e12..0000000000 --- a/docs/development/url_previews.md +++ /dev/null @@ -1,61 +0,0 @@ -URL Previews -============ - -The `GET /_matrix/media/r0/preview_url` endpoint provides a generic preview API -for URLs which outputs [Open Graph](https://ogp.me/) responses (with some Matrix -specific additions). - -This does have trade-offs compared to other designs: - -* Pros: - * Simple and flexible; can be used by any clients at any point -* Cons: - * If each homeserver provides one of these independently, all the HSes in a - room may needlessly DoS the target URI - * The URL metadata must be stored somewhere, rather than just using Matrix - itself to store the media. - * Matrix cannot be used to distribute the metadata between homeservers. - -When Synapse is asked to preview a URL it does the following: - -1. Checks against a URL blacklist (defined as `url_preview_url_blacklist` in the - config). -2. Checks the in-memory cache by URLs and returns the result if it exists. (This - is also used to de-duplicate processing of multiple in-flight requests at once.) -3. Kicks off a background process to generate a preview: - 1. Checks the database cache by URL and timestamp and returns the result if it - has not expired and was successful (a 2xx return code). - 2. Checks if the URL matches an [oEmbed](https://oembed.com/) pattern. If it - does, update the URL to download. - 3. Downloads the URL and stores it into a file via the media storage provider - and saves the local media metadata. - 4. If the media is an image: - 1. Generates thumbnails. - 2. Generates an Open Graph response based on image properties. - 5. If the media is HTML: - 1. Decodes the HTML via the stored file. - 2. Generates an Open Graph response from the HTML. - 3. If a JSON oEmbed URL was found in the HTML via autodiscovery: - 1. Downloads the URL and stores it into a file via the media storage provider - and saves the local media metadata. - 2. Convert the oEmbed response to an Open Graph response. - 3. Override any Open Graph data from the HTML with data from oEmbed. - 4. If an image exists in the Open Graph response: - 1. Downloads the URL and stores it into a file via the media storage - provider and saves the local media metadata. - 2. Generates thumbnails. - 3. Updates the Open Graph response based on image properties. - 6. If the media is JSON and an oEmbed URL was found: - 1. Convert the oEmbed response to an Open Graph response. - 2.
If a thumbnail or image is in the oEmbed response: - 1. Downloads the URL and stores it into a file via the media storage - provider and saves the local media metadata. - 2. Generates thumbnails. - 3. Updates the Open Graph response based on image properties. - 7. Stores the result in the database cache. -4. Returns the result. - -The in-memory cache expires after 1 hour. - -Expired entries in the database cache (and their associated media files) are -deleted every 10 seconds. The default expiration time is 1 hour from download. diff --git a/docs/media_repository.md b/docs/media_repository.md index ba17f8a856..23e6da7f31 100644 --- a/docs/media_repository.md +++ b/docs/media_repository.md @@ -7,8 +7,7 @@ The media repository users. * caches avatars, attachments and their thumbnails for media uploaded by remote users. - * caches resources and thumbnails used for - [URL previews](development/url_previews.md). + * caches resources and thumbnails used for URL previews. All media in Matrix can be identified by a unique [MXC URI](https://spec.matrix.org/latest/client-server-api/#matrix-content-mxc-uris), @@ -59,8 +58,6 @@ remote_thumbnail/matrix.org/aa/bb/cccccccccccccccccccc/128-96-image-jpeg Note that `remote_thumbnail/` does not have an `s`. ## URL Previews -See [URL Previews](development/url_previews.md) for documentation on the URL preview -process. When generating previews for URLs, Synapse may download and cache various resources, including images. These resources are assigned temporary media IDs diff --git a/docs/upgrade.md b/docs/upgrade.md index 312f0b87fe..2c7c258909 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -89,6 +89,31 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.64.0 + +## Delegation of email validation no longer supported + +As of this version, Synapse no longer allows the tasks of verifying email address +ownership, and password reset confirmation, to be delegated to an identity server. + +To continue to allow users to add email addresses to their homeserver accounts, +and perform password resets, make sure that Synapse is configured with a +working email server in the `email` configuration section (including, at a +minimum, a `notif_from` setting.) + +Specifying an `email` setting under `account_threepid_delegates` will now cause +an error at startup. + +## Changes to the event replication streams + +Synapse now includes a flag indicating if an event is an outlier when +replicating it to other workers. This is a forwards- and backwards-incompatible +change: v1.63 workers cannot process events replicated by v1.64 workers, and +vice versa. + +Once all workers are upgraded to v1.64 (or downgraded to v1.63), event +replication will resume as normal. + # Upgrading to v1.62.0 ## New signatures for spam checker callbacks diff --git a/docs/usage/administration/admin_api/README.md b/docs/usage/administration/admin_api/README.md index 3cbedc5dfa..c60b6da0de 100644 --- a/docs/usage/administration/admin_api/README.md +++ b/docs/usage/administration/admin_api/README.md @@ -18,6 +18,11 @@ already on your `$PATH` depending on how Synapse was installed. Finding your user's `access_token` is client-dependent, but will usually be shown in the client's settings. ## Making an Admin API request +For security reasons, we [recommend](reverse_proxy.md#synapse-administration-endpoints) +that the Admin API (`/_synapse/admin/...`) should be hidden from public view using a +reverse proxy.
This means you should typically query the Admin API from a terminal on +the machine which runs Synapse. + Once you have your `access_token`, you will need to authenticate each request to an Admin API endpoint by providing the token as either a query parameter or a request header. To add it as a request header in cURL: @@ -25,5 +30,17 @@ providing the token as either a query parameter or a request header. To add it a curl --header "Authorization: Bearer <access_token>" <the_rest_of_your_API_request> ``` +For example, suppose we want to +[query the account](user_admin_api.md#query-user-account) of the user +`@foo:bar.com`. We need an admin access token (e.g. +`syt_AjfVef2_L33JNpafeif_0feKJfeaf0CQpoZk`), and we need to know which port +Synapse's [`client` listener](config_documentation.md#listeners) is listening +on (e.g. `8008`). Then we can use the following command to request the account +information from the Admin API. + +```sh +curl --header "Authorization: Bearer syt_AjfVef2_L33JNpafeif_0feKJfeaf0CQpoZk" -X GET http://127.0.0.1:8008/_synapse/admin/v2/users/@foo:bar.com +``` + For more details on access tokens in Matrix, please refer to the complete [matrix spec documentation](https://matrix.org/docs/spec/client_server/r0.6.1#using-access-tokens). diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 601fdeb09e..11d1574484 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -239,6 +239,8 @@ If this option is provided, it parses the given yaml to json and serves it on `/.well-known/matrix/client` endpoint alongside the standard properties. +*Added in Synapse 1.62.0.* + Example configuration: ```yaml extra_well_known_client_content : @@ -1155,6 +1157,9 @@ Caching can be configured through the following sub-options: with intermittent connections, at the cost of higher memory usage. A value of zero means that sync responses are not cached. Defaults to 2m. + + *Changed in Synapse 1.62.0*: The default was changed from 0 to 2m. + * `cache_autotuning` and its sub-options `max_cache_memory_usage`, `target_cache_memory_usage`, and `min_cache_ttl` work in conjunction with each other to maintain a balance between cache memory usage and cache entry availability. You must be using [jemalloc](https://github.com/matrix-org/synapse#help-synapse-is-slow-and-eats-all-my-ramcpu) @@ -1472,6 +1477,25 @@ rc_joins: burst_count: 12 ``` --- +### `rc_joins_per_room` + +This option allows admins to ratelimit joins to a room based on the number of recent +joins (local or remote) to that room. It is intended to mitigate mass-join spam +waves which target multiple homeservers. + +By default, one join is permitted to a room every second, with an accumulating +buffer of up to ten instantaneous joins. + +Example configuration (default values): +```yaml +rc_joins_per_room: + per_second: 1 + burst_count: 10 +``` + +_Added in Synapse 1.64.0._ + +--- ### `rc_3pid_validation` This option ratelimits how often a user or IP can attempt to validate a 3PID. @@ -2176,30 +2200,26 @@ default_identity_server: https://matrix.org --- ### `account_threepid_delegates` -Handle threepid (email/phone etc) registration and password resets through a set of -*trusted* identity servers. Note that this allows the configured identity server to -reset passwords for accounts! +Delegate verification of phone numbers to an identity server. 
-Be aware that if `email` is not set, and SMTP options have not been -configured in the email config block, registration and user password resets via -email will be globally disabled. +When a user wishes to add a phone number to their account, we need to verify that they +actually own that phone number, which requires sending them a text message (SMS). +Currently Synapse does not support sending those texts itself and instead delegates the +task to an identity server. The base URI for the identity server to be used is +specified by the `account_threepid_delegates.msisdn` option. -Additionally, if `msisdn` is not set, registration and password resets via msisdn -will be disabled regardless, and users will not be able to associate an msisdn -identifier to their account. This is due to Synapse currently not supporting -any method of sending SMS messages on its own. +If this is left unspecified, Synapse will not allow users to add phone numbers to +their account. -To enable using an identity server for operations regarding a particular third-party -identifier type, set the value to the URL of that identity server as shown in the -examples below. +(Servers handling the these requests must answer the `/requestToken` endpoints defined +by the Matrix Identity Service API +[specification](https://matrix.org/docs/spec/identity_service/latest).) -Servers handling the these requests must answer the `/requestToken` endpoints defined -by the Matrix Identity Service API [specification](https://matrix.org/docs/spec/identity_service/latest). +*Updated in Synapse 1.64.0*: No longer accepts an `email` option. Example configuration: ```yaml account_threepid_delegates: - email: https://example.com # Delegate email sending to example.com msisdn: http://localhost:8090 # Delegate SMS sending to this local process ``` --- diff --git a/poetry.lock b/poetry.lock index b7c0a6869a..41ab40edd1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -290,7 +290,7 @@ importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} [[package]] name = "frozendict" -version = "2.3.0" +version = "2.3.2" description = "A simple immutable dictionary" category = "main" optional = false @@ -1563,7 +1563,7 @@ url_preview = ["lxml"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "e96625923122e29b6ea5964379828e321b6cede2b020fc32c6f86c09d86d1ae8" +content-hash = "c24bbcee7e86dbbe7cdbf49f91a25b310bf21095452641e7440129f59b077f78" [metadata.files] attrs = [ @@ -1753,23 +1753,23 @@ flake8-comprehensions = [ {file = "flake8_comprehensions-3.8.0-py3-none-any.whl", hash = "sha256:9406314803abe1193c064544ab14fdc43c58424c0882f6ff8a581eb73fc9bb58"}, ] frozendict = [ - {file = "frozendict-2.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e18e2abd144a9433b0a8334582843b2aa0d3b9ac8b209aaa912ad365115fe2e1"}, - {file = "frozendict-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96dc7a02e78da5725e5e642269bb7ae792e0c9f13f10f2e02689175ebbfedb35"}, - {file = "frozendict-2.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:752a6dcfaf9bb20a7ecab24980e4dbe041f154509c989207caf185522ef85461"}, - {file = "frozendict-2.3.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:5346d9fc1c936c76d33975a9a9f1a067342963105d9a403a99e787c939cc2bb2"}, - {file = "frozendict-2.3.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60dd2253f1bacb63a7c486ec541a968af4f985ffb06602ee8954a3d39ec6bd2e"}, - {file = "frozendict-2.3.0-cp36-cp36m-win_amd64.whl", hash = 
"sha256:b2e044602ce17e5cd86724add46660fb9d80169545164e763300a3b839cb1b79"}, - {file = "frozendict-2.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a27a69b1ac3591e4258325108aee62b53c0eeb6ad0a993ae68d3c7eaea980420"}, - {file = "frozendict-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f45ef5f6b184d84744fff97b61f6b9a855e24d36b713ea2352fc723a047afa5"}, - {file = "frozendict-2.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2d3f5016650c0e9a192f5024e68fb4d63f670d0ee58b099ed3f5b4c62ea30ecb"}, - {file = "frozendict-2.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6cf605916f50aabaaba5624c81eb270200f6c2c466c46960237a125ec8fe3ae0"}, - {file = "frozendict-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6da06e44904beae4412199d7e49be4f85c6cc168ab06b77c735ea7da5ce3454"}, - {file = "frozendict-2.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:1f34793fb409c4fa70ffd25bea87b01f3bd305fb1c6b09e7dff085b126302206"}, - {file = "frozendict-2.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fd72494a559bdcd28aa71f4aa81860269cd0b7c45fff3e2614a0a053ecfd2a13"}, - {file = "frozendict-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00ea9166aa68cc5feed05986206fdbf35e838a09cb3feef998cf35978ff8a803"}, - {file = "frozendict-2.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:9ffaf440648b44e0bc694c1a4701801941378ba3ba6541e17750ae4b4aeeb116"}, - {file = "frozendict-2.3.0-py3-none-any.whl", hash = "sha256:8578fe06815fcdcc672bd5603eebc98361a5317c1c3a13b28c6c810f6ea3b323"}, - {file = "frozendict-2.3.0.tar.gz", hash = "sha256:da4231adefc5928e7810da2732269d3ad7b5616295b3e693746392a8205ea0b5"}, + {file = "frozendict-2.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4fb171d1e84d17335365877e19d17440373b47ca74a73c06f65ac0b16d01e87f"}, + {file = "frozendict-2.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a3640e9d7533d164160b758351aa49d9e85bbe0bd76d219d4021e90ffa6a52"}, + {file = "frozendict-2.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:87cfd00fafbc147d8cd2590d1109b7db8ac8d7d5bdaa708ba46caee132b55d4d"}, + {file = "frozendict-2.3.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:fb09761e093cfabb2f179dbfdb2521e1ec5701df714d1eb5c51fa7849027be19"}, + {file = "frozendict-2.3.2-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82176dc7adf01cf8f0193e909401939415a230a1853f4a672ec1629a06ceae18"}, + {file = "frozendict-2.3.2-cp36-cp36m-win_amd64.whl", hash = "sha256:c1c70826aa4a50fa283fe161834ac4a3ac7c753902c980bb8b595b0998a38ddb"}, + {file = "frozendict-2.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1db5035ddbed995badd1a62c4102b5e207b5aeb24472df2c60aba79639d7996b"}, + {file = "frozendict-2.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4246fc4cb1413645ba4d3513939b90d979a5bae724be605a10b2b26ee12f839c"}, + {file = "frozendict-2.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:680cd42fb0a255da1ce45678ccbd7f69da750d5243809524ebe8f45b2eda6e6b"}, + {file = "frozendict-2.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6a7f3a181d6722c92a9fab12d0c5c2b006a18ca5666098531f316d1e1c8984e3"}, + {file = "frozendict-2.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1cb866eabb3c1384a7fe88e1e1033e2b6623073589012ab637c552bf03f6364"}, + {file = "frozendict-2.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:952c5e5e664578c5c2ce8489ee0ab6a1855da02b58ef593ee728fc10d672641a"}, + {file = 
"frozendict-2.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:608b77904cd0117cd816df605a80d0043a5326ee62529327d2136c792165a823"}, + {file = "frozendict-2.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0eed41fd326f0bcc779837d8d9e1374da1bc9857fe3b9f2910195bbd5fff3aeb"}, + {file = "frozendict-2.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:bde28db6b5868dd3c45b3555f9d1dc5a1cca6d93591502fa5dcecce0dde6a335"}, + {file = "frozendict-2.3.2-py3-none-any.whl", hash = "sha256:6882a9bbe08ab9b5ff96ce11bdff3fe40b114b9813bc6801261e2a7b45e20012"}, + {file = "frozendict-2.3.2.tar.gz", hash = "sha256:7fac4542f0a13fbe704db4942f41ba3abffec5af8b100025973e59dff6a09d0d"}, ] gitdb = [ {file = "gitdb-4.0.9-py3-none-any.whl", hash = "sha256:8033ad4e853066ba6ca92050b9df2f89301b8fc8bf7e9324d412a63f8bf1a8fd"}, diff --git a/pyproject.toml b/pyproject.toml index 9eabe15e23..4da1331c93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,7 +110,9 @@ jsonschema = ">=3.0.0" frozendict = ">=1,!=2.1.2" # We require 2.1.0 or higher for type hints. Previous guard was >= 1.1.0 unpaddedbase64 = ">=2.1.0" -canonicaljson = "^1.4.0" +# We require 1.5.0 to work around an issue when running against the C implementation of +# frozendict: https://github.com/matrix-org/python-canonicaljson/issues/36 +canonicaljson = "^1.5.0" # we use the type definitions added in signedjson 1.1. signedjson = "^1.1.0" # validating SSL certs for IP addresses requires service_identity 18.1. diff --git a/scripts-dev/build_debian_packages.py b/scripts-dev/build_debian_packages.py index 38564893e9..cd2e64b75f 100755 --- a/scripts-dev/build_debian_packages.py +++ b/scripts-dev/build_debian_packages.py @@ -26,7 +26,6 @@ DISTS = ( "debian:bookworm", "debian:sid", "ubuntu:focal", # 20.04 LTS (our EOL forced by Py38 on 2024-10-14) - "ubuntu:impish", # 21.10 (EOL 2022-07) "ubuntu:jammy", # 22.04 LTS (EOL 2027-04) ) diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 26834a437e..543bba27c2 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -166,22 +166,6 @@ IGNORED_TABLES = { "ui_auth_sessions", "ui_auth_sessions_credentials", "ui_auth_sessions_ips", - # Groups/communities is no longer supported. - "group_attestations_remote", - "group_attestations_renewals", - "group_invites", - "group_roles", - "group_room_categories", - "group_rooms", - "group_summary_roles", - "group_summary_room_categories", - "group_summary_rooms", - "group_summary_users", - "group_users", - "groups", - "local_group_membership", - "local_group_updates", - "remote_profile_cache", } diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 54d13026c9..f43965c1c8 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -27,6 +27,33 @@ class Ratelimiter: """ Ratelimit actions marked by arbitrary keys. + (Note that the source code speaks of "actions" and "burst_count" rather than + "tokens" and a "bucket_size".) + + This is a "leaky bucket as a meter". For each key to be tracked there is a bucket + containing some number 0 <= T <= `burst_count` of tokens corresponding to previously + permitted requests for that key. Each bucket starts empty, and gradually leaks + tokens at a rate of `rate_hz`. + + Upon an incoming request, we must determine: + - the key that this request falls under (which bucket to inspect), and + - the cost C of this request in tokens. 
+ Then, if there is room in the bucket for C tokens (T + C <= `burst_count`), + the request is permitted and `cost` tokens are added to the bucket. + Otherwise the request is denied, and the bucket continues to hold T tokens. + + This means that the limiter enforces an average request frequency of `rate_hz`, + while accumulating a buffer of up to `burst_count` requests which can be consumed + instantaneously. + + The tricky bit is the leaking. We do not want to have a periodic process which + leaks every bucket! Instead, we track + - the time point when the bucket was last completely empty, and + - how many tokens have been added to the bucket since then. + Then for each incoming request, we can calculate how many tokens have leaked + since this time point, and use that to decide if we should accept or reject the + request. + Args: clock: A homeserver clock, for retrieving the current time rate_hz: The long term number of actions that can be performed in a second. @@ -41,14 +68,30 @@ class Ratelimiter: self.burst_count = burst_count self.store = store - # A ordered dictionary keeping track of actions, when they were last - # performed and how often. Each entry is a mapping from a key of arbitrary type - # to a tuple representing: - # * How many times an action has occurred since a point in time - # * The point in time - # * The rate_hz of this particular entry. This can vary per request + # An ordered dictionary representing the token buckets tracked by this rate + # limiter. Each entry maps a key of arbitrary type to a tuple representing: + # * The number of tokens currently in the bucket, + # * The time point when the bucket was last completely empty, and + # * The rate_hz (leak rate) of this particular bucket. self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict() + def _get_key( + self, requester: Optional[Requester], key: Optional[Hashable] + ) -> Hashable: + """Use the requester's MXID as a fallback key if no key is provided.""" + if key is None: + if not requester: + raise ValueError("Must supply at least one of `requester` or `key`") + + key = requester.user.to_string() + return key + + def _get_action_counts( + self, key: Hashable, time_now_s: float + ) -> Tuple[float, float, float]: + """Retrieve the action counts, with a fallback representing an empty bucket.""" + return self.actions.get(key, (0.0, time_now_s, 0.0)) + async def can_do_action( self, requester: Optional[Requester], @@ -88,11 +131,7 @@ class Ratelimiter: * The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero """ - if key is None: - if not requester: - raise ValueError("Must supply at least one of `requester` or `key`") - - key = requester.user.to_string() + key = self._get_key(requester, key) if requester: # Disable rate limiting of users belonging to any AS that is configured @@ -121,7 +160,7 @@ class Ratelimiter: self._prune_message_counts(time_now_s) # Check if there is an existing count entry for this key - action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0)) + action_count, time_start, _ = self._get_action_counts(key, time_now_s) # Check whether performing another action is allowed time_delta = time_now_s - time_start @@ -164,6 +203,37 @@ class Ratelimiter: return allowed, time_allowed + def record_action( + self, + requester: Optional[Requester], + key: Optional[Hashable] = None, + n_actions: int = 1, + _time_now_s: Optional[float] = None, + ) -> None: + """Record that one or more actions took place, even if they violate the rate limit. + + This is useful for tracking the frequency of events that happen across + federation which we still want to impose local rate limits on. For instance, if + we are alice.com monitoring a particular room, we cannot prevent bob.com + from joining users to that room. However, we can track the number of recent + joins in the room and refuse to serve new joins ourselves if there have been too + many in the room across both homeservers. + + Args: + requester: The requester that is doing the action, if any. + key: An arbitrary key used to classify an action. Defaults to the + requester's user ID. + n_actions: The number of times the user performed the action. The count is + incremented unconditionally, even if this takes the bucket over the + rate limit. + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + """ + key = self._get_key(requester, key) + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() + action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s) + self.actions[key] = (action_count + n_actions, time_start, rate_hz) + def _prune_message_counts(self, time_now_s: float) -> None: """Remove message count entries that have not exceeded their defined rate_hz limit diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index 3f85d61b46..00e81b3afc 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -84,6 +84,8 @@ class RoomVersion: # MSC3787: Adds support for a `knock_restricted` join rule, mixing concepts of # knocks and restricted join rules into the same join condition.
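As an illustration of the "leaky bucket as a meter" scheme described in the `Ratelimiter` docstring above, here is a minimal standalone sketch. The class and its names are hypothetical (not Synapse's API); it assumes a single process and uses `time.monotonic()` where Synapse uses its homeserver clock. Per key it stores the time the bucket was last empty plus the tokens added since, so the current level is derived on demand and no periodic leak process is needed:

```python
import time
from typing import Dict, Hashable, Optional, Tuple


class LeakyBucketMeter:
    """Minimal sketch of a "leaky bucket as a meter" rate limiter (illustrative only)."""

    def __init__(self, rate_hz: float, burst_count: float) -> None:
        self.rate_hz = rate_hz  # leak rate, in tokens per second
        self.burst_count = burst_count  # bucket capacity
        # key -> (tokens added since bucket was last empty, time it was last empty)
        self.buckets: Dict[Hashable, Tuple[float, float]] = {}

    def _current_tokens(self, key: Hashable, now: float) -> float:
        added, empty_since = self.buckets.get(key, (0.0, now))
        tokens = added - (now - empty_since) * self.rate_hz
        if tokens <= 0.0:
            # The bucket fully drained at some earlier point: it is empty as of now.
            self.buckets[key] = (0.0, now)
            return 0.0
        return tokens

    def can_do_action(self, key: Hashable, cost: float = 1.0) -> bool:
        now = time.monotonic()
        tokens = self._current_tokens(key, now)
        if tokens + cost > self.burst_count:
            return False  # bucket would overflow: reject and leave state alone
        self.record_action(key, cost, now)
        return True

    def record_action(self, key: Hashable, cost: float = 1.0, now: Optional[float] = None) -> None:
        """Add tokens unconditionally (cf. Ratelimiter.record_action above)."""
        now = time.monotonic() if now is None else now
        self._current_tokens(key, now)  # normalise state; `key` is now present
        added, empty_since = self.buckets[key]
        # The leak is computed from `empty_since`, which is unchanged, so
        # adding `cost` here raises the current level by exactly `cost`.
        self.buckets[key] = (added + cost, empty_since)
```

For example, with `rate_hz=1` and `burst_count=10` (the `rc_joins_per_room` defaults added later in this diff), eleven back-to-back calls to `can_do_action` for the same key permit ten and reject the eleventh.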
msc3787_knock_restricted_join_rule: bool + # MSC3667: Enforce integer power levels + msc3667_int_only_power_levels: bool class RoomVersions: @@ -103,6 +105,7 @@ class RoomVersions: msc2716_historical=False, msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, + msc3667_int_only_power_levels=False, ) V2 = RoomVersion( "2", @@ -120,6 +123,7 @@ class RoomVersions: msc2716_historical=False, msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, + msc3667_int_only_power_levels=False, ) V3 = RoomVersion( "3", @@ -137,6 +141,7 @@ class RoomVersions: msc2716_historical=False, msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, + msc3667_int_only_power_levels=False, ) V4 = RoomVersion( "4", @@ -154,6 +159,7 @@ class RoomVersions: msc2716_historical=False, msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, + msc3667_int_only_power_levels=False, ) V5 = RoomVersion( "5", @@ -171,6 +177,7 @@ class RoomVersions: msc2716_historical=False, msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, + msc3667_int_only_power_levels=False, ) V6 = RoomVersion( "6", @@ -188,6 +195,7 @@ class RoomVersions: msc2716_historical=False, msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, + msc3667_int_only_power_levels=False, ) MSC2176 = RoomVersion( "org.matrix.msc2176", @@ -205,6 +213,7 @@ class RoomVersions: msc2716_historical=False, msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, + msc3667_int_only_power_levels=False, ) V7 = RoomVersion( "7", @@ -222,6 +231,7 @@ class RoomVersions: msc2716_historical=False, msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, + msc3667_int_only_power_levels=False, ) V8 = RoomVersion( "8", @@ -239,6 +249,7 @@ class RoomVersions: msc2716_historical=False, msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, + msc3667_int_only_power_levels=False, ) V9 = RoomVersion( "9", @@ -256,6 +267,7 @@ class RoomVersions: msc2716_historical=False, msc2716_redactions=False, msc3787_knock_restricted_join_rule=False, + msc3667_int_only_power_levels=False, ) MSC2716v3 = RoomVersion( "org.matrix.msc2716v3", @@ -273,6 +285,7 @@ class RoomVersions: msc2716_historical=True, msc2716_redactions=True, msc3787_knock_restricted_join_rule=False, + msc3667_int_only_power_levels=False, ) MSC3787 = RoomVersion( "org.matrix.msc3787", @@ -290,6 +303,25 @@ class RoomVersions: msc2716_historical=False, msc2716_redactions=False, msc3787_knock_restricted_join_rule=True, + msc3667_int_only_power_levels=False, + ) + V10 = RoomVersion( + "10", + RoomDisposition.STABLE, + EventFormatVersions.V3, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + msc2176_redaction_rules=False, + msc3083_join_rules=True, + msc3375_redaction_rules=True, + msc2403_knocking=True, + msc2716_historical=False, + msc2716_redactions=False, + msc3787_knock_restricted_join_rule=True, + msc3667_int_only_power_levels=True, ) @@ -308,6 +340,7 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = { RoomVersions.V9, RoomVersions.MSC2716v3, RoomVersions.MSC3787, + RoomVersions.V10, ) } diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 745e704141..6bafa7d3f3 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -44,7 +44,6 @@ from synapse.app._base import ( register_start, ) from synapse.config._base import ConfigError, format_config_error -from synapse.config.emailconfig 
import ThreepidBehaviour from synapse.config.homeserver import HomeServerConfig from synapse.config.server import ListenerConfig from synapse.federation.transport.server import TransportLayerServer @@ -202,7 +201,7 @@ class SynapseHomeServer(HomeServer): } ) - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.can_verify_email: from synapse.rest.synapse.client.password_reset import ( PasswordResetSubmitTokenResource, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index df1c214462..0963fb3bb4 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -53,6 +53,18 @@ sent_events_counter = Counter( "synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"] ) +sent_ephemeral_counter = Counter( + "synapse_appservice_api_sent_ephemeral", + "Number of ephemeral events sent to the AS", + ["service"], +) + +sent_todevice_counter = Counter( + "synapse_appservice_api_sent_todevice", + "Number of todevice messages sent to the AS", + ["service"], +) + HOUR_IN_MS = 60 * 60 * 1000 @@ -310,6 +322,8 @@ class ApplicationServiceApi(SimpleHttpClient): ) sent_transactions_counter.labels(service.id).inc() sent_events_counter.labels(service.id).inc(len(serialized_events)) + sent_ephemeral_counter.labels(service.id).inc(len(ephemeral)) + sent_todevice_counter.labels(service.id).inc(len(to_device_messages)) return True except CodeMessageException as e: logger.warning( diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 6e11fbdb9a..3ead80d985 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -18,7 +18,6 @@ import email.utils import logging import os -from enum import Enum from typing import Any import attr @@ -131,41 +130,22 @@ class EmailConfig(Config): self.email_enable_notifs = email_config.get("enable_notifs", False) - self.threepid_behaviour_email = ( - # Have Synapse handle the email sending if account_threepid_delegates.email - # is not defined - # msisdn is currently always remote while Synapse does not support any method of - # sending SMS messages - ThreepidBehaviour.REMOTE - if self.root.registration.account_threepid_delegate_email - else ThreepidBehaviour.LOCAL - ) - if config.get("trust_identity_server_for_password_resets"): raise ConfigError( 'The config option "trust_identity_server_for_password_resets" ' - 'has been replaced by "account_threepid_delegate". ' - "Please consult the configuration manual at docs/usage/configuration/config_documentation.md for " - "details and update your config file." + "is no longer supported. Please remove it from the config file." ) - self.local_threepid_handling_disabled_due_to_email_config = False - if ( - self.threepid_behaviour_email == ThreepidBehaviour.LOCAL - and email_config == {} - ): - # We cannot warn the user this has happened here - # Instead do so when a user attempts to reset their password - self.local_threepid_handling_disabled_due_to_email_config = True - - self.threepid_behaviour_email = ThreepidBehaviour.OFF + # If we have email config settings, assume that we can verify ownership of + # email addresses. 
+ self.can_verify_email = email_config != {} # Get lifetime of a validation token in milliseconds self.email_validation_token_lifetime = self.parse_duration( email_config.get("validation_token_lifetime", "1h") ) - if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.can_verify_email: missing = [] if not self.email_notif_from: missing.append("email.notif_from") @@ -356,18 +336,3 @@ class EmailConfig(Config): "Config option email.invite_client_location must be a http or https URL", path=("email", "invite_client_location"), ) - - -class ThreepidBehaviour(Enum): - """ - Enum to define the behaviour of Synapse with regards to when it contacts an identity - server for 3pid registration and password resets - - REMOTE = use an external server to send tokens - LOCAL = send tokens ourselves - OFF = disable registration via 3pid and password resets - """ - - REMOTE = "remote" - LOCAL = "local" - OFF = "off" diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 4fc1784efe..5a91917b4a 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -112,6 +112,13 @@ class RatelimitConfig(Config): defaults={"per_second": 0.01, "burst_count": 10}, ) + # Track the rate of joins to a given room. If there are too many, temporarily + # prevent local joins and remote joins via this server. + self.rc_joins_per_room = RateLimitConfig( + config.get("rc_joins_per_room", {}), + defaults={"per_second": 1, "burst_count": 10}, + ) + # Ratelimit cross-user key requests: # * For local requests this is keyed by the sending device. # * For requests received over federation this is keyed by the origin. diff --git a/synapse/config/registration.py b/synapse/config/registration.py index fcf99be092..685a0423c5 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -20,6 +20,13 @@ from synapse.config._base import Config, ConfigError from synapse.types import JsonDict, RoomAlias, UserID from synapse.util.stringutils import random_string_with_symbols, strtobool +NO_EMAIL_DELEGATE_ERROR = """\ +Delegation of email verification to an identity server is no longer supported. To +continue to allow users to add email addresses to their accounts, and use them for +password resets, configure Synapse with an SMTP server via the `email` setting, and +remove `account_threepid_delegates.email`. +""" + class RegistrationConfig(Config): section = "registration" @@ -51,7 +58,9 @@ class RegistrationConfig(Config): self.bcrypt_rounds = config.get("bcrypt_rounds", 12) account_threepid_delegates = config.get("account_threepid_delegates") or {} - self.account_threepid_delegate_email = account_threepid_delegates.get("email") + if "email" in account_threepid_delegates: + raise ConfigError(NO_EMAIL_DELEGATE_ERROR) self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 3c69dd325f..1033496bb4 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -42,6 +42,18 @@ THUMBNAIL_SIZE_YAML = """\ # method: %(method)s """ +# A map from the given media type to the type of thumbnail we should generate +# for it.
+THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP = { + "image/jpeg": "jpeg", + "image/jpg": "jpeg", + "image/webp": "jpeg", + # Thumbnails can only be jpeg or png. We choose png thumbnails for gif + # because GIFs can have transparency. + "image/gif": "png", + "image/png": "png", +} + HTTP_PROXY_SET_WARNING = """\ The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured.""" @@ -79,13 +91,22 @@ def parse_thumbnail_requirements( width = size["width"] height = size["height"] method = size["method"] - jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg") - png_thumbnail = ThumbnailRequirement(width, height, method, "image/png") - requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail) - requirements.setdefault("image/jpg", []).append(jpeg_thumbnail) - requirements.setdefault("image/webp", []).append(jpeg_thumbnail) - requirements.setdefault("image/gif", []).append(png_thumbnail) - requirements.setdefault("image/png", []).append(png_thumbnail) + + for format, thumbnail_format in THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.items(): + requirement = requirements.setdefault(format, []) + if thumbnail_format == "jpeg": + requirement.append( + ThumbnailRequirement(width, height, method, "image/jpeg") + ) + elif thumbnail_format == "png": + requirement.append( + ThumbnailRequirement(width, height, method, "image/png") + ) + else: + raise Exception( + "Unknown thumbnail mapping from %s to %s. This is a Synapse problem, please report!" + % (format, thumbnail_format) + ) return { media_type: tuple(thumbnails) for media_type, thumbnails in requirements.items() } diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 0fc2c4b27e..965cb265da 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -740,6 +740,32 @@ def _check_power_levels( except Exception: raise SynapseError(400, "Not a valid power level: %s" % (v,)) + # Reject events with stringy power levels if required by room version + if ( + event.type == EventTypes.PowerLevels + and room_version_obj.msc3667_int_only_power_levels + ): + for k, v in event.content.items(): + if k in { + "users_default", + "events_default", + "state_default", + "ban", + "redact", + "kick", + "invite", + }: + if not isinstance(v, int): + raise SynapseError(400, f"{v!r} must be an integer.") + if k in {"events", "notifications", "users"}: + if not isinstance(v, dict) or not all( + isinstance(subvalue, int) for subvalue in v.values() + ): + raise SynapseError( + 400, + f"{v!r} must be a dict wherein all the values are integers.", + ) + key = (event.type, event.state_key) current_state = auth_events.get(key) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 98c203ada0..17f624b68f 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -24,9 +24,11 @@ from synapse.api.room_versions import ( RoomVersion, ) from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.event_auth import auth_types_for_event from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict from synapse.state import StateHandler from synapse.storage.databases.main import DataStore +from synapse.storage.state import StateFilter from synapse.types import EventID, JsonDict from synapse.util import Clock from synapse.util.stringutils import random_string @@ -120,8 +122,12 @@ class EventBuilder: The signed and hashed event.
""" if auth_event_ids is None: - state_ids = await self._state.get_current_state_ids( - self.room_id, prev_event_ids + state_ids = await self._state.compute_state_after_events( + self.room_id, + prev_event_ids, + state_filter=StateFilter.from_types( + auth_types_for_event(self.room_version, self) + ), ) auth_event_ids = self._event_auth_handler.compute_auth_events( self, state_ids diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 66e6305562..842f5327c2 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -53,7 +53,7 @@ from synapse.api.room_versions import ( RoomVersion, RoomVersions, ) -from synapse.events import EventBase, builder +from synapse.events import EventBase, builder, make_event_from_dict from synapse.federation.federation_base import ( FederationBase, InvalidEventSignatureError, @@ -217,7 +217,7 @@ class FederationClient(FederationBase): ) async def claim_client_keys( - self, destination: str, content: JsonDict, timeout: int + self, destination: str, content: JsonDict, timeout: Optional[int] ) -> JsonDict: """Claims one-time keys for a device hosted on a remote server. @@ -299,7 +299,8 @@ class FederationClient(FederationBase): moving to the next destination. None indicates no timeout. Returns: - The requested PDU, or None if we were unable to find it. + A copy of the requested PDU that is safe to modify, or None if we + were unable to find it. Raises: SynapseError, NotRetryingDestination, FederationDeniedError @@ -309,7 +310,7 @@ class FederationClient(FederationBase): ) logger.debug( - "retrieved event id %s from %s: %r", + "get_pdu_from_destination_raw: retrieved event id %s from %s: %r", event_id, destination, transaction_data, @@ -358,54 +359,92 @@ class FederationClient(FederationBase): The requested PDU, or None if we were unable to find it. """ + logger.debug( + "get_pdu: event_id=%s from destinations=%s", event_id, destinations + ) + # TODO: Rate limit the number of times we try and get the same event. - ev = self._get_pdu_cache.get(event_id) - if ev: - return ev + # We might need the same event multiple times in quick succession (before + # it gets persisted to the database), so we cache the results of the lookup. + # Note that this is separate to the regular get_event cache which caches + # events once they have been persisted. 
+ event = self._get_pdu_cache.get(event_id) + + # If we don't see the event in the cache, go try to fetch it from the + # provided remote federated destinations + if not event: + pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) + + for destination in destinations: + now = self._clock.time_msec() + last_attempt = pdu_attempts.get(destination, 0) + if last_attempt + PDU_RETRY_TIME_MS > now: + logger.debug( + "get_pdu: skipping destination=%s because we tried it recently last_attempt=%s and we only check every %s (now=%s)", + destination, + last_attempt, + PDU_RETRY_TIME_MS, + now, + ) + continue + + try: + event = await self.get_pdu_from_destination_raw( + destination=destination, + event_id=event_id, + room_version=room_version, + timeout=timeout, + ) - pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) + pdu_attempts[destination] = now - signed_pdu = None - for destination in destinations: - now = self._clock.time_msec() - last_attempt = pdu_attempts.get(destination, 0) - if last_attempt + PDU_RETRY_TIME_MS > now: - continue + if event: + # Prime the cache + self._get_pdu_cache[event.event_id] = event - try: - signed_pdu = await self.get_pdu_from_destination_raw( - destination=destination, - event_id=event_id, - room_version=room_version, - timeout=timeout, - ) + # FIXME: We should add a `break` here to avoid calling every + # destination after we already found a PDU (will follow-up + # in a separate PR) - pdu_attempts[destination] = now - - except SynapseError as e: - logger.info( - "Failed to get PDU %s from %s because %s", event_id, destination, e - ) - continue - except NotRetryingDestination as e: - logger.info(str(e)) - continue - except FederationDeniedError as e: - logger.info(str(e)) - continue - except Exception as e: - pdu_attempts[destination] = now + except SynapseError as e: + logger.info( + "Failed to get PDU %s from %s because %s", + event_id, + destination, + e, + ) + continue + except NotRetryingDestination as e: + logger.info(str(e)) + continue + except FederationDeniedError as e: + logger.info(str(e)) + continue + except Exception as e: + pdu_attempts[destination] = now + + logger.info( + "Failed to get PDU %s from %s because %s", + event_id, + destination, + e, + ) + continue - logger.info( - "Failed to get PDU %s from %s because %s", event_id, destination, e - ) - continue + if not event: + return None - if signed_pdu: - self._get_pdu_cache[event_id] = signed_pdu + # `event` now refers to an object stored in `get_pdu_cache`. Our + # callers may need to modify the returned object (e.g. to set + # `event.internal_metadata.outlier = True`), so we return a copy + # rather than the original object.
+ event_copy = make_event_from_dict( + event.get_pdu_json(), + event.room_version, + ) - return signed_pdu + return event_copy async def get_room_state_ids( self, destination: str, room_id: str, event_id: str diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 5dfdc86740..ae550d3f4d 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -118,6 +118,7 @@ class FederationServer(FederationBase): self._federation_event_handler = hs.get_federation_event_handler() self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() + self._room_member_handler = hs.get_room_member_handler() self._state_storage_controller = hs.get_storage_controllers().state @@ -621,6 +622,15 @@ class FederationServer(FederationBase): ) raise IncompatibleRoomVersionError(room_version=room_version) + # Refuse the request if that room has seen too many joins recently. + # This is in addition to the HS-level rate limiting applied by + # BaseFederationServlet. + # type-ignore: mypy doesn't seem able to deduce the type of the limiter(!?) + await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type] + requester=None, + key=room_id, + update=False, + ) pdu = await self.handler.on_make_join_request(origin, room_id, user_id) return {"event": pdu.get_templated_pdu_json(), "room_version": room_version} @@ -655,6 +665,12 @@ class FederationServer(FederationBase): room_id: str, caller_supports_partial_state: bool = False, ) -> Dict[str, Any]: + await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type] + requester=None, + key=room_id, + update=False, + ) + event, context = await self._on_send_membership_event( origin, content, Membership.JOIN, room_id ) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 99a794c042..94a65ac65f 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -351,7 +351,11 @@ class FederationSender(AbstractFederationSender): self._is_processing = True while True: last_token = await self.store.get_federation_out_pos("events") - next_token, events = await self.store.get_all_new_events_stream( + ( + next_token, + events, + event_to_received_ts, + ) = await self.store.get_all_new_events_stream( last_token, self._last_poked_id, limit=100 ) @@ -476,7 +480,7 @@ class FederationSender(AbstractFederationSender): await self._send_pdu(event, sharded_destinations) now = self.clock.time_msec() - ts = await self.store.get_received_ts(event.event_id) + ts = event_to_received_ts[event.event_id] assert ts is not None synapse.metrics.event_processing_lag_by_event.labels( "federation_sender" @@ -509,7 +513,7 @@ class FederationSender(AbstractFederationSender): if events: now = self.clock.time_msec() - ts = await self.store.get_received_ts(events[-1].event_id) + ts = event_to_received_ts[events[-1].event_id] assert ts is not None synapse.metrics.event_processing_lag.labels( diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 9e84bd677e..32074b8ca6 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -619,7 +619,7 @@ class TransportLayerClient: ) async def claim_client_keys( - self, destination: str, query_content: JsonDict, timeout: int + self, destination: str, query_content: JsonDict, timeout: Optional[int] ) -> JsonDict: """Claim one-time keys for a list 
of devices hosted on a remote server. diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 814553e098..203b62e015 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -104,14 +104,15 @@ class ApplicationServicesHandler: with Measure(self.clock, "notify_interested_services"): self.is_processing = True try: - limit = 100 upper_bound = -1 while upper_bound < self.current_max: + last_token = await self.store.get_appservice_last_pos() ( upper_bound, events, - ) = await self.store.get_new_events_for_appservice( - self.current_max, limit + event_to_received_ts, + ) = await self.store.get_all_new_events_stream( + last_token, self.current_max, limit=100, get_prev_content=True ) events_by_room: Dict[str, List[EventBase]] = {} @@ -150,7 +151,7 @@ class ApplicationServicesHandler: ) now = self.clock.time_msec() - ts = await self.store.get_received_ts(event.event_id) + ts = event_to_received_ts[event.event_id] assert ts is not None synapse.metrics.event_processing_lag_by_event.labels( @@ -187,7 +188,7 @@ class ApplicationServicesHandler: if events: now = self.clock.time_msec() - ts = await self.store.get_received_ts(events[-1].event_id) + ts = event_to_received_ts[events[-1].event_id] assert ts is not None synapse.metrics.event_processing_lag.labels( diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 52bb5c9c55..84c28c480e 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -92,7 +92,11 @@ class E2eKeysHandler: @trace async def query_devices( - self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str + self, + query_body: JsonDict, + timeout: int, + from_user_id: str, + from_device_id: Optional[str], ) -> JsonDict: """Handle a device key query from a client @@ -120,9 +124,7 @@ class E2eKeysHandler: the number of in-flight queries at a time. """ async with self._query_devices_linearizer.queue((from_user_id, from_device_id)): - device_keys_query: Dict[str, Iterable[str]] = query_body.get( - "device_keys", {} - ) + device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {}) # separate users by domain. # make a map from domain to user_id to device_ids @@ -392,7 +394,7 @@ class E2eKeysHandler: @trace async def query_local_devices( - self, query: Dict[str, Optional[List[str]]] + self, query: Mapping[str, Optional[List[str]]] ) -> Dict[str, Dict[str, dict]]: """Get E2E device keys for local users @@ -461,7 +463,7 @@ class E2eKeysHandler: @trace async def claim_one_time_keys( - self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int + self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int] ) -> JsonDict: local_query: List[Tuple[str, str, str]] = [] remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index c74117c19a..a5f4ce7c8a 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
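Before moving on to the `federation_event.py` changes, a brief aside on the `get_pdu` rewrite above: it combines a lookup cache (separate from the persisted-event cache) with a per-destination backoff map, and hands callers a copy so cached entries cannot be mutated. A standalone sketch of that pattern follows; the class is hypothetical (Synapse keeps this state as the `_get_pdu_cache` and `pdu_destination_tried` attributes of `FederationClient`), and the `PDU_RETRY_TIME_MS` value is assumed, since the diff references only the name:

```python
import copy
import time
from typing import Dict, Optional

PDU_RETRY_TIME_MS = 60 * 1000  # assumed value; only the name appears in the diff


class PduLookupCache:
    """Sketch of the cache-plus-backoff pattern used by get_pdu (illustrative only)."""

    def __init__(self) -> None:
        self._events: Dict[str, dict] = {}  # event_id -> PDU JSON
        self._tried: Dict[str, Dict[str, int]] = {}  # event_id -> destination -> ms

    def should_query(self, event_id: str, destination: str) -> bool:
        # Skip destinations we asked recently: they are unlikely to have
        # learned of the event in the meantime.
        now_ms = int(time.time() * 1000)
        last = self._tried.setdefault(event_id, {}).get(destination, 0)
        return last + PDU_RETRY_TIME_MS <= now_ms

    def mark_tried(self, event_id: str, destination: str) -> None:
        self._tried.setdefault(event_id, {})[destination] = int(time.time() * 1000)

    def put(self, event_id: str, pdu_json: dict) -> None:
        self._events[event_id] = pdu_json

    def get_copy(self, event_id: str) -> Optional[dict]:
        # Return a copy: callers may mutate the result (e.g. flag it as an
        # outlier) and must not corrupt the cached entry.
        pdu = self._events.get(event_id)
        return copy.deepcopy(pdu) if pdu is not None else None
```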
+import collections import itertools import logging from http import HTTPStatus @@ -347,7 +348,7 @@ class FederationEventHandler: event.internal_metadata.send_on_behalf_of = origin context = await self._state_handler.compute_event_context(event) - context = await self._check_event_auth(origin, event, context) + await self._check_event_auth(origin, event, context) if context.rejected: raise SynapseError( 403, f"{event.membership} event was rejected", Codes.FORBIDDEN @@ -485,7 +486,7 @@ class FederationEventHandler: partial_state=partial_state, ) - context = await self._check_event_auth(origin, event, context) + await self._check_event_auth(origin, event, context) if context.rejected: raise SynapseError(400, "Join event was rejected") @@ -765,10 +766,24 @@ class FederationEventHandler: """ logger.info("Processing pulled event %s", event) - # these should not be outliers. - assert ( - not event.internal_metadata.is_outlier() - ), "pulled event unexpectedly flagged as outlier" + # This function should not be used to persist outliers (use something + # else) because this does a bunch of operations that aren't necessary + # (extra work; in particular, it makes sure we have all the prev_events + # and resolves the state across those prev events). If you happen to run + # into a situation where the event you're trying to process/backfill is + # marked as an `outlier`, then you should update that spot to return an + # `EventBase` copy that doesn't have the `outlier` flag set. + # + # `EventBase` is used to represent both an event we have not yet + # persisted, and one that we have persisted and now keep in the cache. + # In an ideal world this method would only be called with the first type + # of event, but it turns out that's not actually the case and, for + # example, you could get an event from cache that is marked as an + # `outlier` (such call sites should be fixed up). + assert not event.internal_metadata.is_outlier(), ( + "Outlier event passed to _process_pulled_event. " + "To persist an event as a non-outlier, make sure to pass in a copy without `event.internal_metadata.outlier = True`." + ) event_id = event.event_id @@ -1036,6 +1051,9 @@ class FederationEventHandler: # XXX: this doesn't sound right? it means that we'll end up with incomplete # state. failed_to_fetch = desired_events - event_metadata.keys() + # `event_id` could be missing from `event_metadata` because it's not necessarily + # a state event. We've already checked that we've fetched it above. + failed_to_fetch.discard(event_id) if failed_to_fetch: logger.warning( "Failed to fetch missing state events for %s %s", @@ -1116,11 +1134,7 @@ class FederationEventHandler: state_ids_before_event=state_ids, ) try: - context = await self._check_event_auth( - origin, - event, - context, - ) + await self._check_event_auth(origin, event, context) except AuthError as e: # This happens only if we couldn't find the auth events. We'll already have # logged a warning, so now we just convert to a FederationError. @@ -1495,11 +1509,8 @@ class FederationEventHandler: ) async def _check_event_auth( - self, - origin: str, - event: EventBase, - context: EventContext, - ) -> EventContext: + self, origin: str, event: EventBase, context: EventContext + ) -> None: """ Checks whether an event should be rejected (for failing auth checks). @@ -1509,9 +1520,6 @@ class FederationEventHandler: context: The event context. - Returns: - The updated context object. - Raises: AuthError if we were unable to find copies of the event's auth events.
(Most other failures just cause us to set `context.rejected`.) @@ -1526,7 +1534,7 @@ class FederationEventHandler: logger.warning("While validating received event %r: %s", event, e) # TODO: use a different rejected reason here? context.rejected = RejectedReason.AUTH_ERROR - return context + return # next, check that we have all of the event's auth events. # @@ -1538,6 +1546,9 @@ class FederationEventHandler: ) # ... and check that the event passes auth at those auth events. + # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu: + # 4. Passes authorization rules based on the event’s auth events, + # otherwise it is rejected. try: await check_state_independent_auth_rules(self._store, event) check_state_dependent_auth_rules(event, claimed_auth_events) @@ -1546,55 +1557,90 @@ class FederationEventHandler: "While checking auth of %r against auth_events: %s", event, e ) context.rejected = RejectedReason.AUTH_ERROR - return context + return - # now check auth against what we think the auth events *should* be. - event_types = event_auth.auth_types_for_event(event.room_version, event) - prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types(event_types) + # now check the auth rules pass against the room state before the event + # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu: + # 5. Passes authorization rules based on the state before the event, + # otherwise it is rejected. + # + # ... however, if we only have partial state for the room, then there is a good + # chance that we'll be missing some of the state needed to auth the new event. + # So, we state-resolve the auth events that we are given against the state that + # we know about, which ensures things like bans are applied. (Note that we'll + # already have checked we have all the auth events, in + # _load_or_fetch_auth_events_for_event above) + if context.partial_state: + room_version = await self._store.get_room_version_id(event.room_id) + + local_state_id_map = await context.get_prev_state_ids() + claimed_auth_events_id_map = { + (ev.type, ev.state_key): ev.event_id for ev in claimed_auth_events + } + + state_for_auth_id_map = ( + await self._state_resolution_handler.resolve_events_with_store( + event.room_id, + room_version, + [local_state_id_map, claimed_auth_events_id_map], + event_map=None, + state_res_store=StateResolutionStore(self._store), + ) + ) + else: + event_types = event_auth.auth_types_for_event(event.room_version, event) + state_for_auth_id_map = await context.get_prev_state_ids( + StateFilter.from_types(event_types) + ) + + calculated_auth_event_ids = self._event_auth_handler.compute_auth_events( + event, state_for_auth_id_map, for_verification=True ) - auth_events_ids = self._event_auth_handler.compute_auth_events( - event, prev_state_ids, for_verification=True + # if those are the same, we're done here. + if collections.Counter(event.auth_event_ids()) == collections.Counter( + calculated_auth_event_ids + ): + return + + # otherwise, re-run the auth checks based on what we calculated. 
+ calculated_auth_events = await self._store.get_events_as_list( + calculated_auth_event_ids ) - auth_events_x = await self._store.get_events(auth_events_ids) + + # log the differences + + claimed_auth_event_map = {(e.type, e.state_key): e for e in claimed_auth_events} calculated_auth_event_map = { - (e.type, e.state_key): e for e in auth_events_x.values() + (e.type, e.state_key): e for e in calculated_auth_events } + logger.info( + "event's auth_events are different to our calculated auth_events. " + "Claimed but not calculated: %s. Calculated but not claimed: %s", + [ + ev + for k, ev in claimed_auth_event_map.items() + if k not in calculated_auth_event_map + or calculated_auth_event_map[k].event_id != ev.event_id + ], + [ + ev + for k, ev in calculated_auth_event_map.items() + if k not in claimed_auth_event_map + or claimed_auth_event_map[k].event_id != ev.event_id + ], + ) try: - updated_auth_events = await self._update_auth_events_for_auth( + check_state_dependent_auth_rules(event, calculated_auth_events) + except AuthError as e: + logger.warning( + "While checking auth of %r against room state before the event: %s", event, - calculated_auth_event_map=calculated_auth_event_map, - ) - except Exception: - # We don't really mind if the above fails, so lets not fail - # processing if it does. However, it really shouldn't fail so - # let's still log as an exception since we'll still want to fix - # any bugs. - logger.exception( - "Failed to double check auth events for %s with remote. " - "Ignoring failure and continuing processing of event.", - event.event_id, - ) - updated_auth_events = None - - if updated_auth_events: - context = await self._update_context_for_auth_events( - event, context, updated_auth_events + e, ) - auth_events_for_auth = updated_auth_events - else: - auth_events_for_auth = calculated_auth_event_map - - try: - check_state_dependent_auth_rules(event, auth_events_for_auth.values()) - except AuthError as e: - logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR - return context - async def _maybe_kick_guest_users(self, event: EventBase) -> None: if event.type != EventTypes.GuestAccess: return @@ -1704,93 +1750,6 @@ class FederationEventHandler: soft_failed_event_counter.inc() event.internal_metadata.soft_failed = True - async def _update_auth_events_for_auth( - self, - event: EventBase, - calculated_auth_event_map: StateMap[EventBase], - ) -> Optional[StateMap[EventBase]]: - """Helper for _check_event_auth. See there for docs. - - Checks whether a given event has the expected auth events. If it - doesn't then we talk to the remote server to compare state to see if - we can come to a consensus (e.g. if one server missed some valid - state). - - This attempts to resolve any potential divergence of state between - servers, but is not essential and so failures should not block further - processing of the event. - - Args: - event: - - calculated_auth_event_map: - Our calculated auth_events based on the state of the room - at the event's position in the DAG. - - Returns: - updated auth event map, or None if no changes are needed. - - """ - assert not event.internal_metadata.outlier - - # check for events which are in the event's claimed auth_events, but not - # in our calculated event map. 
- event_auth_events = set(event.auth_event_ids()) - different_auth = event_auth_events.difference( - e.event_id for e in calculated_auth_event_map.values() - ) - - if not different_auth: - return None - - logger.info( - "auth_events refers to events which are not in our calculated auth " - "chain: %s", - different_auth, - ) - - # XXX: currently this checks for redactions but I'm not convinced that is - # necessary? - different_events = await self._store.get_events_as_list(different_auth) - - # double-check they're all in the same room - we should already have checked - # this but it doesn't hurt to check again. - for d in different_events: - assert ( - d.room_id == event.room_id - ), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room" - - # now we state-resolve between our own idea of the auth events, and the remote's - # idea of them. - - local_state = calculated_auth_event_map.values() - remote_auth_events = dict(calculated_auth_event_map) - remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) - remote_state = remote_auth_events.values() - - room_version = await self._store.get_room_version_id(event.room_id) - new_state = await self._state_handler.resolve_events( - room_version, (local_state, remote_state), event - ) - different_state = { - (d.type, d.state_key): d - for d in new_state.values() - if calculated_auth_event_map.get((d.type, d.state_key)) != d - } - if not different_state: - logger.info("State res returned no new state") - return None - - logger.info( - "After state res: updating auth_events with new state %s", - different_state.values(), - ) - - # take a copy of calculated_auth_event_map before we modify it. - auth_events = dict(calculated_auth_event_map) - auth_events.update(different_state) - return auth_events - async def _load_or_fetch_auth_events_for_event( self, destination: str, event: EventBase ) -> Collection[EventBase]: @@ -1888,61 +1847,6 @@ class FederationEventHandler: await self._auth_and_persist_outliers(room_id, remote_auth_events) - async def _update_context_for_auth_events( - self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] - ) -> EventContext: - """Update the state_ids in an event context after auth event resolution, - storing the changes as a new state group. - - Args: - event: The event we're handling the context for - - context: initial event context - - auth_events: Events to update in the event context. - - Returns: - new event context - """ - # exclude the state key of the new event from the current_state in the context. - if event.is_state(): - event_key: Optional[Tuple[str, str]] = (event.type, event.state_key) - else: - event_key = None - state_updates = { - k: a.event_id for k, a in auth_events.items() if k != event_key - } - - current_state_ids = await context.get_current_state_ids() - current_state_ids = dict(current_state_ids) # type: ignore - - current_state_ids.update(state_updates) - - prev_state_ids = await context.get_prev_state_ids() - prev_state_ids = dict(prev_state_ids) - - prev_state_ids.update({k: a.event_id for k, a in auth_events.items()}) - - # create a new state group as a delta from the existing one. 
- prev_group = context.state_group - state_group = await self._state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=prev_group, - delta_ids=state_updates, - current_state_ids=current_state_ids, - ) - - return EventContext.with_state( - storage=self._storage_controllers, - state_group=state_group, - state_group_before_event=context.state_group_before_event, - state_delta_due_to_event=state_updates, - prev_group=prev_group, - delta_ids=state_updates, - partial_state=context.partial_state, - ) - async def _run_push_actions_and_persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False ) -> None: @@ -2093,6 +1997,10 @@ class FederationEventHandler: event, event_pos, max_stream_token, extra_users=extra_users ) + if event.type == EventTypes.Member and event.membership == Membership.JOIN: + # TODO retrieve the previous state, and exclude join -> join transitions + self._notifier.notify_user_joined_room(event.event_id, event.room_id) + def _sanity_check_event(self, ev: EventBase) -> None: """ Do some early sanity checks of a received event diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 9bca2bc4b2..9571d461c8 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -26,7 +26,6 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.config.emailconfig import ThreepidBehaviour from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient from synapse.http.site import SynapseRequest @@ -163,8 +162,7 @@ class IdentityHandler: sid: str, mxid: str, id_server: str, - id_access_token: Optional[str] = None, - use_v2: bool = True, + id_access_token: str, ) -> JsonDict: """Bind a 3PID to an identity server @@ -174,8 +172,7 @@ class IdentityHandler: mxid: The MXID to bind the 3PID to id_server: The domain of the identity server to query id_access_token: The access token to authenticate to the identity - server with, if necessary. Required if use_v2 is true - use_v2: Whether to use v2 Identity Service API endpoints. 
Defaults to True + server with Raises: SynapseError: On any of the following conditions @@ -187,24 +184,15 @@ class IdentityHandler: """ logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server) - # If an id_access_token is not supplied, force usage of v1 - if id_access_token is None: - use_v2 = False - if not valid_id_server_location(id_server): raise SynapseError( 400, "id_server must be a valid hostname with optional port and path components", ) - # Decide which API endpoint URLs to use - headers = {} bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid} - if use_v2: - bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,) - headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore - else: - bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,) + bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,) + headers = {"Authorization": create_id_access_token_header(id_access_token)} try: # Use the blacklisting http client as this call is only to identity servers @@ -223,21 +211,14 @@ class IdentityHandler: return data except HttpResponseException as e: - if e.code != 404 or not use_v2: - logger.error("3PID bind failed with Matrix error: %r", e) - raise e.to_synapse_error() + logger.error("3PID bind failed with Matrix error: %r", e) + raise e.to_synapse_error() except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") except CodeMessageException as e: data = json_decoder.decode(e.msg) # XXX WAT? return data - logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url) - res = await self.bind_threepid( - client_secret, sid, mxid, id_server, id_access_token, use_v2=False - ) - return res - async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool: """Attempt to remove a 3PID from an identity server, or if one is not provided, all identity servers we're aware the binding is present on @@ -300,8 +281,8 @@ class IdentityHandler: "id_server must be a valid hostname with optional port and path components", ) - url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) - url_bytes = b"/_matrix/identity/api/v1/3pid/unbind" + url = "https://%s/_matrix/identity/v2/3pid/unbind" % (id_server,) + url_bytes = b"/_matrix/identity/v2/3pid/unbind" content = { "mxid": mxid, @@ -434,48 +415,6 @@ class IdentityHandler: return session_id - async def requestEmailToken( - self, - id_server: str, - email: str, - client_secret: str, - send_attempt: int, - next_link: Optional[str] = None, - ) -> JsonDict: - """ - Request an external server send an email on our behalf for the purposes of threepid - validation. 
- - Args: - id_server: The identity server to proxy to - email: The email to send the message to - client_secret: The unique client_secret sends by the user - send_attempt: Which attempt this is - next_link: A link to redirect the user to once they submit the token - - Returns: - The json response body from the server - """ - params = { - "email": email, - "client_secret": client_secret, - "send_attempt": send_attempt, - } - if next_link: - params["next_link"] = next_link - - try: - data = await self.http_client.post_json_get_json( - id_server + "/_matrix/identity/api/v1/validate/email/requestToken", - params, - ) - return data - except HttpResponseException as e: - logger.info("Proxied requestToken failed: %r", e) - raise e.to_synapse_error() - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - async def requestMsisdnToken( self, id_server: str, @@ -549,18 +488,7 @@ class IdentityHandler: validation_session = None # Try to validate as email - if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - # Remote emails will only be used if a valid identity server is provided. - assert ( - self.hs.config.registration.account_threepid_delegate_email is not None - ) - - # Ask our delegated email identity server - validation_session = await self.threepid_from_creds( - self.hs.config.registration.account_threepid_delegate_email, - threepid_creds, - ) - elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.hs.config.email.can_verify_email: # Get a validated session matching these details validation_session = await self.store.get_threepid_validation_session( "email", client_secret, sid=sid, validated=True diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 1980e37dae..bd7baef051 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -463,6 +463,7 @@ class EventCreationHandler: ) self._events_shard_config = self.config.worker.events_shard_config self._instance_name = hs.get_instance_name() + self._notifier = hs.get_notifier() self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state @@ -1444,7 +1445,12 @@ class EventCreationHandler: if state_entry.state_group in self._external_cache_joined_hosts_updates: return - joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry) + state = await state_entry.get_state( + self._storage_controllers.state, StateFilter.all() + ) + joined_hosts = await self.store.get_joined_hosts( + event.room_id, state, state_entry + ) # Note that the expiry times must be larger than the expiry time in # _external_cache_joined_hosts_updates. @@ -1545,6 +1551,16 @@ class EventCreationHandler: requester, is_admin_redaction=is_admin_redaction ) + if event.type == EventTypes.Member and event.membership == Membership.JOIN: + ( + current_membership, + _, + ) = await self.store.get_local_current_membership_for_user_in_room( + event.state_key, event.room_id + ) + if current_membership != Membership.JOIN: + self._notifier.notify_user_joined_room(event.event_id, event.room_id) + await self._maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: @@ -1844,13 +1860,8 @@ class EventCreationHandler: # For each room we need to find a joined member we can use to send # the dummy event with. 
- latest_event_ids = await self.store.get_prev_events_for_room(room_id) - members = await self.state.get_current_users_in_room( - room_id, latest_event_ids=latest_event_ids - ) + members = await self.store.get_local_users_in_room(room_id) for user_id in members: - if not self.hs.is_mine_id(user_id): - continue requester = create_requester(user_id, authenticated_entity=self.server_name) try: event, context = await self.create_event( @@ -1861,7 +1872,6 @@ class EventCreationHandler: "room_id": room_id, "sender": user_id, }, - prev_event_ids=latest_event_ids, ) event.internal_metadata.proactively_send = False diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a54f163c0a..978d3ee39f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -889,7 +889,11 @@ class RoomCreationHandler: # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier - last_stream_id = await self._send_events_for_new_room( + ( + last_stream_id, + last_sent_event_id, + depth, + ) = await self._send_events_for_new_room( requester, room_id, preset_config=preset_config, @@ -905,7 +909,7 @@ class RoomCreationHandler: if "name" in config: name = config["name"] ( - _, + name_event, last_stream_id, ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, @@ -917,12 +921,16 @@ class RoomCreationHandler: "content": {"name": name}, }, ratelimit=False, + prev_event_ids=[last_sent_event_id], + depth=depth, ) + last_sent_event_id = name_event.event_id + depth += 1 if "topic" in config: topic = config["topic"] ( - _, + topic_event, last_stream_id, ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, @@ -934,7 +942,11 @@ class RoomCreationHandler: "content": {"topic": topic}, }, ratelimit=False, + prev_event_ids=[last_sent_event_id], + depth=depth, ) + last_sent_event_id = topic_event.event_id + depth += 1 # we avoid dropping the lock between invites, as otherwise joins can # start coming in and making the createRoom slow. @@ -949,7 +961,7 @@ class RoomCreationHandler: for invitee in invite_list: ( - _, + member_event_id, last_stream_id, ) = await self.room_member_handler.update_membership_locked( requester, @@ -959,7 +971,11 @@ class RoomCreationHandler: ratelimit=False, content=content, new_room=True, + prev_event_ids=[last_sent_event_id], + depth=depth, ) + last_sent_event_id = member_event_id + depth += 1 for invite_3pid in invite_3pid_list: id_server = invite_3pid["id_server"] @@ -968,7 +984,10 @@ class RoomCreationHandler: medium = invite_3pid["medium"] # Note that do_3pid_invite can raise a ShadowBanError, but this was # handled above by emptying invite_3pid_list. - last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite( + ( + member_event_id, + last_stream_id, + ) = await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -977,7 +996,11 @@ class RoomCreationHandler: requester, txn_id=None, id_access_token=id_access_token, + prev_event_ids=[last_sent_event_id], + depth=depth, ) + last_sent_event_id = member_event_id + depth += 1 result = {"room_id": room_id} @@ -1005,20 +1028,22 @@ class RoomCreationHandler: power_level_content_override: Optional[JsonDict] = None, creator_join_profile: Optional[JsonDict] = None, ratelimit: bool = True, - ) -> int: + ) -> Tuple[int, str, int]: """Sends the initial events into a new room. `power_level_content_override` doesn't apply when initial state has power level state event content. 
Returns: - The stream_id of the last event persisted. + A tuple containing the stream ID, event ID and depth of the last + event sent to the room. """ creator_id = creator.user.to_string() event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} + depth = 1 last_sent_event_id: Optional[str] = None def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict: @@ -1031,6 +1056,7 @@ class RoomCreationHandler: async def send(etype: str, content: JsonDict, **kwargs: Any) -> int: nonlocal last_sent_event_id + nonlocal depth event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) @@ -1047,9 +1073,11 @@ class RoomCreationHandler: # Note: we don't pass state_event_ids here because this triggers # an additional query per event to look them up from the events table. prev_event_ids=[last_sent_event_id] if last_sent_event_id else [], + depth=depth, ) last_sent_event_id = sent_event.event_id + depth += 1 return last_stream_id @@ -1075,6 +1103,7 @@ class RoomCreationHandler: content=creator_join_profile, new_room=True, prev_event_ids=[last_sent_event_id], + depth=depth, ) last_sent_event_id = member_event_id @@ -1168,7 +1197,7 @@ class RoomCreationHandler: content={"algorithm": RoomEncryptionAlgorithms.DEFAULT}, ) - return last_sent_stream_id + return last_sent_stream_id, last_sent_event_id, depth def _generate_room_id(self) -> str: """Generates a random room ID. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 04c44b2ccb..30b4cb23df 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -94,12 +94,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, ) + # Tracks joins from local users to rooms this server isn't a member of. + # I.e. joins this server makes by requesting /make_join /send_join from + # another server. self._join_rate_limiter_remote = Ratelimiter( store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) + # TODO: find a better place to keep this Ratelimiter. + # It needs to be + # - written to by event persistence code + # - written to by something which can snoop on replication streams + # - read by the RoomMemberHandler to rate limit joins from local users + # - read by the FederationServer to rate limit make_joins and send_joins from + # other homeservers + # I wonder if a homeserver-wide collection of rate limiters might be cleaner? + self._join_rate_per_room_limiter = Ratelimiter( + store=self.store, + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second, + burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count, + ) # Ratelimiter for invites, keyed by room (across all issuers, all # recipients). @@ -136,6 +153,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) self.request_ratelimiter = hs.get_request_ratelimiter() + hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room) + + def _on_user_joined_room(self, event_id: str, room_id: str) -> None: + """Notify the rate limiter that a room join has occurred. + + Use this to inform the RoomMemberHandler about joins that have either + - taken place on another homeserver, or + - on another worker in this homeserver. 
+ Joins actioned by this worker should use the usual `ratelimit` method, which + checks the limit and increments the counter in one go. + """ + self._join_rate_per_room_limiter.record_action(requester=None, key=room_id) @abc.abstractmethod async def _remote_join( @@ -285,6 +314,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, + depth: Optional[int] = None, txn_id: Optional[str] = None, ratelimit: bool = True, content: Optional[dict] = None, @@ -315,6 +345,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): prev_events are set so we need to set them ourself via this argument. This should normally be left as None, which will cause the auth_event_ids to be calculated based on the room state at the prev_events. + depth: Override the depth used to order the event in the DAG. + Should normally be set to None, which will cause the depth to be calculated + based on the prev_events. txn_id: ratelimit: @@ -370,6 +403,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events=allow_no_prev_events, prev_event_ids=prev_event_ids, state_event_ids=state_event_ids, + depth=depth, require_consent=require_consent, outlier=outlier, historical=historical, @@ -391,6 +425,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # up blocking profile updates. if newly_joined and ratelimit: await self._join_rate_limiter_local.ratelimit(requester) + await self._join_rate_per_room_limiter.ratelimit( + requester, key=room_id, update=False + ) result_event = await self.event_creation_handler.handle_new_client_event( requester, @@ -466,6 +503,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, + depth: Optional[int] = None, ) -> Tuple[str, int]: """Update a user's membership in a room. @@ -501,6 +539,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): prev_events are set so we need to set them ourself via this argument. This should normally be left as None, which will cause the auth_event_ids to be calculated based on the room state at the prev_events. + depth: Override the depth used to order the event in the DAG. + Should normally be set to None, which will cause the depth to be calculated + based on the prev_events. Returns: A tuple of the new event ID and stream ID. @@ -540,6 +581,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events=allow_no_prev_events, prev_event_ids=prev_event_ids, state_event_ids=state_event_ids, + depth=depth, ) return result @@ -562,6 +604,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, + depth: Optional[int] = None, ) -> Tuple[str, int]: """Helper for update_membership. @@ -599,6 +642,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): prev_events are set so we need to set them ourself via this argument. This should normally be left as None, which will cause the auth_event_ids to be calculated based on the room state at the prev_events. + depth: Override the depth used to order the event in the DAG. + Should normally be set to None, which will cause the depth to be calculated + based on the prev_events. Returns: A tuple of the new event ID and stream ID. 
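The pattern threaded through the room-creation and membership changes above is worth spelling out: each initial event in a new room is now sent with an explicit prev_event_ids=[last_sent_event_id] and a monotonically increasing depth, so the events form one linear chain instead of each send re-reading the forward extremities from the database. Below is a minimal sketch of that chaining pattern; send_event here is a hypothetical stand-in for Synapse's event-creation coroutines, not a real API.

from typing import Any, Awaitable, Callable, Dict, List, Tuple

async def send_linear_chain(
    send_event: Callable[..., Awaitable[str]],
    payloads: List[Dict[str, Any]],
    first_prev_event_id: str,
    first_depth: int,
) -> Tuple[str, int]:
    """Send each payload chained off the previous event, returning the last
    event ID and the next unused depth so callers can keep extending the
    chain (as createRoom does for the name, topic and invite events)."""
    last_event_id = first_prev_event_id
    depth = first_depth
    for payload in payloads:
        # Passing prev_event_ids and depth explicitly avoids a database
        # round-trip to compute forward extremities for every event.
        last_event_id = await send_event(
            payload, prev_event_ids=[last_event_id], depth=depth
        )
        depth += 1
    return last_event_id, depth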
@@ -732,6 +778,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): allow_no_prev_events=allow_no_prev_events, prev_event_ids=prev_event_ids, state_event_ids=state_event_ids, + depth=depth, content=content, require_consent=require_consent, outlier=outlier, @@ -740,14 +787,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): latest_event_ids = await self.store.get_prev_events_for_room(room_id) - current_state_ids = await self.state_handler.get_current_state_ids( - room_id, latest_event_ids=latest_event_ids + state_before_join = await self.state_handler.compute_state_after_events( + room_id, latest_event_ids ) # TODO: Refactor into dictionary of explicitly allowed transitions # between old and new state, with specific error messages for some # transitions and generic otherwise - old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) + old_state_id = state_before_join.get((EventTypes.Member, target.to_string())) if old_state_id: old_state = await self.store.get_event(old_state_id, allow_none=True) old_membership = old_state.content.get("membership") if old_state else None @@ -798,11 +845,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if action == "kick": raise AuthError(403, "The target user is not in the room") - is_host_in_room = await self._is_host_in_room(current_state_ids) + is_host_in_room = await self._is_host_in_room(state_before_join) if effective_membership_state == Membership.JOIN: if requester.is_guest: - guest_can_join = await self._can_guest_join(current_state_ids) + guest_can_join = await self._can_guest_join(state_before_join) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. @@ -840,13 +887,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Check if a remote join should be performed. remote_join, remote_room_hosts = await self._should_perform_remote_join( - target.to_string(), room_id, remote_room_hosts, content, is_host_in_room + target.to_string(), + room_id, + remote_room_hosts, + content, + is_host_in_room, + state_before_join, ) if remote_join: if ratelimit: await self._join_rate_limiter_remote.ratelimit( requester, ) + await self._join_rate_per_room_limiter.ratelimit( + requester, + key=room_id, + update=False, + ) inviter = await self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): @@ -967,6 +1024,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ratelimit=ratelimit, prev_event_ids=latest_event_ids, state_event_ids=state_event_ids, + depth=depth, content=content, require_consent=require_consent, outlier=outlier, @@ -979,6 +1037,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): remote_room_hosts: List[str], content: JsonDict, is_host_in_room: bool, + state_before_join: StateMap[str], ) -> Tuple[bool, List[str]]: """ Check whether the server should do a remote join (as opposed to a local @@ -998,6 +1057,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content: The content to use as the event body of the join. This may be modified. is_host_in_room: True if the host is in the room. + state_before_join: The state before the join event (i.e. the resolution of + the states after its parent events). Returns: A tuple of: @@ -1014,20 +1075,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # If the host is in the room, but not one of the authorised hosts # for restricted join rules, a remote join must be used. 
room_version = await self.store.get_room_version(room_id) - current_state_ids = await self._storage_controllers.state.get_current_state_ids( - room_id - ) # If restricted join rules are not being used, a local join can always # be used. if not await self.event_auth_handler.has_restricted_join_rules( - current_state_ids, room_version + state_before_join, room_version ): return False, [] # If the user is invited to the room or already joined, the join # event can always be issued locally. - prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None) + prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None) prev_member_event = None if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) @@ -1042,10 +1100,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # # If not, generate a new list of remote hosts based on which # can issue invites. - event_map = await self.store.get_events(current_state_ids.values()) + event_map = await self.store.get_events(state_before_join.values()) current_state = { state_key: event_map[event_id] - for state_key, event_id in current_state_ids.items() + for state_key, event_id in state_before_join.items() } allowed_servers = get_servers_from_users( get_users_which_can_issue_invite(current_state) ) @@ -1059,7 +1117,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Ensure the member should be allowed access via membership in a room. await self.event_auth_handler.check_restricted_join_rules( - current_state_ids, room_version, user_id, prev_member_event + state_before_join, room_version, user_id, prev_member_event ) # If this is going to be a local join, additional information must @@ -1069,7 +1127,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): EventContentFields.AUTHORISING_USER ] = await self.event_auth_handler.get_user_which_could_invite( room_id, - current_state_ids, + state_before_join, ) return False, [] @@ -1322,7 +1380,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): requester: Requester, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> int: + prev_event_ids: Optional[List[str]] = None, + depth: Optional[int] = None, + ) -> Tuple[str, int]: """Invite a 3PID to a room. Args: @@ -1335,9 +1395,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): txn_id: The transaction ID this is part of, or None if this is not part of a transaction. id_access_token: The optional identity server access token. + prev_event_ids: The event IDs to use as the prev events. + depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. Returns: - The new stream ID. + Tuple of event ID and stream ordering position Raises: ShadowBanError if the requester has been shadow-banned. @@ -1383,7 +1447,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # We don't check the invite against the spamchecker(s) here (through # user_may_invite) because we'll do it further down the line anyway (in # update_membership_locked).
- _, stream_id = await self.update_membership( + event_id, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: @@ -1402,7 +1466,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): additional_fields=spam_check[1], ) - stream_id = await self._make_and_store_3pid_invite( + event, stream_id = await self._make_and_store_3pid_invite( requester, id_server, medium, @@ -1411,9 +1475,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): inviter, txn_id=txn_id, id_access_token=id_access_token, + prev_event_ids=prev_event_ids, + depth=depth, ) + event_id = event.event_id - return stream_id + return event_id, stream_id async def _make_and_store_3pid_invite( self, @@ -1425,7 +1492,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): user: UserID, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> int: + prev_event_ids: Optional[List[str]] = None, + depth: Optional[int] = None, + ) -> Tuple[EventBase, int]: room_state = await self._storage_controllers.state.get_current_state( room_id, StateFilter.from_types( @@ -1518,8 +1587,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): }, ratelimit=False, txn_id=txn_id, + prev_event_ids=prev_event_ids, + depth=depth, ) - return stream_id + return event, stream_id async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool: # Have we just created the room, and is this about to be the very diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 05cebb5d4d..a744d68c64 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -19,7 +19,6 @@ from twisted.web.client import PartialDownloadError from synapse.api.constants import LoginType from synapse.api.errors import Codes, LoginError, SynapseError -from synapse.config.emailconfig import ThreepidBehaviour from synapse.util import json_decoder if TYPE_CHECKING: @@ -153,7 +152,7 @@ class _BaseThreepidAuthChecker: logger.info("Getting validated threepid. 
threepidcreds: %r", (threepid_creds,)) - # msisdns are currently always ThreepidBehaviour.REMOTE + # msisdns are currently always verified via the IS if medium == "msisdn": if not self.hs.config.registration.account_threepid_delegate_msisdn: raise SynapseError( @@ -164,18 +163,7 @@ class _BaseThreepidAuthChecker: threepid_creds, ) elif medium == "email": - if ( - self.hs.config.email.threepid_behaviour_email - == ThreepidBehaviour.REMOTE - ): - assert self.hs.config.registration.account_threepid_delegate_email - threepid = await identity_handler.threepid_from_creds( - self.hs.config.registration.account_threepid_delegate_email, - threepid_creds, - ) - elif ( - self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL - ): + if self.hs.config.email.can_verify_email: threepid = None row = await self.store.get_threepid_validation_session( medium, @@ -227,10 +215,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec _BaseThreepidAuthChecker.__init__(self, hs) def is_enabled(self) -> bool: - return self.hs.config.email.threepid_behaviour_email in ( - ThreepidBehaviour.REMOTE, - ThreepidBehaviour.LOCAL, - ) + return self.hs.config.email.can_verify_email async def check_auth(self, authdict: dict, clientip: str) -> Any: return await self._check_threepid("email", authdict) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c63d068f74..3c35b1d2c7 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -79,6 +79,7 @@ from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred from synapse.util.metrics import Measure +from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.server import HomeServer @@ -479,6 +480,14 @@ class MatrixFederationHttpClient: RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. """ + # Validate server name and log if it is an invalid destination, this is + # partially to help track down code paths where we haven't validated before here + try: + parse_and_validate_server_name(request.destination) + except ValueError: + logger.exception(f"Invalid destination: {request.destination}.") + raise FederationDeniedError(request.destination) + if timeout: _sec_timeout = timeout / 1000 else: diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 50c57940f9..17e729f0c7 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -84,14 +84,13 @@ the function becomes the operation name for the span. return something_usual_and_useful -Operation names can be explicitly set for a function by passing the -operation name to ``trace`` +Operation names can be explicitly set for a function by using ``trace_with_opname``: .. 
code-block:: python - from synapse.logging.opentracing import trace + from synapse.logging.opentracing import trace_with_opname - @trace(opname="a_better_operation_name") + @trace_with_opname("a_better_operation_name") def interesting_badly_named_function(*args, **kwargs): # Does all kinds of cool and expected things return something_usual_and_useful @@ -798,33 +797,31 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte # Tracing decorators -def trace(func=None, opname: Optional[str] = None): +def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]: """ - Decorator to trace a function. - Sets the operation name to that of the function's or that given - as operation_name. See the module's doc string for usage - examples. + Decorator to trace a function with a custom opname. + + See the module's doc string for usage examples. + """ - def decorator(func): + def decorator(func: Callable[P, R]) -> Callable[P, R]: if opentracing is None: return func # type: ignore[unreachable] - _opname = opname if opname else func.__name__ - if inspect.iscoroutinefunction(func): @wraps(func) - async def _trace_inner(*args, **kwargs): - with start_active_span(_opname): - return await func(*args, **kwargs) + async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R: + with start_active_span(opname): + return await func(*args, **kwargs) # type: ignore[misc] else: # The other case here handles both sync functions and those # decorated with inlineDeferred. @wraps(func) - def _trace_inner(*args, **kwargs): - scope = start_active_span(_opname) + def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R: + scope = start_active_span(opname) scope.__enter__() try: @@ -858,12 +855,21 @@ def trace(func=None, opname: Optional[str] = None): scope.__exit__(type(e), None, e.__traceback__) raise - return _trace_inner + return _trace_inner # type: ignore[return-value] - if func: - return decorator(func) - else: - return decorator + return decorator + + +def trace(func: Callable[P, R]) -> Callable[P, R]: + """ + Decorator to trace a function. + + Sets the operation name to that of the function's name. + + See the module's doc string for usage examples. + """ + + return trace_with_opname(func.__name__)(func) def tag_args(func: Callable[P, R]) -> Callable[P, R]: diff --git a/synapse/notifier.py b/synapse/notifier.py index 54b0ec4b97..c42bb8266a 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -228,6 +228,7 @@ class Notifier: # Called when there are new things to stream over replication self.replication_callbacks: List[Callable[[], None]] = [] + self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = [] self._federation_client = hs.get_federation_http_client() @@ -280,6 +281,19 @@ class Notifier: """ self.replication_callbacks.append(cb) + def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None: + """Add a callback that will be called when a user joins a room. + + This only fires on genuine membership changes, e.g. "invite" -> "join". + Membership transitions like "join" -> "join" (for e.g. displayname changes) do + not trigger the callback. + + When called, the callback receives two arguments: the event ID and the room ID. + It should *not* return a Deferred - if it needs to do any asynchronous work, a + background thread should be started and wrapped with run_as_background_process. 
+ """ + self._new_join_in_room_callbacks.append(cb) + async def on_new_room_event( self, event: EventBase, @@ -723,6 +737,10 @@ class Notifier: for cb in self.replication_callbacks: cb() + def notify_user_joined_room(self, event_id: str, room_id: str) -> None: + for cb in self._new_join_in_room_callbacks: + cb(event_id, room_id) + def notify_remote_server_up(self, server: str) -> None: """Notify any replication that a remote server has come back up""" # We call federation_sender directly rather than registering as a diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index d0cc657b44..1e0ef44fc7 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -328,7 +328,7 @@ class PusherPool: return None try: - p = self.pusher_factory.create_pusher(pusher_config) + pusher = self.pusher_factory.create_pusher(pusher_config) except PusherConfigException as e: logger.warning( "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s", @@ -346,23 +346,28 @@ class PusherPool: ) return None - if not p: + if not pusher: return None - appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey) + appid_pushkey = "%s:%s" % (pusher.app_id, pusher.pushkey) - byuser = self.pushers.setdefault(pusher_config.user_name, {}) + byuser = self.pushers.setdefault(pusher.user_id, {}) if appid_pushkey in byuser: - byuser[appid_pushkey].on_stop() - byuser[appid_pushkey] = p + previous_pusher = byuser[appid_pushkey] + previous_pusher.on_stop() - synapse_pushers.labels(type(p).__name__, p.app_id).inc() + synapse_pushers.labels( + type(previous_pusher).__name__, previous_pusher.app_id + ).dec() + byuser[appid_pushkey] = pusher + + synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc() # Check if there *may* be push to process. We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to # push. - user_id = pusher_config.user_name - last_stream_ordering = pusher_config.last_stream_ordering + user_id = pusher.user_id + last_stream_ordering = pusher.last_stream_ordering if last_stream_ordering: have_notifs = await self.store.get_if_maybe_push_in_range_for_user( user_id, last_stream_ordering @@ -372,9 +377,9 @@ class PusherPool: # risk missing push. 
have_notifs = True - p.on_started(have_notifs) + pusher.on_started(have_notifs) - return p + return pusher async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: appid_pushkey = "%s:%s" % (app_id, pushkey) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index a4ae4040c3..561ad5bf04 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -29,7 +29,7 @@ from synapse.http import RequestTimedOutError from synapse.http.server import HttpServer, is_method_cancellable from synapse.http.site import SynapseRequest from synapse.logging import opentracing -from synapse.logging.opentracing import trace +from synapse.logging.opentracing import trace_with_opname from synapse.types import JsonDict from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import random_string @@ -196,7 +196,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): "ascii" ) - @trace(opname="outgoing_replication_request") + @trace_with_opname("outgoing_replication_request") async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: with outgoing_gauge.track_inprogress(): if instance_name == local_instance_name: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 2f59245058..e4f2201c92 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -21,7 +21,7 @@ from twisted.internet.interfaces import IAddress, IConnector from twisted.internet.protocol import ReconnectingClientFactory from twisted.python.failure import Failure -from synapse.api.constants import EventTypes, ReceiptTypes +from synapse.api.constants import EventTypes, Membership, ReceiptTypes from synapse.federation import send_queue from synapse.federation.sender import FederationSender from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable @@ -219,6 +219,21 @@ class ReplicationDataHandler: membership=row.data.membership, ) + # If this event is a join, make a note of it so we have an accurate + # cross-worker room rate limit. + # TODO: Erik said we should exclude rows that came from ex_outliers + # here, but I don't see how we can determine that. I guess we could + # add a flag to row.data? 
+ if ( + row.data.type == EventTypes.Member + and row.data.membership == Membership.JOIN + and not row.data.outlier + ): + # TODO retrieve the previous state, and exclude join -> join transitions + self.notifier.notify_user_joined_room( + row.data.event_id, row.data.room_id + ) + await self._presence_handler.process_replication_rows( stream_name, instance_name, token, rows ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 26f4fa7cfd..14b6705862 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -98,6 +98,7 @@ class EventsStreamEventRow(BaseEventsStreamRow): relates_to: Optional[str] membership: Optional[str] rejected: bool + outlier: bool @attr.s(slots=True, frozen=True, auto_attribs=True) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index f0614a2897..ba2f7fa6d8 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -373,6 +373,7 @@ class UserRestServletV2(RestServlet): if ( self.hs.config.email.email_enable_notifs and self.hs.config.email.email_notif_for_new_users + and medium == "email" ): await self.pusher_pool.add_pusher( user_id=user_id, diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index bdc4a9c068..0cc87a4001 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -28,7 +28,6 @@ from synapse.api.errors import ( SynapseError, ThreepidValidationError, ) -from synapse.config.emailconfig import ThreepidBehaviour from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http.server import HttpServer, finish_request, respond_with_html from synapse.http.servlet import ( @@ -64,7 +63,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): self.config = hs.config self.identity_handler = hs.get_identity_handler() - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.can_verify_email: self.mailer = Mailer( hs=self.hs, app_name=self.config.email.email_app_name, @@ -73,11 +72,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.email.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "User password resets have been disabled due to lack of email config" - ) + if not self.config.email.can_verify_email: + logger.warning( + "User password resets have been disabled due to lack of email config" + ) raise SynapseError( 400, "Email-based password resets have been disabled on this server" ) @@ -129,35 +127,21 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.registration.account_threepid_delegate_email - - # Have the configured identity server handle the request - ret = await self.identity_handler.requestEmailToken( - self.hs.config.registration.account_threepid_delegate_email, - email, - client_secret, - send_attempt, - next_link, - ) - else: - # Send password reset emails from Synapse - sid = await self.identity_handler.send_threepid_validation( - email, - client_secret, - send_attempt, - self.mailer.send_password_reset_mail, - next_link, - ) - - # Wrap the session id in a JSON object - ret = {"sid": sid} + # Send password reset emails from 
Synapse + sid = await self.identity_handler.send_threepid_validation( + email, + client_secret, + send_attempt, + self.mailer.send_password_reset_mail, + next_link, + ) threepid_send_requests.labels(type="email", reason="password_reset").observe( send_attempt ) - return 200, ret + # Wrap the session id in a JSON object + return 200, {"sid": sid} class PasswordRestServlet(RestServlet): @@ -349,7 +333,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.store = self.hs.get_datastores().main - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.can_verify_email: self.mailer = Mailer( hs=self.hs, app_name=self.config.email.email_app_name, @@ -358,11 +342,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.email.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "Adding emails have been disabled due to lack of an email config" - ) + if not self.config.email.can_verify_email: + logger.warning( + "Adding emails have been disabled due to lack of an email config" + ) raise SynapseError( 400, "Adding an email to your account is disabled on this server" ) @@ -413,35 +396,20 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.registration.account_threepid_delegate_email - - # Have the configured identity server handle the request - ret = await self.identity_handler.requestEmailToken( - self.hs.config.registration.account_threepid_delegate_email, - email, - client_secret, - send_attempt, - next_link, - ) - else: - # Send threepid validation emails from Synapse - sid = await self.identity_handler.send_threepid_validation( - email, - client_secret, - send_attempt, - self.mailer.send_add_threepid_mail, - next_link, - ) - - # Wrap the session id in a JSON object - ret = {"sid": sid} + sid = await self.identity_handler.send_threepid_validation( + email, + client_secret, + send_attempt, + self.mailer.send_add_threepid_mail, + next_link, + ) threepid_send_requests.labels(type="email", reason="add_threepid").observe( send_attempt ) - return 200, ret + # Wrap the session id in a JSON object + return 200, {"sid": sid} class MsisdnThreepidRequestTokenRestServlet(RestServlet): @@ -534,25 +502,18 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.config = hs.config self.clock = hs.get_clock() self.store = hs.get_datastores().main - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.can_verify_email: self._failure_email_template = ( self.config.email.email_add_threepid_template_failure_html ) async def on_GET(self, request: Request) -> None: - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.email.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "Adding emails have been disabled due to lack of an email config" - ) - raise SynapseError( - 400, "Adding an email to your account is disabled on this server" + if not self.config.email.can_verify_email: + logger.warning( + "Adding emails have been disabled due to lack of an email config" ) - elif self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: 
raise SynapseError( - 400, - "This homeserver is not validating threepids. Use an identity server " - "instead.", + 400, "Adding an email to your account is disabled on this server" ) sid = parse_string(request, "sid", required=True) @@ -743,10 +704,12 @@ class ThreepidBindRestServlet(RestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) + assert_params_in_dict( + body, ["id_server", "sid", "id_access_token", "client_secret"] + ) id_server = body["id_server"] sid = body["sid"] - id_access_token = body.get("id_access_token") # optional + id_access_token = body["id_access_token"] client_secret = body["client_secret"] assert_valid_client_secret(client_secret) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index ce806e3c11..eb1b85721f 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -26,7 +26,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.http.site import SynapseRequest -from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname from synapse.types import JsonDict, StreamToken from ._base import client_patterns, interactive_auth_handler @@ -71,7 +71,7 @@ class KeyUploadServlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() self.device_handler = hs.get_device_handler() - @trace(opname="upload_keys") + @trace_with_opname("upload_keys") async def on_POST( self, request: SynapseRequest, device_id: Optional[str] ) -> Tuple[int, JsonDict]: diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index dd75e40f34..0437c87d8d 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -28,7 +28,7 @@ from typing import ( from typing_extensions import TypedDict -from synapse.api.errors import Codes, LoginError, SynapseError +from synapse.api.errors import Codes, InvalidClientTokenError, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.api.urls import CLIENT_API_PREFIX from synapse.appservice import ApplicationService @@ -172,7 +172,13 @@ class LoginRestServlet(RestServlet): try: if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: - appservice = self.auth.get_appservice_by_req(request) + requester = await self.auth.get_user_by_req(request) + appservice = requester.app_service + + if appservice is None: + raise InvalidClientTokenError( + "This login method is only valid for application services" + ) if appservice.is_rate_limited(): await self._address_ratelimiter.ratelimit( diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 3644705e6a..8896f2df50 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -40,6 +40,10 @@ class ReadMarkerRestServlet(RestServlet): self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() + self._known_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ} + if hs.config.experimental.msc2285_enabled: + self._known_receipt_types.add(ReceiptTypes.READ_PRIVATE) + async def on_POST( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: @@ -49,13 +53,7 @@ class ReadMarkerRestServlet(RestServlet): body = parse_json_object_from_request(request) - valid_receipt_types = { - ReceiptTypes.READ, - ReceiptTypes.FULLY_READ, - 
ReceiptTypes.READ_PRIVATE, - } - - unrecognized_types = set(body.keys()) - valid_receipt_types + unrecognized_types = set(body.keys()) - self._known_receipt_types if unrecognized_types: # It's fine if there are unrecognized receipt types, but let's log # it to help debug clients that have typoed the receipt type. @@ -65,31 +63,25 @@ class ReadMarkerRestServlet(RestServlet): # types. logger.info("Ignoring unrecognized receipt types: %s", unrecognized_types) - read_event_id = body.get(ReceiptTypes.READ, None) - if read_event_id: - await self.receipts_handler.received_client_receipt( - room_id, - ReceiptTypes.READ, - user_id=requester.user.to_string(), - event_id=read_event_id, - ) - - read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None) - if read_private_event_id and self.config.experimental.msc2285_enabled: - await self.receipts_handler.received_client_receipt( - room_id, - ReceiptTypes.READ_PRIVATE, - user_id=requester.user.to_string(), - event_id=read_private_event_id, - ) - - read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None) - if read_marker_event_id: - await self.read_marker_handler.received_client_read_marker( - room_id, - user_id=requester.user.to_string(), - event_id=read_marker_event_id, - ) + for receipt_type in self._known_receipt_types: + event_id = body.get(receipt_type, None) + # TODO Add validation to reject non-string event IDs. + if not event_id: + continue + + if receipt_type == ReceiptTypes.FULLY_READ: + await self.read_marker_handler.received_client_read_marker( + room_id, + user_id=requester.user.to_string(), + event_id=event_id, + ) + else: + await self.receipts_handler.received_client_receipt( + room_id, + receipt_type, + user_id=requester.user.to_string(), + event_id=event_id, + ) return 200, {} diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 4b03eb876b..409bfd43c1 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -39,31 +39,27 @@ class ReceiptRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() + self._known_receipt_types = {ReceiptTypes.READ} + if hs.config.experimental.msc2285_enabled: + self._known_receipt_types.update( + (ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ) + ) + async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - if self.hs.config.experimental.msc2285_enabled and receipt_type not in [ - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.FULLY_READ, - ]: + if receipt_type not in self._known_receipt_types: raise SynapseError( 400, - "Receipt type must be 'm.read', 'org.matrix.msc2285.read.private' or 'm.fully_read'", + f"Receipt type must be {', '.join(self._known_receipt_types)}", ) - elif ( - not self.hs.config.experimental.msc2285_enabled - and receipt_type != ReceiptTypes.READ - ): - raise SynapseError(400, "Receipt type must be 'm.read'") parse_json_object_from_request(request, allow_empty_body=False) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index e8e51a9c66..a8402cdb3a 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -31,7 +31,6 @@ from synapse.api.errors import ( ) from synapse.api.ratelimiting import 
Ratelimiter from synapse.config import ConfigError -from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.server import is_threepid_reserved @@ -74,7 +73,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.config = hs.config - if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.hs.config.email.can_verify_email: self.mailer = Mailer( hs=self.hs, app_name=self.config.email.email_app_name, @@ -83,13 +82,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: - if ( - self.hs.config.email.local_threepid_handling_disabled_due_to_email_config - ): - logger.warning( - "Email registration has been disabled due to lack of email config" - ) + if not self.hs.config.email.can_verify_email: + logger.warning( + "Email registration has been disabled due to lack of email config" + ) raise SynapseError( 400, "Email-based registration has been disabled on this server" ) @@ -138,35 +134,21 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.registration.account_threepid_delegate_email - - # Have the configured identity server handle the request - ret = await self.identity_handler.requestEmailToken( - self.hs.config.registration.account_threepid_delegate_email, - email, - client_secret, - send_attempt, - next_link, - ) - else: - # Send registration emails from Synapse - sid = await self.identity_handler.send_threepid_validation( - email, - client_secret, - send_attempt, - self.mailer.send_registration_mail, - next_link, - ) - - # Wrap the session id in a JSON object - ret = {"sid": sid} + # Send registration emails from Synapse + sid = await self.identity_handler.send_threepid_validation( + email, + client_secret, + send_attempt, + self.mailer.send_registration_mail, + next_link, + ) threepid_send_requests.labels(type="email", reason="register").observe( send_attempt ) - return 200, ret + # Wrap the session id in a JSON object + return 200, {"sid": sid} class MsisdnRegisterRequestTokenRestServlet(RestServlet): @@ -260,7 +242,7 @@ class RegistrationSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastores().main - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.can_verify_email: self._failure_email_template = ( self.config.email.email_registration_template_failure_html ) @@ -270,11 +252,10 @@ class RegistrationSubmitTokenServlet(RestServlet): raise SynapseError( 400, "This medium is currently not supported for registration" ) - if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.email.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "User registration via email has been disabled due to lack of email config" - ) + if not self.config.email.can_verify_email: + logger.warning( + "User registration via email has been disabled due to lack of email config" + ) raise SynapseError( 400, "Email-based registration is disabled on this server" ) diff --git a/synapse/rest/client/room_keys.py 
b/synapse/rest/client/room_keys.py index 37e39570f6..f7081f638e 100644 --- a/synapse/rest/client/room_keys.py +++ b/synapse/rest/client/room_keys.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple, cast from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer @@ -127,7 +127,7 @@ class RoomKeysServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() body = parse_json_object_from_request(request) - version = parse_string(request, "version") + version = parse_string(request, "version", required=True) if session_id: body = {"sessions": {session_id: body}} @@ -196,8 +196,11 @@ class RoomKeysServlet(RestServlet): user_id = requester.user.to_string() version = parse_string(request, "version", required=True) - room_keys = await self.e2e_room_keys_handler.get_room_keys( - user_id, version, room_id, session_id + room_keys = cast( + JsonDict, + await self.e2e_room_keys_handler.get_room_keys( + user_id, version, room_id, session_id + ), ) # Convert room_keys to the right format to return. @@ -240,7 +243,7 @@ class RoomKeysServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - version = parse_string(request, "version") + version = parse_string(request, "version", required=True) ret = await self.e2e_room_keys_handler.delete_room_keys( user_id, version, room_id, session_id diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py index 3322c8ef48..1a8e9a96d4 100644 --- a/synapse/rest/client/sendtodevice.py +++ b/synapse/rest/client/sendtodevice.py @@ -19,7 +19,7 @@ from synapse.http import servlet from synapse.http.server import HttpServer from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.logging.opentracing import set_tag, trace +from synapse.logging.opentracing import set_tag, trace_with_opname from synapse.rest.client.transactions import HttpTransactionCache from synapse.types import JsonDict @@ -43,7 +43,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): self.txns = HttpTransactionCache(hs) self.device_message_handler = hs.get_device_message_handler() - @trace(opname="sendToDevice") + @trace_with_opname("sendToDevice") def on_PUT( self, request: SynapseRequest, message_type: str, txn_id: str ) -> Awaitable[Tuple[int, JsonDict]]: diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 8bbf35148d..c2989765ce 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -37,7 +37,7 @@ from synapse.handlers.sync import ( from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest -from synapse.logging.opentracing import trace +from synapse.logging.opentracing import trace_with_opname from synapse.types import JsonDict, StreamToken from synapse.util import json_decoder @@ -210,7 +210,7 @@ class SyncRestServlet(RestServlet): logger.debug("Event formatting complete") return 200, response_content - @trace(opname="sync.encode_response") + @trace_with_opname("sync.encode_response") async def encode_response( self, time_now: int, @@ -315,7 +315,7 @@ class SyncRestServlet(RestServlet): ] } - 
@trace(opname="sync.encode_joined") + @trace_with_opname("sync.encode_joined") async def encode_joined( self, rooms: List[JoinedSyncResult], @@ -340,7 +340,7 @@ class SyncRestServlet(RestServlet): return joined - @trace(opname="sync.encode_invited") + @trace_with_opname("sync.encode_invited") async def encode_invited( self, rooms: List[InvitedSyncResult], @@ -371,7 +371,7 @@ class SyncRestServlet(RestServlet): return invited - @trace(opname="sync.encode_knocked") + @trace_with_opname("sync.encode_knocked") async def encode_knocked( self, rooms: List[KnockedSyncResult], @@ -420,7 +420,7 @@ class SyncRestServlet(RestServlet): return knocked - @trace(opname="sync.encode_archived") + @trace_with_opname("sync.encode_archived") async def encode_archived( self, rooms: List[ArchivedSyncResult], diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 54a849eac9..b36c98a08e 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -109,10 +109,64 @@ class MediaInfo: class PreviewUrlResource(DirectServeJsonResource): """ - Generating URL previews is a complicated task which many potential pitfalls. - - See docs/development/url_previews.md for discussion of the design and - algorithm followed in this module. + The `GET /_matrix/media/r0/preview_url` endpoint provides a generic preview API + for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix + specific additions). + + This does have trade-offs compared to other designs: + + * Pros: + * Simple and flexible; can be used by any clients at any point + * Cons: + * If each homeserver provides one of these independently, all the homeservers in a + room may needlessly DoS the target URI + * The URL metadata must be stored somewhere, rather than just using Matrix + itself to store the media. + * Matrix cannot be used to distribute the metadata between homeservers. + + When Synapse is asked to preview a URL it does the following: + + 1. Checks against a URL blacklist (defined as `url_preview_url_blacklist` in the + config). + 2. Checks the URL against an in-memory cache and returns the result if it exists. (This + is also used to de-duplicate processing of multiple in-flight requests at once.) + 3. Kicks off a background process to generate a preview: + 1. Checks URL and timestamp against the database cache and returns the result if it + has not expired and was successful (a 2xx return code). + 2. Checks if the URL matches an oEmbed (https://oembed.com/) pattern. If it + does, update the URL to download. + 3. Downloads the URL and stores it into a file via the media storage provider + and saves the local media metadata. + 4. If the media is an image: + 1. Generates thumbnails. + 2. Generates an Open Graph response based on image properties. + 5. If the media is HTML: + 1. Decodes the HTML via the stored file. + 2. Generates an Open Graph response from the HTML. + 3. If a JSON oEmbed URL was found in the HTML via autodiscovery: + 1. Downloads the URL and stores it into a file via the media storage provider + and saves the local media metadata. + 2. Convert the oEmbed response to an Open Graph response. + 3. Override any Open Graph data from the HTML with data from oEmbed. + 4. If an image exists in the Open Graph response: + 1. Downloads the URL and stores it into a file via the media storage + provider and saves the local media metadata. + 2. Generates thumbnails. + 3. 
Updates the Open Graph response based on image properties. + 6. If the media is JSON and an oEmbed URL was found: + 1. Convert the oEmbed response to an Open Graph response. + 2. If a thumbnail or image is in the oEmbed response: + 1. Downloads the URL and stores it into a file via the media storage + provider and saves the local media metadata. + 2. Generates thumbnails. + 3. Updates the Open Graph response based on image properties. + 7. Stores the result in the database cache. + 4. Returns the result. + + The in-memory cache expires after 1 hour. + + Expired entries in the database cache (and their associated media files) are + deleted every 10 seconds. The default expiration time is 1 hour from download. """ isLeaf = True diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 2295adfaa7..5f725c7600 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -17,9 +17,11 @@ import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError, cs_error +from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP from synapse.http.server import ( DirectServeJsonResource, + respond_with_json, set_corp_headers, set_cors_headers, ) @@ -309,6 +311,19 @@ class ThumbnailResource(DirectServeJsonResource): url_cache: True if this is from a URL cache. server_name: The server name, if this is a remote thumbnail. """ + logger.debug( + "_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s", + media_id, + desired_width, + desired_height, + desired_method, + thumbnail_infos, + ) + + # If `dynamic_thumbnails` is enabled, we expect Synapse to go down a + # different code path to handle it. + assert not self.dynamic_thumbnails + if thumbnail_infos: file_info = self._select_thumbnail( desired_width, @@ -384,8 +399,29 @@ class ThumbnailResource(DirectServeJsonResource): file_info.thumbnail.length, ) else: + # This might be because: + # 1. We can't create thumbnails for the given media (corrupted or + # unsupported file type), or + # 2. The thumbnailing process never ran or errored out initially + # when the media was first uploaded (these bugs should be + # reported and fixed). + # Note that we don't attempt to generate a thumbnail now because + # `dynamic_thumbnails` is disabled. logger.info("Failed to find any generated thumbnails") - respond_404(request) + + respond_with_json( + request, + 400, + cs_error( + "Cannot find any thumbnails for the requested media (%r). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. 
(Dynamic thumbnails are disabled on this server.)" + % ( + request.postpath, + ", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()), + ), + code=Codes.UNKNOWN, + ), + send_cors=True, + ) def _select_thumbnail( self, diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py index 6ac9dbc7c9..b9402cfb75 100644 --- a/synapse/rest/synapse/client/password_reset.py +++ b/synapse/rest/synapse/client/password_reset.py @@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Tuple from twisted.web.server import Request from synapse.api.errors import ThreepidValidationError -from synapse.config.emailconfig import ThreepidBehaviour from synapse.http.server import DirectServeHtmlResource from synapse.http.servlet import parse_string from synapse.util.stringutils import assert_valid_client_secret @@ -46,9 +45,6 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource): self.clock = hs.get_clock() self.store = hs.get_datastores().main - self._local_threepid_handling_disabled_due_to_email_config = ( - hs.config.email.local_threepid_handling_disabled_due_to_email_config - ) self._confirmation_email_template = ( hs.config.email.email_password_reset_template_confirmation_html ) @@ -59,8 +55,8 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource): hs.config.email.email_password_reset_template_failure_html ) - # This resource should not be mounted if threepid behaviour is not LOCAL - assert hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL + # This resource should only be mounted if email validation is enabled + assert hs.config.email.can_verify_email async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]: sid = parse_string(request, "sid", required=True) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 781d9f06da..e3faa52cd6 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -24,14 +24,12 @@ from typing import ( DefaultDict, Dict, FrozenSet, - Iterable, List, Mapping, Optional, Sequence, Set, Tuple, - Union, ) import attr @@ -47,6 +45,7 @@ from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServ from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo +from synapse.storage.state import StateFilter from synapse.types import StateMap from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -54,6 +53,7 @@ from synapse.util.metrics import Measure, measure_func if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.controllers import StateStorageController from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -83,17 +83,23 @@ def _gen_state_id() -> str: class _StateCacheEntry: - __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] + __slots__ = ["_state", "state_group", "prev_group", "delta_ids"] def __init__( self, - state: StateMap[str], + state: Optional[StateMap[str]], state_group: Optional[int], prev_group: Optional[int] = None, delta_ids: Optional[StateMap[str]] = None, ): + if state is None and state_group is None: + raise Exception("Either state or state group must be not None") + # A map from (type, state_key) to event_id. - self.state = frozendict(state) + # + # This can be None if we have a `state_group` (as then we can fetch the + # state from the DB.) 
+ self._state = frozendict(state) if state is not None else None # the ID of a state group if one and only one is involved. # otherwise, None otherwise? @@ -102,20 +108,30 @@ class _StateCacheEntry: self.prev_group = prev_group self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None - # The `state_id` is a unique ID we generate that can be used as ID for - # this collection of state. Usually this would be the same as the - # state group, but on worker instances we can't generate a new state - # group each time we resolve state, so we generate a separate one that - # isn't persisted and is used solely for caches. - # `state_id` is either a state_group (and so an int) or a string. This - # ensures we don't accidentally persist a state_id as a stateg_group - if state_group: - self.state_id: Union[str, int] = state_group - else: - self.state_id = _gen_state_id() + async def get_state( + self, + state_storage: "StateStorageController", + state_filter: Optional["StateFilter"] = None, + ) -> StateMap[str]: + """Get the state map for this entry, either from the in-memory state or + looking up the state group in the DB. + """ + + if self._state is not None: + return self._state + + assert self.state_group is not None + + return await state_storage.get_state_ids_for_group( + self.state_group, state_filter + ) def __len__(self) -> int: - return len(self.state) + # The len is used to estimate how large this cache entry is, for + # cache eviction purposes. This is why if `self._state` is None it's fine + # to return 1. + + return len(self._state) if self._state else 1 class StateHandler: @@ -137,23 +153,29 @@ class StateHandler: ReplicationUpdateCurrentStateRestServlet.make_client(hs) ) - async def get_current_state_ids( + async def compute_state_after_events( self, room_id: str, - latest_event_ids: Collection[str], + event_ids: Collection[str], + state_filter: Optional[StateFilter] = None, ) -> StateMap[str]: - """Get the current state, or the state at a set of events, for a room + """Fetch the state after each of the given event IDs. Resolve them and return. + + This is typically used where `event_ids` is a collection of forward extremities + in a room, intended to become the `prev_events` of a new event E. If so, the + return value of this function represents the state before E. Args: - room_id: - latest_event_ids: The forward extremities to resolve. + room_id: the room_id containing the given events. + event_ids: the events whose state should be fetched and resolved. Returns: - the state dict, mapping from (event_type, state_key) -> event_id + the state dict (a mapping from (event_type, state_key) -> event_id) which + holds the resolution of the states after the given event IDs.
""" - logger.debug("calling resolve_state_groups from get_current_state_ids") - ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - return ret.state + logger.debug("calling resolve_state_groups from compute_state_after_events") + ret = await self.resolve_state_groups_for_events(room_id, event_ids) + return await ret.get_state(self._state_storage_controller, state_filter) async def get_current_users_in_room( self, room_id: str, latest_event_ids: List[str] @@ -177,7 +199,8 @@ class StateHandler: logger.debug("calling resolve_state_groups from get_current_users_in_room") entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - return await self.store.get_joined_users_from_state(room_id, entry) + state = await entry.get_state(self._state_storage_controller, StateFilter.all()) + return await self.store.get_joined_users_from_state(room_id, state, entry) async def get_hosts_in_room_at_events( self, room_id: str, event_ids: Collection[str] @@ -192,7 +215,8 @@ class StateHandler: The hosts in the room at the given events """ entry = await self.resolve_state_groups_for_events(room_id, event_ids) - return await self.store.get_joined_hosts(room_id, entry) + state = await entry.get_state(self._state_storage_controller, StateFilter.all()) + return await self.store.get_joined_hosts(room_id, state, entry) async def compute_event_context( self, @@ -227,10 +251,19 @@ class StateHandler: # if state_ids_before_event: # if we're given the state before the event, then we use that - state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None - entry = None + + # .. though we need to get a state group for it. + state_group_before_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=None, + delta_ids=None, + current_state_ids=state_ids_before_event, + ) + ) else: # otherwise, we'll need to resolve the state across the prev_events. @@ -264,36 +297,32 @@ class StateHandler: await_full_state=False, ) - state_ids_before_event = entry.state - state_group_before_event = entry.state_group state_group_before_event_prev_group = entry.prev_group deltas_to_state_group_before_event = entry.delta_ids + state_ids_before_event = None + + # We make sure that we have a state group assigned to the state. + if entry.state_group is None: + # store_state_group requires us to have either a previous state group + # (with deltas) or the complete state map. So, if we don't have a + # previous state group, load the complete state map now. + if state_group_before_event_prev_group is None: + state_ids_before_event = await entry.get_state( + self._state_storage_controller, StateFilter.all() + ) - # - # make sure that we have a state group at that point. If it's not a state event, - # that will be the state group for the new event. 
If it *is* a state event, - # it might get rejected (in which case we'll need to persist it with the - # previous state group) - # - - if not state_group_before_event: - state_group_before_event = ( - await self._state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event_prev_group, - delta_ids=deltas_to_state_group_before_event, - current_state_ids=state_ids_before_event, + state_group_before_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + current_state_ids=state_ids_before_event, + ) ) - ) - - # Assign the new state group to the cached state entry. - # - # Note that this can race in that we could generate multiple state - # groups for the same state entry, but that is just inefficient - # rather than dangerous. - if entry and entry.state_group is None: entry.state_group = state_group_before_event + else: + state_group_before_event = entry.state_group # # now if it's not a state event, we're done @@ -315,13 +344,18 @@ class StateHandler: # key = (event.type, event.state_key) - if key in state_ids_before_event: - replaces = state_ids_before_event[key] - if replaces != event.event_id: - event.unsigned["replaces_state"] = replaces - state_ids_after_event = dict(state_ids_before_event) - state_ids_after_event[key] = event.event_id + if state_ids_before_event is not None: + replaces = state_ids_before_event.get(key) + else: + replaces_state_map = await entry.get_state( + self._state_storage_controller, StateFilter.from_types([key]) + ) + replaces = replaces_state_map.get(key) + + if replaces and replaces != event.event_id: + event.unsigned["replaces_state"] = replaces + delta_ids = {key: event.event_id} state_group_after_event = ( @@ -330,7 +364,7 @@ class StateHandler: event.room_id, prev_group=state_group_before_event, delta_ids=delta_ids, - current_state_ids=state_ids_after_event, + current_state_ids=None, ) ) @@ -372,9 +406,6 @@ class StateHandler: state_group_ids_set = set(state_group_ids) if len(state_group_ids_set) == 1: (state_group_id,) = state_group_ids_set - state = await self._state_storage_controller.get_state_for_groups( - state_group_ids_set - ) ( prev_group, delta_ids, @@ -382,7 +413,7 @@ class StateHandler: state_group_id ) return _StateCacheEntry( - state=state[state_group_id], + state=None, state_group=state_group_id, prev_group=prev_group, delta_ids=delta_ids, @@ -405,31 +436,6 @@ class StateHandler: ) return result - async def resolve_events( - self, - room_version: str, - state_sets: Collection[Iterable[EventBase]], - event: EventBase, - ) -> StateMap[EventBase]: - logger.info( - "Resolving state for %s with %d groups", event.room_id, len(state_sets) - ) - state_set_ids = [ - {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets - ] - - state_map = {ev.event_id: ev for st in state_sets for ev in st} - - new_state = await self._state_resolution_handler.resolve_events_with_store( - event.room_id, - room_version, - state_set_ids, - event_map=state_map, - state_res_store=StateResolutionStore(self.store), - ) - - return {key: state_map[ev_id] for key, ev_id in new_state.items()} - async def update_current_state(self, room_id: str) -> None: """Recalculates the current state for a room, and persists it. 
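The `_StateCacheEntry` rework above drops the eagerly-materialised `state` map in favour of an optional `_state` plus a lazy `get_state` lookup keyed on the state group. A minimal, self-contained sketch of that pattern, assuming a toy `StateStorage` stand-in for the real storage controller (none of these names are Synapse APIs):

```python
from typing import Dict, Optional, Tuple

StateKey = Tuple[str, str]  # (event_type, state_key)
StateMap = Dict[StateKey, str]  # -> event_id

class StateStorage:
    """Toy stand-in: maps state group IDs to state maps."""

    def __init__(self, groups: Dict[int, StateMap]) -> None:
        self._groups = groups

    async def get_state_ids_for_group(self, state_group: int) -> StateMap:
        return self._groups[state_group]

class LazyStateEntry:
    def __init__(self, state: Optional[StateMap], state_group: Optional[int]) -> None:
        if state is None and state_group is None:
            raise ValueError("either state or state_group must be set")
        self._state = state
        self.state_group = state_group

    async def get_state(self, storage: StateStorage) -> StateMap:
        if self._state is not None:
            return self._state  # cheap path: map already in memory
        assert self.state_group is not None
        return await storage.get_state_ids_for_group(self.state_group)

    def __len__(self) -> int:
        # only a size estimate for cache eviction; 1 when the map isn't loaded
        return len(self._state) if self._state else 1
```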
@@ -752,6 +758,12 @@ def _make_state_cache_entry( delta_ids: Optional[StateMap[str]] = None for old_group, old_state in state_groups_ids.items(): + if old_state.keys() - new_state.keys(): + # Currently we don't support deltas that remove keys from the state + # map, so we have to ignore this group as a candidate to base the + # new group on. + continue + n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v} if not delta_ids or len(n_delta_ids) < len(delta_ids): prev_group = old_group diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b8c8dcd76b..a2f8310388 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -96,6 +96,10 @@ class SQLBaseStore(metaclass=ABCMeta): cache doesn't exist. Mainly used for invalidating caches on workers, where they may not have the cache. + Note that this function does not invalidate any remote caches, only the + local in-memory ones. Any remote invalidation must be performed before + calling this. + Args: cache_name key: Entry to invalidate. If None then invalidates the entire @@ -112,7 +116,10 @@ class SQLBaseStore(metaclass=ABCMeta): if key is None: cache.invalidate_all() else: - cache.invalidate(tuple(key)) + # Prefer any local-only invalidation method. Invalidating any non-local + # cache must be done before this. + invalidate_method = getattr(cache, "invalidate_local", cache.invalidate) + invalidate_method(tuple(key)) def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any: diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py index 55649719f6..45101cda7a 100644 --- a/synapse/storage/controllers/__init__.py +++ b/synapse/storage/controllers/__init__.py @@ -43,4 +43,6 @@ class StorageControllers: self.persistence = None if stores.persist_events: - self.persistence = EventsPersistenceStorageController(hs, stores) + self.persistence = EventsPersistenceStorageController( + hs, stores, self.state + ) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index ea499ce0f8..cf98b0ab48 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -48,9 +48,11 @@ from synapse.events.snapshot import EventContext from synapse.logging import opentracing from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.controllers.state import StateStorageController from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState from synapse.storage.databases.main.events_worker import EventRedactBehaviour +from synapse.storage.state import StateFilter from synapse.types import ( PersistedEventPosition, RoomStreamToken, @@ -308,7 +310,12 @@ class EventsPersistenceStorageController: current state and forward extremity changes. """ - def __init__(self, hs: "HomeServer", stores: Databases): + def __init__( + self, + hs: "HomeServer", + stores: Databases, + state_controller: StateStorageController, + ): # We ultimately want to split out the state store from the main store, # so we use separate variables here even though they point to the same # store for now.
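The `getattr` fallback in `_invalidate_state_caches` above is a small duck-typing trick: caches that expose a local-only invalidation method get that preferred, while everything else falls back to the plain `invalidate`. A toy illustration (both cache classes here are hypothetical, not Synapse classes):

```python
class PlainCache:
    def invalidate(self, key) -> None:
        print("full invalidate:", key)

class TieredCache(PlainCache):
    def invalidate_local(self, key) -> None:
        print("local-only invalidate:", key)

for cache in (PlainCache(), TieredCache()):
    # Prefer a local-only method when the cache provides one.
    invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
    invalidate_method(("!room:example.org",))
```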
@@ -325,6 +332,7 @@ class EventsPersistenceStorageController: self._process_event_persist_queue_task ) self._state_resolution_handler = hs.get_state_resolution_handler() + self._state_controller = state_controller async def _process_event_persist_queue_task( self, @@ -504,7 +512,7 @@ class EventsPersistenceStorageController: state_res_store=StateResolutionStore(self.main_store), ) - return res.state + return await res.get_state(self._state_controller, StateFilter.all()) async def _persist_event_batch( self, _room_id: str, task: _PersistEventsTask @@ -940,7 +948,8 @@ class EventsPersistenceStorageController: events_context, ) - return res.state, None, new_latest_event_ids + full_state = await res.get_state(self._state_controller) + return full_state, None, new_latest_event_ids async def _prune_extremities( self, diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index d3a44bc876..e08f956e6e 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -346,7 +346,7 @@ class StateStorageController: room_id: str, prev_group: Optional[int], delta_ids: Optional[StateMap[str]], - current_state_ids: StateMap[str], + current_state_ids: Optional[StateMap[str]], ) -> int: """Store a new set of state, returning a newly assigned state group. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index e21ab08515..ea672ff89e 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -23,6 +23,7 @@ from time import monotonic as monotonic_time from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Collection, Dict, @@ -168,6 +169,7 @@ class LoggingDatabaseConnection: *, txn_name: Optional[str] = None, after_callbacks: Optional[List["_CallbackListEntry"]] = None, + async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None, exception_callbacks: Optional[List["_CallbackListEntry"]] = None, ) -> "LoggingTransaction": if not txn_name: @@ -178,6 +180,7 @@ class LoggingDatabaseConnection: name=txn_name, database_engine=self.engine, after_callbacks=after_callbacks, + async_after_callbacks=async_after_callbacks, exception_callbacks=exception_callbacks, ) @@ -209,6 +212,9 @@ class LoggingDatabaseConnection: # The type of entry which goes on our after_callbacks and exception_callbacks lists. _CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]] +_AsyncCallbackListEntry = Tuple[ + Callable[..., Awaitable], Tuple[object, ...], Dict[str, object] +] P = ParamSpec("P") R = TypeVar("R") @@ -227,6 +233,10 @@ class LoggingTransaction: that have been added by `call_after` which should be run on successful completion of the transaction. None indicates that no callbacks should be allowed to be scheduled to run. + async_after_callbacks: A list that asynchronous callbacks will be appended + to by `async_call_after` which should run, before after_callbacks, on + successful completion of the transaction. None indicates that no + callbacks should be allowed to be scheduled to run. exception_callbacks: A list that callbacks will be appended to that have been added by `call_on_exception` which should be run if transaction ends with an error. 
None indicates that no callbacks @@ -238,6 +248,7 @@ class LoggingTransaction: "name", "database_engine", "after_callbacks", + "async_after_callbacks", "exception_callbacks", ] @@ -247,12 +258,14 @@ class LoggingTransaction: name: str, database_engine: BaseDatabaseEngine, after_callbacks: Optional[List[_CallbackListEntry]] = None, + async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None, exception_callbacks: Optional[List[_CallbackListEntry]] = None, ): self.txn = txn self.name = name self.database_engine = database_engine self.after_callbacks = after_callbacks + self.async_after_callbacks = async_after_callbacks self.exception_callbacks = exception_callbacks def call_after( @@ -277,6 +290,28 @@ class LoggingTransaction: # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + def async_call_after( + self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs + ) -> None: + """Call the given asynchronous callback on the main twisted thread after + the transaction has finished (but before those added in `call_after`). + + Mostly used to invalidate remote caches after transactions. + + Note that transactions may be retried a few times if they encounter database + errors such as serialization failures. Callbacks given to `async_call_after` + will accumulate across transaction attempts and will _all_ be called once a + transaction attempt succeeds, regardless of whether previous transaction + attempts failed. Otherwise, if all transaction attempts fail, all + `call_on_exception` callbacks will be run instead. + """ + # if self.async_after_callbacks is None, that means that whatever constructed the + # LoggingTransaction isn't expecting there to be any callbacks; assert that + # is not the case. + assert self.async_after_callbacks is not None + # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668 + self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type] + def call_on_exception( self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs ) -> None: @@ -574,6 +609,7 @@ class DatabasePool: conn: LoggingDatabaseConnection, desc: str, after_callbacks: List[_CallbackListEntry], + async_after_callbacks: List[_AsyncCallbackListEntry], exception_callbacks: List[_CallbackListEntry], func: Callable[Concatenate[LoggingTransaction, P], R], *args: P.args, @@ -597,6 +633,7 @@ class DatabasePool: conn desc after_callbacks + async_after_callbacks exception_callbacks func *args @@ -659,6 +696,7 @@ class DatabasePool: cursor = conn.cursor( txn_name=name, after_callbacks=after_callbacks, + async_after_callbacks=async_after_callbacks, exception_callbacks=exception_callbacks, ) try: @@ -798,6 +836,7 @@ class DatabasePool: async def _runInteraction() -> R: after_callbacks: List[_CallbackListEntry] = [] + async_after_callbacks: List[_AsyncCallbackListEntry] = [] exception_callbacks: List[_CallbackListEntry] = [] if not current_context(): @@ -809,6 +848,7 @@ class DatabasePool: self.new_transaction, desc, after_callbacks, + async_after_callbacks, exception_callbacks, func, *args, @@ -817,13 +857,17 @@ class DatabasePool: **kwargs, ) + # We order these assuming that async functions call out to external + # systems (e.g. to invalidate a cache) and the sync functions make these + # changes on any local in-memory caches/similar, and thus must be second. 
+ for async_callback, async_args, async_kwargs in async_after_callbacks: + await async_callback(*async_args, **async_kwargs) for after_callback, after_args, after_kwargs in after_callbacks: after_callback(*after_args, **after_kwargs) - return cast(R, result) except Exception: - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) + for exception_callback, after_args, after_kwargs in exception_callbacks: + exception_callback(*after_args, **after_kwargs) raise # To handle cancellation, we ensure that `after_callback`s and diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index e284454b66..64b70a7b28 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -371,52 +371,30 @@ class ApplicationServiceTransactionWorkerStore( device_list_summary=DeviceListUpdates(), ) - async def set_appservice_last_pos(self, pos: int) -> None: - def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None: - txn.execute( - "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) - ) + async def get_appservice_last_pos(self) -> int: + """ + Get the last stream ordering position for the appservice process. + """ - await self.db_pool.runInteraction( - "set_appservice_last_pos", set_appservice_last_pos_txn + return await self.db_pool.simple_select_one_onecol( + table="appservice_stream_position", + retcol="stream_ordering", + keyvalues={}, + desc="get_appservice_last_pos", ) - async def get_new_events_for_appservice( - self, current_id: int, limit: int - ) -> Tuple[int, List[EventBase]]: - """Get all new events for an appservice""" - - def get_new_events_for_appservice_txn( - txn: LoggingTransaction, - ) -> Tuple[int, List[str]]: - sql = ( - "SELECT e.stream_ordering, e.event_id" - " FROM events AS e" - " WHERE" - " (SELECT stream_ordering FROM appservice_stream_position)" - " < e.stream_ordering" - " AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - " LIMIT ?" - ) - - txn.execute(sql, (current_id, limit)) - rows = txn.fetchall() - - upper_bound = current_id - if len(rows) == limit: - upper_bound = rows[-1][0] - - return upper_bound, [row[1] for row in rows] + async def set_appservice_last_pos(self, pos: int) -> None: + """ + Set the last stream ordering position for the appservice process. + """ - upper_bound, event_ids = await self.db_pool.runInteraction( - "get_new_events_for_appservice", get_new_events_for_appservice_txn + await self.db_pool.simple_update_one( + table="appservice_stream_position", + keyvalues={}, + updatevalues={"stream_ordering": pos}, + desc="set_appservice_last_pos", ) - events = await self.get_events_as_list(event_ids, get_prev_content=True) - - return upper_bound, events - async def get_type_stream_id_for_appservice( self, service: ApplicationService, type: str ) -> int: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 1653a6a9b6..2367ddeea3 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -193,7 +193,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): relates_to: Optional[str], backfilled: bool, ) -> None: - self._invalidate_get_event_cache(event_id) + # This invalidates any local in-memory cached event objects; the original + # process triggering the invalidation is responsible for clearing any external + # cached objects.
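Putting the two pieces above together: awaited callbacks registered via `async_call_after` (typically external/remote invalidations) run first, then the synchronous `call_after` ones (local in-memory invalidations) run second. A hedged end-to-end sketch of that ordering, where both invalidation helpers are made-up stand-ins rather than real Synapse functions:

```python
import asyncio

async def invalidate_remote_cache(event_id: str) -> None:
    print("remote invalidation for", event_id)  # e.g. an external cache service

def invalidate_local_cache(event_id: str) -> None:
    print("local invalidation for", event_id)

async def run_transaction(event_id: str) -> None:
    async_after = [(invalidate_remote_cache, (event_id,))]
    sync_after = [(invalidate_local_cache, (event_id,))]
    # ... the actual DB work would happen here ...
    # On success: awaited (external) callbacks first, sync (local) ones second.
    for cb, args in async_after:
        await cb(*args)
    for cb, args in sync_after:
        cb(*args)

asyncio.run(run_transaction("$some_event_id"))
```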
+ self._invalidate_local_get_event_cache(event_id) self.have_seen_event.invalidate((room_id, event_id)) self.get_latest_event_ids_in_room.invalidate((room_id,)) @@ -208,7 +211,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._events_stream_cache.entity_has_changed(room_id, stream_ordering) if redacts: - self._invalidate_get_event_cache(redacts) + self._invalidate_local_get_event_cache(redacts) # Caches which might leak edits must be invalidated for the event being # redacted. self.get_relations_for_event.invalidate((redacts,)) diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index fd3fc298b3..58177ecec1 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase # changed its content in the database. We can't call # self._invalidate_cache_and_stream because self.get_event_cache isn't of the # right type. - txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) + self.invalidate_get_event_cache_after_txn(txn, event.event_id) # Send that invalidation to replication so that other workers also invalidate # the event cache. self._send_invalidation_to_replication( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index adde5d0978..7a6ed332aa 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -669,7 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore): @trace async def get_user_devices_from_cache( - self, query_list: List[Tuple[str, str]] + self, query_list: List[Tuple[str, Optional[str]]] ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: """Get the devices (and keys if any) for remote users from the cache. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 9b293475c8..60f622ad71 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -22,11 +22,14 @@ from typing import ( List, Optional, Tuple, + Union, cast, + overload, ) import attr from canonicaljson import encode_canonical_json +from typing_extensions import Literal from synapse.api.constants import DeviceKeyAlgorithms from synapse.appservice import ( @@ -113,7 +116,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker user_devices = devices[user_id] results = [] for device_id, device in user_devices.items(): - result = {"device_id": device_id} + result: JsonDict = {"device_id": device_id} keys = device.keys if keys: @@ -156,6 +159,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker rv[user_id] = {} for device_id, device_info in device_keys.items(): r = device_info.keys + if r is None: + continue + r["unsigned"] = {} display_name = device_info.display_name if display_name is not None: @@ -164,13 +170,42 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker return rv + @overload + async def get_e2e_device_keys_and_signatures( + self, + query_list: Collection[Tuple[str, Optional[str]]], + include_all_devices: Literal[False] = False, + ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: + ... 
+ + @overload + async def get_e2e_device_keys_and_signatures( + self, + query_list: Collection[Tuple[str, Optional[str]]], + include_all_devices: bool = False, + include_deleted_devices: Literal[False] = False, + ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: + ... + + @overload + async def get_e2e_device_keys_and_signatures( + self, + query_list: Collection[Tuple[str, Optional[str]]], + include_all_devices: Literal[True], + include_deleted_devices: Literal[True], + ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + ... + @trace async def get_e2e_device_keys_and_signatures( self, - query_list: List[Tuple[str, Optional[str]]], + query_list: Collection[Tuple[str, Optional[str]]], include_all_devices: bool = False, include_deleted_devices: bool = False, - ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + ) -> Union[ + Dict[str, Dict[str, DeviceKeyLookupResult]], + Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]], + ]: """Fetch a list of device keys Any cross-signatures made on the keys by the owner of the device are also @@ -1044,7 +1079,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple db_autocommit = False - row = await self.db_pool.runInteraction( + claim_row = await self.db_pool.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_key, user_id, @@ -1052,11 +1087,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker algorithm, db_autocommit=db_autocommit, ) - if row: + if claim_row: device_results = results.setdefault(user_id, {}).setdefault( device_id, {} ) - device_results[row[0]] = row[1] + device_results[claim_row[0]] = claim_row[1] continue # No one-time key available, so see if there's a fallback diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index eb4efbb93c..1f600f1190 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1293,7 +1293,7 @@ class PersistEventsStore: depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: # Remove any existing cache entries for the event_ids - txn.call_after(self.store._invalidate_get_event_cache, event.event_id) + self.store.invalidate_get_event_cache_after_txn(txn, event.event_id) # Then update the `stream_ordering` position to mark the latest # event as the front of the room. This should not be done for # backfilled events because backfilled events have negative @@ -1346,9 +1346,24 @@ class PersistEventsStore: event_id: outlier for event_id, outlier in txn } + logger.debug( + "_update_outliers_txn: events=%s have_persisted=%s", + [ev.event_id for ev, _ in events_and_contexts], + have_persisted, + ) + to_remove = set() for event, context in events_and_contexts: - if event.event_id not in have_persisted: + outlier_persisted = have_persisted.get(event.event_id) + logger.debug( + "_update_outliers_txn: event=%s outlier=%s outlier_persisted=%s", + event.event_id, + event.internal_metadata.is_outlier(), + outlier_persisted, + ) + + # Ignore events which we haven't persisted at all + if outlier_persisted is None: continue to_remove.add(event) @@ -1358,7 +1373,6 @@ class PersistEventsStore: # was an outlier or not - what we have is at least as good. continue - outlier_persisted = have_persisted[event.event_id] if not event.internal_metadata.is_outlier() and outlier_persisted: # We received a copy of an event that we had already stored as # an outlier in the database.
We now have some state at that event @@ -1369,7 +1383,10 @@ class PersistEventsStore: # events down /sync. In general they will be historical events, so that # doesn't matter too much, but that is not always the case. - logger.info("Updating state for ex-outlier event %s", event.event_id) + logger.info( + "_update_outliers_txn: Updating state for ex-outlier event %s", + event.event_id, + ) # insert into event_to_state_groups. try: @@ -1669,13 +1686,13 @@ class PersistEventsStore: if not row["rejects"] and not row["redacts"]: to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) - def prefill() -> None: + async def prefill() -> None: for cache_entry in to_prefill: - self.store._get_event_cache.set( + await self.store._get_event_cache.set( (cache_entry.event.event_id,), cache_entry ) - txn.call_after(prefill) + txn.async_call_after(prefill) def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: """Invalidate the caches for the redacted event. @@ -1684,7 +1701,7 @@ class PersistEventsStore: _invalidate_caches_for_event. """ assert event.redacts is not None - txn.call_after(self.store._invalidate_get_event_cache, event.redacts) + self.store.invalidate_get_event_cache_after_txn(txn, event.redacts) txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,)) txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,)) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index eeca85fc94..6e8aeed7b4 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -67,6 +67,8 @@ class _BackgroundUpdates: EVENT_EDGES_DROP_INVALID_ROWS = "event_edges_drop_invalid_rows" EVENT_EDGES_REPLACE_INDEX = "event_edges_replace_index" + EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections" + @attr.s(slots=True, frozen=True, auto_attribs=True) class _CalculateChainCover: @@ -253,6 +255,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): replaces_index="ev_edges_id", ) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS, + self._background_events_populate_state_key_rejections, + ) + async def _background_reindex_fields_sender( self, progress: JsonDict, batch_size: int ) -> int: @@ -1399,3 +1406,83 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) return batch_size + + async def _background_events_populate_state_key_rejections( + self, progress: JsonDict, batch_size: int ) -> int: + """Back-populate `events.state_key` and `events.rejection_reason`""" + + min_stream_ordering_exclusive = progress["min_stream_ordering_exclusive"] + max_stream_ordering_inclusive = progress["max_stream_ordering_inclusive"] + + def _populate_txn(txn: LoggingTransaction) -> bool: + """Returns True if we're done.""" + + # first we need to find an endpoint. + # we need to find the final row in the batch of batch_size, which means + # we need to skip over (batch_size-1) rows and get the next row. + txn.execute( + """ + SELECT stream_ordering FROM events + WHERE stream_ordering > ? AND stream_ordering <= ? + ORDER BY stream_ordering + LIMIT 1 OFFSET ? + """, + ( + min_stream_ordering_exclusive, + max_stream_ordering_inclusive, + batch_size - 1, + ), + ) + + endpoint = None + row = txn.fetchone() + if row: + endpoint = row[0] + + where_clause = "stream_ordering > ?"
+ args = [min_stream_ordering_exclusive] + if endpoint: + where_clause += " AND stream_ordering <= ?" + args.append(endpoint) + + # now do the updates. + txn.execute( + f""" + UPDATE events + SET state_key = (SELECT state_key FROM state_events se WHERE se.event_id = events.event_id), + rejection_reason = (SELECT reason FROM rejections rej WHERE rej.event_id = events.event_id) + WHERE ({where_clause}) + """, + args, + ) + + logger.info( + "populated new `events` columns up to %s/%i: updated %i rows", + endpoint, + max_stream_ordering_inclusive, + txn.rowcount, + ) + + if endpoint is None: + # we're done + return True + + progress["min_stream_ordering_exclusive"] = endpoint + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS, + progress, + ) + return False + + done = await self.db_pool.runInteraction( + desc="events_populate_state_key_rejections", func=_populate_txn + ) + + if done: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS + ) + + return batch_size diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index b99b107784..5914a35420 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -79,7 +79,7 @@ from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.lrucache import AsyncLruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -238,7 +238,9 @@ class EventsWorkerStore(SQLBaseStore): 5 * 60 * 1000, ) - self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache( + self._get_event_cache: AsyncLruCache[ + Tuple[str], EventCacheEntry + ] = AsyncLruCache( cache_name="*getEvent*", max_size=hs.config.caches.event_cache_size, ) @@ -292,25 +294,6 @@ class EventsWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) - async def get_received_ts(self, event_id: str) -> Optional[int]: - """Get received_ts (when it was persisted) for the event. - - Raises an exception for unknown events. - - Args: - event_id: The event ID to query. - - Returns: - Timestamp in milliseconds, or None for events that were persisted - before received_ts was implemented. - """ - return await self.db_pool.simple_select_one_onecol( - table="events", - keyvalues={"event_id": event_id}, - retcol="received_ts", - desc="get_received_ts", - ) - async def have_censored_event(self, event_id: str) -> bool: """Check if an event has been censored, i.e. if the content of the event has been erased from the database due to a redaction. 
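The background update above walks the `events` table in fixed-size batches: it finds the `stream_ordering` of the last row of the next batch via `LIMIT 1 OFFSET batch_size - 1`, updates every row up to that endpoint, and stores the endpoint as the next lower bound. A pure-sqlite toy of the same batching scheme (only the table and column names mirror the real update; nothing here is Synapse code):

```python
import sqlite3
from typing import Optional

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (stream_ordering INTEGER PRIMARY KEY, state_key TEXT)")
conn.executemany(
    "INSERT INTO events (stream_ordering) VALUES (?)", [(i,) for i in range(1, 26)]
)

def run_batch(min_exclusive: int, batch_size: int) -> Optional[int]:
    """Process one batch; return the new lower bound, or None when done."""
    row = conn.execute(
        "SELECT stream_ordering FROM events WHERE stream_ordering > ? "
        "ORDER BY stream_ordering LIMIT 1 OFFSET ?",
        (min_exclusive, batch_size - 1),
    ).fetchone()
    endpoint = row[0] if row else None

    where_clause = "stream_ordering > ?"
    args = [min_exclusive]
    if endpoint is not None:
        where_clause += " AND stream_ordering <= ?"
        args.append(endpoint)
    # the final (short) batch has no endpoint and sweeps up the remainder
    conn.execute(f"UPDATE events SET state_key = 'done' WHERE {where_clause}", args)
    return endpoint

lower_bound = 0
while (nxt := run_batch(lower_bound, batch_size=10)) is not None:
    lower_bound = nxt
assert conn.execute("SELECT COUNT(*) FROM events WHERE state_key IS NULL").fetchone()[0] == 0
```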
@@ -617,7 +600,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: map from event id to result """ - event_entry_map = self._get_events_from_cache( + event_entry_map = await self._get_events_from_cache( event_ids, ) @@ -729,12 +712,46 @@ class EventsWorkerStore(SQLBaseStore): return event_entry_map - def _invalidate_get_event_cache(self, event_id: str) -> None: - self._get_event_cache.invalidate((event_id,)) + def invalidate_get_event_cache_after_txn( + self, txn: LoggingTransaction, event_id: str + ) -> None: + """ + Prepares a database transaction to invalidate the get event cache for a given + event ID when executed successfully. This is achieved by attaching two callbacks + to the transaction, one to invalidate the async cache and one for the in-memory + sync cache (importantly called in that order). + + Arguments: + txn: the database transaction to attach the callbacks to + event_id: the event ID to be invalidated from caches + """ + + txn.async_call_after(self._invalidate_async_get_event_cache, event_id) + txn.call_after(self._invalidate_local_get_event_cache, event_id) + + async def _invalidate_async_get_event_cache(self, event_id: str) -> None: + """ + Invalidates an event in the asynchronous get event cache, which may be remote. + + Arguments: + event_id: the event ID to invalidate + """ + + await self._get_event_cache.invalidate((event_id,)) + + def _invalidate_local_get_event_cache(self, event_id: str) -> None: + """ + Invalidates an event in local in-memory get event caches. + + Arguments: + event_id: the event ID to invalidate + """ + + self._get_event_cache.invalidate_local((event_id,)) self._event_ref.pop(event_id, None) self._current_event_fetches.pop(event_id, None) - def _get_events_from_cache( + async def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True ) -> Dict[str, EventCacheEntry]: """Fetch events from the caches. @@ -749,7 +766,7 @@ class EventsWorkerStore(SQLBaseStore): for event_id in events: # First check if it's in the event cache - ret = self._get_event_cache.get( + ret = await self._get_event_cache.get( (event_id,), None, update_metrics=update_metrics ) if ret: @@ -771,7 +788,7 @@ class EventsWorkerStore(SQLBaseStore): # We add the entry back into the cache as we want to keep # recently queried events in the cache. - self._get_event_cache.set((event_id,), cache_entry) + await self._get_event_cache.set((event_id,), cache_entry) return event_map @@ -965,7 +982,13 @@ class EventsWorkerStore(SQLBaseStore): } row_dict = self.db_pool.new_transaction( - conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch + conn, + "do_fetch", + [], + [], + [], + self._fetch_event_rows, + events_to_fetch, ) # We only want to resolve deferreds from the main thread @@ -1148,7 +1171,7 @@ class EventsWorkerStore(SQLBaseStore): event=original_ev, redacted_event=redacted_event ) - self._get_event_cache.set((event_id,), cache_entry) + await self._get_event_cache.set((event_id,), cache_entry) result_map[event_id] = cache_entry if not redacted_event: @@ -1382,7 +1405,9 @@ class EventsWorkerStore(SQLBaseStore): # if the event cache contains the event, obviously we've seen it.
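The `have_seen_events` check this hunk continues into follows the same two-step shape as `_get_events_from_cache`: probe the now-awaitable cache for each key, then fall back to the database only for the misses. A rough sketch of that shape, with `cache_contains` and `db_lookup` as hypothetical stand-ins for the real cache probe and DB query:

```python
from typing import Awaitable, Callable, Dict, Iterable, List, Set

async def have_seen(
    keys: Iterable[str],
    cache_contains: Callable[[str], Awaitable[bool]],
    db_lookup: Callable[[List[str]], Awaitable[Dict[str, bool]]],
) -> Dict[str, bool]:
    keys = list(keys)
    # anything in the cache has obviously been seen
    hits: Set[str] = {k for k in keys if await cache_contains(k)}
    results: Dict[str, bool] = dict.fromkeys(hits, True)
    remaining = [k for k in keys if k not in hits]
    if remaining:
        results.update(await db_lookup(remaining))
    return results
```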
cache_results = { - (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,)) + (rid, eid) + for (rid, eid) in keys + if await self._get_event_cache.contains((eid,)) } results = dict.fromkeys(cache_results, True) remaining = [k for k in keys if k not in cache_results] @@ -1465,7 +1490,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_all_new_forward_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: """Returns new events, for the Events replication stream Args: @@ -1481,10 +1506,11 @@ class EventsWorkerStore(SQLBaseStore): def get_all_new_forward_event_rows( txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL," + " e.outlier" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events AS se USING (event_id)" @@ -1498,7 +1524,8 @@ class EventsWorkerStore(SQLBaseStore): ) txn.execute(sql, (last_id, current_id, instance_name, limit)) return cast( - List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() + List[Tuple[int, str, str, str, str, str, str, str, bool, bool]], + txn.fetchall(), ) return await self.db_pool.runInteraction( @@ -1507,7 +1534,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_ex_outlier_stream_rows( self, instance_name: str, last_id: int, current_id: int - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: """Returns de-outliered events, for the Events replication stream Args: @@ -1522,11 +1549,14 @@ class EventsWorkerStore(SQLBaseStore): def get_ex_outlier_stream_rows_txn( txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL," + " e.outlier" " FROM events AS e" + # NB: the next line (inner join) is what makes this query different from + # get_all_new_forward_event_rows. 
" INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events AS se USING (event_id)" @@ -1541,7 +1571,8 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, instance_name)) return cast( - List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() + List[Tuple[int, str, str, str, str, str, str, str, bool, bool]], + txn.fetchall(), ) return await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 9a63f953fb..efd136a864 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -66,6 +66,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): "initialise_mau_threepids", [], [], + [], self._initialise_reserved_users, hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 87b0d09039..f6822707e4 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -19,6 +19,8 @@ from synapse.api.errors import SynapseError from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore +from synapse.storage.engines import PostgresEngine +from synapse.storage.engines._base import IsolationLevel from synapse.types import RoomStreamToken logger = logging.getLogger(__name__) @@ -302,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): self._invalidate_cache_and_stream( txn, self.have_seen_event, (room_id, event_id) ) - self._invalidate_get_event_cache(event_id) + self.invalidate_get_event_cache_after_txn(txn, event_id) logger.info("[purge] done") @@ -317,11 +319,38 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): Returns: The list of state groups to delete. """ - return await self.db_pool.runInteraction( - "purge_room", self._purge_room_txn, room_id + + # This first runs the purge transaction with READ_COMMITTED isolation level, + # meaning any new rows in the tables will not trigger a serialization error. + # We then run the same purge a second time without this isolation level to + # purge any of those rows which were added during the first. + + state_groups_to_delete = await self.db_pool.runInteraction( + "purge_room", + self._purge_room_txn, + room_id=room_id, + isolation_level=IsolationLevel.READ_COMMITTED, + ) + + state_groups_to_delete.extend( + await self.db_pool.runInteraction( + "purge_room", + self._purge_room_txn, + room_id=room_id, + ), ) + return state_groups_to_delete + def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: + # This collides with event persistence so we cannot write new events and metadata into + # a room while deleting it or this transaction will fail. + if isinstance(self.database_engine, PostgresEngine): + txn.execute( + "SELECT room_version FROM rooms WHERE room_id = ? FOR UPDATE", + (room_id,), + ) + # First, fetch all the state groups that should be deleted, before # we delete that information. 
txn.execute( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 86649c1e6c..768f95d16c 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -228,6 +228,7 @@ class PushRulesWorkerStore( iterable=user_ids, retcols=("*",), desc="bulk_get_push_rules", + batch_size=1000, ) rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 13d6a1d5c0..d6d485507b 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -175,7 +175,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): rooms.creator, state.encryption, state.is_federatable AS federatable, rooms.is_public AS public, state.join_rules, state.guest_access, state.history_visibility, curr.current_state_events AS state_events, - state.avatar, state.topic + state.avatar, state.topic, state.room_type FROM rooms LEFT JOIN room_stats_state state USING (room_id) LEFT JOIN room_stats_current curr USING (room_id) @@ -596,7 +596,8 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members, curr.local_users_in_room, rooms.room_version, rooms.creator, state.encryption, state.is_federatable, rooms.is_public, state.join_rules, - state.guest_access, state.history_visibility, curr.current_state_events + state.guest_access, state.history_visibility, curr.current_state_events, + state.room_type FROM room_stats_state state INNER JOIN room_stats_current curr USING (room_id) INNER JOIN rooms USING (room_id) @@ -646,6 +647,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): "guest_access": room[11], "history_visibility": room[12], "state_events": room[13], + "room_type": room[14], } ) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 0b5e4e4254..df6b82660e 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -31,7 +31,6 @@ import attr from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase -from synapse.events.snapshot import EventContext from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import ( run_as_background_process, @@ -244,7 +243,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn: LoggingTransaction, ) -> Dict[str, ProfileInfo]: clause, ids = make_in_list_sql_clause( - self.database_engine, "m.user_id", user_ids + self.database_engine, "c.state_key", user_ids ) sql = """ @@ -780,26 +779,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): return shared_room_ids or frozenset() - async def get_joined_users_from_context( - self, event: EventBase, context: EventContext - ) -> Dict[str, ProfileInfo]: - state_group: Union[object, int] = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. 
- # To do this we set the state_group to a new object as object() != object() - state_group = object() - - current_state_ids = await context.get_current_state_ids() - assert current_state_ids is not None - assert state_group is not None - return await self._get_joined_users_from_context( - event.room_id, state_group, current_state_ids, event=event, context=context - ) - async def get_joined_users_from_state( - self, room_id: str, state_entry: "_StateCacheEntry" + self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry" ) -> Dict[str, ProfileInfo]: state_group: Union[object, int] = state_entry.state_group if not state_group: @@ -812,18 +793,17 @@ class RoomMemberWorkerStore(EventsWorkerStore): assert state_group is not None with Measure(self._clock, "get_joined_users_from_state"): return await self._get_joined_users_from_context( - room_id, state_group, state_entry.state, context=state_entry + room_id, state_group, state, context=state_entry ) - @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000) + @cached(num_args=2, iterable=True, max_entries=100000) async def _get_joined_users_from_context( self, room_id: str, state_group: Union[object, int], current_state_ids: StateMap[str], - cache_context: _CacheContext, event: Optional[EventBase] = None, - context: Optional[Union[EventContext, "_StateCacheEntry"]] = None, + context: Optional["_StateCacheEntry"] = None, ) -> Dict[str, ProfileInfo]: # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states @@ -863,7 +843,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): # We don't update the event cache hit ratio as it completely throws off # the hit ratio counts. After all, we don't populate the cache if we # miss it here - event_map = self._get_events_from_cache(member_event_ids, update_metrics=False) + event_map = await self._get_events_from_cache( + member_event_ids, update_metrics=False + ) missing_member_event_ids = [] for event_id in member_event_ids: @@ -922,7 +904,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): iterable=event_ids, retcols=("user_id", "display_name", "avatar_url", "event_id"), keyvalues={"membership": Membership.JOIN}, - batch_size=500, + batch_size=1000, desc="_get_joined_profiles_from_event_ids", ) @@ -1017,7 +999,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) async def get_joined_hosts( - self, room_id: str, state_entry: "_StateCacheEntry" + self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry" ) -> FrozenSet[str]: state_group: Union[object, int] = state_entry.state_group if not state_group: @@ -1030,7 +1012,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): assert state_group is not None with Measure(self._clock, "get_joined_hosts"): return await self._get_joined_hosts( - room_id, state_group, state_entry=state_entry + room_id, state_group, state, state_entry=state_entry ) @cached(num_args=2, max_entries=10000, iterable=True) @@ -1038,6 +1020,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): self, room_id: str, state_group: Union[object, int], + state: StateMap[str], state_entry: "_StateCacheEntry", ) -> FrozenSet[str]: # We don't use `state_group`, it's there so that we can cache based on @@ -1093,7 +1076,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # The cache doesn't match the state group or prev state group, # so we calculate the result from first principles. 
joined_users = await self.get_joined_users_from_state( - room_id, state_entry + room_id, state, state_entry ) cache.hosts_to_joined_users = {} diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 3a1df7776c..2590b52f73 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1022,8 +1022,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): } async def get_all_new_events_stream( - self, from_id: int, current_id: int, limit: int - ) -> Tuple[int, List[EventBase]]: + self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False + ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]: """Get all new events Returns all events with from_id < stream_ordering <= current_id. @@ -1032,19 +1032,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): from_id: the stream_ordering of the last event we processed current_id: the stream_ordering of the most recently processed event limit: the maximum number of events to return + get_prev_content: whether to fetch previous event content Returns: - A tuple of (next_id, events), where `next_id` is the next value to - pass as `from_id` (it will either be the stream_ordering of the - last returned event, or, if fewer than `limit` events were found, - the `current_id`). + A tuple of (next_id, events, event_to_received_ts), where `next_id` + is the next value to pass as `from_id` (it will either be the + stream_ordering of the last returned event, or, if fewer than `limit` + events were found, the `current_id`). The `event_to_received_ts` is + a dictionary mapping event ID to the event `received_ts`. """ def get_all_new_events_stream_txn( txn: LoggingTransaction, - ) -> Tuple[int, List[str]]: + ) -> Tuple[int, Dict[str, Optional[int]]]: sql = ( - "SELECT e.stream_ordering, e.event_id" + "SELECT e.stream_ordering, e.event_id, e.received_ts" " FROM events AS e" " WHERE" " ? < e.stream_ordering AND e.stream_ordering <= ?" 
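`get_all_new_events_stream` now returns, alongside the events, a mapping from event ID to `received_ts` (replacing the `get_received_ts` helper removed earlier in this diff). A small sketch of how the fetched rows fold into that return shape:

```python
from typing import Dict, List, Optional, Tuple

Row = Tuple[int, str, Optional[int]]  # (stream_ordering, event_id, received_ts)

def fold_rows(
    rows: List[Row], current_id: int, limit: int
) -> Tuple[int, Dict[str, Optional[int]]]:
    # If the limit was filled, resume after the last returned row next time;
    # otherwise we have caught up to current_id.
    upper_bound = rows[-1][0] if len(rows) == limit else current_id
    event_to_received_ts = {event_id: ts for _, event_id, ts in rows}
    return upper_bound, event_to_received_ts

rows: List[Row] = [(101, "$a", 1_650_000_000_000), (102, "$b", None)]
assert fold_rows(rows, current_id=200, limit=3) == (200, {"$a": 1_650_000_000_000, "$b": None})
assert fold_rows(rows, current_id=200, limit=2)[0] == 102
```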
@@ -1059,15 +1061,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if len(rows) == limit: upper_bound = rows[-1][0] - return upper_bound, [row[1] for row in rows] + event_to_received_ts: Dict[str, Optional[int]] = { + row[1]: row[2] for row in rows + } + return upper_bound, event_to_received_ts - upper_bound, event_ids = await self.db_pool.runInteraction( + upper_bound, event_to_received_ts = await self.db_pool.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) - events = await self.get_events_as_list(event_ids) + events = await self.get_events_as_list( + event_to_received_ts.keys(), + get_prev_content=get_prev_content, + ) - return upper_bound, events + return upper_bound, events, event_to_received_ts async def get_federation_out_pos(self, typ: str) -> int: if self._need_to_reset_federation_stream_positions: diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index fa9eadaca7..a7fcc564a9 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -24,6 +24,7 @@ from synapse.storage.database import ( from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter from synapse.types import MutableStateMap, StateMap +from synapse.util.caches import intern_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -136,7 +137,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): txn.execute(sql % (where_clause,), args) for row in txn: typ, state_key, event_id = row - key = (typ, state_key) + key = (intern_string(typ), intern_string(state_key)) results[group][key] = event_id else: max_entries_returned = state_filter.max_entries_returned() diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 609a2b88bf..afbc85ad0c 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -400,14 +400,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): room_id: str, prev_group: Optional[int], delta_ids: Optional[StateMap[str]], - current_state_ids: StateMap[str], + current_state_ids: Optional[StateMap[str]], ) -> int: """Store a new set of state, returning a newly assigned state group. + At least one of `current_state_ids` and `prev_group` must be provided. Whenever + `prev_group` is not None, `delta_ids` must also not be None. + Args: event_id: The event ID for which the state was calculated room_id - prev_group: A previous state group for the room, optional. + prev_group: A previous state group for the room. delta_ids: The delta between state at `prev_group` and `current_state_ids`, if `prev_group` was given. Same format as `current_state_ids`. @@ -418,10 +421,41 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): The state group ID """ - def _store_state_group_txn(txn: LoggingTransaction) -> int: - if current_state_ids is None: - # AFAIK, this can never happen - raise Exception("current_state_ids cannot be None") + if prev_group is None and current_state_ids is None: + raise Exception("current_state_ids and prev_group can't both be None") + + if prev_group is not None and delta_ids is None: + raise Exception("delta_ids is None when prev_group is not None") + + def insert_delta_group_txn( + txn: LoggingTransaction, prev_group: int, delta_ids: StateMap[str] + ) -> Optional[int]: + """Try and persist the new group as a delta. 
+ + Requires that we have the state as a delta from a previous state group. + + Returns: + The state group if successfully created, or None if the state + needs to be persisted as a full state. + """ + is_in_db = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + # if the chain of state group deltas is going too long, we fall back to + # persisting a complete state group. + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if potential_hops >= MAX_STATE_DELTA_HOPS: + return None state_group = self._state_group_seq_gen.get_next_id_txn(txn) @@ -431,51 +465,45 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): values={"id": state_group, "room_id": room_id, "event_id": event_id}, ) - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. - if prev_group: - is_in_db = self.db_pool.simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - assert delta_ids is not None - - self.db_pool.simple_insert_txn( - txn, - table="state_group_edges", - values={"state_group": state_group, "prev_state_group": prev_group}, - ) + self.db_pool.simple_insert_txn( + txn, + table="state_group_edges", + values={"state_group": state_group, "prev_state_group": prev_group}, + ) - self.db_pool.simple_insert_many_txn( - txn, - table="state_groups_state", - keys=("state_group", "room_id", "type", "state_key", "event_id"), - values=[ - (state_group, room_id, key[0], key[1], state_id) - for key, state_id in delta_ids.items() - ], - ) - else: - self.db_pool.simple_insert_many_txn( - txn, - table="state_groups_state", - keys=("state_group", "room_id", "type", "state_key", "event_id"), - values=[ - (state_group, room_id, key[0], key[1], state_id) - for key, state_id in current_state_ids.items() - ], - ) + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), + values=[ + (state_group, room_id, key[0], key[1], state_id) + for key, state_id in delta_ids.items() + ], + ) + + return state_group + + def insert_full_state_txn( + txn: LoggingTransaction, current_state_ids: StateMap[str] + ) -> int: + """Persist the full state, returning the new state group.""" + state_group = self._state_group_seq_gen.get_next_id_txn(txn) + + self.db_pool.simple_insert_txn( + txn, + table="state_groups", + values={"id": state_group, "room_id": room_id, "event_id": event_id}, + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), + values=[ + (state_group, room_id, key[0], key[1], state_id) + for key, state_id in current_state_ids.items() + ], + ) # Prefill the state group caches with this group. 
# It's fine to use the sequence like this as the state group map @@ -491,7 +519,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_members_cache.update, self._state_group_members_cache.sequence, key=state_group, - value=dict(current_member_state_ids), + value=current_member_state_ids, ) current_non_member_state_ids = { @@ -503,13 +531,35 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_cache.update, self._state_group_cache.sequence, key=state_group, - value=dict(current_non_member_state_ids), + value=current_non_member_state_ids, ) return state_group + if prev_group is not None: + state_group = await self.db_pool.runInteraction( + "store_state_group.insert_delta_group", + insert_delta_group_txn, + prev_group, + delta_ids, + ) + if state_group is not None: + return state_group + + # We're going to persist the state as a complete group rather than + # a delta, so first we need to ensure we have loaded the state map + # from the database. + if current_state_ids is None: + assert prev_group is not None + assert delta_ids is not None + groups = await self._get_state_for_groups([prev_group]) + current_state_ids = dict(groups[prev_group]) + current_state_ids.update(delta_ids) + return await self.db_pool.runInteraction( - "store_state_group", _store_state_group_txn + "store_state_group.insert_full_state", + insert_full_state_txn, + current_state_ids, ) async def purge_unreferenced_state_groups( diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index dc237e3032..a9a88c8bfd 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -74,13 +74,14 @@ Changes in SCHEMA_VERSION = 71: Changes in SCHEMA_VERSION = 72: - event_edges.(room_id, is_state) are no longer written to. + - Tables related to groups are dropped. """ SCHEMA_COMPAT_VERSION = ( - # We no longer maintain `event_edges.room_id`, so synapses with SCHEMA_VERSION < 71 - # will break. - 71 + # The groups tables are no longer accessible, so synapses with SCHEMA_VERSION < 72 + # could break. + 72 ) """Limit on how far the synapse codebase can be rolled back without breaking db compat diff --git a/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py b/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py new file mode 100644 index 0000000000..55a5d092cc --- /dev/null +++ b/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py @@ -0,0 +1,47 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.storage.types import Cursor + + +def run_create(cur: Cursor, database_engine, *args, **kwargs): + """Add a bg update to populate the `state_key` and `rejection_reason` columns of `events`""" + + # we know that any new events will have the columns populated (and that has been + # the case since schema_version 68, so there is no chance of rolling back now). 
+ # + # So, we only need to make sure that existing rows are updated. We read the + # current min and max stream orderings, since that is guaranteed to include all + # the events that were stored before the new columns were added. + cur.execute("SELECT MIN(stream_ordering), MAX(stream_ordering) FROM events") + (min_stream_ordering, max_stream_ordering) = cur.fetchone() + + if min_stream_ordering is None: + # no rows, nothing to do. + return + + cur.execute( + "INSERT into background_updates (ordering, update_name, progress_json)" + " VALUES (7203, 'events_populate_state_key_rejections', ?)", + ( + json.dumps( + { + "min_stream_ordering_exclusive": min_stream_ordering - 1, + "max_stream_ordering_inclusive": max_stream_ordering, + } + ), + ), + ) diff --git a/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql b/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql new file mode 100644 index 0000000000..0da668aa3a --- /dev/null +++ b/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql @@ -0,0 +1,17 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- event_reference_hashes is unused, so we can drop it +DROP TABLE event_reference_hashes; diff --git a/synapse/storage/schema/main/delta/72/03remove_groups.sql b/synapse/storage/schema/main/delta/72/03remove_groups.sql new file mode 100644 index 0000000000..b7c5894de8 --- /dev/null +++ b/synapse/storage/schema/main/delta/72/03remove_groups.sql @@ -0,0 +1,31 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Remove the tables which powered the unspecced groups/communities feature. 
+DROP TABLE IF EXISTS group_attestations_remote;
+DROP TABLE IF EXISTS group_attestations_renewals;
+DROP TABLE IF EXISTS group_invites;
+DROP TABLE IF EXISTS group_roles;
+DROP TABLE IF EXISTS group_room_categories;
+DROP TABLE IF EXISTS group_rooms;
+DROP TABLE IF EXISTS group_summary_roles;
+DROP TABLE IF EXISTS group_summary_room_categories;
+DROP TABLE IF EXISTS group_summary_rooms;
+DROP TABLE IF EXISTS group_summary_users;
+DROP TABLE IF EXISTS group_users;
+DROP TABLE IF EXISTS groups;
+DROP TABLE IF EXISTS local_group_membership;
+DROP TABLE IF EXISTS local_group_updates;
+DROP TABLE IF EXISTS remote_profile_cache;
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index 211437cfaa..466e5137f2 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -166,6 +166,7 @@ class PartialCurrentStateTracker:
         logger.info(
             "Awaiting un-partial-stating of room %s",
             room_id,
+            stack_info=True,
         )
 
         await make_deferred_yieldable(d)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 8ed5325c5d..31f41fec82 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -730,3 +730,41 @@ class LruCache(Generic[KT, VT]):
         # This happens e.g. in the sync code where we have an expiring cache of
         # lru caches.
         self.clear()
+
+
+class AsyncLruCache(Generic[KT, VT]):
+    """
+    An asynchronous wrapper around a subset of the LruCache API.
+
+    On its own this doesn't change the behaviour, but it allows the creation
+    of subclasses that use external cache systems requiring await behaviour.
+    """
+
+    def __init__(self, *args, **kwargs):  # type: ignore
+        self._lru_cache: LruCache[KT, VT] = LruCache(*args, **kwargs)
+
+    async def get(
+        self, key: KT, default: Optional[T] = None, update_metrics: bool = True
+    ) -> Optional[VT]:
+        return self._lru_cache.get(key, update_metrics=update_metrics)
+
+    async def set(self, key: KT, value: VT) -> None:
+        self._lru_cache.set(key, value)
+
+    async def invalidate(self, key: KT) -> None:
+        # This method should invalidate any external cache and then invalidate the LruCache.
+        return self._lru_cache.invalidate(key)
+
+    def invalidate_local(self, key: KT) -> None:
+        """Remove an entry from the local cache
+
+        This variant of `invalidate` is useful if we know that the external
+        cache has already been invalidated.
+        """
+        return self._lru_cache.invalidate(key)
+
+    async def contains(self, key: KT) -> bool:
+        return self._lru_cache.contains(key)
+
+    async def clear(self) -> None:
+        self._lru_cache.clear()
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 18649c2c05..c86f783c5b 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -314,3 +314,77 @@ class TestRatelimiter(unittest.HomeserverTestCase):
 
         # Check that we get rate limited after using that token.
         self.assertFalse(consume_at(11.1))
+
+    def test_record_action_which_doesnt_fill_bucket(self) -> None:
+        limiter = Ratelimiter(
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+        )
+
+        # Observe two actions, leaving room in the bucket for one more.
+        limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0)
+
+        # We should be able to take a new action now.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+        )
+        self.assertTrue(success)
+
+        # ... but not two.
+ success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertFalse(success) + + def test_record_action_which_fills_bucket(self) -> None: + limiter = Ratelimiter( + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + ) + + # Observe three actions, filling up the bucket. + limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0) + + # We should be unable to take a new action now. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertFalse(success) + + # If we wait 10 seconds to leak a token, we should be able to take one action... + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=10.0) + ) + self.assertTrue(success) + + # ... but not two. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=10.0) + ) + self.assertFalse(success) + + def test_record_action_which_overfills_bucket(self) -> None: + limiter = Ratelimiter( + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + ) + + # Observe four actions, exceeding the bucket. + limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0) + + # We should be prevented from taking a new action now. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertFalse(success) + + # If we wait 10 seconds to leak a token, we should be unable to take an action + # because the bucket is still full. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=10.0) + ) + self.assertFalse(success) + + # But after another 10 seconds we leak a second token, giving us room for + # action. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=20.0) + ) + self.assertTrue(success) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 9f1115dd23..c6dd99316a 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
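The ratelimiter tests above lean on token-bucket arithmetic: with `rate_hz=0.1` and `burst_count=3` the bucket holds at most three actions and drains one every 10 seconds, `record_action` may overfill it, and `can_do_action` only admits an action when there is room. A minimal sketch of that model (single key, no requester overrides or persistence, unlike Synapse's real `Ratelimiter`):

    # Minimal token-bucket sketch matching the arithmetic in the tests above.
    class TokenBucket:
        def __init__(self, rate_hz: float, burst_count: int) -> None:
            self.rate_hz = rate_hz          # tokens leaked per second
            self.burst_count = burst_count  # bucket capacity
            self.level = 0.0                # current fill
            self.last_ts = 0.0

        def _leak(self, now: float) -> None:
            # Drain the bucket at rate_hz tokens per second since the last update.
            self.level = max(0.0, self.level - (now - self.last_ts) * self.rate_hz)
            self.last_ts = now

        def record_action(self, n_actions: int, now: float) -> None:
            # Unconditionally count actions, even if this overfills the bucket.
            self._leak(now)
            self.level += n_actions

        def can_do_action(self, now: float) -> bool:
            # Admit (and count) one action if there is room in the bucket.
            self._leak(now)
            if self.level + 1 > self.burst_count:
                return False
            self.level += 1
            return True

    # rate_hz=0.1, burst_count=3: one token drains every 10 seconds.
    bucket = TokenBucket(rate_hz=0.1, burst_count=3)
    bucket.record_action(n_actions=4, now=0.0)  # overfill, as in the last test
    assert not bucket.can_do_action(now=10.0)   # one token leaked: still full
    assert bucket.can_do_action(now=20.0)       # two leaked: room for one action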
+from http import HTTPStatus from unittest.mock import Mock from synapse.api.errors import Codes, SynapseError @@ -50,7 +51,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) complexity = channel.json_body["v1"] self.assertTrue(complexity > 0, complexity) @@ -62,7 +63,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) complexity = channel.json_body["v1"] self.assertEqual(complexity, 1.23) diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index 268a48d7ba..50e376f695 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -22,6 +22,7 @@ from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -38,42 +39,46 @@ class FederationClientTest(FederatingHomeserverTestCase): self._mock_agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True) homeserver.get_federation_http_client().agent = self._mock_agent - def test_get_room_state(self): - creator = f"@creator:{self.OTHER_SERVER_NAME}" - test_room_id = "!room_id" + # Move clock up to somewhat realistic time so the PDU destination retry + # works (`now` needs to be larger than `0 + PDU_RETRY_TIME_MS`). + self.reactor.advance(1000000000) + + self.creator = f"@creator:{self.OTHER_SERVER_NAME}" + self.test_room_id = "!room_id" + def test_get_room_state(self): # mock up some events to use in the response. # In real life, these would have things in `prev_events` and `auth_events`, but that's # a bit annoying to mock up, and the code under test doesn't care, so we don't bother. 
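The `setUp` comment above hints at why the clock is advanced: with a fresh test reactor `now` starts at zero, and a destination whose last-failure timestamp defaults to zero would still look like it is inside its retry window, so requests to it would be skipped. A hypothetical sketch of that gate (the constant's value here is illustrative; Synapse's actual retry bookkeeping lives in its retry-limiter storage):

    PDU_RETRY_TIME_MS = 60 * 1000  # illustrative retry window

    def should_try_destination(now_ms: int, last_failure_ms: int) -> bool:
        # Skip a destination that failed too recently.
        return now_ms - last_failure_ms >= PDU_RETRY_TIME_MS

    # At reactor time zero, "never failed" (stored as 0) is indistinguishable
    # from a failure just now, so the destination would be skipped; advancing
    # the clock first (reactor.advance(1000000000)) avoids that.
    assert not should_try_destination(now_ms=0, last_failure_ms=0)
    assert should_try_destination(now_ms=1_000_000_000 * 1000, last_failure_ms=0)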
-        create_event_dict = self.add_hashes_and_signatures(
+        create_event_dict = self.add_hashes_and_signatures_from_other_server(
             {
-                "room_id": test_room_id,
+                "room_id": self.test_room_id,
                 "type": "m.room.create",
                 "state_key": "",
-                "sender": creator,
-                "content": {"creator": creator},
+                "sender": self.creator,
+                "content": {"creator": self.creator},
                 "prev_events": [],
                 "auth_events": [],
                 "origin_server_ts": 500,
             }
         )
 
-        member_event_dict = self.add_hashes_and_signatures(
+        member_event_dict = self.add_hashes_and_signatures_from_other_server(
             {
-                "room_id": test_room_id,
+                "room_id": self.test_room_id,
                 "type": "m.room.member",
-                "sender": creator,
-                "state_key": creator,
+                "sender": self.creator,
+                "state_key": self.creator,
                 "content": {"membership": "join"},
                 "prev_events": [],
                 "auth_events": [],
                 "origin_server_ts": 600,
             }
         )
 
-        pl_event_dict = self.add_hashes_and_signatures(
+        pl_event_dict = self.add_hashes_and_signatures_from_other_server(
             {
-                "room_id": test_room_id,
+                "room_id": self.test_room_id,
                 "type": "m.room.power_levels",
-                "sender": creator,
+                "sender": self.creator,
                 "state_key": "",
                 "content": {},
                 "prev_events": [],
@@ -102,8 +107,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
         # now fire off the request
         state_resp, auth_resp = self.get_success(
             self.hs.get_federation_client().get_room_state(
-                "yet_another_server",
-                test_room_id,
+                "yet.another.server",
+                self.test_room_id,
                 "event_id",
                 RoomVersions.V9,
             )
@@ -112,7 +117,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
         # check the right call got made to the agent
         self._mock_agent.request.assert_called_once_with(
             b"GET",
-            b"matrix://yet_another_server/_matrix/federation/v1/state/%21room_id?event_id=event_id",
+            b"matrix://yet.another.server/_matrix/federation/v1/state/%21room_id?event_id=event_id",
             headers=mock.ANY,
             bodyProducer=None,
         )
@@ -130,6 +135,102 @@ class FederationClientTest(FederatingHomeserverTestCase):
             ["m.room.create", "m.room.member", "m.room.power_levels"],
         )
 
+    def test_get_pdu_returns_nothing_when_event_does_not_exist(self):
+        """No event should be returned when the event does not exist"""
+        remote_pdu = self.get_success(
+            self.hs.get_federation_client().get_pdu(
+                ["yet.another.server"],
+                "event_should_not_exist",
+                RoomVersions.V9,
+            )
+        )
+        self.assertEqual(remote_pdu, None)
+
+    def test_get_pdu(self):
+        """Test to make sure an event is returned by `get_pdu()`"""
+        self._get_pdu_once()
+
+    def test_get_pdu_event_from_cache_is_pristine(self):
+        """Test that modifications made to events returned by `get_pdu()`
+        do not propagate back to the internal cache (events returned should
+        be a copy).
+        """
+
+        # Get the PDU in the cache
+        remote_pdu = self._get_pdu_once()
+
+        # Modify the event reference.
+        # This change should not make it back to the `_get_pdu_cache`.
+        remote_pdu.internal_metadata.outlier = True
+
+        # Get the event again. This time it should read it from cache.
+        remote_pdu2 = self.get_success(
+            self.hs.get_federation_client().get_pdu(
+                ["yet.another.server"],
+                remote_pdu.event_id,
+                RoomVersions.V9,
+            )
+        )
+
+        # Sanity check that we are working against the same event
+        self.assertEqual(remote_pdu.event_id, remote_pdu2.event_id)
+
+        # Make sure the event does not include modification from earlier
+        self.assertIsNotNone(remote_pdu2)
+        self.assertEqual(remote_pdu2.internal_metadata.outlier, False)
+
+    def _get_pdu_once(self) -> EventBase:
+        """Retrieve an event via `get_pdu()` and assert that an event was returned.
+ Also used to prime the cache for subsequent test logic. + """ + message_event_dict = self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.test_room_id, + "type": "m.room.message", + "sender": self.creator, + "state_key": "", + "content": {}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 700, + "depth": 10, + } + ) + + # mock up the response, and have the agent return it + self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( + _mock_response( + { + "origin": "yet.another.server", + "origin_server_ts": 900, + "pdus": [ + message_event_dict, + ], + } + ) + ) + + remote_pdu = self.get_success( + self.hs.get_federation_client().get_pdu( + ["yet.another.server"], + "event_id", + RoomVersions.V9, + ) + ) + + # check the right call got made to the agent + self._mock_agent.request.assert_called_once_with( + b"GET", + b"matrix://yet.another.server/_matrix/federation/v1/event/event_id", + headers=mock.ANY, + bodyProducer=None, + ) + + self.assertIsNotNone(remote_pdu) + self.assertEqual(remote_pdu.internal_metadata.outlier, False) + + return remote_pdu + def _mock_response(resp: JsonDict): body = json.dumps(resp).encode("utf-8") diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 413b3c9426..3a6ef221ae 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from parameterized import parameterized @@ -20,7 +21,6 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import make_event_from_dict from synapse.federation.federation_server import server_matches_acl_event from synapse.rest import admin @@ -59,7 +59,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): "/_matrix/federation/v1/get_missing_events/%s" % (room_1,), query_content, ) - self.assertEqual(400, channel.code, channel.result) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") @@ -120,7 +120,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/state/%s?event_id=xyz" % (room_1,) ) - self.assertEqual(403, channel.code, channel.result) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -148,13 +148,13 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): tok2 = self.login("fozzie", "bear") self.helper.join(self._room_id, second_member_user_id, tok=tok2) - def _make_join(self, user_id) -> JsonDict: + def _make_join(self, user_id: str) -> JsonDict: channel = self.make_signed_federation_request( "GET", f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}" f"?ver={DEFAULT_ROOM_VERSION}", ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) return channel.json_body def test_send_join(self): @@ -163,18 +163,16 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): join_result = self._make_join(joining_user) join_event_dict = 
join_result["event"] - add_hashes_and_signatures( - KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + self.add_hashes_and_signatures_from_other_server( join_event_dict, - signature_name=self.OTHER_SERVER_NAME, - signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], ) channel = self.make_signed_federation_request( "PUT", f"/_matrix/federation/v2/send_join/{self._room_id}/x", content=join_event_dict, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # we should get complete room state back returned_state = [ @@ -220,18 +218,16 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): join_result = self._make_join(joining_user) join_event_dict = join_result["event"] - add_hashes_and_signatures( - KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + self.add_hashes_and_signatures_from_other_server( join_event_dict, - signature_name=self.OTHER_SERVER_NAME, - signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], ) channel = self.make_signed_federation_request( "PUT", f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true", content=join_event_dict, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # expect a reduced room state returned_state = [ @@ -264,6 +260,67 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): ) self.assertEqual(r[("m.room.member", joining_user)].membership, "join") + @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}}) + def test_make_join_respects_room_join_rate_limit(self) -> None: + # In the test setup, two users join the room. Since the rate limiter burst + # count is 3, a new make_join request to the room should be accepted. + + joining_user = "@ronniecorbett:" + self.OTHER_SERVER_NAME + self._make_join(joining_user) + + # Now have a new local user join the room. This saturates the rate limiter + # bucket, so the next make_join should be denied. + new_local_user = self.register_user("animal", "animal") + token = self.login("animal", "animal") + self.helper.join(self._room_id, new_local_user, tok=token) + + joining_user = "@ronniebarker:" + self.OTHER_SERVER_NAME + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/make_join/{self._room_id}/{joining_user}" + f"?ver={DEFAULT_ROOM_VERSION}", + ) + self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body) + + @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}}) + def test_send_join_contributes_to_room_join_rate_limit_and_is_limited(self) -> None: + # Make two make_join requests up front. (These are rate limited, but do not + # contribute to the rate limit.) + join_event_dicts = [] + for i in range(2): + joining_user = f"@misspiggy{i}:{self.OTHER_SERVER_NAME}" + join_result = self._make_join(joining_user) + join_event_dict = join_result["event"] + self.add_hashes_and_signatures_from_other_server( + join_event_dict, + KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + ) + join_event_dicts.append(join_event_dict) + + # In the test setup, two users join the room. Since the rate limiter burst + # count is 3, the first send_join should be accepted... 
+ channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self._room_id}/join0", + content=join_event_dicts[0], + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # ... but the second should be denied. + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self._room_id}/join1", + content=join_event_dicts[1], + ) + self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body) + + # NB: we could write a test which checks that the send_join event is seen + # by other workers over replication, and that they update their rate limit + # buckets accordingly. I'm going to assume that the join event gets sent over + # replication, at which point the tests.handlers.room_member test + # test_local_users_joining_on_another_worker_contribute_to_rate_limit + # is probably sufficient to reassure that the bucket is updated. + def _create_acl_event(content): return make_event_from_dict( diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index d21c11b716..0d048207b7 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict +from http import HTTPStatus from typing import Dict, List from synapse.api.constants import EventTypes, JoinRules, Membership @@ -255,7 +256,7 @@ class FederationKnockingTestCase( RoomVersions.V7.identifier, ), ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # Note: We don't expect the knock membership event to be sent over federation as # part of the stripped room state, as the knocking homeserver already has that @@ -293,7 +294,7 @@ class FederationKnockingTestCase( % (room_id, signed_knock_event.event_id), signed_knock_event_json, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # Check that we got the stripped room state in return room_state_events = channel.json_body["knock_state_events"] diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index d96d5aa138..b17af2725b 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -50,7 +50,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_scheduler = Mock() hs = Mock() hs.get_datastores.return_value = Mock(main=self.mock_store) - self.mock_store.get_received_ts.return_value = make_awaitable(0) + self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None) self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( None @@ -76,9 +76,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock( sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) - self.mock_store.get_new_events_for_appservice.side_effect = [ - make_awaitable((0, [])), - make_awaitable((1, [event])), + self.mock_store.get_all_new_events_stream.side_effect = [ + make_awaitable((0, [], {})), + make_awaitable((1, [event], {event.event_id: 0})), ] self.handler.notify_interested_services(RoomStreamToken(None, 1)) @@ -95,8 +95,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") 
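Stepping back to the federation join tests earlier in this changeset: they encode a deliberate asymmetry in the per-room join limiter. A `make_join` is checked against the bucket but records nothing (the join may never complete), while a `send_join` that persists a join both checks and records. A sketch of that split with illustrative names (Synapse's real `Ratelimiter.can_do_action` exposes a comparable distinction via an `update` flag):

    from typing import Dict

    class PerRoomJoinLimiter:
        """Illustrative only: separates "check" (make_join) from "record" (send_join)."""

        def __init__(self, burst_count: int) -> None:
            self.burst_count = burst_count
            self.joins_by_room: Dict[str, int] = {}

        def check(self, room_id: str) -> None:
            # make_join path: reject if the bucket is full, but record nothing.
            if self.joins_by_room.get(room_id, 0) + 1 > self.burst_count:
                raise RuntimeError("too many joins")  # stand-in for a 429 response

        def check_and_record(self, room_id: str) -> None:
            # send_join path: the join will actually be persisted, so count it.
            self.check(room_id)
            self.joins_by_room[room_id] = self.joins_by_room.get(room_id, 0) + 1

    limiter = PerRoomJoinLimiter(burst_count=3)
    limiter.joins_by_room["!r:hs"] = 2     # two members already joined
    limiter.check("!r:hs")                 # make_join: allowed, not counted
    limiter.check("!r:hs")                 # still allowed: nothing was recorded
    limiter.check_and_record("!r:hs")      # send_join: allowed and counted (now 3)
    try:
        limiter.check("!r:hs")             # bucket full: denied
    except RuntimeError:
        pass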
self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_new_events_for_appservice.side_effect = [ - make_awaitable((0, [event])), + self.mock_store.get_all_new_events_stream.side_effect = [ + make_awaitable((0, [event], {event.event_id: 0})), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -112,8 +112,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_new_events_for_appservice.side_effect = [ - make_awaitable((0, [event])), + self.mock_store.get_all_new_events_stream.side_effect = [ + make_awaitable((0, [event], {event.event_id: 0})), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 9b9c11fab7..8a0bb91f40 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, cast +from typing import cast from unittest import TestCase from twisted.test.proto_helpers import MemoryReactor @@ -50,8 +50,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastores().main - self.state_storage_controller = hs.get_storage_controllers().state - self._event_auth_handler = hs.get_event_auth_handler() return hs def test_exchange_revoked_invite(self) -> None: @@ -256,7 +254,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ] for _ in range(0, 8): event = make_event_from_dict( - self.add_hashes_and_signatures( + self.add_hashes_and_signatures_from_other_server( { "origin_server_ts": 1, "type": "m.room.message", @@ -314,142 +312,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ) self.get_success(d) - def test_backfill_floating_outlier_membership_auth(self) -> None: - """ - As the local homeserver, check that we can properly process a federated - event from the OTHER_SERVER with auth_events that include a floating - membership event from the OTHER_SERVER. - - Regression test, see #10439. 
- """ - OTHER_SERVER = "otherserver" - OTHER_USER = "@otheruser:" + OTHER_SERVER - - # create the room - user_id = self.register_user("kermit", "test") - tok = self.login("kermit", "test") - room_id = self.helper.create_room_as( - room_creator=user_id, - is_public=True, - tok=tok, - extra_content={ - "preset": "public_chat", - }, - ) - room_version = self.get_success(self.store.get_room_version(room_id)) - - prev_event_ids = self.get_success(self.store.get_prev_events_for_room(room_id)) - ( - most_recent_prev_event_id, - most_recent_prev_event_depth, - ) = self.get_success(self.store.get_max_depth_of(prev_event_ids)) - # mapping from (type, state_key) -> state_event_id - assert most_recent_prev_event_id is not None - prev_state_map = self.get_success( - self.state_storage_controller.get_state_ids_for_event( - most_recent_prev_event_id - ) - ) - # List of state event ID's - prev_state_ids = list(prev_state_map.values()) - auth_event_ids = prev_state_ids - auth_events = list( - self.get_success(self.store.get_events(auth_event_ids)).values() - ) - - # build a floating outlier member state event - fake_prev_event_id = "$" + random_string(43) - member_event_dict = { - "type": EventTypes.Member, - "content": { - "membership": "join", - }, - "state_key": OTHER_USER, - "room_id": room_id, - "sender": OTHER_USER, - "depth": most_recent_prev_event_depth, - "prev_events": [fake_prev_event_id], - "origin_server_ts": self.clock.time_msec(), - "signatures": {OTHER_SERVER: {"ed25519:key_version": "SomeSignatureHere"}}, - } - builder = self.hs.get_event_builder_factory().for_room_version( - room_version, member_event_dict - ) - member_event = self.get_success( - builder.build( - prev_event_ids=member_event_dict["prev_events"], - auth_event_ids=self._event_auth_handler.compute_auth_events( - builder, - prev_state_map, - for_verification=False, - ), - depth=member_event_dict["depth"], - ) - ) - # Override the signature added from "test" homeserver that we created the event with - member_event.signatures = member_event_dict["signatures"] - - # Add the new member_event to the StateMap - updated_state_map = dict(prev_state_map) - updated_state_map[ - (member_event.type, member_event.state_key) - ] = member_event.event_id - auth_events.append(member_event) - - # build and send an event authed based on the member event - message_event_dict = { - "type": EventTypes.Message, - "content": {}, - "room_id": room_id, - "sender": OTHER_USER, - "depth": most_recent_prev_event_depth, - "prev_events": prev_event_ids.copy(), - "origin_server_ts": self.clock.time_msec(), - "signatures": {OTHER_SERVER: {"ed25519:key_version": "SomeSignatureHere"}}, - } - builder = self.hs.get_event_builder_factory().for_room_version( - room_version, message_event_dict - ) - message_event = self.get_success( - builder.build( - prev_event_ids=message_event_dict["prev_events"], - auth_event_ids=self._event_auth_handler.compute_auth_events( - builder, - updated_state_map, - for_verification=False, - ), - depth=message_event_dict["depth"], - ) - ) - # Override the signature added from "test" homeserver that we created the event with - message_event.signatures = message_event_dict["signatures"] - - # Stub the /event_auth response from the OTHER_SERVER - async def get_event_auth( - destination: str, room_id: str, event_id: str - ) -> List[EventBase]: - return [ - event_from_pdu_json(ae.get_pdu_json(), room_version=room_version) - for ae in auth_events - ] - - self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment] 
- - with LoggingContext("receive_pdu"): - # Fake the OTHER_SERVER federating the message event over to our local homeserver - d = run_in_background( - self.hs.get_federation_event_handler().on_receive_pdu, - OTHER_SERVER, - message_event, - ) - self.get_success(d) - - # Now try and get the events on our local homeserver - stored_event = self.get_success( - self.store.get_event(message_event.event_id, allow_none=True) - ) - self.assertTrue(stored_event is not None) - @unittest.override_config( {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} ) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 4b1a8f04db..51c8dd6498 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -104,7 +104,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): # mock up a load of state events which we are missing state_events = [ make_event_from_dict( - self.add_hashes_and_signatures( + self.add_hashes_and_signatures_from_other_server( { "type": "test_state_type", "state_key": f"state_{i}", @@ -131,7 +131,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): # Depending on the test, we either persist this upfront (as an outlier), # or let the server request it. prev_event = make_event_from_dict( - self.add_hashes_and_signatures( + self.add_hashes_and_signatures_from_other_server( { "type": "test_regular_type", "room_id": room_id, @@ -165,7 +165,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): # mock up a regular event to pass into _process_pulled_event pulled_event = make_event_from_dict( - self.add_hashes_and_signatures( + self.add_hashes_and_signatures_from_other_server( { "type": "test_regular_type", "room_id": room_id, diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 82b3bb3b73..4c62449c89 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -14,6 +14,7 @@ """Tests for the password_auth_provider interface""" +from http import HTTPStatus from typing import Any, Type, Union from unittest.mock import Mock @@ -188,14 +189,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # check_password must return an awaitable mock_password_provider.check_password.return_value = make_awaitable(True) channel = self._send_password_login("u", "p") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@u:test", channel.json_body["user_id"]) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") mock_password_provider.reset_mock() # login with mxid should work too channel = self._send_password_login("@u:bz", "p") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@u:bz", channel.json_body["user_id"]) mock_password_provider.check_password.assert_called_once_with("@u:bz", "p") mock_password_provider.reset_mock() @@ -204,7 +205,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # in these cases, but at least we can guard against the API changing # unexpectedly channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"]) 
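The password-provider tests around this hunk pin down the legacy module interface: `check_password` receives a fully-qualified mxid and the raw password, while `check_auth` receives the unqualified username, the login type, and the submitted fields, returning `(user_id, callback)` or `None`. A hypothetical minimal provider with that shape, sketched from the calls asserted in these tests rather than copied from any real module:

    class ExamplePasswordProvider:
        """Sketch of a legacy password auth provider module."""

        def __init__(self, config, account_handler):
            self.account_handler = account_handler

        async def check_password(self, user_id: str, password: str) -> bool:
            # Synapse has already qualified the localpart into a full mxid,
            # e.g. "@u:test"; return True to accept the login.
            return password == "p"

        async def check_auth(self, username: str, login_type: str, login_dict: dict):
            # Custom login types: return (user_id, post-login callback) or
            # None to reject. The username arrives exactly as submitted.
            if login_type == "test.login_type" and login_dict.get("test_field"):
                return (f"@{username}:bz", None)
            return None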
mock_password_provider.check_password.assert_called_once_with( "@ USER🙂NAME :test", " pASS😢word " @@ -258,10 +259,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # check_password must return an awaitable mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._send_password_login("u", "p") - self.assertEqual(channel.code, 403, channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@localuser:test", channel.json_body["user_id"]) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) @@ -382,7 +383,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("u", "p") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_password.assert_not_called() @override_config(legacy_providers_config(LegacyCustomAuthProvider)) @@ -406,14 +407,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login with missing param should be rejected channel = self._send_login("test.login_type", "u") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", None) ) channel = self._send_login("test.login_type", "u", test_field="y") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@user:bz", channel.json_body["user_id"]) mock_password_provider.check_auth.assert_called_once_with( "u", "test.login_type", {"test_field": "y"} @@ -427,7 +428,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ("@ MALFORMED! :bz", None) ) channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@ MALFORMED! 
:bz", channel.json_body["user_id"]) mock_password_provider.check_auth.assert_called_once_with( " USER🙂NAME ", "test.login_type", {"test_field": " abc "} @@ -510,7 +511,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ("@user:bz", callback) ) channel = self._send_login("test.login_type", "u", test_field="y") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@user:bz", channel.json_body["user_id"]) mock_password_provider.check_auth.assert_called_once_with( "u", "test.login_type", {"test_field": "y"} @@ -549,7 +550,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() @override_config( @@ -584,7 +585,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() @override_config( @@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_password.assert_not_called() @@ -646,13 +647,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) tok1 = channel.json_body["access_token"] channel = self._send_login( "test.login_type", "localuser", test_field="", device_id="dev2" ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # make the initial request which returns a 401 channel = self._delete_device(tok1, "dev2") @@ -721,7 +722,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # password login shouldn't work and should be rejected with a 400 # ("unknown login type") channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) def test_on_logged_out(self): """Tests that the on_logged_out callback is called when the user logs out.""" @@ -884,7 +885,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): }, access_token=tok, ) - self.assertEqual(channel.code, 403, channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.THREEPID_DENIED, @@ -906,7 +907,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): }, access_token=tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, 
channel.result)
         self.assertIn("sid", channel.json_body)
         m.assert_called_once_with("email", "bar@test.com", registration)
@@ -949,12 +950,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "register",
             {"auth": {"session": session, "type": LoginType.DUMMY}},
         )
-        self.assertEqual(channel.code, 200, channel.json_body)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
         return channel.json_body
 
     def _get_login_flows(self) -> JsonDict:
         channel = self.make_request("GET", "/_matrix/client/r0/login")
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
         return channel.json_body["flows"]
 
     def _send_password_login(self, user: str, password: str) -> FakeChannel:
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
new file mode 100644
index 0000000000..254e7e4b80
--- /dev/null
+++ b/tests/handlers/test_room_member.py
@@ -0,0 +1,290 @@
+from http import HTTPStatus
+from unittest.mock import Mock, patch
+
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+import synapse.rest.client.login
+import synapse.rest.client.room
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import LimitExceededError
+from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.events import FrozenEventV3
+from synapse.federation.federation_client import SendJoinResult
+from synapse.server import HomeServer
+from synapse.types import UserID, create_requester
+from synapse.util import Clock
+
+from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.server import make_request
+from tests.test_utils import make_awaitable
+from tests.unittest import FederatingHomeserverTestCase, override_config
+
+
+class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        synapse.rest.client.login.register_servlets,
+        synapse.rest.client.room.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.handler = hs.get_room_member_handler()
+
+        # Create three users.
+        self.alice = self.register_user("alice", "pass")
+        self.alice_token = self.login("alice", "pass")
+        self.bob = self.register_user("bob", "pass")
+        self.bob_token = self.login("bob", "pass")
+        self.chris = self.register_user("chris", "pass")
+        self.chris_token = self.login("chris", "pass")
+
+        # Create a room on this homeserver. Note that this counts as a join: it
+        # contributes to the rate limiter's count of actions.
+        self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+
+        self.intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}"
+
+    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+    def test_local_user_local_joins_contribute_to_limit_and_are_limited(self) -> None:
+        # The rate limiter has accumulated one token from Alice's join after the create
+        # event.
+        # Try joining the room as Bob.
+        self.get_success(
+            self.handler.update_membership(
+                requester=create_requester(self.bob),
+                target=UserID.from_string(self.bob),
+                room_id=self.room_id,
+                action=Membership.JOIN,
+            )
+        )
+
+        # The rate limiter bucket is full. A second join should be denied.
+        self.get_failure(
+            self.handler.update_membership(
+                requester=create_requester(self.chris),
+                target=UserID.from_string(self.chris),
+                room_id=self.room_id,
+                action=Membership.JOIN,
+            ),
+            LimitExceededError,
+        )
+
+    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+    def test_local_user_profile_edits_dont_contribute_to_limit(self) -> None:
+        # The rate limiter has accumulated one token from Alice's join after the create
+        # event. Alice should still be able to change her displayname.
+        self.get_success(
+            self.handler.update_membership(
+                requester=create_requester(self.alice),
+                target=UserID.from_string(self.alice),
+                room_id=self.room_id,
+                action=Membership.JOIN,
+                content={"displayname": "Alice Cooper"},
+            )
+        )
+
+        # Still room in the limiter bucket. Chris's join should be accepted.
+        self.get_success(
+            self.handler.update_membership(
+                requester=create_requester(self.chris),
+                target=UserID.from_string(self.chris),
+                room_id=self.room_id,
+                action=Membership.JOIN,
+            )
+        )
+
+    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 1}})
+    def test_remote_joins_contribute_to_rate_limit(self) -> None:
+        # Join once, to fill the rate limiter bucket.
+        #
+        # To do this we have to mock the responses from the remote homeserver.
+        # We also patch out a bunch of event checks on our end. All we're really
+        # trying to check here is that remote joins will bump the rate limiter when
+        # they are persisted.
+        create_event_source = {
+            "auth_events": [],
+            "content": {
+                "creator": f"@creator:{self.OTHER_SERVER_NAME}",
+                "room_version": self.hs.config.server.default_room_version.identifier,
+            },
+            "depth": 0,
+            "origin_server_ts": 0,
+            "prev_events": [],
+            "room_id": self.intially_unjoined_room_id,
+            "sender": f"@creator:{self.OTHER_SERVER_NAME}",
+            "state_key": "",
+            "type": EventTypes.Create,
+        }
+        self.add_hashes_and_signatures_from_other_server(
+            create_event_source,
+            self.hs.config.server.default_room_version,
+        )
+        create_event = FrozenEventV3(
+            create_event_source,
+            self.hs.config.server.default_room_version,
+            {},
+            None,
+        )
+
+        join_event_source = {
+            "auth_events": [create_event.event_id],
+            "content": {"membership": "join"},
+            "depth": 1,
+            "origin_server_ts": 100,
+            "prev_events": [create_event.event_id],
+            "sender": self.bob,
+            "state_key": self.bob,
+            "room_id": self.intially_unjoined_room_id,
+            "type": EventTypes.Member,
+        }
+        add_hashes_and_signatures(
+            self.hs.config.server.default_room_version,
+            join_event_source,
+            self.hs.hostname,
+            self.hs.signing_key,
+        )
+        join_event = FrozenEventV3(
+            join_event_source,
+            self.hs.config.server.default_room_version,
+            {},
+            None,
+        )
+
+        mock_make_membership_event = Mock(
+            return_value=make_awaitable(
+                (
+                    self.OTHER_SERVER_NAME,
+                    join_event,
+                    self.hs.config.server.default_room_version,
+                )
+            )
+        )
+        mock_send_join = Mock(
+            return_value=make_awaitable(
+                SendJoinResult(
+                    join_event,
+                    self.OTHER_SERVER_NAME,
+                    state=[create_event],
+                    auth_chain=[create_event],
+                    partial_state=False,
+                    servers_in_room=[],
+                )
+            )
+        )
+
+        with patch.object(
+            self.handler.federation_handler.federation_client,
+            "make_membership_event",
+            mock_make_membership_event,
+        ), patch.object(
+            self.handler.federation_handler.federation_client,
+            "send_join",
+            mock_send_join,
+        ), patch(
+            "synapse.event_auth._is_membership_change_allowed",
+            return_value=None,
+        ), patch(
+            "synapse.handlers.federation_event.check_state_dependent_auth_rules",
+            return_value=None,
+        ):
+            self.get_success(
+                self.handler.update_membership(
+                    requester=create_requester(self.bob),
+                    target=UserID.from_string(self.bob),
+                    room_id=self.intially_unjoined_room_id,
+                    action=Membership.JOIN,
+                    remote_room_hosts=[self.OTHER_SERVER_NAME],
+                )
+            )
+
+        # Try to join as Chris. Should get denied.
+        self.get_failure(
+            self.handler.update_membership(
+                requester=create_requester(self.chris),
+                target=UserID.from_string(self.chris),
+                room_id=self.intially_unjoined_room_id,
+                action=Membership.JOIN,
+                remote_room_hosts=[self.OTHER_SERVER_NAME],
+            ),
+            LimitExceededError,
+        )
+
+    # TODO: test that remote joins to a room are rate limited.
+    #       Could do this by setting the burst count to 1, then:
+    #       - remote-joining a room
+    #       - immediately leaving
+    #       - trying to remote-join again.
+
+
+class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        synapse.rest.client.login.register_servlets,
+        synapse.rest.client.room.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.handler = hs.get_room_member_handler()
+
+        # Create three users.
+        self.alice = self.register_user("alice", "pass")
+        self.alice_token = self.login("alice", "pass")
+        self.bob = self.register_user("bob", "pass")
+        self.bob_token = self.login("bob", "pass")
+        self.chris = self.register_user("chris", "pass")
+        self.chris_token = self.login("chris", "pass")
+
+        # Create a room on this homeserver. Note that this counts as a join: it
+        # contributes to the rate limiter's count of actions.
+        self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+        self.intially_unjoined_room_id = "!example:otherhs"
+
+    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+    def test_local_users_joining_on_another_worker_contribute_to_rate_limit(
+        self,
+    ) -> None:
+        # The rate limiter has accumulated one token from Alice's join after the create
+        # event.
+        self.replicate()
+
+        # Spawn another worker and have bob join via it.
+        worker_app = self.make_worker_hs(
+            "synapse.app.generic_worker", extra_config={"worker_name": "other worker"}
+        )
+        worker_site = self._hs_to_site[worker_app]
+        channel = make_request(
+            self.reactor,
+            worker_site,
+            "POST",
+            f"/_matrix/client/v3/rooms/{self.room_id}/join",
+            access_token=self.bob_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+        # wait for join to arrive over replication
+        self.replicate()
+
+        # Try to join as Chris on the worker. Should get denied because Alice
+        # and Bob have both joined the room.
+        self.get_failure(
+            worker_app.get_room_member_handler().update_membership(
+                requester=create_requester(self.chris),
+                target=UserID.from_string(self.chris),
+                room_id=self.room_id,
+                action=Membership.JOIN,
+            ),
+            LimitExceededError,
+        )
+
+        # Try to join as Chris on the original worker. Should get denied because Alice
+        # and Bob have both joined the room.
+        self.get_failure(
+            self.handler.update_membership(
+                requester=create_requester(self.chris),
+                target=UserID.from_string(self.chris),
+                room_id=self.room_id,
+                action=Membership.JOIN,
+            ),
+            LimitExceededError,
+        )
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index ecc7cc6461..e3f38fbcc5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -159,7 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
 
         # Blow away caches (supported room versions can only change due to a restart).
self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) self.store._event_ref.clear() # The rooms should be excluded from the sync response. diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 230dc76f72..2526136ff8 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -21,7 +21,7 @@ from parameterized import parameterized from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RoomTypes from synapse.api.errors import Codes from synapse.handlers.pagination import PaginationHandler from synapse.rest.client import directory, events, login, room @@ -1130,6 +1130,8 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("guest_access", r) self.assertIn("history_visibility", r) self.assertIn("state_events", r) + self.assertIn("room_type", r) + self.assertIsNone(r["room_type"]) # Check that the correct number of total rooms was returned self.assertEqual(channel.json_body["total_rooms"], total_rooms) @@ -1229,7 +1231,11 @@ class RoomTestCase(unittest.HomeserverTestCase): def test_correct_room_attributes(self) -> None: """Test the correct attributes for a room are returned""" # Create a test room - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id = self.helper.create_room_as( + self.admin_user, + tok=self.admin_user_tok, + extra_content={"creation_content": {"type": RoomTypes.SPACE}}, + ) test_alias = "#test:test" test_room_name = "something" @@ -1306,6 +1312,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id, r["room_id"]) self.assertEqual(test_room_name, r["name"]) self.assertEqual(test_alias, r["canonical_alias"]) + self.assertEqual(RoomTypes.SPACE, r["room_type"]) def test_room_list_sort_order(self) -> None: """Test room list sort ordering. 
alphabetical name versus number of members, @@ -1630,7 +1637,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("guest_access", channel.json_body) self.assertIn("history_visibility", channel.json_body) self.assertIn("state_events", channel.json_body) - + self.assertIn("room_type", channel.json_body) self.assertEqual(room_id_1, channel.json_body["room_id"]) def test_single_room_devices(self) -> None: diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e32aaadb98..12db68d564 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1379,7 +1379,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1434,7 +1434,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1512,7 +1512,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123", "admin": False}, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) @@ -1550,7 +1550,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # Admin user is not blocked by mau anymore - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) @@ -1585,7 +1585,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1626,7 +1626,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1636,6 +1636,41 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(len(pushers), 0) + @override_config( + { + "email": { + "enable_notifs": True, + "notif_for_new_users": True, + "notif_from": "test@example.com", + }, + "public_baseurl": "https://example.com", + } + ) + def test_create_user_email_notif_for_new_users_with_msisdn_threepid(self) -> None: + """ + Check that a new regular user is created successfully when they have a msisdn + threepid and email notif_for_new_users is set to 
True. + """ + url = self.url_prefix % "@bob:test" + + # Create user + body = { + "password": "abc123", + "threepids": [{"medium": "msisdn", "address": "1234567890"}], + } + + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body, + ) + + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"]) + def test_set_password(self) -> None: """ Test setting a new password for another user. @@ -2372,7 +2407,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123"}, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 1f9b65351e..7ae926dc9c 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import os import re from email.parser import Parser +from http import HTTPStatus from typing import Any, Dict, List, Optional, Union from unittest.mock import Mock @@ -95,10 +95,8 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): """ body = {"type": "m.login.password", "user": username, "password": password} - channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") - ) - self.assertEqual(channel.code, 403, channel.result) + channel = self.make_request("POST", "/_matrix/client/r0/login", body) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) def test_basic_password_reset(self) -> None: """Test basic password reset flow""" @@ -347,7 +345,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # Now POST to the same endpoint, mimicking the same behaviour as clicking the # password reset confirm button @@ -362,7 +360,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, content_is_form=True, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" @@ -390,7 +388,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): new_password: str, session_id: str, client_secret: str, - expected_code: int = 200, + expected_code: int = HTTPStatus.OK, ) -> None: channel = self.make_request( "POST", @@ -479,16 +477,14 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.assertEqual(memberships[0].room_id, room_id, memberships) def deactivate(self, user_id: str, tok: str) -> None: - request_data = json.dumps( - { - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "test", - }, - "erase": False, - } - ) + request_data = { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "test", + }, + "erase": False, + } channel = self.make_request( "POST", 
"account/deactivate", request_data, access_token=tok ) @@ -715,7 +711,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -725,7 +723,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email(self) -> None: @@ -747,7 +745,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -756,7 +754,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email_if_disabled(self) -> None: @@ -781,7 +779,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -791,7 +791,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) @@ -817,7 +817,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -827,7 +829,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_no_valid_token(self) -> None: @@ -852,7 +854,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -862,7 +866,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) 
self.assertFalse(channel.json_body["threepids"]) @override_config({"next_link_domain_whitelist": None}) @@ -872,7 +876,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="https://example.com/a/good/site", - expect_code=200, + expect_code=HTTPStatus.OK, ) @override_config({"next_link_domain_whitelist": None}) @@ -884,7 +888,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="some-protocol://abcdefghijklmopqrstuvwxyz", - expect_code=200, + expect_code=HTTPStatus.OK, ) @override_config({"next_link_domain_whitelist": None}) @@ -895,7 +899,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="file:///host/path", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) @@ -907,28 +911,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link=None, - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://example.com/some/good/page", - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://example.org/some/also/good/page", - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://bad.example.org/some/bad/page", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) @override_config({"next_link_domain_whitelist": []}) @@ -940,7 +944,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="https://example.com/a/page", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) def _request_token( @@ -948,7 +952,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): email: str, client_secret: str, next_link: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, ) -> Optional[str]: """Request a validation token to add an email address to a user's account @@ -993,7 +997,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): b"account/3pid/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(expected_errcode, channel.json_body["errcode"]) self.assertEqual(expected_error, channel.json_body["error"]) @@ -1002,7 +1008,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): path = link.replace("https://example.com", "") channel = self.make_request("GET", path, shorthand=False) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" @@ -1052,7 +1058,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -1061,7 +1067,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, 
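Taken together, the `next_link_domain_whitelist` cases above pin down a simple rule: the link must name a host at all, and if a whitelist is configured the host must be on it. A rough standalone model of that rule (illustrative only; `next_link_allowed` is not Synapse's real function):

from typing import Optional, Sequence
from urllib.parse import urlparse

def next_link_allowed(next_link: str, whitelist: Optional[Sequence[str]]) -> bool:
    hostname = urlparse(next_link).hostname
    if hostname is None:
        return False  # e.g. file:///host/path has no host component
    if whitelist is None:
        return True   # no whitelist configured: any host is accepted
    return hostname in whitelist

assert next_link_allowed("https://example.com/a/good/site", None)
assert next_link_allowed("some-protocol://abcdefghijklmopqrstuvwxyz", None)
assert not next_link_allowed("file:///host/path", None)
assert next_link_allowed("https://example.org/page", ["example.com", "example.org"])
assert not next_link_allowed("https://bad.example.org/page", ["example.com", "example.org"])
assert not next_link_allowed("https://example.com/a/page", [])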
@@ -1052,7 +1058,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
            access_token=self.user_id_tok,
        )

-        self.assertEqual(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])

        # Get user
        channel = self.make_request(
            "GET",
            self.url_3pid,
            access_token=self.user_id_tok,
        )

-        self.assertEqual(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
        self.assertEqual("email", channel.json_body["threepids"][0]["medium"])

        threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
@@ -1092,7 +1098,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
        """Tests that not providing any MXID raises an error."""
        self._test_status(
            users=None,
-            expected_status_code=400,
+            expected_status_code=HTTPStatus.BAD_REQUEST,
            expected_errcode=Codes.MISSING_PARAM,
        )

@@ -1100,7 +1106,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
        """Tests that providing an invalid MXID raises an error."""
        self._test_status(
            users=["bad:test"],
-            expected_status_code=400,
+            expected_status_code=HTTPStatus.BAD_REQUEST,
            expected_errcode=Codes.INVALID_PARAM,
        )

@@ -1286,7 +1292,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
    def _test_status(
        self,
        users: Optional[List[str]],
-        expected_status_code: int = 200,
+        expected_status_code: int = HTTPStatus.OK,
        expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
        expected_failures: Optional[List[str]] = None,
        expected_errcode: Optional[str] = None,
diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index 16e7ef41bc..7a88aa2cda 100644
--- a/tests/rest/client/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 from http import HTTPStatus

 from twisted.test.proto_helpers import MemoryReactor
@@ -97,8 +96,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
        # We use deliberately a localpart under the length threshold so
        # that we can make sure that the check is done on the whole alias.
-        data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
-        request_data = json.dumps(data)
+        request_data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
        channel = self.make_request(
            "POST", url, request_data, access_token=self.user_tok
        )
@@ -110,8 +108,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
        # Check with an alias of allowed length. There should already be
        # a test that ensures it works in test_register.py, but let's be
        # as cautious as possible here.
-        data = {"room_alias_name": random_string(5)}
-        request_data = json.dumps(data)
+        request_data = {"room_alias_name": random_string(5)}
        channel = self.make_request(
            "POST", url, request_data, access_token=self.user_tok
        )
@@ -144,8 +141,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
        # Add an alias for the room, as the appservice
        alias = RoomAlias(f"asns-{random_string(5)}", self.hs.hostname).to_string()
-        data = {"room_id": self.room_id}
-        request_data = json.dumps(data)
+        request_data = {"room_id": self.room_id}

        channel = self.make_request(
            "PUT",
@@ -193,8 +189,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
            self.hs.hostname,
        )

-        data = {"aliases": [self.random_alias(alias_length)]}
-        request_data = json.dumps(data)
+        request_data = {"aliases": [self.random_alias(alias_length)]}

        channel = self.make_request(
            "PUT", url, request_data, access_token=self.user_tok
@@ -206,8 +201,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
    ) -> str:
        alias = self.random_alias(alias_length)
        url = "/_matrix/client/r0/directory/room/%s" % alias
-        data = {"room_id": self.room_id}
-        request_data = json.dumps(data)
+        request_data = {"room_id": self.room_id}

        channel = self.make_request(
            "PUT", url, request_data, access_token=self.user_tok
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index 299b9d21e2..dc17c9d113 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import json
 from http import HTTPStatus

 from twisted.test.proto_helpers import MemoryReactor
@@ -51,12 +50,11 @@ class IdentityTestCase(unittest.HomeserverTestCase):
        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
        room_id = channel.json_body["room_id"]

-        params = {
+        request_data = {
            "id_server": "testis",
            "medium": "email",
            "address": "test@example.com",
        }
-        request_data = json.dumps(params)
        request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
        channel = self.make_request(
            b"POST", request_url, request_data, access_token=tok
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f6efa5fe37..a2958f6959 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -11,9 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import time
 import urllib.parse
+from http import HTTPStatus
 from typing import Any, Dict, List, Optional
 from unittest.mock import Mock
 from urllib.parse import urlencode
@@ -261,20 +261,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
        }
        channel = self.make_request(b"POST", LOGIN_URL, params)
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
        access_token = channel.json_body["access_token"]
        device_id = channel.json_body["device_id"]

        # we should now be able to make requests with the access token
        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)

        # time passes
        self.reactor.advance(24 * 3600)

        # ... and we should be soft-logouted
        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
        self.assertEqual(channel.json_body["soft_logout"], True)

@@ -288,7 +288,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
        # more requests with the expired token should still return a soft-logout
        self.reactor.advance(3600)
        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
        self.assertEqual(channel.json_body["soft_logout"], True)

@@ -296,7 +296,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
        self._delete_device(access_token_2, "kermit", "monkey", device_id)

        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
        self.assertEqual(channel.json_body["soft_logout"], False)

@@ -307,7 +307,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
        channel = self.make_request(
            b"DELETE", "devices/" + device_id, access_token=access_token
        )
-        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
        # check it's a UI-Auth fail
        self.assertEqual(
            set(channel.json_body.keys()),
@@ -330,7 +330,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
            access_token=access_token,
            content={"auth": auth},
        )
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)

    @override_config({"session_lifetime": "24h"})
    def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
@@ -341,14 +341,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):

        # we should now be able to make requests with the access token
        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)

        # time passes
        self.reactor.advance(24 * 3600)

        # ... and we should be soft-logouted
        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
        self.assertEqual(channel.json_body["soft_logout"], True)

@@ -367,14 +367,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):

        # we should now be able to make requests with the access token
        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)

        # time passes
        self.reactor.advance(24 * 3600)

        # ... and we should be soft-logouted
        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
        self.assertEqual(channel.json_body["soft_logout"], True)

@@ -399,7 +399,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
        channel = self.make_request(
            "POST",
            "/_matrix/client/v3/login",
-            json.dumps(body).encode("utf8"),
+            body,
            custom_headers=None,
        )

@@ -466,7 +466,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
    def test_get_login_flows(self) -> None:
        """GET /login should return password and SSO flows"""
        channel = self.make_request("GET", "/_matrix/client/r0/login")
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)

        expected_flow_types = [
            "m.login.cas",
@@ -494,14 +494,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
        """/login/sso/redirect should redirect to an identity picker"""
        # first hit the redirect url, which should redirect to our idp picker
        channel = self._make_sso_redirect_request(None)
-        self.assertEqual(channel.code, 302, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
        location_headers = channel.headers.getRawHeaders("Location")
        assert location_headers
        uri = location_headers[0]

        # hitting that picker should give us some HTML
        channel = self.make_request("GET", uri)
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)

        # parse the form to check it has fields assumed elsewhere in this class
        html = channel.result["body"].decode("utf-8")
@@ -530,7 +530,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
            + "&idp=cas",
            shorthand=False,
        )
-        self.assertEqual(channel.code, 302, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
        location_headers = channel.headers.getRawHeaders("Location")
        assert location_headers
        cas_uri = location_headers[0]
@@ -555,7 +555,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
            + "&idp=saml",
        )
-        self.assertEqual(channel.code, 302, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
        location_headers = channel.headers.getRawHeaders("Location")
        assert location_headers
        saml_uri = location_headers[0]
@@ -579,7 +579,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
            + "&idp=oidc",
        )
-        self.assertEqual(channel.code, 302, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
        location_headers = channel.headers.getRawHeaders("Location")
        assert location_headers
        oidc_uri = location_headers[0]
@@ -606,7 +606,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
        channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})

        # that should serve a confirmation page
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
        content_type_headers = channel.headers.getRawHeaders("Content-Type")
        assert content_type_headers
        self.assertTrue(content_type_headers[-1].startswith("text/html"))
@@ -634,7 +634,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
            "/login",
            content={"type": "m.login.token", "token": login_token},
        )
-        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
        self.assertEqual(chan.json_body["user_id"], "@user1:test")

    def test_multi_sso_redirect_to_unknown(self) -> None:
@@ -643,18 +643,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
            "GET",
            "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
        )
-        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)

    def test_client_idp_redirect_to_unknown(self) -> None:
        """If the client tries to pick an unknown IdP, return a 404"""
        channel = self._make_sso_redirect_request("xxx")
-        self.assertEqual(channel.code, 404, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")

    def test_client_idp_redirect_to_oidc(self) -> None:
        """If the client pick a known IdP, redirect to it"""
        channel = self._make_sso_redirect_request("oidc")
-        self.assertEqual(channel.code, 302, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
        location_headers = channel.headers.getRawHeaders("Location")
        assert location_headers
        oidc_uri = location_headers[0]
@@ -765,7 +765,7 @@ class CASTestCase(unittest.HomeserverTestCase):
        channel = self.make_request("GET", cas_ticket_url)

        # Test that the response is HTML.
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
        content_type_header_value = ""
        for header in channel.result.get("headers", []):
            if header[0] == b"Content-Type":
@@ -1246,7 +1246,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
        )

        # that should redirect to the username picker
-        self.assertEqual(channel.code, 302, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
        location_headers = channel.headers.getRawHeaders("Location")
        assert location_headers
        picker_url = location_headers[0]
@@ -1290,7 +1290,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
                ("Content-Length", str(len(content))),
            ],
        )
-        self.assertEqual(chan.code, 302, chan.result)
+        self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
        location_headers = chan.headers.getRawHeaders("Location")
        assert location_headers

@@ -1300,7 +1300,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
            path=location_headers[0],
            custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
        )
-        self.assertEqual(chan.code, 302, chan.result)
+        self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
        location_headers = chan.headers.getRawHeaders("Location")
        assert location_headers

@@ -1325,5 +1325,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
            "/login",
            content={"type": "m.login.token", "token": login_token},
        )
-        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
        self.assertEqual(chan.json_body["user_id"], "@bobby:test")
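The soft-logout tests above fast-forward time with `self.reactor.advance(24 * 3600)` instead of sleeping. That works because the test reactor runs on a virtual clock in the style of Twisted's `task.Clock`; a sketch of the underlying mechanism (not Synapse's own test harness):

from twisted.internet.task import Clock

clock = Clock()
fired = []
# Schedule a callback 24 hours out, then jump the virtual clock past it;
# advance() fires every pending call whose time has come, with no real waiting.
clock.callLater(24 * 3600, fired.append, "session expired")
clock.advance(24 * 3600)
assert fired == ["session expired"]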
diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py
index 3a74d2e96c..e19d21d6ee 100644
--- a/tests/rest/client/test_password_policy.py
+++ b/tests/rest/client/test_password_policy.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import json
 from http import HTTPStatus

 from twisted.test.proto_helpers import MemoryReactor
@@ -89,7 +88,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
        )

    def test_password_too_short(self) -> None:
-        request_data = json.dumps({"username": "kermit", "password": "shorty"})
+        request_data = {"username": "kermit", "password": "shorty"}
        channel = self.make_request("POST", self.register_url, request_data)

        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -100,7 +99,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
        )

    def test_password_no_digit(self) -> None:
-        request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
+        request_data = {"username": "kermit", "password": "longerpassword"}
        channel = self.make_request("POST", self.register_url, request_data)

        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -111,7 +110,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
        )

    def test_password_no_symbol(self) -> None:
-        request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
+        request_data = {"username": "kermit", "password": "l0ngerpassword"}
        channel = self.make_request("POST", self.register_url, request_data)

        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -122,7 +121,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
        )

    def test_password_no_uppercase(self) -> None:
-        request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
+        request_data = {"username": "kermit", "password": "l0ngerpassword!"}
        channel = self.make_request("POST", self.register_url, request_data)

        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -133,7 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
        )

    def test_password_no_lowercase(self) -> None:
-        request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
+        request_data = {"username": "kermit", "password": "L0NGERPASSWORD!"}
        channel = self.make_request("POST", self.register_url, request_data)

        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -144,7 +143,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
        )

    def test_password_compliant(self) -> None:
-        request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
+        request_data = {"username": "kermit", "password": "L0ngerpassword!"}
        channel = self.make_request("POST", self.register_url, request_data)

        # Getting a 401 here means the password has passed validation and the server has
@@ -161,16 +160,14 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
        user_id = self.register_user("kermit", compliant_password)
        tok = self.login("kermit", compliant_password)

-        request_data = json.dumps(
-            {
-                "new_password": not_compliant_password,
-                "auth": {
-                    "password": compliant_password,
-                    "type": LoginType.PASSWORD,
-                    "user": user_id,
-                },
-            }
-        )
+        request_data = {
+            "new_password": not_compliant_password,
+            "auth": {
+                "password": compliant_password,
+                "type": LoginType.PASSWORD,
+                "user": user_id,
+            },
+        }
        channel = self.make_request(
            "POST",
            "/_matrix/client/r0/account/password",
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index afb08b2736..071b488cc0 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import datetime
-import json
 import os
 from typing import Any, Dict, List, Tuple
@@ -62,9 +61,10 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
        )
        self.hs.get_datastores().main.services_cache.append(appservice)

-        request_data = json.dumps(
-            {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
-        )
+        request_data = {
+            "username": "as_user_kermit",
+            "type": APP_SERVICE_REGISTRATION_TYPE,
+        }

        channel = self.make_request(
            b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
@@ -85,7 +85,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
        )
        self.hs.get_datastores().main.services_cache.append(appservice)

-        request_data = json.dumps({"username": "as_user_kermit"})
+        request_data = {"username": "as_user_kermit"}

        channel = self.make_request(
            b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
@@ -95,9 +95,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
    def test_POST_appservice_registration_invalid(self) -> None:
        self.appservice = None  # no application service exists
-        request_data = json.dumps(
-            {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
-        )
+        request_data = {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
        channel = self.make_request(
            b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
        )
@@ -105,14 +103,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
        self.assertEqual(channel.result["code"], b"401", channel.result)

    def test_POST_bad_password(self) -> None:
-        request_data = json.dumps({"username": "kermit", "password": 666})
+        request_data = {"username": "kermit", "password": 666}
        channel = self.make_request(b"POST", self.url, request_data)

        self.assertEqual(channel.result["code"], b"400", channel.result)
        self.assertEqual(channel.json_body["error"], "Invalid password")

    def test_POST_bad_username(self) -> None:
-        request_data = json.dumps({"username": 777, "password": "monkey"})
+        request_data = {"username": 777, "password": "monkey"}
        channel = self.make_request(b"POST", self.url, request_data)

        self.assertEqual(channel.result["code"], b"400", channel.result)
@@ -121,13 +119,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
    def test_POST_user_valid(self) -> None:
        user_id = "@kermit:test"
        device_id = "frogfone"
-        params = {
+        request_data = {
            "username": "kermit",
            "password": "monkey",
            "device_id": device_id,
            "auth": {"type": LoginType.DUMMY},
        }
-        request_data = json.dumps(params)
        channel = self.make_request(b"POST", self.url, request_data)

        det_data = {
@@ -140,7 +137,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):

    @override_config({"enable_registration": False})
    def test_POST_disabled_registration(self) -> None:
-        request_data = json.dumps({"username": "kermit", "password": "monkey"})
+        request_data = {"username": "kermit", "password": "monkey"}
        self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)

        channel = self.make_request(b"POST", self.url, request_data)
@@ -188,13 +185,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
    @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
    def test_POST_ratelimiting(self) -> None:
        for i in range(0, 6):
-            params = {
+            request_data = {
                "username": "kermit" + str(i),
                "password": "monkey",
                "device_id": "frogfone",
                "auth": {"type": LoginType.DUMMY},
            }
-            request_data = json.dumps(params)
            channel = self.make_request(b"POST", self.url, request_data)

            if i == 5:
@@ -234,7 +230,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
        }

        # Request without auth to get flows and session
-        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        channel = self.make_request(b"POST", self.url, params)
        self.assertEqual(channel.result["code"], b"401", channel.result)
        flows = channel.json_body["flows"]
        # Synapse adds a dummy stage to differentiate flows where otherwise one
@@ -251,8 +247,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
            "token": token,
            "session": session,
        }
-        request_data = json.dumps(params)
-        channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, params)
        self.assertEqual(channel.result["code"], b"401", channel.result)
        completed = channel.json_body["completed"]
        self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@@ -262,8 +257,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
            "type": LoginType.DUMMY,
            "session": session,
        }
-        request_data = json.dumps(params)
-        channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, params)
        det_data = {
            "user_id": f"@{username}:{self.hs.hostname}",
            "home_server": self.hs.hostname,
@@ -290,7 +284,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
            "password": "monkey",
        }
        # Request without auth to get session
-        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        channel = self.make_request(b"POST", self.url, params)
        session = channel.json_body["session"]

        # Test with token param missing (invalid)
@@ -298,21 +292,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
            "type": LoginType.REGISTRATION_TOKEN,
            "session": session,
        }
-        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        channel = self.make_request(b"POST", self.url, params)
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
        self.assertEqual(channel.json_body["completed"], [])

        # Test with non-string (invalid)
        params["auth"]["token"] = 1234
-        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        channel = self.make_request(b"POST", self.url, params)
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
        self.assertEqual(channel.json_body["completed"], [])

        # Test with unknown token (invalid)
        params["auth"]["token"] = "1234"
-        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        channel = self.make_request(b"POST", self.url, params)
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
        self.assertEqual(channel.json_body["completed"], [])
@@ -337,9 +331,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
        params1: JsonDict = {"username": "bert", "password": "monkey"}
        params2: JsonDict = {"username": "ernie", "password": "monkey"}
        # Do 2 requests without auth to get two session IDs
-        channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+        channel1 = self.make_request(b"POST", self.url, params1)
        session1 = channel1.json_body["session"]
-        channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+        channel2 = self.make_request(b"POST", self.url, params2)
        session2 = channel2.json_body["session"]

        # Use token with session1 and check `pending` is 1
@@ -348,9 +342,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
            "token": token,
            "session": session1,
        }
-        self.make_request(b"POST", self.url, json.dumps(params1))
+        self.make_request(b"POST", self.url, params1)
        # Repeat request to make sure pending isn't increased again
-        self.make_request(b"POST", self.url, json.dumps(params1))
+        self.make_request(b"POST", self.url, params1)
        pending = self.get_success(
            store.db_pool.simple_select_one_onecol(
                "registration_tokens",
@@ -366,14 +360,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
            "token": token,
            "session": session2,
        }
-        channel = self.make_request(b"POST", self.url, json.dumps(params2))
+        channel = self.make_request(b"POST", self.url, params2)
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
        self.assertEqual(channel.json_body["completed"], [])

        # Complete registration with session1
        params1["auth"]["type"] = LoginType.DUMMY
-        self.make_request(b"POST", self.url, json.dumps(params1))
+        self.make_request(b"POST", self.url, params1)
        # Check pending=0 and completed=1
        res = self.get_success(
            store.db_pool.simple_select_one(
@@ -386,7 +380,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
        self.assertEqual(res["completed"], 1)

        # Check auth still fails when using token with session2
-        channel = self.make_request(b"POST", self.url, json.dumps(params2))
+        channel = self.make_request(b"POST", self.url, params2)
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
        self.assertEqual(channel.json_body["completed"], [])
@@ -411,7 +405,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
        )
        params: JsonDict = {"username": "kermit", "password": "monkey"}
        # Request without auth to get session
-        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        channel = self.make_request(b"POST", self.url, params)
        session = channel.json_body["session"]

        # Check authentication fails with expired token
@@ -420,7 +414,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
            "token": token,
            "session": session,
        }
-        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        channel = self.make_request(b"POST", self.url, params)
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
        self.assertEqual(channel.json_body["completed"], [])
@@ -435,7 +429,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
        )

        # Check authentication succeeds
-        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        channel = self.make_request(b"POST", self.url, params)
        completed = channel.json_body["completed"]
        self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@@ -460,9 +454,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
        # Do 2 requests without auth to get two session IDs
        params1: JsonDict = {"username": "bert", "password": "monkey"}
        params2: JsonDict = {"username": "ernie", "password": "monkey"}
-        channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+        channel1 = self.make_request(b"POST", self.url, params1)
        session1 = channel1.json_body["session"]
-        channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+        channel2 = self.make_request(b"POST", self.url, params2)
        session2 = channel2.json_body["session"]

        # Use token with both sessions
@@ -471,18 +465,18 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
            "token": token,
            "session": session1,
        }
-        self.make_request(b"POST", self.url, json.dumps(params1))
+        self.make_request(b"POST", self.url, params1)

        params2["auth"] = {
            "type": LoginType.REGISTRATION_TOKEN,
            "token": token,
            "session": session2,
        }
-        self.make_request(b"POST", self.url, json.dumps(params2))
+        self.make_request(b"POST", self.url, params2)

        # Complete registration with session1
        params1["auth"]["type"] = LoginType.DUMMY
-        self.make_request(b"POST", self.url, json.dumps(params1))
+        self.make_request(b"POST", self.url, params1)

        # Check `result` of registration token stage for session1 is `True`
        result1 = self.get_success(
@@ -550,7 +544,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):

        # Do request without auth to get a session ID
        params: JsonDict = {"username": "kermit", "password": "monkey"}
-        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        channel = self.make_request(b"POST", self.url, params)
        session = channel.json_body["session"]

        # Use token
@@ -559,7 +553,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
            "token": token,
            "session": session,
        }
-        self.make_request(b"POST", self.url, json.dumps(params))
+        self.make_request(b"POST", self.url, params)

        # Delete token
        self.get_success(
@@ -592,9 +586,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                "require_at_registration": True,
            },
            "account_threepid_delegates": {
-                "email": "https://id_server",
                "msisdn": "https://id_server",
            },
+            "email": {"notif_from": "Synapse <synapse@example.com>"},
        }
    )
    def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
@@ -827,8 +821,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
        admin_tok = self.login("admin", "adminpassword")

        url = "/_synapse/admin/v1/account_validity/validity"
-        params = {"user_id": user_id}
-        request_data = json.dumps(params)
+        request_data = {"user_id": user_id}
        channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
        self.assertEqual(channel.result["code"], b"200", channel.result)
@@ -845,12 +838,11 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
        admin_tok = self.login("admin", "adminpassword")

        url = "/_synapse/admin/v1/account_validity/validity"
-        params = {
+        request_data = {
            "user_id": user_id,
            "expiration_ts": 0,
            "enable_renewal_emails": False,
        }
-        request_data = json.dumps(params)
        channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
        self.assertEqual(channel.result["code"], b"200", channel.result)
@@ -870,12 +862,11 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
        admin_tok = self.login("admin", "adminpassword")

        url = "/_synapse/admin/v1/account_validity/validity"
-        params = {
+        request_data = {
            "user_id": user_id,
            "expiration_ts": 0,
            "enable_renewal_emails": False,
        }
-        request_data = json.dumps(params)
        channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
        self.assertEqual(channel.result["code"], b"200", channel.result)
@@ -1041,16 +1032,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):

        (user_id, tok) = self.create_user()

-        request_data = json.dumps(
-            {
-                "auth": {
-                    "type": "m.login.password",
-                    "user": user_id,
-                    "password": "monkey",
-                },
-                "erase": False,
-            }
-        )
+        request_data = {
+            "auth": {
+                "type": "m.login.password",
+                "user": user_id,
+                "password": "monkey",
+            },
+            "erase": False,
+        }
        channel = self.make_request(
            "POST", "account/deactivate", request_data, access_token=tok
        )
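The registration-token tests above all follow the Matrix User-Interactive Auth loop: an unauthenticated POST to /register returns a 401 carrying a session ID plus the available flows, and the client re-POSTs with an `auth` dict until every stage of a flow is completed. Schematically (a hypothetical client-side sketch; `post` stands in for any JSON HTTP helper):

def register_with_token(post, username: str, password: str, token: str) -> dict:
    # 1. No auth yet: expect HTTP 401 with a session ID and the flows.
    body = {"username": username, "password": password}
    resp = post("/_matrix/client/v3/register", body)
    body["auth"] = {
        "type": "m.login.registration_token",
        "token": token,
        "session": resp["session"],
    }

    # 2. Complete the registration-token stage; the server answers 401 again,
    #    now listing the stage under "completed", until all stages are done.
    resp = post("/_matrix/client/v3/register", body)

    # 3. Finish the remaining stage (the tests use the dummy stage) to get
    #    the final 200 response carrying the new user_id.
    body["auth"]["type"] = "m.login.dummy"
    return post("/_matrix/client/v3/register", body)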
diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index 20a259fc43..ad0d0209f7 100644
--- a/tests/rest/client/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import json
-
 from twisted.test.proto_helpers import MemoryReactor

 import synapse.rest.admin
@@ -77,10 +75,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
    def _assert_status(self, response_status: int, data: JsonDict) -> None:
        channel = self.make_request(
-            "POST",
-            self.report_path,
-            json.dumps(data),
-            access_token=self.other_user_tok,
+            "POST", self.report_path, data, access_token=self.other_user_tok
        )
        self.assertEqual(
            response_status, int(channel.result["code"]), msg=channel.result["body"]
        )
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index df7ffbe545..c45cb32090 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,6 +18,7 @@
 """Tests REST events for /rooms paths."""

 import json
+from http import HTTPStatus
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 from unittest.mock import Mock, call
 from urllib import parse as urlparse
@@ -104,7 +105,7 @@ class RoomPermissionsTestCase(RoomBase):
        channel = self.make_request(
            "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
        )
-        self.assertEqual(200, channel.code, channel.result)
+        self.assertEqual(HTTPStatus.OK, channel.code, channel.result)

        # set topic for public room
        channel = self.make_request(
            "PUT",
            ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"),
            b'{"topic":"Public Room Topic"}',
        )
-        self.assertEqual(200, channel.code, channel.result)
+        self.assertEqual(HTTPStatus.OK, channel.code, channel.result)

        # auth as user_id now
        self.helper.auth_user_id = self.user_id
@@ -134,28 +135,28 @@ class RoomPermissionsTestCase(RoomBase):
            "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
            msg_content,
        )
-        self.assertEqual(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])

        # send message in created room not joined (no state), expect 403
        channel = self.make_request("PUT", send_msg_path(), msg_content)
-        self.assertEqual(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])

        # send message in created room and invited, expect 403
        self.helper.invite(
            room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
        )
        channel = self.make_request("PUT", send_msg_path(), msg_content)
-        self.assertEqual(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])

        # send message in created room and joined, expect 200
        self.helper.join(room=self.created_rmid, user=self.user_id)
        channel = self.make_request("PUT", send_msg_path(), msg_content)
-        self.assertEqual(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])

        # send message in created room and left, expect 403
        self.helper.leave(room=self.created_rmid, user=self.user_id)
        channel = self.make_request("PUT", send_msg_path(), msg_content)
-        self.assertEqual(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])

    def test_topic_perms(self) -> None:
        topic_content = b'{"topic":"My Topic Name"}'
@@ -165,28 +166,28 @@ class RoomPermissionsTestCase(RoomBase):
        channel = self.make_request(
            "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
        )
-        self.assertEqual(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
        channel = self.make_request(
            "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
        )
-        self.assertEqual(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])

        # set/get topic in created PRIVATE room not joined, expect 403
        channel = self.make_request("PUT", topic_path, topic_content)
-        self.assertEqual(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
        channel = self.make_request("GET", topic_path)
-        self.assertEqual(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])

        # set topic in created PRIVATE room and invited, expect 403
        self.helper.invite(
            room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
        )
        channel = self.make_request("PUT", topic_path, topic_content)
-        self.assertEqual(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])

        # get topic in created PRIVATE room and invited, expect 403
        channel = self.make_request("GET", topic_path)
-        self.assertEqual(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])

        # set/get topic in created PRIVATE room and joined, expect 200
        self.helper.join(room=self.created_rmid, user=self.user_id)
@@ -194,25 +195,25 @@ class RoomPermissionsTestCase(RoomBase):
        # Only room ops can set topic by default
        self.helper.auth_user_id = self.rmcreator_id
        channel = self.make_request("PUT", topic_path, topic_content)
-        self.assertEqual(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
        self.helper.auth_user_id = self.user_id

        channel = self.make_request("GET", topic_path)
-        self.assertEqual(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
        self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body)

        # set/get topic in created PRIVATE room and left, expect 403
        self.helper.leave(room=self.created_rmid, user=self.user_id)
        channel = self.make_request("PUT", topic_path, topic_content)
-        self.assertEqual(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
        channel = self.make_request("GET", topic_path)
-        self.assertEqual(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])

        # get topic in PUBLIC room, not joined, expect 403
        channel = self.make_request(
            "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
        )
-        self.assertEqual(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])

        # set topic in PUBLIC room, not joined, expect 403
        channel = self.make_request(
            "PUT",
            "/rooms/%s/state/m.room.topic" % self.created_public_rmid,
            topic_content,
        )
-        self.assertEqual(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])

    def _test_get_membership(
        self, room: str, members: Iterable = frozenset(), expect_code: int = 200
@@ -309,14 +310,14 @@ class RoomPermissionsTestCase(RoomBase):
            src=self.user_id,
            targ=self.rmcreator_id,
            membership=Membership.JOIN,
-            expect_code=403,
+            expect_code=HTTPStatus.FORBIDDEN,
        )
        self.helper.change_membership(
            room=room,
            src=self.user_id,
            targ=self.rmcreator_id,
            membership=Membership.LEAVE,
-            expect_code=403,
+            expect_code=HTTPStatus.FORBIDDEN,
        )

    def test_joined_permissions(self) -> None:
@@ -342,7 +343,7 @@ class RoomPermissionsTestCase(RoomBase):
            src=self.user_id,
            targ=other,
            membership=Membership.JOIN,
-            expect_code=403,
+            expect_code=HTTPStatus.FORBIDDEN,
        )

        # set left of other, expect 403
@@ -351,7 +352,7 @@ class RoomPermissionsTestCase(RoomBase):
            src=self.user_id,
            targ=other,
            membership=Membership.LEAVE,
-            expect_code=403,
+            expect_code=HTTPStatus.FORBIDDEN,
        )

        # set left of self, expect 200
@@ -371,7 +372,7 @@ class RoomPermissionsTestCase(RoomBase):
            src=self.user_id,
            targ=usr,
            membership=Membership.INVITE,
-            expect_code=403,
+            expect_code=HTTPStatus.FORBIDDEN,
        )

        self.helper.change_membership(
@@ -379,7 +380,7 @@ class RoomPermissionsTestCase(RoomBase):
            src=self.user_id,
            targ=usr,
            membership=Membership.JOIN,
-            expect_code=403,
+            expect_code=HTTPStatus.FORBIDDEN,
        )

        # It is always valid to LEAVE if you've already left (currently.)
@@ -388,7 +389,7 @@ class RoomPermissionsTestCase(RoomBase):
            src=self.user_id,
            targ=self.rmcreator_id,
            membership=Membership.LEAVE,
-            expect_code=403,
+            expect_code=HTTPStatus.FORBIDDEN,
        )

    # tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember
@@ -405,7 +406,7 @@ class RoomPermissionsTestCase(RoomBase):
            src=self.user_id,
            targ=other,
            membership=Membership.BAN,
-            expect_code=403,  # expect failure
+            expect_code=HTTPStatus.FORBIDDEN,  # expect failure
            expect_errcode=Codes.FORBIDDEN,
        )

@@ -415,7 +416,7 @@ class RoomPermissionsTestCase(RoomBase):
            src=self.rmcreator_id,
            targ=other,
            membership=Membership.BAN,
-            expect_code=200,
+            expect_code=HTTPStatus.OK,
        )

        # from ban to invite: Must never happen.
@@ -424,7 +425,7 @@ class RoomPermissionsTestCase(RoomBase):
            src=self.rmcreator_id,
            targ=other,
            membership=Membership.INVITE,
-            expect_code=403,  # expect failure
+            expect_code=HTTPStatus.FORBIDDEN,  # expect failure
            expect_errcode=Codes.BAD_STATE,
        )

@@ -434,7 +435,7 @@ class RoomPermissionsTestCase(RoomBase):
            src=other,
            targ=other,
            membership=Membership.JOIN,
-            expect_code=403,  # expect failure
+            expect_code=HTTPStatus.FORBIDDEN,  # expect failure
            expect_errcode=Codes.BAD_STATE,
        )

@@ -444,7 +445,7 @@ class RoomPermissionsTestCase(RoomBase):
            src=self.rmcreator_id,
            targ=other,
            membership=Membership.BAN,
-            expect_code=200,
+            expect_code=HTTPStatus.OK,
        )
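The ban-related cases in this test encode the "from banned" row of the m.room.member transition table in the spec (linked above). Summarised as data, the outcomes asserted in this hunk and the ones that follow are:

# Target's current membership is "ban"; keys are (attempted change, actor).
FROM_BAN = {
    ("ban", "non-op"): "403 M_FORBIDDEN",       # re-banning needs power
    ("ban", "op"): "200 OK",                    # no-op re-ban is allowed
    ("invite", "op"): "403 M_BAD_STATE",        # must unban first
    ("join", "banned user"): "403 M_BAD_STATE",
    ("knock", "op"): "403 M_BAD_STATE",         # must never happen
    ("leave", "non-op"): "403 M_FORBIDDEN",     # unbanning needs power
    ("leave", "op"): "200 OK",                  # the unban itself
}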
@@ -453,7 +454,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.KNOCK, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -463,7 +464,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.LEAVE, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.FORBIDDEN, ) @@ -473,7 +474,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.LEAVE, - expect_code=200, + expect_code=HTTPStatus.OK, ) @@ -493,7 +494,7 @@ class RoomStateTestCase(RoomBase): "/rooms/%s/state" % room_id, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertCountEqual( [state_event["type"] for state_event in channel.json_body], { @@ -516,7 +517,7 @@ class RoomStateTestCase(RoomBase): "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id), ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(channel.json_body, {"membership": "join"}) @@ -530,16 +531,16 @@ class RoomsMemberListTestCase(RoomBase): def test_get_member_list(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self) -> None: channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self) -> None: room_id = self.helper.create_room_as("@some_other_guy:red") channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_with_at_token(self) -> None: """ @@ -550,7 +551,7 @@ class RoomsMemberListTestCase(RoomBase): # first sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check that permission is denied for @sid1:red to get the @@ -559,7 +560,7 @@ class RoomsMemberListTestCase(RoomBase): "GET", f"/rooms/{room_id}/members?at={sync_token}", ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member(self) -> None: """ @@ -572,14 +573,14 @@ class RoomsMemberListTestCase(RoomBase): # check that the user can see the member list to start with channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # ban the user self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban") # check the user can no longer see the member list channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(403, 
channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member_with_at_token(self) -> None: """ @@ -593,14 +594,14 @@ class RoomsMemberListTestCase(RoomBase): # sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check that the user can see the member list to start with channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # ban the user (Note: the user is actually allowed to see this event and # state so that they know they're banned!) @@ -612,14 +613,14 @@ class RoomsMemberListTestCase(RoomBase): # now, with the original user, sync again to get a new at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check the user can no longer see the updated member list channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self) -> None: room_creator = "@some_other_guy:red" @@ -628,17 +629,17 @@ class RoomsMemberListTestCase(RoomBase): self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. channel = self.make_request("GET", room_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined channel = self.make_request("GET", room_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left channel = self.make_request("GET", room_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_get_member_list_cancellation(self) -> None: """Test cancellation of a `/rooms/$room_id/members` request.""" @@ -651,7 +652,7 @@ class RoomsMemberListTestCase(RoomBase): "/rooms/%s/members" % room_id, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(len(channel.json_body["chunk"]), 1) self.assertLessEqual( { @@ -671,7 +672,7 @@ class RoomsMemberListTestCase(RoomBase): # first sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] channel = make_request_with_cancellation_test( @@ -682,7 +683,7 @@ class RoomsMemberListTestCase(RoomBase): "/rooms/%s/members?at=%s" % (room_id, sync_token), ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(len(channel.json_body["chunk"]), 1) self.assertLessEqual( { 
@@ -706,10 +707,10 @@ class RoomsCreateTestCase(RoomBase): # POST with no config keys, expect new room id channel = self.make_request("POST", "/createRoom", "{}") - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(37, channel.resource_usage.db_txn_count) + self.assertEqual(44, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -719,21 +720,21 @@ class RoomsCreateTestCase(RoomBase): b'{"initial_state":[{"type": "m.bridge", "content": {}}]}', ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(41, channel.resource_usage.db_txn_count) + self.assertEqual(50, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self) -> None: # POST with custom config keys, expect new room id channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self) -> None: @@ -741,16 +742,16 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self) -> None: # POST with invalid content / paths, expect 400 channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) def test_post_room_invitees_invalid_mxid(self) -> None: # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 @@ -758,7 +759,7 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' ) - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) def test_post_room_invitees_ratelimit(self) -> None: @@ -769,20 +770,18 @@ class RoomsCreateTestCase(RoomBase): # Build the request's content. We use local MXIDs because invites over federation # are more difficult to mock. - content = json.dumps( - { - "invite": [ - "@alice1:red", - "@alice2:red", - "@alice3:red", - "@alice4:red", - ] - } - ).encode("utf8") + content = { + "invite": [ + "@alice1:red", + "@alice2:red", + "@alice3:red", + "@alice4:red", + ] + } # Test that the invites are correctly ratelimited. 
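The hunk above (and several below) stops pre-encoding request bodies with json.dumps(...).encode("utf8") and hands make_request a plain dict. That only works if the test helper serialises dict content itself; a rough sketch of the kind of dispatch involved, as an illustration only and not Synapse's actual helper:

    import json
    from typing import Union

    JsonLike = Union[bytes, str, dict, list]

    def encode_body(content: JsonLike) -> bytes:
        """Illustrative sketch: pass raw bytes/str through, JSON-encode the rest."""
        if isinstance(content, bytes):
            return content
        if isinstance(content, str):
            return content.encode("utf8")
        # dicts/lists, as now passed directly by the updated tests
        return json.dumps(content).encode("utf8")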
channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) self.assertEqual( "Cannot invite so many users at once", channel.json_body["error"], @@ -795,7 +794,7 @@ class RoomsCreateTestCase(RoomBase): # Test that the invites aren't ratelimited anymore. channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly bypassed @@ -819,7 +818,7 @@ class RoomsCreateTestCase(RoomBase): "/createRoom", {}, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) @@ -845,7 +844,7 @@ class RoomsCreateTestCase(RoomBase): "/createRoom", {}, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) @@ -865,7 +864,7 @@ class RoomsCreateTestCase(RoomBase): "/createRoom", {}, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) @@ -882,54 +881,68 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self) -> None: # missing keys or invalid json channel = self.make_request("PUT", self.path, "{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, '{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request( "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, "text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, "") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # valid key, wrong type content = '{"topic":["Topic name"]}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_topic(self) -> None: # nothing should be there channel = self.make_request("GET", self.path) - self.assertEqual(404, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, 
msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self) -> None: # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -945,22 +958,34 @@ class RoomMemberStateTestCase(RoomBase): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json channel = self.make_request("PUT", path, "{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, '{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, "text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, "") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # valid keys, wrong types content = '{"membership":["%s","%s","%s"]}' % ( @@ -969,7 +994,9 @@ class RoomMemberStateTestCase(RoomBase): Membership.LEAVE, ) channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_members_self(self) -> None: path = "/rooms/%s/state/m.room.member/%s" % ( @@ -980,10 +1007,10 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} self.assertEqual(expected_response, channel.json_body) @@ -998,10 +1025,10 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = 
'{"membership":"%s"}' % Membership.INVITE channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) def test_rooms_members_other_custom_keys(self) -> None: @@ -1017,10 +1044,10 @@ class RoomMemberStateTestCase(RoomBase): "Join us!", ) channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) @@ -1137,7 +1164,9 @@ class RoomJoinTestCase(RoomBase): # Now make the callback deny all room joins, and check that a join actually fails. return_value = False - self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) + self.helper.join( + self.room3, self.user2, expect_code=HTTPStatus.FORBIDDEN, tok=self.tok2 + ) def test_spam_checker_may_join_room(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called @@ -1205,7 +1234,7 @@ class RoomJoinTestCase(RoomBase): self.helper.join( self.room3, self.user2, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, expect_errcode=return_value, tok=self.tok2, ) @@ -1216,7 +1245,7 @@ class RoomJoinTestCase(RoomBase): self.helper.join( self.room3, self.user2, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, expect_errcode=return_value[0], tok=self.tok2, expect_additional_fields=return_value[1], @@ -1270,7 +1299,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # Check that all the rooms have been sent a profile update into. 
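For the user_may_join_room tests above: the pattern is to register a single async callback whose behaviour is steered from the test body, so one registration covers both the allow and the deny cases. A sketch of that shape, where the callback name and signature follow the tests and the exact module API should be treated as an assumption:

    from unittest.mock import Mock

    # A mock the test can re-point between runs; per the hunks above,
    # return_value = False makes the join fail with HTTPStatus.FORBIDDEN,
    # and newer-style callbacks can also return an error code or a
    # (code, extra_fields) tuple to control the error body.
    callback_mock = Mock(return_value=True)

    async def user_may_join_room(user_id: str, room_id: str, is_invited: bool):
        return callback_mock(user_id, room_id, is_invited)

    # later in a test: callback_mock.return_value = False  # joins now denied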
for room_id in room_ids: @@ -1335,71 +1364,93 @@ class RoomMessagesTestCase(RoomBase): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json channel = self.make_request("PUT", path, b"{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b"text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b"") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_messages_sent(self) -> None: path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' channel = self.make_request("PUT", path, content) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # custom message types content = b'{"body":"test","msgtype":"test.custom.text"}' channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = b'{"body":"test2","msgtype":"m.text"}' channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) @parameterized.expand( [ # Allow param( - name="NOT_SPAM", value="NOT_SPAM", expected_code=200, expected_fields={} + name="NOT_SPAM", + value="NOT_SPAM", + expected_code=HTTPStatus.OK, + expected_fields={}, + ), + param( + name="False", + value=False, + expected_code=HTTPStatus.OK, + expected_fields={}, ), - param(name="False", value=False, expected_code=200, expected_fields={}), # Block param( name="scalene string", value="ANY OTHER STRING", - expected_code=403, + expected_code=HTTPStatus.FORBIDDEN, expected_fields={"errcode": "M_FORBIDDEN"}, ), param( name="True", value=True, - expected_code=403, + expected_code=HTTPStatus.FORBIDDEN, expected_fields={"errcode": "M_FORBIDDEN"}, ), param( name="Code", value=Codes.LIMIT_EXCEEDED, - expected_code=403, + expected_code=HTTPStatus.FORBIDDEN, expected_fields={"errcode": "M_LIMIT_EXCEEDED"}, ), param( name="Tuple", value=(Codes.SERVER_NOT_TRUSTED, {"additional_field": "12345"}), - expected_code=403, + expected_code=HTTPStatus.FORBIDDEN, expected_fields={ "errcode": "M_SERVER_NOT_TRUSTED", "additional_field": "12345", @@ 
-1584,7 +1635,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am allowed - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_normal_user_can_not_post_state_event(self) -> None: # Given I am a normal member of a room @@ -1598,7 +1649,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed because state events require PL>=50 - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " "user_level (0) < send_level (50)", @@ -1625,7 +1676,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am allowed - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) @unittest.override_config( { @@ -1653,7 +1704,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) @unittest.override_config( { @@ -1681,7 +1732,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " + "user_level (0) < send_level (1)", @@ -1712,7 +1763,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): # Then I am not allowed because the public_chat config does not # affect this room, because this room is a private_chat - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. 
" + "user_level (0) < send_level (50)", @@ -1731,7 +1782,7 @@ class RoomInitialSyncTestCase(RoomBase): def test_initial_sync(self) -> None: channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertEqual(self.room_id, channel.json_body["room_id"]) self.assertEqual("join", channel.json_body["membership"]) @@ -1774,7 +1825,7 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("start" in channel.json_body) self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) @@ -1785,7 +1836,7 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("start" in channel.json_body) self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) @@ -1824,7 +1875,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) @@ -1852,7 +1903,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) @@ -1869,7 +1920,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) @@ -1997,14 +2048,14 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): def test_restricted_no_auth(self) -> None: channel = self.make_request("GET", self.url) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) def test_restricted_auth(self) -> None: self.register_user("user", "pass") tok = self.login("user", "pass") channel = self.make_request("GET", self.url, access_token=tok) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): @@ -2123,7 +2174,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): content={"filter": search_filter}, access_token=self.token, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined] "testserv", @@ -2140,7 +2191,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): # The `get_public_rooms` should be called again if the first call fails # with a 404, when using search filters. 
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] - HttpResponseException(404, "Not Found", b""), + HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""), make_awaitable({}), ) @@ -2152,7 +2203,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): content={"filter": search_filter}, access_token=self.token, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined] [ @@ -2198,21 +2249,19 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): # Set a profile for the test user self.displayname = "test user" - data = {"displayname": self.displayname} - request_data = json.dumps(data) + request_data = {"displayname": self.displayname} channel = self.make_request( "PUT", "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), request_data, access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) def test_per_room_profile_forbidden(self) -> None: - data = {"membership": "join", "displayname": "other test user"} - request_data = json.dumps(data) + request_data = {"membership": "join", "displayname": "other test user"} channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" @@ -2220,7 +2269,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): request_data, access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) event_id = channel.json_body["event_id"] channel = self.make_request( @@ -2228,7 +2277,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) res_displayname = channel.json_body["content"]["displayname"] self.assertEqual(res_displayname, self.displayname, channel.result) @@ -2262,7 +2311,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2276,7 +2325,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2290,7 +2339,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2304,7 +2353,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2316,7 +2365,7 @@ class 
RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2328,7 +2377,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2347,7 +2396,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2359,7 +2408,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): ), access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) event_content = channel.json_body @@ -2407,7 +2456,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2437,7 +2486,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2472,7 +2521,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2552,16 +2601,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_labels(self) -> None: """Test that we can filter by a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2589,16 +2636,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_NOT_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2638,16 +2683,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by both a label and the absence of another label on a /search request. 
""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS_NOT_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2820,7 +2863,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) return channel.json_body["chunk"] @@ -2925,7 +2968,7 @@ class ContextTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2991,7 +3034,7 @@ class ContextTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id), access_token=invited_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -3092,8 +3135,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok @@ -3122,8 +3164,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok @@ -3149,7 +3190,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), - json.dumps(content), + content, access_token=self.room_owner_tok, ) self.assertEqual(channel.code, expected_code, channel.result) @@ -3283,7 +3324,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. - make_invite_mock = Mock(return_value=make_awaitable(0)) + make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock self.hs.get_identity_handler().lookup_3pid = Mock( return_value=make_awaitable(None), @@ -3344,7 +3385,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. 
- make_invite_mock = Mock(return_value=make_awaitable(0)) + make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock self.hs.get_identity_handler().lookup_3pid = Mock( return_value=make_awaitable(None), diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index e3efd1f1b0..b085c50356 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -606,11 +606,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(1) # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8") channel = self.make_request( "POST", f"/rooms/{self.room_id}/read_markers", - body, + {ReceiptTypes.READ: res["event_id"]}, access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 5eb0f243f7..9a48e9286f 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -21,7 +21,6 @@ from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.errors import SynapseError from synapse.api.room_versions import RoomVersion from synapse.events import EventBase -from synapse.events.snapshot import EventContext from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.rest import admin from synapse.rest.client import account, login, profile, room @@ -113,14 +112,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # Have this homeserver skip event auth checks. This is necessary due to # event auth checks ensuring that events were signed by the sender's homeserver. - async def _check_event_auth( - origin: str, - event: EventBase, - context: EventContext, - *args: Any, - **kwargs: Any, - ) -> EventContext: - return context + async def _check_event_auth(origin: Any, event: Any, context: Any) -> None: + pass hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 93f749744d..105d418698 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -136,7 +136,7 @@ class RestHelper: self.site, "POST", path, - json.dumps(content).encode("utf8"), + content, custom_headers=custom_headers, ) @@ -210,7 +210,7 @@ class RestHelper: self.site, "POST", path, - json.dumps(data).encode("utf8"), + data, ) assert ( @@ -309,7 +309,7 @@ class RestHelper: self.site, "PUT", path, - json.dumps(data).encode("utf8"), + data, ) assert ( @@ -392,7 +392,7 @@ class RestHelper: self.site, "PUT", path, - json.dumps(content or {}).encode("utf8"), + content or {}, custom_headers=custom_headers, ) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 79727c430f..d18fc13c21 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -126,7 +126,9 @@ class _TestImage: expected_scaled: The expected bytes from scaled thumbnailing, or None if test should just check for a valid image returned. expected_found: True if the file should exist on the server, or False if - a 404 is expected. + a 404/400 is expected. 
+ unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or + False if the thumbnailing should succeed or a normal 404 is expected. """ data: bytes @@ -135,6 +137,7 @@ class _TestImage: expected_cropped: Optional[bytes] = None expected_scaled: Optional[bytes] = None expected_found: bool = True + unable_to_thumbnail: bool = False @parameterized_class( @@ -192,6 +195,7 @@ class _TestImage: b"image/gif", b".gif", expected_found=False, + unable_to_thumbnail=True, ), ), ], @@ -366,18 +370,29 @@ class MediaRepoTests(unittest.HomeserverTestCase): def test_thumbnail_crop(self) -> None: """Test that a cropped remote thumbnail is available.""" self._test_thumbnail( - "crop", self.test_image.expected_cropped, self.test_image.expected_found + "crop", + self.test_image.expected_cropped, + expected_found=self.test_image.expected_found, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, ) def test_thumbnail_scale(self) -> None: """Test that a scaled remote thumbnail is available.""" self._test_thumbnail( - "scale", self.test_image.expected_scaled, self.test_image.expected_found + "scale", + self.test_image.expected_scaled, + expected_found=self.test_image.expected_found, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, ) def test_invalid_type(self) -> None: """An invalid thumbnail type is never available.""" - self._test_thumbnail("invalid", None, False) + self._test_thumbnail( + "invalid", + None, + expected_found=False, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, + ) @unittest.override_config( {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} @@ -386,7 +401,12 @@ class MediaRepoTests(unittest.HomeserverTestCase): """ Override the config to generate only scaled thumbnails, but request a cropped one. """ - self._test_thumbnail("crop", None, False) + self._test_thumbnail( + "crop", + None, + expected_found=False, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, + ) @unittest.override_config( {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} @@ -395,14 +415,22 @@ class MediaRepoTests(unittest.HomeserverTestCase): """ Override the config to generate only cropped thumbnails, but request a scaled one. """ - self._test_thumbnail("scale", None, False) + self._test_thumbnail( + "scale", + None, + expected_found=False, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, + ) def test_thumbnail_repeated_thumbnail(self) -> None: """Test that fetching the same thumbnail works, and deleting the on disk thumbnail regenerates it. """ self._test_thumbnail( - "scale", self.test_image.expected_scaled, self.test_image.expected_found + "scale", + self.test_image.expected_scaled, + expected_found=self.test_image.expected_found, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, ) if not self.test_image.expected_found: @@ -459,8 +487,24 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) def _test_thumbnail( - self, method: str, expected_body: Optional[bytes], expected_found: bool + self, + method: str, + expected_body: Optional[bytes], + expected_found: bool, + unable_to_thumbnail: bool = False, ) -> None: + """Test the given thumbnailing method works as expected. + + Args: + method: The thumbnailing method to use (crop, scale). + expected_body: The expected bytes from thumbnailing, or None if + test should just check for a valid image. + expected_found: True if the file should exist on the server, or False if + a 404/400 is expected. 
+ unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or + False if the thumbnailing should succeed or a normal 404 is expected. + """ + params = "?width=32&height=32&method=" + method channel = make_request( self.reactor, @@ -496,6 +540,16 @@ class MediaRepoTests(unittest.HomeserverTestCase): else: # ensure that the result is at least some valid image Image.open(BytesIO(channel.result["body"])) + elif unable_to_thumbnail: + # A 400 with a JSON body. + self.assertEqual(channel.code, 400) + self.assertEqual( + channel.json_body, + { + "errcode": "M_UNKNOWN", + "error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)", + }, + ) else: # A 404 with a JSON body. self.assertEqual(channel.code, 404) diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 38963ce4a7..46d829b062 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -143,7 +143,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): self.event_id = res["event_id"] # Reset the event cache so the tests start with it empty - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) def test_simple(self): """Test that we cache events that we pull from the DB.""" @@ -160,7 +160,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): """ # Reset the event cache - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) with LoggingContext("test") as ctx: # We keep hold of the event event though we never use it. @@ -170,7 +170,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) # Reset the event cache - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) with LoggingContext("test") as ctx: self.get_success(self.store.get_event(self.event_id)) @@ -345,7 +345,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): self.event_id = res["event_id"] # Reset the event cache so the tests start with it empty - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) @contextmanager def blocking_get_event_calls( diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index e8c53f16d9..ba40124c8a 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
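On the _get_event_cache.clear() hunks above: every call site gains a self.get_success(...) wrapper, which indicates clear() now returns a Deferred/awaitable instead of completing synchronously. A runnable toy illustrating why fire-and-forget call sites had to change; the cache here is a stand-in, not Synapse's:

    from twisted.internet import defer

    def clear_cache() -> "defer.Deferred[None]":
        # Stand-in for an asynchronous cache invalidation; callers that
        # ignore the returned Deferred may leave work pending or swallow
        # errors it carries.
        return defer.succeed(None)

    d = clear_cache()
    assert d.called  # the tests instead drive this via self.get_success(...)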
-from unittest.mock import Mock - from twisted.test.proto_helpers import MemoryReactor +from synapse.rest import admin +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.databases.main.event_push_actions import NotifCounts from synapse.util import Clock @@ -24,15 +24,14 @@ from tests.unittest import HomeserverTestCase USER_ID = "@user:example.com" -PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}] -HIGHLIGHT = [ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight"}, -] - class EventPushActionsStoreTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main persist_events_store = hs.get_datastores().persist_events @@ -54,154 +53,118 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) def test_count_aggregation(self) -> None: - room_id = "!foo:example.com" - user_id = "@user1235:test" + # Create a user to receive notifications and send receipts. + user_id = self.register_user("user1235", "pass") + token = self.login("user1235", "pass") + + # And another users to send events. + other_id = self.register_user("other", "pass") + other_token = self.login("other", "pass") + + # Create a room and put both users in it. + room_id = self.helper.create_room_as(user_id, tok=token) + self.helper.join(room_id, other_id, tok=other_token) - last_read_stream_ordering = [0] + last_event_id: str - def _assert_counts(noitf_count: int, highlight_count: int) -> None: + def _assert_counts( + noitf_count: int, unread_count: int, highlight_count: int + ) -> None: counts = self.get_success( self.store.db_pool.runInteraction( - "", - self.store._get_unread_counts_by_pos_txn, + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, room_id, user_id, - last_read_stream_ordering[0], ) ) self.assertEqual( counts, NotifCounts( notify_count=noitf_count, - unread_count=0, # Unread counts are tested in the sync tests. 
+ unread_count=unread_count, highlight_count=highlight_count, ), ) - def _inject_actions(stream: int, action: list) -> None: - event = Mock() - event.room_id = room_id - event.event_id = f"$test{stream}:example.com" - event.internal_metadata.stream_ordering = stream - event.internal_metadata.is_outlier.return_value = False - event.depth = stream - - self.store._events_stream_cache.entity_has_changed(room_id, stream) - - self.get_success( - self.store.db_pool.simple_insert( - table="events", - values={ - "stream_ordering": stream, - "topological_ordering": stream, - "type": "m.room.message", - "room_id": room_id, - "processed": True, - "outlier": False, - "event_id": event.event_id, - }, - ) + def _create_event(highlight: bool = False) -> str: + result = self.helper.send_event( + room_id, + type="m.room.message", + content={"msgtype": "m.text", "body": user_id if highlight else "msg"}, + tok=other_token, ) + nonlocal last_event_id + last_event_id = result["event_id"] + return last_event_id - self.get_success( - self.store.add_push_actions_to_staging( - event.event_id, - {user_id: action}, - False, - ) - ) - self.get_success( - self.store.db_pool.runInteraction( - "", - self.persist_events_store._set_push_actions_for_event_and_users_txn, - [(event, None)], - [(event, None)], - ) - ) - - def _rotate(stream: int) -> None: - self.get_success( - self.store.db_pool.runInteraction( - "rotate-receipts", self.store._handle_new_receipts_for_notifs_txn - ) - ) - - self.get_success( - self.store.db_pool.runInteraction( - "rotate-notifs", self.store._rotate_notifs_before_txn, stream - ) - ) - - def _mark_read(stream: int, depth: int) -> None: - last_read_stream_ordering[0] = stream + def _rotate() -> None: + self.get_success(self.store._rotate_notifs()) + def _mark_read(event_id: str) -> None: self.get_success( self.store.insert_receipt( room_id, "m.read", user_id=user_id, - event_ids=[f"$test{stream}:example.com"], + event_ids=[event_id], data={}, ) ) - _assert_counts(0, 0) - _inject_actions(1, PlAIN_NOTIF) - _assert_counts(1, 0) - _rotate(1) - _assert_counts(1, 0) + _assert_counts(0, 0, 0) + _create_event() + _assert_counts(1, 1, 0) + _rotate() + _assert_counts(1, 1, 0) - _inject_actions(3, PlAIN_NOTIF) - _assert_counts(2, 0) - _rotate(3) - _assert_counts(2, 0) + event_id = _create_event() + _assert_counts(2, 2, 0) + _rotate() + _assert_counts(2, 2, 0) - _inject_actions(5, PlAIN_NOTIF) - _mark_read(3, 3) - _assert_counts(1, 0) + _create_event() + _mark_read(event_id) + _assert_counts(1, 1, 0) - _mark_read(5, 5) - _assert_counts(0, 0) + _mark_read(last_event_id) + _assert_counts(0, 0, 0) - _inject_actions(6, PlAIN_NOTIF) - _rotate(6) - _assert_counts(1, 0) - - self.get_success( - self.store.db_pool.simple_delete( - table="event_push_actions", keyvalues={"1": 1}, desc="" - ) - ) + _create_event() + _rotate() + _assert_counts(1, 1, 0) - _assert_counts(1, 0) + # Delete old event push actions, this should not affect the (summarised) count. + self.get_success(self.store._remove_old_push_actions_that_have_rotated()) + _assert_counts(1, 1, 0) - _mark_read(6, 6) - _assert_counts(0, 0) + _mark_read(last_event_id) + _assert_counts(0, 0, 0) - _inject_actions(8, HIGHLIGHT) - _assert_counts(1, 1) - _rotate(8) - _assert_counts(1, 1) + event_id = _create_event(True) + _assert_counts(1, 1, 1) + _rotate() + _assert_counts(1, 1, 1) # Check that adding another notification and rotating after highlight # works. 
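Worth spelling out how the rewritten aggregation test gets a highlight: _create_event(True) sends a message whose body is the recipient's MXID, which trips the default "contains user name" content push rule. Roughly the shape of that rule, paraphrased from the spec's default push rules rather than copied from Synapse:

    # The body of a highlight event above is the recipient's user id, which
    # the default content rule glob-matches against the user's localpart.
    contains_user_name_rule = {
        "rule_id": ".m.rule.contains_user_name",
        "kind": "content",
        "pattern": "user1235",  # substituted per-user at runtime
        "actions": [
            "notify",
            {"set_tweak": "sound", "value": "default"},
            {"set_tweak": "highlight"},
        ],
    }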
- _inject_actions(10, PlAIN_NOTIF) - _rotate(10) - _assert_counts(2, 1) + _create_event() + _rotate() + _assert_counts(2, 2, 1) # Check that sending read receipts at different points results in the # right counts. - _mark_read(8, 8) - _assert_counts(1, 0) - _mark_read(10, 10) - _assert_counts(0, 0) - - _inject_actions(11, HIGHLIGHT) - _assert_counts(1, 1) - _mark_read(11, 11) - _assert_counts(0, 0) - _rotate(11) - _assert_counts(0, 0) + _mark_read(event_id) + _assert_counts(1, 1, 0) + _mark_read(last_event_id) + _assert_counts(0, 0, 0) + + _create_event(True) + _assert_counts(1, 1, 1) + _mark_read(last_event_id) + _assert_counts(0, 0, 0) + _rotate() + _assert_counts(0, 0, 0) def test_find_first_stream_ordering_after_ts(self) -> None: def add_event(so: int, ts: int) -> None: diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 8dfaa0559b..9c1182ed16 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -115,6 +115,6 @@ class PurgeTests(HomeserverTestCase): ) # The events aren't found. - self.store._invalidate_get_event_cache(create_event.event_id) + self.store._invalidate_local_get_event_cache(create_event.event_id) self.get_failure(self.store.get_event(create_event.event_id), NotFoundError) self.get_failure(self.store.get_event(first["event_id"]), NotFoundError) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 1218786d79..240b02cb9f 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -23,7 +23,6 @@ from synapse.util import Clock from tests import unittest from tests.server import TestHomeServer -from tests.test_utils import event_injection class RoomMemberStoreTestCase(unittest.HomeserverTestCase): @@ -110,60 +109,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # It now knows about Charlie's server. self.assertEqual(self.store._known_servers_count, 2) - def test_get_joined_users_from_context(self) -> None: - room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) - bob_event = self.get_success( - event_injection.inject_member_event( - self.hs, room, self.u_bob, Membership.JOIN - ) - ) - - # first, create a regular event - event, context = self.get_success( - event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[bob_event.event_id], - type="m.test.1", - content={}, - ) - ) - - users = self.get_success( - self.store.get_joined_users_from_context(event, context) - ) - self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) - - # Regression test for #7376: create a state event whose key matches bob's - # user_id, but which is *not* a membership event, and persist that; then check - # that `get_joined_users_from_context` returns the correct users for the next event. 
- non_member_event = self.get_success( - event_injection.inject_event( - self.hs, - room_id=room, - sender=self.u_bob, - prev_event_ids=[bob_event.event_id], - type="m.test.2", - state_key=self.u_bob, - content={}, - ) - ) - event, context = self.get_success( - event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[non_member_event.event_id], - type="m.test.3", - content={}, - ) - ) - users = self.get_success( - self.store.get_joined_users_from_context(event, context) - ) - self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) - def test__null_byte_in_display_name_properly_handled(self) -> None: room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 371cd201af..e42d7b9ba0 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -19,7 +19,7 @@ from parameterized import parameterized from synapse import event_auth from synapse.api.constants import EventContentFields -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.storage.databases.main.events_worker import EventRedactBehaviour @@ -689,6 +689,45 @@ class EventAuthTestCase(unittest.TestCase): auth_events.values(), ) + def test_room_v10_rejects_string_power_levels(self) -> None: + pl_event_content = {"users_default": "42"} + pl_event = make_event_from_dict( + { + "room_id": TEST_ROOM_ID, + **_maybe_get_event_id_dict_for_room_version(RoomVersions.V10), + "type": "m.room.power_levels", + "sender": "@test:test.com", + "state_key": "", + "content": pl_event_content, + "signatures": {"test.com": {"ed25519:0": "some9signature"}}, + }, + room_version=RoomVersions.V10, + ) + + pl_event2_content = {"events": {"m.room.name": "42", "m.room.power_levels": 42}} + pl_event2 = make_event_from_dict( + { + "room_id": TEST_ROOM_ID, + **_maybe_get_event_id_dict_for_room_version(RoomVersions.V10), + "type": "m.room.power_levels", + "sender": "@test:test.com", + "state_key": "", + "content": pl_event2_content, + "signatures": {"test.com": {"ed25519:0": "some9signature"}}, + }, + room_version=RoomVersions.V10, + ) + + with self.assertRaises(SynapseError): + event_auth._check_power_levels( + pl_event.room_version, pl_event, {("fake_type", "fake_key"): pl_event2} + ) + + with self.assertRaises(SynapseError): + event_auth._check_power_levels( + pl_event.room_version, pl_event2, {("fake_type", "fake_key"): pl_event} + ) + # helpers for making events TEST_DOMAIN = "example.com" diff --git a/tests/test_federation.py b/tests/test_federation.py index 0cbef70bfa..779fad1f63 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -81,12 +81,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.handler = self.homeserver.get_federation_handler() federation_event_handler = self.homeserver.get_federation_event_handler() - async def _check_event_auth( - origin, - event, - context, - ): - return context + async def _check_event_auth(origin, event, context): + pass federation_event_handler._check_event_auth = _check_event_auth self.client = self.homeserver.get_federation_client() diff --git a/tests/test_server.py b/tests/test_server.py index fc4bce899c..2fe4411401 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -231,7 +231,7 @@ class OptionsResourceTests(unittest.TestCase): 
parse_listener_def({"type": "http", "port": 0}), self.resource, "1.0", - max_request_body_size=1234, + max_request_body_size=4096, reactor=self.reactor, ) diff --git a/tests/test_state.py b/tests/test_state.py index 6ca8d8f21d..bafd6d1750 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -21,7 +21,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext -from synapse.state import StateHandler, StateResolutionHandler +from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry from synapse.util import Clock from synapse.util.macaroons import MacaroonGenerator @@ -99,6 +99,10 @@ class _DummyStore: state_group = self._next_group self._next_group += 1 + if current_state_ids is None: + current_state_ids = dict(self._group_to_state[prev_group]) + current_state_ids.update(delta_ids) + self._group_to_state[state_group] = dict(current_state_ids) return state_group @@ -760,3 +764,43 @@ class StateTestCase(unittest.TestCase): result = yield defer.ensureDeferred(self.state.compute_event_context(event)) return result + + def test_make_state_cache_entry(self): + "Test that calculating a prev_group and delta is correct" + + new_state = { + ("a", ""): "E", + ("b", ""): "E", + ("c", ""): "E", + ("d", ""): "E", + } + + # old_state_1 has fewer differences to new_state than old_state_2, but + # the delta involves deleting a key, which isn't allowed in the deltas, + # so we should pick old_state_2 as the prev_group. + + # `old_state_1` has two differences: `a` and `e` + old_state_1 = { + ("a", ""): "F", + ("b", ""): "E", + ("c", ""): "E", + ("d", ""): "E", + ("e", ""): "E", + } + + # `old_state_2` has three differences: `a`, `c` and `d` + old_state_2 = { + ("a", ""): "F", + ("b", ""): "E", + ("c", ""): "F", + ("d", ""): "F", + } + + entry = _make_state_cache_entry(new_state, {1: old_state_1, 2: old_state_2}) + + self.assertEqual(entry.prev_group, 2) + + # There are three changes from `old_state_2` to `new_state` + self.assertEqual( + entry.delta_ids, {("a", ""): "E", ("c", ""): "E", ("d", ""): "E"} + ) diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 37fada5c53..d3c13cf14c 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
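The new test_make_state_cache_entry above pins down a subtle rule: delta_ids may only add or overwrite state keys, never delete them, so the candidate with the fewest raw differences (old_state_1) is still ineligible, because reaching new_state from it would require dropping ("e", ""). A small standalone check of that reasoning, using a hypothetical helper rather than Synapse's implementation:

    from typing import Dict, Optional, Tuple

    StateKey = Tuple[str, str]
    StateMap = Dict[StateKey, str]

    def delta_or_none(old_state: StateMap, new_state: StateMap) -> Optional[StateMap]:
        # A valid delta may only add or overwrite entries; if old_state holds
        # a key that new_state lacks, no delta can express the change.
        if not set(old_state) <= set(new_state):
            return None
        return {k: v for k, v in new_state.items() if old_state.get(k) != v}

    new_state = {("a", ""): "E", ("b", ""): "E", ("c", ""): "E", ("d", ""): "E"}
    old_state_1 = {**new_state, ("a", ""): "F", ("e", ""): "E"}
    old_state_2 = {**new_state, ("a", ""): "F", ("c", ""): "F", ("d", ""): "F"}

    assert delta_or_none(old_state_1, new_state) is None  # would need a deletion
    assert delta_or_none(old_state_2, new_state) == {
        ("a", ""): "E",
        ("c", ""): "E",
        ("d", ""): "E",
    }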
-import json from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactorClock @@ -51,7 +50,7 @@ class TermsTestCase(unittest.HomeserverTestCase): def test_ui_auth(self): # Do a UI auth request - request_data = json.dumps({"username": "kermit", "password": "monkey"}) + request_data = {"username": "kermit", "password": "monkey"} channel = self.make_request(b"POST", self.url, request_data) self.assertEqual(channel.result["code"], b"401", channel.result) @@ -82,16 +81,14 @@ class TermsTestCase(unittest.HomeserverTestCase): self.assertDictContainsSubset(channel.json_body["params"], expected_params) # We have to complete the dummy auth stage before completing the terms stage - request_data = json.dumps( - { - "username": "kermit", - "password": "monkey", - "auth": { - "session": channel.json_body["session"], - "type": "m.login.dummy", - }, - } - ) + request_data = { + "username": "kermit", + "password": "monkey", + "auth": { + "session": channel.json_body["session"], + "type": "m.login.dummy", + }, + } self.registration_handler.check_username = Mock(return_value=True) @@ -102,16 +99,14 @@ class TermsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"401", channel.result) # Finish the UI auth for terms - request_data = json.dumps( - { - "username": "kermit", - "password": "monkey", - "auth": { - "session": channel.json_body["session"], - "type": "m.login.terms", - }, - } - ) + request_data = { + "username": "kermit", + "password": "monkey", + "auth": { + "session": channel.json_body["session"], + "type": "m.login.terms", + }, + } channel = self.make_request(b"POST", self.url, request_data) # We're interested in getting a response that looks like a successful diff --git a/tests/test_visibility.py b/tests/test_visibility.py index f338af6c36..c385b2f8d4 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -272,7 +272,7 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): "state_key": "@user:test", "content": {"membership": "invite"}, } - self.add_hashes_and_signatures(invite_pdu) + self.add_hashes_and_signatures_from_other_server(invite_pdu) invite_event_id = make_event_from_dict(invite_pdu, RoomVersions.V9).event_id self.get_success( diff --git a/tests/unittest.py b/tests/unittest.py index c645dd3563..66ce92f4a6 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -16,7 +16,6 @@ import gc import hashlib import hmac -import json import logging import secrets import time @@ -285,7 +284,7 @@ class HomeserverTestCase(TestCase): config=self.hs.config.server.listeners[0], resource=self.resource, server_version_string="1", - max_request_body_size=1234, + max_request_body_size=4096, reactor=self.reactor, ) @@ -619,20 +618,16 @@ class HomeserverTestCase(TestCase): want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str) want_mac_digest = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": username, - "displayname": displayname, - "password": password, - "admin": admin, - "mac": want_mac_digest, - "inhibit_login": True, - } - ) - channel = self.make_request( - "POST", "/_synapse/admin/v1/register", body.encode("utf8") - ) + body = { + "nonce": nonce, + "username": username, + "displayname": displayname, + "password": password, + "admin": admin, + "mac": want_mac_digest, + "inhibit_login": True, + } + channel = self.make_request("POST", "/_synapse/admin/v1/register", body) self.assertEqual(channel.code, 200, channel.json_body) user_id = channel.json_body["user_id"] @@ -676,9 
+671,7 @@ class HomeserverTestCase(TestCase): custom_headers: Optional[Iterable[CustomHeaderType]] = None, ) -> str: """ - Log in a user, and get an access token. Requires the Login API be - registered. - + Log in a user, and get an access token. Requires the Login API be registered. """ body = {"type": "m.login.password", "user": username, "password": password} if device_id: @@ -687,7 +680,7 @@ class HomeserverTestCase(TestCase): channel = self.make_request( "POST", "/_matrix/client/r0/login", - json.dumps(body).encode("utf8"), + body, custom_headers=custom_headers, ) self.assertEqual(channel.code, 200, channel.result) @@ -780,7 +773,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): verify_key_id, FetchKeyResult( verify_key=verify_key, - valid_until_ts=clock.time_msec() + 1000, + valid_until_ts=clock.time_msec() + 10000, ), ) ], @@ -838,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): client_ip=client_ip, ) - def add_hashes_and_signatures( + def add_hashes_and_signatures_from_other_server( self, event_dict: JsonDict, room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], diff --git a/tests/utils.py b/tests/utils.py index 424cc4c2a0..d2c6d1e852 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -167,6 +167,7 @@ def default_config( "local": {"per_second": 10000, "burst_count": 10000}, "remote": {"per_second": 10000, "burst_count": 10000}, }, + "rc_joins_per_room": {"per_second": 10000, "burst_count": 10000}, "rc_invites": { "per_room": {"per_second": 10000, "burst_count": 10000}, "per_user": {"per_second": 10000, "burst_count": 10000},
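Finally, the tests/utils.py hunk gives the default test config an effectively unlimited rc_joins_per_room, so the newly added per-room join rate limiter stays out of the way of unrelated tests. A test that actually wants to exercise the limiter can dial it back down with the same override_config decorator seen elsewhere in this diff; the test below is hypothetical:

    from tests import unittest

    class JoinRatelimitExampleTestCase(unittest.HomeserverTestCase):
        @unittest.override_config(
            {"rc_joins_per_room": {"per_second": 0.1, "burst_count": 3}}
        )
        def test_joins_are_ratelimited(self) -> None:
            # Hypothetical: with burst_count=3, a fourth rapid join to the
            # same room should be rejected with M_LIMIT_EXCEEDED.
            ...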