diff options
author | H. Shay <hillerys@element.io> | 2023-03-06 12:48:59 -0800 |
---|---|---|
committer | H. Shay <hillerys@element.io> | 2023-03-06 12:48:59 -0800 |
commit | 7fc487421f19d8841c36481701655bc31bfcd79f (patch) | |
tree | e5c4e634fa957a561675156c868c561c36a71d66 | |
parent | add clearer return values (diff) | |
parent | Pass the requester during event serialization. (#15174) (diff) | |
download | synapse-7fc487421f19d8841c36481701655bc31bfcd79f.tar.xz |
Merge branch 'develop' into shay/rework_module
346 files changed, 4916 insertions, 3242 deletions
diff --git a/.ci/scripts/calculate_jobs.py b/.ci/scripts/calculate_jobs.py index 0cdc20e19c..b41ec0b6e2 100755 --- a/.ci/scripts/calculate_jobs.py +++ b/.ci/scripts/calculate_jobs.py @@ -109,12 +109,27 @@ sytest_tests = [ "postgres": "multi-postgres", "workers": "workers", }, + { + "sytest-tag": "focal", + "postgres": "multi-postgres", + "workers": "workers", + "reactor": "asyncio", + }, ] if not IS_PR: sytest_tests.extend( [ { + "sytest-tag": "focal", + "reactor": "asyncio", + }, + { + "sytest-tag": "focal", + "postgres": "postgres", + "reactor": "asyncio", + }, + { "sytest-tag": "testing", "postgres": "postgres", }, diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index c3638c35eb..839b895c82 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -21,4 +21,8 @@ aff1eb7c671b0a3813407321d2702ec46c71fa56 0a00b7ff14890987f09112a2ae696c61001e6cf1 # Convert tests/rest/admin/test_room.py to unix file endings (#7953). -c4268e3da64f1abb5b31deaeb5769adb6510c0a7 \ No newline at end of file +c4268e3da64f1abb5b31deaeb5769adb6510c0a7 + +# Update black to 23.1.0 (#15103) +9bb2eac71962970d02842bca441f4bcdbbf93a11 + diff --git a/.github/workflows/docs-pr-netlify.yaml b/.github/workflows/docs-pr-netlify.yaml index ef7a38144e..a5e74eb297 100644 --- a/.github/workflows/docs-pr-netlify.yaml +++ b/.github/workflows/docs-pr-netlify.yaml @@ -14,7 +14,7 @@ jobs: # There's a 'download artifact' action, but it hasn't been updated for the workflow_run action # (https://github.com/actions/download-artifact/issues/60) so instead we get this mess: - name: 📥 Download artifact - uses: dawidd6/action-download-artifact@bd10f381a96414ce2b13a11bfa89902ba7cea07f # v2.24.3 + uses: dawidd6/action-download-artifact@5e780fc7bbd0cac69fc73271ed86edf5dcb72d67 # v2.26.0 with: workflow: docs-pr.yaml run_id: ${{ github.event.workflow_run.id }} diff --git a/.github/workflows/docs-pr.yaml b/.github/workflows/docs-pr.yaml index d41f6c4490..6634f2644e 100644 --- a/.github/workflows/docs-pr.yaml +++ b/.github/workflows/docs-pr.yaml @@ -12,7 +12,7 @@ jobs: name: GitHub Pages runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Setup mdbook uses: peaceiris/actions-mdbook@adeb05db28a0c0004681db83893d56c0388ea9ea # v1.2.0 @@ -39,7 +39,7 @@ jobs: name: Check links in documentation runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Setup mdbook uses: peaceiris/actions-mdbook@adeb05db28a0c0004681db83893d56c0388ea9ea # v1.2.0 diff --git a/.github/workflows/latest_deps.yml b/.github/workflows/latest_deps.yml index 8485daf87f..6da7c22e4c 100644 --- a/.github/workflows/latest_deps.yml +++ b/.github/workflows/latest_deps.yml @@ -27,7 +27,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: toolchain: stable - uses: Swatinem/rust-cache@v2 @@ -61,7 +61,7 @@ jobs: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: toolchain: stable - uses: Swatinem/rust-cache@v2 @@ -134,7 +134,7 @@ jobs: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: toolchain: stable - uses: Swatinem/rust-cache@v2 diff --git a/.github/workflows/push_complement_image.yml b/.github/workflows/push_complement_image.yml index f26143de6b..b76c4cb323 100644 --- a/.github/workflows/push_complement_image.yml +++ b/.github/workflows/push_complement_image.yml @@ -48,7 +48,7 @@ jobs: with: ref: master - name: Login to registry - uses: docker/login-action@v1 + uses: docker/login-action@v2 with: registry: ghcr.io username: ${{ github.actor }} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 94f7f2657c..806bd2bfa4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -112,7 +112,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: toolchain: 1.58.1 components: clippy @@ -134,7 +134,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: toolchain: nightly-2022-12-01 components: clippy @@ -154,9 +154,10 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: - toolchain: 1.58.1 + # We use nightly so that it correctly groups together imports + toolchain: nightly-2022-12-01 components: rustfmt - uses: Swatinem/rust-cache@v2 @@ -221,7 +222,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 @@ -266,7 +267,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 @@ -368,6 +369,7 @@ jobs: SYTEST_BRANCH: ${{ github.head_ref }} POSTGRES: ${{ matrix.job.postgres && 1}} MULTI_POSTGRES: ${{ (matrix.job.postgres == 'multi-postgres') && 1}} + ASYNCIO_REACTOR: ${{ (matrix.job.reactor == 'asyncio') && 1 }} WORKERS: ${{ matrix.job.workers && 1 }} BLACKLIST: ${{ matrix.job.workers && 'synapse-blacklist-with-workers' }} TOP: ${{ github.workspace }} @@ -386,7 +388,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 @@ -531,7 +533,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 @@ -562,7 +564,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: toolchain: 1.58.1 - uses: Swatinem/rust-cache@v2 @@ -585,7 +587,7 @@ jobs: # There don't seem to be versioned releases of this action per se: for each rust # version there is a branch which gets constantly rebased on top of master. # We pin to a specific commit for paranoia's sake. - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: toolchain: nightly-2022-12-01 - uses: Swatinem/rust-cache@v2 diff --git a/.github/workflows/triage-incoming.yml b/.github/workflows/triage-incoming.yml index 0f0397cf5b..24dac47bf2 100644 --- a/.github/workflows/triage-incoming.yml +++ b/.github/workflows/triage-incoming.yml @@ -6,7 +6,7 @@ on: jobs: triage: - uses: matrix-org/backend-meta/.github/workflows/triage-incoming.yml@v1 + uses: matrix-org/backend-meta/.github/workflows/triage-incoming.yml@v2 with: project_id: 'PVT_kwDOAIB0Bs4AFDdZ' content_id: ${{ github.event.issue.node_id }} diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml index 5654d2f3e2..db514571c4 100644 --- a/.github/workflows/twisted_trunk.yml +++ b/.github/workflows/twisted_trunk.yml @@ -18,7 +18,7 @@ jobs: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: toolchain: stable - uses: Swatinem/rust-cache@v2 @@ -43,7 +43,7 @@ jobs: - run: sudo apt-get -qq install xmlsec1 - name: Install Rust - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: toolchain: stable - uses: Swatinem/rust-cache@v2 @@ -82,7 +82,7 @@ jobs: - uses: actions/checkout@v3 - name: Install Rust - uses: dtolnay/rust-toolchain@25dc93b901a87e864900a8aec6c12e9aa794c0c3 + uses: dtolnay/rust-toolchain@e12eda571dc9a5ee5d58eecf4738ec291c66f295 with: toolchain: stable - uses: Swatinem/rust-cache@v2 diff --git a/CHANGES.md b/CHANGES.md index a62bd4eb28..644ef6e036 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,100 @@ +Synapse 1.78.0 (2023-02-28) +=========================== + +Bugfixes +-------- + +- Fix a bug introduced in Synapse 1.76 where 5s delays would occasionally occur in deployments using workers. ([\#15150](https://github.com/matrix-org/synapse/issues/15150)) + + +Synapse 1.78.0rc1 (2023-02-21) +============================== + +Features +-------- + +- Implement the experimental `exact_event_match` push rule condition from [MSC3758](https://github.com/matrix-org/matrix-spec-proposals/pull/3758). ([\#14964](https://github.com/matrix-org/synapse/issues/14964)) +- Add account data to the command line [user data export tool](https://matrix-org.github.io/synapse/v1.78/usage/administration/admin_faq.html#how-can-i-export-user-data). ([\#14969](https://github.com/matrix-org/synapse/issues/14969)) +- Implement [MSC3873](https://github.com/matrix-org/matrix-spec-proposals/pull/3873) to disambiguate push rule keys with dots in them. ([\#15004](https://github.com/matrix-org/synapse/issues/15004)) +- Allow Synapse to use a specific Redis [logical database](https://redis.io/commands/select/) in worker-mode deployments. ([\#15034](https://github.com/matrix-org/synapse/issues/15034)) +- Tag opentracing spans for federation requests with the name of the worker serving the request. ([\#15042](https://github.com/matrix-org/synapse/issues/15042)) +- Implement the experimental `exact_event_property_contains` push rule condition from [MSC3966](https://github.com/matrix-org/matrix-spec-proposals/pull/3966). ([\#15045](https://github.com/matrix-org/synapse/issues/15045)) +- Remove spurious `dont_notify` action from the defaults for the `.m.rule.reaction` pushrule. ([\#15073](https://github.com/matrix-org/synapse/issues/15073)) +- Update the error code returned when user sends a duplicate annotation. ([\#15075](https://github.com/matrix-org/synapse/issues/15075)) + + +Bugfixes +-------- + +- Prevent clients from reporting nonexistent events. ([\#13779](https://github.com/matrix-org/synapse/issues/13779)) +- Return spec-compliant JSON errors when unknown endpoints are requested. ([\#14605](https://github.com/matrix-org/synapse/issues/14605)) +- Fix a long-standing bug where the room aliases returned could be corrupted. ([\#15038](https://github.com/matrix-org/synapse/issues/15038)) +- Fix a bug introduced in Synapse 1.76.0 where partially-joined rooms could not be deleted using the [purge room API](https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#delete-room-api). ([\#15068](https://github.com/matrix-org/synapse/issues/15068)) +- Fix a long-standing bug where federated joins would fail if the first server in the list of servers to try is not in the room. ([\#15074](https://github.com/matrix-org/synapse/issues/15074)) +- Fix a bug introduced in Synapse v1.74.0 where searching with colons when using ICU for search term tokenisation would fail with an error. ([\#15079](https://github.com/matrix-org/synapse/issues/15079)) +- Reduce the likelihood of a rare race condition where rejoining a restricted room over federation would fail. ([\#15080](https://github.com/matrix-org/synapse/issues/15080)) +- Fix a bug introduced in Synapse 1.76 where workers would fail to start if the `health` listener was configured. ([\#15096](https://github.com/matrix-org/synapse/issues/15096)) +- Fix a bug introduced in Synapse 1.75 where the [portdb script](https://matrix-org.github.io/synapse/release-v1.78/postgres.html#porting-from-sqlite) would fail to run after a room had been faster-joined. ([\#15108](https://github.com/matrix-org/synapse/issues/15108)) + + +Improved Documentation +---------------------- + +- Document how to start Synapse with Poetry. Contributed by @thezaidbintariq. ([\#14892](https://github.com/matrix-org/synapse/issues/14892), [\#15022](https://github.com/matrix-org/synapse/issues/15022)) +- Update delegation documentation to clarify that SRV DNS delegation does not eliminate all needs to serve files from .well-known locations. Contributed by @williamkray. ([\#14959](https://github.com/matrix-org/synapse/issues/14959)) +- Fix a mistake in registration_shared_secret_path docs. ([\#15078](https://github.com/matrix-org/synapse/issues/15078)) +- Refer to a more recent blog post on the [Database Maintenance Tools](https://matrix-org.github.io/synapse/latest/usage/administration/database_maintenance_tools.html) page. Contributed by @jahway603. ([\#15083](https://github.com/matrix-org/synapse/issues/15083)) + + +Internal Changes +---------------- + +- Re-type hint some collections as read-only. ([\#13755](https://github.com/matrix-org/synapse/issues/13755)) +- Faster joins: don't stall when another user joins during a partial-state room resync. ([\#14606](https://github.com/matrix-org/synapse/issues/14606)) +- Add a class `UnpersistedEventContext` to allow for the batching up of storing state groups. ([\#14675](https://github.com/matrix-org/synapse/issues/14675)) +- Add a check to ensure that locked dependencies have source distributions available. ([\#14742](https://github.com/matrix-org/synapse/issues/14742)) +- Tweak comment on `_is_local_room_accessible` as part of room visibility in `/hierarchy` to clarify the condition for a room being visible. ([\#14834](https://github.com/matrix-org/synapse/issues/14834)) +- Prevent `WARNING: there is already a transaction in progress` lines appearing in PostgreSQL's logs on some occasions. ([\#14840](https://github.com/matrix-org/synapse/issues/14840)) +- Use `StrCollection` to avoid potential bugs with `Collection[str]`. ([\#14929](https://github.com/matrix-org/synapse/issues/14929)) +- Improve performance of `/sync` in a few situations. ([\#14973](https://github.com/matrix-org/synapse/issues/14973)) +- Limit concurrent event creation for a room to avoid state resolution when sending bursts of events to a local room. ([\#14977](https://github.com/matrix-org/synapse/issues/14977)) +- Skip calculating unread push actions in /sync when enable_push is false. ([\#14980](https://github.com/matrix-org/synapse/issues/14980)) +- Add a schema dump symlinks inside `contrib`, to make it easier for IDEs to interrogate Synapse's database schema. ([\#14982](https://github.com/matrix-org/synapse/issues/14982)) +- Improve type hints. ([\#15008](https://github.com/matrix-org/synapse/issues/15008), [\#15026](https://github.com/matrix-org/synapse/issues/15026), [\#15027](https://github.com/matrix-org/synapse/issues/15027), [\#15028](https://github.com/matrix-org/synapse/issues/15028), [\#15031](https://github.com/matrix-org/synapse/issues/15031), [\#15035](https://github.com/matrix-org/synapse/issues/15035), [\#15052](https://github.com/matrix-org/synapse/issues/15052), [\#15072](https://github.com/matrix-org/synapse/issues/15072), [\#15084](https://github.com/matrix-org/synapse/issues/15084)) +- Update [MSC3952](https://github.com/matrix-org/matrix-spec-proposals/pull/3952) support based on changes to the MSC. ([\#15037](https://github.com/matrix-org/synapse/issues/15037)) +- Avoid mutating a cached value in `get_user_devices_from_cache`. ([\#15040](https://github.com/matrix-org/synapse/issues/15040)) +- Fix a rare exception in logs on start up. ([\#15041](https://github.com/matrix-org/synapse/issues/15041)) +- Update pyo3-log to v0.8.1. ([\#15043](https://github.com/matrix-org/synapse/issues/15043)) +- Avoid mutating cached values in `_generate_sync_entry_for_account_data`. ([\#15047](https://github.com/matrix-org/synapse/issues/15047)) +- Refactor arguments of `try_unbind_threepid` and `_try_unbind_threepid_with_id_server` to not use dictionaries. ([\#15053](https://github.com/matrix-org/synapse/issues/15053)) +- Merge debug logging from the hotfixes branch. ([\#15054](https://github.com/matrix-org/synapse/issues/15054)) +- Faster joins: omit device list updates originating from partial state rooms in /sync responses without lazy loading of members enabled. ([\#15069](https://github.com/matrix-org/synapse/issues/15069)) +- Fix clashing database transaction name. ([\#15070](https://github.com/matrix-org/synapse/issues/15070)) +- Upper-bound frozendict dependency. This works around us being unable to test installing our wheels against Python 3.11 in CI. ([\#15114](https://github.com/matrix-org/synapse/issues/15114)) +- Tweak logging for when a worker waits for its view of a replication stream to catch up. ([\#15120](https://github.com/matrix-org/synapse/issues/15120)) + +<details><summary>Locked dependency updates</summary> + +- Bump bleach from 5.0.1 to 6.0.0. ([\#15059](https://github.com/matrix-org/synapse/issues/15059)) +- Bump cryptography from 38.0.4 to 39.0.1. ([\#15020](https://github.com/matrix-org/synapse/issues/15020)) +- Bump ruff version from 0.0.230 to 0.0.237. ([\#15033](https://github.com/matrix-org/synapse/issues/15033)) +- Bump dtolnay/rust-toolchain from 9cd00a88a73addc8617065438eff914dd08d0955 to 25dc93b901a87e864900a8aec6c12e9aa794c0c3. ([\#15060](https://github.com/matrix-org/synapse/issues/15060)) +- Bump systemd-python from 234 to 235. ([\#15061](https://github.com/matrix-org/synapse/issues/15061)) +- Bump serde_json from 1.0.92 to 1.0.93. ([\#15062](https://github.com/matrix-org/synapse/issues/15062)) +- Bump types-requests from 2.28.11.8 to 2.28.11.12. ([\#15063](https://github.com/matrix-org/synapse/issues/15063)) +- Bump types-pillow from 9.4.0.5 to 9.4.0.10. ([\#15064](https://github.com/matrix-org/synapse/issues/15064)) +- Bump sentry-sdk from 1.13.0 to 1.15.0. ([\#15065](https://github.com/matrix-org/synapse/issues/15065)) +- Bump types-jsonschema from 4.17.0.3 to 4.17.0.5. ([\#15099](https://github.com/matrix-org/synapse/issues/15099)) +- Bump types-bleach from 5.0.3.1 to 6.0.0.0. ([\#15100](https://github.com/matrix-org/synapse/issues/15100)) +- Bump dtolnay/rust-toolchain from 25dc93b901a87e864900a8aec6c12e9aa794c0c3 to e12eda571dc9a5ee5d58eecf4738ec291c66f295. ([\#15101](https://github.com/matrix-org/synapse/issues/15101)) +- Bump dawidd6/action-download-artifact from 2.24.3 to 2.25.0. ([\#15102](https://github.com/matrix-org/synapse/issues/15102)) +- Bump types-pillow from 9.4.0.10 to 9.4.0.13. ([\#15104](https://github.com/matrix-org/synapse/issues/15104)) +- Bump types-setuptools from 67.1.0.0 to 67.3.0.1. ([\#15105](https://github.com/matrix-org/synapse/issues/15105)) + + +</details> + + Synapse 1.77.0 (2023-02-14) =========================== @@ -63,7 +160,7 @@ Internal Changes - Preparatory work for adding a denormalised event stream ordering column in the future. Contributed by Nick @ Beeper (@fizzadar). ([\#14979](https://github.com/matrix-org/synapse/issues/14979), [9cd7610](https://github.com/matrix-org/synapse/commit/9cd7610f86ab5051c9365dd38d1eec405a5f8ca6), [f10caa7](https://github.com/matrix-org/synapse/commit/f10caa73eee0caa91cf373966104d1ededae2aee); see [\#15014](https://github.com/matrix-org/synapse/issues/15014)) - Add tests for `_flatten_dict`. ([\#14981](https://github.com/matrix-org/synapse/issues/14981), [\#15002](https://github.com/matrix-org/synapse/issues/15002)) -<details><summary>Dependabot updates</summary> +<details><summary>Locked dependency updates</summary> - Bump dtolnay/rust-toolchain from e645b0cf01249a964ec099494d38d2da0f0b349f to 9cd00a88a73addc8617065438eff914dd08d0955. ([\#14968](https://github.com/matrix-org/synapse/issues/14968)) - Bump docker/build-push-action from 3 to 4. ([\#14952](https://github.com/matrix-org/synapse/issues/14952)) diff --git a/Cargo.lock b/Cargo.lock index 1bf76cb863..f858b2107f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -343,9 +343,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.93" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cad406b69c91885b5107daf2c29572f6c8cdb3c66826821e286c533490c0bc76" +checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea" dependencies = [ "itoa", "ryu", diff --git a/changelog.d/13755.misc b/changelog.d/13755.misc deleted file mode 100644 index 662ee00e99..0000000000 --- a/changelog.d/13755.misc +++ /dev/null @@ -1 +0,0 @@ -Re-type hint some collections as read-only. diff --git a/changelog.d/13779.bugfix b/changelog.d/13779.bugfix deleted file mode 100644 index a92c722c6e..0000000000 --- a/changelog.d/13779.bugfix +++ /dev/null @@ -1 +0,0 @@ -Prevent clients from reporting nonexistent events. \ No newline at end of file diff --git a/changelog.d/14026.doc b/changelog.d/14026.doc new file mode 100644 index 0000000000..28fc5568ea --- /dev/null +++ b/changelog.d/14026.doc @@ -0,0 +1 @@ +Document how to use caches in a module. diff --git a/changelog.d/14101.misc b/changelog.d/14101.misc new file mode 100644 index 0000000000..c48f40cd38 --- /dev/null +++ b/changelog.d/14101.misc @@ -0,0 +1 @@ +Run the integration test suites with the asyncio reactor enabled in CI. diff --git a/changelog.d/14605.bugfix b/changelog.d/14605.bugfix deleted file mode 100644 index cb95a87d92..0000000000 --- a/changelog.d/14605.bugfix +++ /dev/null @@ -1 +0,0 @@ -Return spec-compliant JSON errors when unknown endpoints are requested. diff --git a/changelog.d/14606.misc b/changelog.d/14606.misc deleted file mode 100644 index e2debc96d8..0000000000 --- a/changelog.d/14606.misc +++ /dev/null @@ -1 +0,0 @@ -Faster joins: don't stall when another user joins during a fast join resync. diff --git a/changelog.d/14675.misc b/changelog.d/14675.misc deleted file mode 100644 index bc1ac1c82a..0000000000 --- a/changelog.d/14675.misc +++ /dev/null @@ -1 +0,0 @@ -Add a class UnpersistedEventContext to allow for the batching up of storing state groups. diff --git a/changelog.d/14742.misc b/changelog.d/14742.misc deleted file mode 100644 index c0b5d2c062..0000000000 --- a/changelog.d/14742.misc +++ /dev/null @@ -1 +0,0 @@ -Add check to ensure locked dependencies have source distributions available. \ No newline at end of file diff --git a/changelog.d/14834.misc b/changelog.d/14834.misc deleted file mode 100644 index e683212dc4..0000000000 --- a/changelog.d/14834.misc +++ /dev/null @@ -1 +0,0 @@ -Tweak comment on `_is_local_room_accessible` as part of room visibility in `/hierarchy` to clarify the condition for a room being visible. \ No newline at end of file diff --git a/changelog.d/14840.misc b/changelog.d/14840.misc deleted file mode 100644 index ff6084284a..0000000000 --- a/changelog.d/14840.misc +++ /dev/null @@ -1 +0,0 @@ -Prevent "WARNING: there is already a transaction in progress" lines appearing in PostgreSQL's logs on some occasions. \ No newline at end of file diff --git a/changelog.d/14869.bugfix b/changelog.d/14869.bugfix new file mode 100644 index 0000000000..865b597741 --- /dev/null +++ b/changelog.d/14869.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.75.0rc1 that caused experimental support for deleting account data to raise an internal server error while using an account data writer worker. \ No newline at end of file diff --git a/changelog.d/14918.misc b/changelog.d/14918.misc new file mode 100644 index 0000000000..828794354a --- /dev/null +++ b/changelog.d/14918.misc @@ -0,0 +1 @@ +Batch up storing state groups when creating a new room. \ No newline at end of file diff --git a/changelog.d/14929.misc b/changelog.d/14929.misc deleted file mode 100644 index 2cc3614dfd..0000000000 --- a/changelog.d/14929.misc +++ /dev/null @@ -1 +0,0 @@ -Use `StrCollection` to avoid potential bugs with `Collection[str]`. diff --git a/changelog.d/14959.doc b/changelog.d/14959.doc deleted file mode 100644 index 45edf1a765..0000000000 --- a/changelog.d/14959.doc +++ /dev/null @@ -1 +0,0 @@ -Update delegation documentation to clarify that SRV DNS delegation does not eliminate all needs to serve files from .well-known locations. Contributed by @williamkray. diff --git a/changelog.d/14964.feature b/changelog.d/14964.feature deleted file mode 100644 index 13c0bc193b..0000000000 --- a/changelog.d/14964.feature +++ /dev/null @@ -1 +0,0 @@ -Implement the experimental `exact_event_match` push rule condition from [MSC3758](https://github.com/matrix-org/matrix-spec-proposals/pull/3758). diff --git a/changelog.d/14973.misc b/changelog.d/14973.misc deleted file mode 100644 index 3657623602..0000000000 --- a/changelog.d/14973.misc +++ /dev/null @@ -1 +0,0 @@ -Improve performance of `/sync` in a few situations. diff --git a/changelog.d/14977.misc b/changelog.d/14977.misc deleted file mode 100644 index 4d551c52b7..0000000000 --- a/changelog.d/14977.misc +++ /dev/null @@ -1 +0,0 @@ -Limit concurrent event creation for a room to avoid state resolution when sending bursts of events to a local room. \ No newline at end of file diff --git a/changelog.d/14980.misc b/changelog.d/14980.misc deleted file mode 100644 index 145f4a788b..0000000000 --- a/changelog.d/14980.misc +++ /dev/null @@ -1 +0,0 @@ -Skip calculating unread push actions in /sync when enable_push is false. diff --git a/changelog.d/14982.misc b/changelog.d/14982.misc deleted file mode 100644 index 9aaa7ce264..0000000000 --- a/changelog.d/14982.misc +++ /dev/null @@ -1 +0,0 @@ -Add a schema dump symlinks inside `contrib`, to make it easier for IDEs to interrogate Synapse's database schema. diff --git a/changelog.d/15004.feature b/changelog.d/15004.feature deleted file mode 100644 index d11d0aca91..0000000000 --- a/changelog.d/15004.feature +++ /dev/null @@ -1 +0,0 @@ -Implement [MSC3873](https://github.com/matrix-org/matrix-spec-proposals/pull/3873) to unambiguate push rule keys with dots in them. diff --git a/changelog.d/15020.misc b/changelog.d/15020.misc deleted file mode 100644 index c5290283f0..0000000000 --- a/changelog.d/15020.misc +++ /dev/null @@ -1 +0,0 @@ -Bump cryptography from 38.0.4 to 39.0.1. diff --git a/changelog.d/15022.doc b/changelog.d/15022.doc deleted file mode 100644 index e1627c20cb..0000000000 --- a/changelog.d/15022.doc +++ /dev/null @@ -1 +0,0 @@ -Document how to start Synapse in the contributing guide. diff --git a/changelog.d/15026.misc b/changelog.d/15026.misc deleted file mode 100644 index 93ceaeafc9..0000000000 --- a/changelog.d/15026.misc +++ /dev/null @@ -1 +0,0 @@ -Improve type hints. diff --git a/changelog.d/15027.misc b/changelog.d/15027.misc deleted file mode 100644 index 93ceaeafc9..0000000000 --- a/changelog.d/15027.misc +++ /dev/null @@ -1 +0,0 @@ -Improve type hints. diff --git a/changelog.d/15028.misc b/changelog.d/15028.misc deleted file mode 100644 index 93ceaeafc9..0000000000 --- a/changelog.d/15028.misc +++ /dev/null @@ -1 +0,0 @@ -Improve type hints. diff --git a/changelog.d/15031.misc b/changelog.d/15031.misc deleted file mode 100644 index 93ceaeafc9..0000000000 --- a/changelog.d/15031.misc +++ /dev/null @@ -1 +0,0 @@ -Improve type hints. diff --git a/changelog.d/15033.misc b/changelog.d/15033.misc deleted file mode 100644 index 83dc3a75b6..0000000000 --- a/changelog.d/15033.misc +++ /dev/null @@ -1 +0,0 @@ -Bump ruff version from 0.0.230 to 0.0.237. diff --git a/changelog.d/15034.feature b/changelog.d/15034.feature deleted file mode 100644 index 34f320da92..0000000000 --- a/changelog.d/15034.feature +++ /dev/null @@ -1 +0,0 @@ -Allow Synapse to use a specific Redis [logical database](https://redis.io/commands/select/) in worker-mode deployments. diff --git a/changelog.d/15035.misc b/changelog.d/15035.misc deleted file mode 100644 index 93ceaeafc9..0000000000 --- a/changelog.d/15035.misc +++ /dev/null @@ -1 +0,0 @@ -Improve type hints. diff --git a/changelog.d/15038.bugfix b/changelog.d/15038.bugfix deleted file mode 100644 index 4695a09756..0000000000 --- a/changelog.d/15038.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where the room aliases returned could be corrupted. diff --git a/changelog.d/15040.misc b/changelog.d/15040.misc deleted file mode 100644 index ca129b64af..0000000000 --- a/changelog.d/15040.misc +++ /dev/null @@ -1 +0,0 @@ -Avoid mutating a cached value in `get_user_devices_from_cache`. diff --git a/changelog.d/15041.misc b/changelog.d/15041.misc deleted file mode 100644 index d602b0043a..0000000000 --- a/changelog.d/15041.misc +++ /dev/null @@ -1 +0,0 @@ -Fix a rare exception in logs on start up. diff --git a/changelog.d/15042.feature b/changelog.d/15042.feature deleted file mode 100644 index 7a4de89f00..0000000000 --- a/changelog.d/15042.feature +++ /dev/null @@ -1 +0,0 @@ -Tag opentracing spans for federation requests with the name of the worker serving the request. diff --git a/changelog.d/15043.misc b/changelog.d/15043.misc deleted file mode 100644 index cb18394123..0000000000 --- a/changelog.d/15043.misc +++ /dev/null @@ -1 +0,0 @@ -Update pyo3-log to v0.8.1. diff --git a/changelog.d/15044.feature b/changelog.d/15044.feature new file mode 100644 index 0000000000..91e5cda8c3 --- /dev/null +++ b/changelog.d/15044.feature @@ -0,0 +1 @@ +Add two new Third Party Rules module API callbacks: [`on_add_user_third_party_identifier`](https://matrix-org.github.io/synapse/v1.79/modules/third_party_rules_callbacks.html#on_add_user_third_party_identifier) and [`on_remove_user_third_party_identifier`](https://matrix-org.github.io/synapse/v1.79/modules/third_party_rules_callbacks.html#on_remove_user_third_party_identifier). \ No newline at end of file diff --git a/changelog.d/15045.feature b/changelog.d/15045.feature deleted file mode 100644 index 87766befda..0000000000 --- a/changelog.d/15045.feature +++ /dev/null @@ -1 +0,0 @@ -Experimental support for [MSC3966](https://github.com/matrix-org/matrix-spec-proposals/pull/3966): the `exact_event_property_contains` push rule condition. diff --git a/changelog.d/15047.misc b/changelog.d/15047.misc deleted file mode 100644 index 561dc874de..0000000000 --- a/changelog.d/15047.misc +++ /dev/null @@ -1 +0,0 @@ -Avoid mutating cached values in `_generate_sync_entry_for_account_data`. diff --git a/changelog.d/15051.misc b/changelog.d/15051.misc new file mode 100644 index 0000000000..fabfe77d35 --- /dev/null +++ b/changelog.d/15051.misc @@ -0,0 +1 @@ +Update [MSC3952](https://github.com/matrix-org/matrix-spec-proposals/pull/3952) support based on changes to the MSC. diff --git a/changelog.d/15053.misc b/changelog.d/15053.misc deleted file mode 100644 index c27528f5c6..0000000000 --- a/changelog.d/15053.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor arguments of `try_unbind_threepid` and `_try_unbind_threepid_with_id_server` to not use dictionaries. \ No newline at end of file diff --git a/changelog.d/15054.misc b/changelog.d/15054.misc deleted file mode 100644 index d800b107cf..0000000000 --- a/changelog.d/15054.misc +++ /dev/null @@ -1 +0,0 @@ -Merge debug logging from the hotfixes branch. diff --git a/changelog.d/15059.misc b/changelog.d/15059.misc deleted file mode 100644 index e962b208fd..0000000000 --- a/changelog.d/15059.misc +++ /dev/null @@ -1 +0,0 @@ -Bump bleach from 5.0.1 to 6.0.0. diff --git a/changelog.d/15060.misc b/changelog.d/15060.misc deleted file mode 100644 index 5b99e06003..0000000000 --- a/changelog.d/15060.misc +++ /dev/null @@ -1 +0,0 @@ -Bump dtolnay/rust-toolchain from 9cd00a88a73addc8617065438eff914dd08d0955 to 25dc93b901a87e864900a8aec6c12e9aa794c0c3. diff --git a/changelog.d/15061.misc b/changelog.d/15061.misc deleted file mode 100644 index 40017827a2..0000000000 --- a/changelog.d/15061.misc +++ /dev/null @@ -1 +0,0 @@ -Bump systemd-python from 234 to 235. diff --git a/changelog.d/15062.misc b/changelog.d/15062.misc deleted file mode 100644 index adc1940630..0000000000 --- a/changelog.d/15062.misc +++ /dev/null @@ -1 +0,0 @@ -Bump serde_json from 1.0.92 to 1.0.93. diff --git a/changelog.d/15063.misc b/changelog.d/15063.misc deleted file mode 100644 index b52e1faed0..0000000000 --- a/changelog.d/15063.misc +++ /dev/null @@ -1 +0,0 @@ -Bump types-requests from 2.28.11.8 to 2.28.11.12. diff --git a/changelog.d/15064.misc b/changelog.d/15064.misc deleted file mode 100644 index 644d4bb230..0000000000 --- a/changelog.d/15064.misc +++ /dev/null @@ -1 +0,0 @@ -Bump types-pillow from 9.4.0.5 to 9.4.0.10. diff --git a/changelog.d/15065.misc b/changelog.d/15065.misc deleted file mode 100644 index df2f9a773e..0000000000 --- a/changelog.d/15065.misc +++ /dev/null @@ -1 +0,0 @@ -Bump sentry-sdk from 1.13.0 to 1.15.0. diff --git a/changelog.d/15068.bugfix b/changelog.d/15068.bugfix deleted file mode 100644 index f09ffa2877..0000000000 --- a/changelog.d/15068.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in Synapse 1.76.0 where partially-joined rooms could not be deleted using the [purge room API](https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#delete-room-api). diff --git a/changelog.d/15069.misc b/changelog.d/15069.misc deleted file mode 100644 index e7a619ad2b..0000000000 --- a/changelog.d/15069.misc +++ /dev/null @@ -1 +0,0 @@ -Faster joins: omit device list updates originating from partial state rooms in /sync responses without lazy loading of members enabled. diff --git a/changelog.d/15070.misc b/changelog.d/15070.misc deleted file mode 100644 index 0f3244de9f..0000000000 --- a/changelog.d/15070.misc +++ /dev/null @@ -1 +0,0 @@ -Fix clashing database transaction name. diff --git a/changelog.d/15071.doc b/changelog.d/15071.doc new file mode 100644 index 0000000000..7fbaba3e8c --- /dev/null +++ b/changelog.d/15071.doc @@ -0,0 +1 @@ +Clarify which worker processes the ThirdPartyRules' [`on_new_event`](https://matrix-org.github.io/synapse/v1.78/modules/third_party_rules_callbacks.html#on_new_event) module API callback runs on. \ No newline at end of file diff --git a/changelog.d/15072.misc b/changelog.d/15072.misc deleted file mode 100644 index 93ceaeafc9..0000000000 --- a/changelog.d/15072.misc +++ /dev/null @@ -1 +0,0 @@ -Improve type hints. diff --git a/changelog.d/15073.feature b/changelog.d/15073.feature deleted file mode 100644 index 2889e3444f..0000000000 --- a/changelog.d/15073.feature +++ /dev/null @@ -1 +0,0 @@ -Remove spurious `dont_notify` action from the defaults for the `.m.rule.reaction` pushrule. diff --git a/changelog.d/15074.bugfix b/changelog.d/15074.bugfix deleted file mode 100644 index d1ceb4f4c8..0000000000 --- a/changelog.d/15074.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where federated joins would fail if the first server in the list of servers to try is not in the room. diff --git a/changelog.d/15075.feature b/changelog.d/15075.feature deleted file mode 100644 index d25a7567a4..0000000000 --- a/changelog.d/15075.feature +++ /dev/null @@ -1,2 +0,0 @@ -Update the error code returned when user sends a duplicate annotation. - diff --git a/changelog.d/15077.feature b/changelog.d/15077.feature new file mode 100644 index 0000000000..384e751056 --- /dev/null +++ b/changelog.d/15077.feature @@ -0,0 +1 @@ +Experimental support for MSC3967 to not require UIA for setting up cross-signing on first use. diff --git a/changelog.d/15088.bugfix b/changelog.d/15088.bugfix new file mode 100644 index 0000000000..15d5286f80 --- /dev/null +++ b/changelog.d/15088.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse handled an unspecced field on push rules. diff --git a/changelog.d/15092.bugfix b/changelog.d/15092.bugfix new file mode 100644 index 0000000000..67509c5c69 --- /dev/null +++ b/changelog.d/15092.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where a URL preview would break if the discovered oEmbed failed to download. diff --git a/changelog.d/15093.bugfix b/changelog.d/15093.bugfix new file mode 100644 index 0000000000..00f1c19391 --- /dev/null +++ b/changelog.d/15093.bugfix @@ -0,0 +1 @@ +Remove the unspecced `room_alias` field from the [`/createRoom`](https://spec.matrix.org/v1.6/client-server-api/#post_matrixclientv3createroom) response. diff --git a/changelog.d/15095.misc b/changelog.d/15095.misc new file mode 100644 index 0000000000..a2fafe2fff --- /dev/null +++ b/changelog.d/15095.misc @@ -0,0 +1 @@ +Refactor writing json data in `FileExfiltrationWriter`. \ No newline at end of file diff --git a/changelog.d/15103.misc b/changelog.d/15103.misc new file mode 100644 index 0000000000..65322498c9 --- /dev/null +++ b/changelog.d/15103.misc @@ -0,0 +1 @@ +Bump black from 22.12.0 to 23.1.0. diff --git a/changelog.d/15107.feature b/changelog.d/15107.feature new file mode 100644 index 0000000000..2bdb6a29fc --- /dev/null +++ b/changelog.d/15107.feature @@ -0,0 +1 @@ +Add media information to the command line [user data export tool](https://matrix-org.github.io/synapse/v1.79/usage/administration/admin_faq.html#how-can-i-export-user-data). \ No newline at end of file diff --git a/changelog.d/15112.doc b/changelog.d/15112.doc new file mode 100644 index 0000000000..7dec43a50b --- /dev/null +++ b/changelog.d/15112.doc @@ -0,0 +1 @@ +Document using [Shibboleth](https://www.shibboleth.net/) as an OpenID Provider. diff --git a/changelog.d/15116.feature b/changelog.d/15116.feature new file mode 100644 index 0000000000..087d8dc7f1 --- /dev/null +++ b/changelog.d/15116.feature @@ -0,0 +1 @@ +Add an [admin API](https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/index.html) to delete a [specific event report](https://spec.matrix.org/v1.6/client-server-api/#reporting-content). \ No newline at end of file diff --git a/changelog.d/15133.feature b/changelog.d/15133.feature new file mode 100644 index 0000000000..e0af0d4554 --- /dev/null +++ b/changelog.d/15133.feature @@ -0,0 +1 @@ +Add support for knocking to workers. \ No newline at end of file diff --git a/changelog.d/15134.feature b/changelog.d/15134.feature new file mode 100644 index 0000000000..0dbb30bc8f --- /dev/null +++ b/changelog.d/15134.feature @@ -0,0 +1 @@ +Allow use of the `/filter` Client-Server APIs on workers. \ No newline at end of file diff --git a/changelog.d/15135.misc b/changelog.d/15135.misc new file mode 100644 index 0000000000..25c4dbffe1 --- /dev/null +++ b/changelog.d/15135.misc @@ -0,0 +1 @@ +Tighten the login ratelimit defaults. diff --git a/changelog.d/15137.removal b/changelog.d/15137.removal new file mode 100644 index 0000000000..c533b0c9dd --- /dev/null +++ b/changelog.d/15137.removal @@ -0,0 +1 @@ +Remove the undocumented and unspecced `type` parameter to the `/thumbnail` endpoint. diff --git a/changelog.d/15138.misc b/changelog.d/15138.misc new file mode 100644 index 0000000000..fb706b27f2 --- /dev/null +++ b/changelog.d/15138.misc @@ -0,0 +1 @@ +Fix a typo in an experimental config setting. diff --git a/changelog.d/15139.doc b/changelog.d/15139.doc new file mode 100644 index 0000000000..d8ab48b272 --- /dev/null +++ b/changelog.d/15139.doc @@ -0,0 +1 @@ +Correct reference to `federation_verify_certificates` in configuration documentation. diff --git a/changelog.d/15143.misc b/changelog.d/15143.misc new file mode 100644 index 0000000000..cff4518811 --- /dev/null +++ b/changelog.d/15143.misc @@ -0,0 +1 @@ +Fix a long-standing bug where the user directory search was not case-insensitive for accented characters. diff --git a/changelog.d/15146.misc b/changelog.d/15146.misc new file mode 100644 index 0000000000..8de5f95239 --- /dev/null +++ b/changelog.d/15146.misc @@ -0,0 +1 @@ +Refactor the media modules. diff --git a/changelog.d/15148.doc b/changelog.d/15148.doc new file mode 100644 index 0000000000..4e9e163306 --- /dev/null +++ b/changelog.d/15148.doc @@ -0,0 +1 @@ +Correct small documentation errors in some `MatrixFederationHttpClient` methods. \ No newline at end of file diff --git a/changelog.d/15152.misc b/changelog.d/15152.misc new file mode 100644 index 0000000000..6b2c73d0ab --- /dev/null +++ b/changelog.d/15152.misc @@ -0,0 +1 @@ +Bump dawidd6/action-download-artifact from 2.25.0 to 2.26.0. diff --git a/changelog.d/15154.misc b/changelog.d/15154.misc new file mode 100644 index 0000000000..c958b52078 --- /dev/null +++ b/changelog.d/15154.misc @@ -0,0 +1 @@ +Bump docker/login-action from 1 to 2. diff --git a/changelog.d/15155.misc b/changelog.d/15155.misc new file mode 100644 index 0000000000..40c73e96ec --- /dev/null +++ b/changelog.d/15155.misc @@ -0,0 +1 @@ +Bump actions/checkout from 2 to 3. diff --git a/changelog.d/15156.misc b/changelog.d/15156.misc new file mode 100644 index 0000000000..ebae4cb456 --- /dev/null +++ b/changelog.d/15156.misc @@ -0,0 +1 @@ +Bump matrix-org/backend-meta from 1 to 2. diff --git a/changelog.d/15157.misc b/changelog.d/15157.misc new file mode 100644 index 0000000000..730b706dfe --- /dev/null +++ b/changelog.d/15157.misc @@ -0,0 +1 @@ +Bump typing-extensions from 4.4.0 to 4.5.0. diff --git a/changelog.d/15158.misc b/changelog.d/15158.misc new file mode 100644 index 0000000000..fc0eecfd21 --- /dev/null +++ b/changelog.d/15158.misc @@ -0,0 +1 @@ +Bump types-opentracing from 2.4.10.1 to 2.4.10.3. diff --git a/changelog.d/15159.misc b/changelog.d/15159.misc new file mode 100644 index 0000000000..ebb857a89e --- /dev/null +++ b/changelog.d/15159.misc @@ -0,0 +1 @@ +Bump ruff from 0.0.237 to 0.0.252. diff --git a/changelog.d/15160.misc b/changelog.d/15160.misc new file mode 100644 index 0000000000..13b098d17c --- /dev/null +++ b/changelog.d/15160.misc @@ -0,0 +1 @@ +Bump types-setuptools from 67.3.0.1 to 67.4.0.3. diff --git a/changelog.d/15163.bugfix b/changelog.d/15163.bugfix new file mode 100644 index 0000000000..7ff1cd4463 --- /dev/null +++ b/changelog.d/15163.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where an initial sync would not respond to changes to the list of ignored users if there was an initial sync cached. \ No newline at end of file diff --git a/changelog.d/15008.misc b/changelog.d/15164.misc index 93ceaeafc9..93ceaeafc9 100644 --- a/changelog.d/15008.misc +++ b/changelog.d/15164.misc diff --git a/changelog.d/15165.misc b/changelog.d/15165.misc new file mode 100644 index 0000000000..a75be84dac --- /dev/null +++ b/changelog.d/15165.misc @@ -0,0 +1 @@ +Move `get_event_report` and `get_event_reports_paginate` from `RoomStore` to `RoomWorkerStore`. \ No newline at end of file diff --git a/changelog.d/15167.misc b/changelog.d/15167.misc new file mode 100644 index 0000000000..175c2a3b83 --- /dev/null +++ b/changelog.d/15167.misc @@ -0,0 +1 @@ +Remove dangling reference to being a reference implementation in docstring. diff --git a/changelog.d/15168.doc b/changelog.d/15168.doc new file mode 100644 index 0000000000..dbd3c54714 --- /dev/null +++ b/changelog.d/15168.doc @@ -0,0 +1 @@ +Correct the description of the behavior of `registration_shared_secret_path` on startup. diff --git a/changelog.d/15172.feature b/changelog.d/15172.feature new file mode 100644 index 0000000000..3f789edb7f --- /dev/null +++ b/changelog.d/15172.feature @@ -0,0 +1 @@ +Remove support for server-side aggregation of reactions. diff --git a/changelog.d/15174.bugfix b/changelog.d/15174.bugfix new file mode 100644 index 0000000000..a0c70cbe22 --- /dev/null +++ b/changelog.d/15174.bugfix @@ -0,0 +1 @@ +Add the `transaction_id` in the events included in many endpoints responses. diff --git a/changelog.d/15175.misc b/changelog.d/15175.misc new file mode 100644 index 0000000000..8de5f95239 --- /dev/null +++ b/changelog.d/15175.misc @@ -0,0 +1 @@ +Refactor the media modules. diff --git a/changelog.d/15177.bugfix b/changelog.d/15177.bugfix new file mode 100644 index 0000000000..b9764947eb --- /dev/null +++ b/changelog.d/15177.bugfix @@ -0,0 +1 @@ +Fix test_icu_word_boundary_punctuation for alpine / macos installed ICU versions. diff --git a/changelog.d/15180.bugfix b/changelog.d/15180.bugfix new file mode 100644 index 0000000000..e7a3dcd41a --- /dev/null +++ b/changelog.d/15180.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.78.0 where requests to claim dehydrated devices would fail with a `405` error. diff --git a/changelog.d/15185.feature b/changelog.d/15185.feature new file mode 100644 index 0000000000..901900bdec --- /dev/null +++ b/changelog.d/15185.feature @@ -0,0 +1 @@ +Stabilise support for [MSC3758](https://github.com/matrix-org/matrix-spec-proposals/pull/3758): `event_property_is` push condition. diff --git a/changelog.d/15186.docker b/changelog.d/15186.docker new file mode 100644 index 0000000000..5e436ff7e2 --- /dev/null +++ b/changelog.d/15186.docker @@ -0,0 +1 @@ +Improve startup logging in the with-workers Docker image. diff --git a/changelog.d/15188.misc b/changelog.d/15188.misc new file mode 100644 index 0000000000..e4e9472f01 --- /dev/null +++ b/changelog.d/15188.misc @@ -0,0 +1 @@ +Use nightly rustfmt in CI. diff --git a/changelog.d/15189.misc b/changelog.d/15189.misc new file mode 100644 index 0000000000..ded2feb79e --- /dev/null +++ b/changelog.d/15189.misc @@ -0,0 +1 @@ +Remove the unspecced `PUT` on the `/knock/{roomIdOrAlias}` endpoint. diff --git a/changelog.d/15191.misc b/changelog.d/15191.misc new file mode 100644 index 0000000000..579f76d451 --- /dev/null +++ b/changelog.d/15191.misc @@ -0,0 +1 @@ +Add a `get_next_txn` method to `StreamIdGenerator` to match `MultiWriterIdGenerator`. \ No newline at end of file diff --git a/changelog.d/15192.misc b/changelog.d/15192.misc new file mode 100644 index 0000000000..1076686875 --- /dev/null +++ b/changelog.d/15192.misc @@ -0,0 +1 @@ +Combine `AbstractStreamIdTracker` and `AbstractStreamIdGenerator`. diff --git a/changelog.d/15193.bugfix b/changelog.d/15193.bugfix new file mode 100644 index 0000000000..ca781e9631 --- /dev/null +++ b/changelog.d/15193.bugfix @@ -0,0 +1 @@ +Stop applying edits when bundling aggregations, per [MSC3925](https://github.com/matrix-org/matrix-spec-proposals/pull/3925). diff --git a/changelog.d/15194.misc b/changelog.d/15194.misc new file mode 100644 index 0000000000..931bf5448a --- /dev/null +++ b/changelog.d/15194.misc @@ -0,0 +1 @@ +Automatically fix errors with `ruff`. diff --git a/changelog.d/15199.misc b/changelog.d/15199.misc new file mode 100644 index 0000000000..145b03fe16 --- /dev/null +++ b/changelog.d/15199.misc @@ -0,0 +1 @@ +Remove unspecced and buggy `PUT` method on the unstable `/rooms/<room_id>/batch_send` endpoint. diff --git a/changelog.d/15214.misc b/changelog.d/15214.misc new file mode 100644 index 0000000000..91a8cb9d72 --- /dev/null +++ b/changelog.d/15214.misc @@ -0,0 +1 @@ +Bump serde_json from 1.0.93 to 1.0.94. diff --git a/debian/changelog b/debian/changelog index ea651438f1..0f094308c1 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,16 @@ +matrix-synapse-py3 (1.78.0) stable; urgency=medium + + * New Synapse release 1.78.0. + + -- Synapse Packaging team <packages@matrix.org> Tue, 28 Feb 2023 08:56:03 -0800 + +matrix-synapse-py3 (1.78.0~rc1) stable; urgency=medium + + * Add `matrix-org-archive-keyring` package as recommended. + * New Synapse release 1.78.0rc1. + + -- Synapse Packaging team <packages@matrix.org> Tue, 21 Feb 2023 14:29:19 +0000 + matrix-synapse-py3 (1.77.0) stable; urgency=medium * New Synapse release 1.77.0. diff --git a/debian/control b/debian/control index bc628cec08..2ff55db5de 100644 --- a/debian/control +++ b/debian/control @@ -37,6 +37,7 @@ Depends: # so we put perl:Depends in Suggests rather than Depends. Recommends: ${shlibs1:Recommends}, + matrix-org-archive-keyring, Suggests: sqlite3, ${perl:Depends}, diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 58c62f2231..add8bb1ff6 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -142,6 +142,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_matrix/client/(api/v1|r0|v3|unstable/.*)/rooms/.*/aliases", "^/_matrix/client/v1/rooms/.*/timestamp_to_event$", "^/_matrix/client/(api/v1|r0|v3|unstable)/search", + "^/_matrix/client/(r0|v3|unstable)/user/.*/filter(/|$)", ], "shared_extra_conf": {}, "worker_extra_conf": "", @@ -204,6 +205,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/send", "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$", "^/_matrix/client/(api/v1|r0|v3|unstable)/join/", + "^/_matrix/client/(api/v1|r0|v3|unstable)/knock/", "^/_matrix/client/(api/v1|r0|v3|unstable)/profile/", "^/_matrix/client/(v1|unstable/org.matrix.msc2716)/rooms/.*/batch_send", ], @@ -674,17 +676,21 @@ def main(args: List[str], environ: MutableMapping[str, str]) -> None: if not os.path.exists(config_path): log("Generating base homeserver config") generate_base_homeserver_config() - + else: + log("Base homeserver config exists—not regenerating") # This script may be run multiple times (mostly by Complement, see note at top of file). # Don't re-configure workers in this instance. mark_filepath = "/conf/workers_have_been_configured" if not os.path.exists(mark_filepath): # Always regenerate all other config files + log("Generating worker config files") generate_worker_files(environ, config_path, data_dir) # Mark workers as being configured with open(mark_filepath, "w") as f: f.write("") + else: + log("Worker config exists—not regenerating") # Lifted right out of start.py jemallocpath = "/usr/lib/%s-linux-gnu/libjemalloc.so.2" % (platform.machine(),) diff --git a/docs/admin_api/event_reports.md b/docs/admin_api/event_reports.md index beec8bb7ef..83f7dc37f4 100644 --- a/docs/admin_api/event_reports.md +++ b/docs/admin_api/event_reports.md @@ -169,3 +169,17 @@ The following fields are returned in the JSON response body: * `canonical_alias`: string - The canonical alias of the room. `null` if the room does not have a canonical alias set. * `event_json`: object - Details of the original event that was reported. + +# Delete a specific event report + +This API deletes a specific event report. If the request is successful, the response body +will be an empty JSON object. + +The api is: +``` +DELETE /_synapse/admin/v1/event_reports/<report_id> +``` + +**URL parameters:** + +* `report_id`: string - The ID of the event report. diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 50969edd46..1a0c6ec954 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -307,8 +307,8 @@ _Changed in Synapse v1.62.0: `synapse.module_api.NOT_SPAM` and `synapse.module_a ```python async def check_media_file_for_spam( - file_wrapper: "synapse.rest.media.v1.media_storage.ReadableFileWrapper", - file_info: "synapse.rest.media.v1._base.FileInfo", + file_wrapper: "synapse.media.media_storage.ReadableFileWrapper", + file_info: "synapse.media._base.FileInfo", ) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes", bool] ``` diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md index 996b007968..d0eaf77f76 100644 --- a/docs/modules/third_party_rules_callbacks.md +++ b/docs/modules/third_party_rules_callbacks.md @@ -215,6 +215,9 @@ Note that this callback is called when the event has already been processed and into the room, which means this callback cannot be used to deny persisting the event. To deny an incoming event, see [`check_event_for_spam`](spam_checker_callbacks.md#check_event_for_spam) instead. +For any given event, this callback will be called on every worker process, even if that worker will not end up +acting on that event. This callback will not be called for events that are marked as rejected. + If multiple modules implement this callback, Synapse runs them all in order. ### `check_can_shutdown_room` @@ -320,6 +323,11 @@ If multiple modules implement this callback, Synapse runs them all in order. _First introduced in Synapse v1.56.0_ +**<span style="color:red"> +This callback is deprecated in favour of the `on_add_user_third_party_identifier` callback, which +features the same functionality. The only difference is in name. +</span>** + ```python async def on_threepid_bind(user_id: str, medium: str, address: str) -> None: ``` @@ -334,6 +342,44 @@ server_. If multiple modules implement this callback, Synapse runs them all in order. +### `on_add_user_third_party_identifier` + +_First introduced in Synapse v1.79.0_ + +```python +async def on_add_user_third_party_identifier(user_id: str, medium: str, address: str) -> None: +``` + +Called after successfully creating an association between a user and a third-party identifier +(email address, phone number). The module is given the Matrix ID of the user the +association is for, as well as the medium (`email` or `msisdn`) and address of the +third-party identifier (i.e. an email address). + +Note that this callback is _not_ called if a user attempts to bind their third-party identifier +to an identity server (via a call to [`POST +/_matrix/client/v3/account/3pid/bind`](https://spec.matrix.org/v1.5/client-server-api/#post_matrixclientv3account3pidbind)). + +If multiple modules implement this callback, Synapse runs them all in order. + +### `on_remove_user_third_party_identifier` + +_First introduced in Synapse v1.79.0_ + +```python +async def on_remove_user_third_party_identifier(user_id: str, medium: str, address: str) -> None: +``` + +Called after successfully removing an association between a user and a third-party identifier +(email address, phone number). The module is given the Matrix ID of the user the +association is for, as well as the medium (`email` or `msisdn`) and address of the +third-party identifier (i.e. an email address). + +Note that this callback is _not_ called if a user attempts to unbind their third-party +identifier from an identity server (via a call to [`POST +/_matrix/client/v3/account/3pid/unbind`](https://spec.matrix.org/v1.5/client-server-api/#post_matrixclientv3account3pidunbind)). + +If multiple modules implement this callback, Synapse runs them all in order. + ## Example The example below is a module that implements the third-party rules callback @@ -366,4 +412,4 @@ class EventCensorer: ) event_dict["content"] = new_event_content return event_dict -``` +``` \ No newline at end of file diff --git a/docs/modules/writing_a_module.md b/docs/modules/writing_a_module.md index 30de69a533..b99f64b9d8 100644 --- a/docs/modules/writing_a_module.md +++ b/docs/modules/writing_a_module.md @@ -83,3 +83,59 @@ the callback name as the argument name and the function as its value. A Callbacks for each category can be found on their respective page of the [Synapse documentation website](https://matrix-org.github.io/synapse). + +## Caching + +_Added in Synapse 1.74.0._ + +Modules can leverage Synapse's caching tools to manage their own cached functions. This +can be helpful for modules that need to repeatedly request the same data from the database +or a remote service. + +Functions that need to be wrapped with a cache need to be decorated with a `@cached()` +decorator (which can be imported from `synapse.module_api`) and registered with the +[`ModuleApi.register_cached_function`](https://github.com/matrix-org/synapse/blob/release-v1.77/synapse/module_api/__init__.py#L888) +API when initialising the module. If the module needs to invalidate an entry in a cache, +it needs to use the [`ModuleApi.invalidate_cache`](https://github.com/matrix-org/synapse/blob/release-v1.77/synapse/module_api/__init__.py#L904) +API, with the function to invalidate the cache of and the key(s) of the entry to +invalidate. + +Below is an example of a simple module using a cached function: + +```python +from typing import Any +from synapse.module_api import cached, ModuleApi + +class MyModule: + def __init__(self, config: Any, api: ModuleApi): + self.api = api + + # Register the cached function so Synapse knows how to correctly invalidate + # entries for it. + self.api.register_cached_function(self.get_user_from_id) + + @cached() + async def get_department_for_user(self, user_id: str) -> str: + """A function with a cache.""" + # Request a department from an external service. + return await self.http_client.get_json( + "https://int.example.com/users", {"user_id": user_id) + )["department"] + + async def do_something_with_users(self) -> None: + """Calls the cached function and then invalidates an entry in its cache.""" + + user_id = "@alice:example.com" + + # Get the user. Since get_department_for_user is wrapped with a cache, + # the return value for this user_id will be cached. + department = await self.get_department_for_user(user_id) + + # Do something with `department`... + + # Let's say something has changed with our user, and the entry we have for + # them in the cache is out of date, so we want to invalidate it. + await self.api.invalidate_cache(self.get_department_for_user, (user_id,)) +``` + +See the [`cached` docstring](https://github.com/matrix-org/synapse/blob/release-v1.77/synapse/module_api/__init__.py#L190) for more details. diff --git a/docs/openid.md b/docs/openid.md index 6ee8c83ec0..73f1e06121 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -590,6 +590,47 @@ oidc_providers: Note that the fields `client_id` and `client_secret` are taken from the CURL response above. +### Shibboleth with OIDC Plugin + +[Shibboleth](https://www.shibboleth.net/) is an open Standard IdP solution widely used by Universities. + +1. Shibboleth needs the [OIDC Plugin](https://shibboleth.atlassian.net/wiki/spaces/IDPPLUGINS/pages/1376878976/OIDC+OP) installed and working correctly. +2. Create a new config on the IdP Side, ensure that the `client_id` and `client_secret` + are randomly generated data. +```json +{ + "client_id": "SOME-CLIENT-ID", + "client_secret": "SOME-SUPER-SECRET-SECRET", + "response_types": ["code"], + "grant_types": ["authorization_code"], + "scope": "openid profile email", + "redirect_uris": ["https://[synapse public baseurl]/_synapse/client/oidc/callback"] +} +``` + +Synapse config: + +```yaml +oidc_providers: + # Shibboleth IDP + # + - idp_id: shibboleth + idp_name: "Shibboleth Login" + discover: true + issuer: "https://YOUR-IDP-URL.TLD" + client_id: "YOUR_CLIENT_ID" + client_secret: "YOUR-CLIENT-SECRECT-FROM-YOUR-IDP" + scopes: ["openid", "profile", "email"] + allow_existing_users: true + user_profile_method: "userinfo_endpoint" + user_mapping_provider: + config: + subject_claim: "sub" + localpart_template: "{{ user.sub.split('@')[0] }}" + display_name_template: "{{ user.name }}" + email_template: "{{ user.email }}" +``` + ### Twitch 1. Setup a developer account on [Twitch](https://dev.twitch.tv/) diff --git a/docs/upgrade.md b/docs/upgrade.md index 15167b8c58..f06e874054 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,30 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.79.0 + +## The `on_threepid_bind` module callback method has been deprecated + +Synapse v1.79.0 deprecates the +[`on_threepid_bind`](modules/third_party_rules_callbacks.md#on_threepid_bind) +"third-party rules" Synapse module callback method in favour of a new module method, +[`on_add_user_third_party_identifier`](modules/third_party_rules_callbacks.md#on_add_user_third_party_identifier). +`on_threepid_bind` will be removed in a future version of Synapse. You should check whether any Synapse +modules in use in your deployment are making use of `on_threepid_bind`, and update them where possible. + +The arguments and functionality of the new method are the same. + +The justification behind the name change is that the old method's name, `on_threepid_bind`, was +misleading. A user is considered to "bind" their third-party ID to their Matrix ID only if they +do so via an [identity server](https://spec.matrix.org/latest/identity-service-api/) +(so that users on other homeservers may find them). But this method was not called in that case - +it was only called when a user added a third-party identifier on the local homeserver. + +Module developers may also be interested in the related +[`on_remove_user_third_party_identifier`](modules/third_party_rules_callbacks.md#on_remove_user_third_party_identifier) +module callback method that was also added in Synapse v1.79.0. This new method is called when a +user removes a third-party identifier from their account. + # Upgrading to v1.78.0 ## Deprecate the `/_synapse/admin/v1/media/<server_name>/delete` admin API diff --git a/docs/usage/administration/admin_faq.md b/docs/usage/administration/admin_faq.md index 7a27741199..28c3dd53a5 100644 --- a/docs/usage/administration/admin_faq.md +++ b/docs/usage/administration/admin_faq.md @@ -70,10 +70,55 @@ output-directory │ ├───state │ ├───invite_state │ └───knock_state -└───user_data - ├───connections - ├───devices - └───profile +├───user_data +│ ├───account_data +│ │ ├───global +│ │ └───<room_id> +│ ├───connections +│ ├───devices +│ └───profile +└───media_ids + └───<media_id> +``` + +The `media_ids` folder contains only the metadata of the media uploaded by the user. +It does not contain the media itself. +Furthermore, only the `media_ids` that Synapse manages itself are exported. +If another media repository (e.g. [matrix-media-repo](https://github.com/turt2live/matrix-media-repo)) +is used, the data must be exported separately. + +With the `media_ids` the media files can be downloaded. +Media that have been sent in encrypted rooms are only retrieved in encrypted form. +The following script can help with download the media files: + +```bash +#!/usr/bin/env bash + +# Parameters +# +# source_directory: Directory which contains the export with the media_ids. +# target_directory: Directory into which all files are to be downloaded. +# repository_url: Address of the media repository resp. media worker. +# serverName: Name of the server (`server_name` from homeserver.yaml). +# +# Example: +# ./download_media.sh /tmp/export_data/media_ids/ /tmp/export_data/media_files/ http://localhost:8008 matrix.example.com + +source_directory=$1 +target_directory=$2 +repository_url=$3 +serverName=$4 + +mkdir -p $target_directory + +for file in $source_directory/*; do + filename=$(basename ${file}) + url=$repository_url/_matrix/media/v3/download/$serverName/$filename + echo "Downloading $filename - $url" + if ! wget -o /dev/null -P $target_directory $url; then + echo "Could not download $filename" + fi +done ``` Manually resetting passwords @@ -84,7 +129,7 @@ can reset a user's password using the [admin API](../../admin_api/user_admin_api I have a problem with my server. Can I just delete my database and start again? --- -Deleting your database is unlikely to make anything better. +Deleting your database is unlikely to make anything better. It's easy to make the mistake of thinking that you can start again from a clean slate by dropping your database, but things don't work like that in a federated @@ -99,7 +144,7 @@ Come and seek help in https://matrix.to/#/#synapse:matrix.org. There are two exceptions when it might be sensible to delete your database and start again: * You have *never* joined any rooms which are federated with other servers. For -instance, a local deployment which the outside world can't talk to. +instance, a local deployment which the outside world can't talk to. * You are changing the `server_name` in the homeserver configuration. In effect this makes your server a completely new one from the point of view of the network, so in this case it makes sense to start with a clean database. @@ -112,7 +157,7 @@ Using the following curl command: curl -H 'Authorization: Bearer <access-token>' -X DELETE https://matrix.org/_matrix/client/r0/directory/room/<room-alias> ``` `<access-token>` - can be obtained in riot by looking in the riot settings, down the bottom is: -Access Token:\<click to reveal\> +Access Token:\<click to reveal\> `<room-alias>` - the room alias, eg. #my_room:matrix.org this possibly needs to be URL encoded also, for example %23my_room%3Amatrix.org @@ -149,13 +194,13 @@ What are the biggest rooms on my server? --- ```sql -SELECT s.canonical_alias, g.room_id, count(*) AS num_rows -FROM - state_groups_state AS g, - room_stats_state AS s -WHERE g.room_id = s.room_id +SELECT s.canonical_alias, g.room_id, count(*) AS num_rows +FROM + state_groups_state AS g, + room_stats_state AS s +WHERE g.room_id = s.room_id GROUP BY s.canonical_alias, g.room_id -ORDER BY num_rows desc +ORDER BY num_rows desc LIMIT 10; ``` diff --git a/docs/usage/administration/database_maintenance_tools.md b/docs/usage/administration/database_maintenance_tools.md index 92b805d413..e19380db07 100644 --- a/docs/usage/administration/database_maintenance_tools.md +++ b/docs/usage/administration/database_maintenance_tools.md @@ -1,4 +1,4 @@ -This blog post by Victor Berger explains how to use many of the tools listed on this page: https://levans.fr/shrink-synapse-database.html +_This [blog post by Jackson Chen](https://jacksonchen666.com/posts/2022-12-03/14-33-00/) (Dec 2022) explains how to use many of the tools listed on this page. There is also an [earlier blog by Victor Berger](https://levans.fr/shrink-synapse-database.html) (June 2020), though this may be outdated in places._ # List of useful tools and scripts for maintenance Synapse database: @@ -15,4 +15,4 @@ The purge history API allows server admins to purge historic events from their d Tool for compressing (deduplicating) `state_groups_state` table. ## [SQL for analyzing Synapse PostgreSQL database stats](useful_sql_for_admins.md) -Some easy SQL that reports useful stats about your Synapse database. \ No newline at end of file +Some easy SQL that reports useful stats about your Synapse database. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 75483bfb12..015855ee7e 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -1105,7 +1105,7 @@ This setting should only be used in very specific cases, such as federation over Tor hidden services and similar. For private networks of homeservers, you likely want to use a private CA instead. -Only effective if `federation_verify_certicates` is `true`. +Only effective if `federation_verify_certificates` is `true`. Example configuration: ```yaml @@ -1518,11 +1518,11 @@ rc_registration_token_validity: This option specifies several limits for login: * `address` ratelimits login requests based on the client's IP - address. Defaults to `per_second: 0.17`, `burst_count: 3`. + address. Defaults to `per_second: 0.003`, `burst_count: 5`. * `account` ratelimits login requests based on the account the - client is attempting to log into. Defaults to `per_second: 0.17`, - `burst_count: 3`. + client is attempting to log into. Defaults to `per_second: 0.03`, + `burst_count: 5`. * `failed_attempts` ratelimits login requests based on the account the client is attempting to log into, based on the amount of failed login @@ -2227,12 +2227,12 @@ allows the shared secret to be specified in an external file. The file should be a plain text file, containing only the shared secret. -If this file does not exist, Synapse will create a new signing -key on startup and store it in this file. +If this file does not exist, Synapse will create a new shared +secret on startup and store it in this file. Example configuration: ```yaml -registration_shared_secret_file: /path/to/secrets/file +registration_shared_secret_path: /path/to/secrets/file ``` _Added in Synapse 1.67.0._ diff --git a/docs/workers.md b/docs/workers.md index bc66f0e1bc..fa536cd310 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -160,7 +160,18 @@ recommend the use of `systemd` where available: for information on setting up [Systemd with Workers](systemd-with-workers/). To use `synctl`, see [Using synctl with Workers](synctl_workers.md). +## Start Synapse with Poetry +The following applies to Synapse installations that have been installed from source using `poetry`. + +You can start the main Synapse process with Poetry by running the following command: +```console +poetry run synapse_homeserver -c [your homeserver.yaml] +``` +For worker setups, you can run the following command +```console +poetry run synapse_worker -c [your worker.yaml] +``` ## Available worker applications ### `synapse.app.generic_worker` @@ -221,6 +232,7 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$ ^/_matrix/client/v1/rooms/.*/timestamp_to_event$ ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ + ^/_matrix/client/(r0|v3|unstable)/user/.*/filter(/|$) # Encryption requests ^/_matrix/client/(r0|v3|unstable)/keys/query$ @@ -240,6 +252,7 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state/ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$ ^/_matrix/client/(api/v1|r0|v3|unstable)/join/ + ^/_matrix/client/(api/v1|r0|v3|unstable)/knock/ ^/_matrix/client/(api/v1|r0|v3|unstable)/profile/ # Account data requests diff --git a/mypy.ini b/mypy.ini index ff6e04b12f..572734f8e7 100644 --- a/mypy.ini +++ b/mypy.ini @@ -31,16 +31,11 @@ exclude = (?x) |synapse/storage/databases/__init__.py |synapse/storage/databases/main/cache.py |synapse/storage/schema/ - - |tests/server.py )$ [mypy-synapse.federation.transport.client] disallow_untyped_defs = False -[mypy-synapse.http.client] -disallow_untyped_defs = False - [mypy-synapse.http.matrixfederationclient] disallow_untyped_defs = False diff --git a/poetry.lock b/poetry.lock index e534b30d2b..cd3dc6fdcd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -90,32 +90,46 @@ typecheck = ["mypy"] [[package]] name = "black" -version = "22.12.0" +version = "23.1.0" description = "The uncompromising code formatter." category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "black-22.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eedd20838bd5d75b80c9f5487dbcb06836a43833a37846cf1d8c1cc01cef59d"}, - {file = "black-22.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:159a46a4947f73387b4d83e87ea006dbb2337eab6c879620a3ba52699b1f4351"}, - {file = "black-22.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d30b212bffeb1e252b31dd269dfae69dd17e06d92b87ad26e23890f3efea366f"}, - {file = "black-22.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:7412e75863aa5c5411886804678b7d083c7c28421210180d67dfd8cf1221e1f4"}, - {file = "black-22.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c116eed0efb9ff870ded8b62fe9f28dd61ef6e9ddd28d83d7d264a38417dcee2"}, - {file = "black-22.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:1f58cbe16dfe8c12b7434e50ff889fa479072096d79f0a7f25e4ab8e94cd8350"}, - {file = "black-22.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77d86c9f3db9b1bf6761244bc0b3572a546f5fe37917a044e02f3166d5aafa7d"}, - {file = "black-22.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:82d9fe8fee3401e02e79767016b4907820a7dc28d70d137eb397b92ef3cc5bfc"}, - {file = "black-22.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101c69b23df9b44247bd88e1d7e90154336ac4992502d4197bdac35dd7ee3320"}, - {file = "black-22.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:559c7a1ba9a006226f09e4916060982fd27334ae1998e7a38b3f33a37f7a2148"}, - {file = "black-22.12.0-py3-none-any.whl", hash = "sha256:436cc9167dd28040ad90d3b404aec22cedf24a6e4d7de221bec2730ec0c97bcf"}, - {file = "black-22.12.0.tar.gz", hash = "sha256:229351e5a18ca30f447bf724d007f890f97e13af070bb6ad4c0a441cd7596a2f"}, + {file = "black-23.1.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:b6a92a41ee34b883b359998f0c8e6eb8e99803aa8bf3123bf2b2e6fec505a221"}, + {file = "black-23.1.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:57c18c5165c1dbe291d5306e53fb3988122890e57bd9b3dcb75f967f13411a26"}, + {file = "black-23.1.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:9880d7d419bb7e709b37e28deb5e68a49227713b623c72b2b931028ea65f619b"}, + {file = "black-23.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e6663f91b6feca5d06f2ccd49a10f254f9298cc1f7f49c46e498a0771b507104"}, + {file = "black-23.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9afd3f493666a0cd8f8df9a0200c6359ac53940cbde049dcb1a7eb6ee2dd7074"}, + {file = "black-23.1.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:bfffba28dc52a58f04492181392ee380e95262af14ee01d4bc7bb1b1c6ca8d27"}, + {file = "black-23.1.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c1c476bc7b7d021321e7d93dc2cbd78ce103b84d5a4cf97ed535fbc0d6660648"}, + {file = "black-23.1.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:382998821f58e5c8238d3166c492139573325287820963d2f7de4d518bd76958"}, + {file = "black-23.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bf649fda611c8550ca9d7592b69f0637218c2369b7744694c5e4902873b2f3a"}, + {file = "black-23.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:121ca7f10b4a01fd99951234abdbd97728e1240be89fde18480ffac16503d481"}, + {file = "black-23.1.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:a8471939da5e824b891b25751955be52ee7f8a30a916d570a5ba8e0f2eb2ecad"}, + {file = "black-23.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8178318cb74f98bc571eef19068f6ab5613b3e59d4f47771582f04e175570ed8"}, + {file = "black-23.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:a436e7881d33acaf2536c46a454bb964a50eff59b21b51c6ccf5a40601fbef24"}, + {file = "black-23.1.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:a59db0a2094d2259c554676403fa2fac3473ccf1354c1c63eccf7ae65aac8ab6"}, + {file = "black-23.1.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:0052dba51dec07ed029ed61b18183942043e00008ec65d5028814afaab9a22fd"}, + {file = "black-23.1.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:49f7b39e30f326a34b5c9a4213213a6b221d7ae9d58ec70df1c4a307cf2a1580"}, + {file = "black-23.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:162e37d49e93bd6eb6f1afc3e17a3d23a823042530c37c3c42eeeaf026f38468"}, + {file = "black-23.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b70eb40a78dfac24842458476135f9b99ab952dd3f2dab738c1881a9b38b753"}, + {file = "black-23.1.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:a29650759a6a0944e7cca036674655c2f0f63806ddecc45ed40b7b8aa314b651"}, + {file = "black-23.1.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:bb460c8561c8c1bec7824ecbc3ce085eb50005883a6203dcfb0122e95797ee06"}, + {file = "black-23.1.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:c91dfc2c2a4e50df0026f88d2215e166616e0c80e86004d0003ece0488db2739"}, + {file = "black-23.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a951cc83ab535d248c89f300eccbd625e80ab880fbcfb5ac8afb5f01a258ac9"}, + {file = "black-23.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:0680d4380db3719ebcfb2613f34e86c8e6d15ffeabcf8ec59355c5e7b85bb555"}, + {file = "black-23.1.0-py3-none-any.whl", hash = "sha256:7a0f701d314cfa0896b9001df70a530eb2472babb76086344e688829efd97d32"}, + {file = "black-23.1.0.tar.gz", hash = "sha256:b0bd97bea8903f5a2ba7219257a44e3f1f9d00073d6cc1add68f0beec69692ac"}, ] [package.dependencies] click = ">=8.0.0" mypy-extensions = ">=0.4.3" +packaging = ">=22.0" pathspec = ">=0.9.0" platformdirs = ">=2" -tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""} +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} typed-ast = {version = ">=1.4.2", markers = "python_version < \"3.8\" and implementation_name == \"cpython\""} typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} @@ -146,14 +160,14 @@ css = ["tinycss2 (>=1.1.0,<1.2)"] [[package]] name = "canonicaljson" -version = "1.6.4" +version = "1.6.5" description = "Canonical JSON" category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "canonicaljson-1.6.4-py3-none-any.whl", hash = "sha256:55d282853b4245dbcd953fe54c39b91571813d7c44e1dbf66e3c4f97ff134a48"}, - {file = "canonicaljson-1.6.4.tar.gz", hash = "sha256:6c09b2119511f30eb1126cfcd973a10824e20f1cfd25039cde3d1218dd9c8d8f"}, + {file = "canonicaljson-1.6.5-py3-none-any.whl", hash = "sha256:806ea6f2cbb7405d20259e1c36dd1214ba5c242fa9165f5bd0bf2081f82c23fb"}, + {file = "canonicaljson-1.6.5.tar.gz", hash = "sha256:68dfc157b011e07d94bf74b5d4ccc01958584ed942d9dfd5fdd706609e81cd4b"}, ] [package.dependencies] @@ -1146,36 +1160,38 @@ files = [ [[package]] name = "mypy" -version = "0.981" +version = "1.0.0" description = "Optional static typing for Python" category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "mypy-0.981-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4bc460e43b7785f78862dab78674e62ec3cd523485baecfdf81a555ed29ecfa0"}, - {file = "mypy-0.981-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:756fad8b263b3ba39e4e204ee53042671b660c36c9017412b43af210ddee7b08"}, - {file = "mypy-0.981-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a16a0145d6d7d00fbede2da3a3096dcc9ecea091adfa8da48fa6a7b75d35562d"}, - {file = "mypy-0.981-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce65f70b14a21fdac84c294cde75e6dbdabbcff22975335e20827b3b94bdbf49"}, - {file = "mypy-0.981-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e35d764784b42c3e256848fb8ed1d4292c9fc0098413adb28d84974c095b279"}, - {file = "mypy-0.981-cp310-cp310-win_amd64.whl", hash = "sha256:e53773073c864d5f5cec7f3fc72fbbcef65410cde8cc18d4f7242dea60dac52e"}, - {file = "mypy-0.981-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6ee196b1d10b8b215e835f438e06965d7a480f6fe016eddbc285f13955cca659"}, - {file = "mypy-0.981-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ad21d4c9d3673726cf986ea1d0c9fb66905258709550ddf7944c8f885f208be"}, - {file = "mypy-0.981-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d1debb09043e1f5ee845fa1e96d180e89115b30e47c5d3ce53bc967bab53f62d"}, - {file = "mypy-0.981-cp37-cp37m-win_amd64.whl", hash = "sha256:9f362470a3480165c4c6151786b5379351b790d56952005be18bdbdd4c7ce0ae"}, - {file = "mypy-0.981-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c9e0efb95ed6ca1654951bd5ec2f3fa91b295d78bf6527e026529d4aaa1e0c30"}, - {file = "mypy-0.981-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e178eaffc3c5cd211a87965c8c0df6da91ed7d258b5fc72b8e047c3771317ddb"}, - {file = "mypy-0.981-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:06e1eac8d99bd404ed8dd34ca29673c4346e76dd8e612ea507763dccd7e13c7a"}, - {file = "mypy-0.981-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa38f82f53e1e7beb45557ff167c177802ba7b387ad017eab1663d567017c8ee"}, - {file = "mypy-0.981-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:64e1f6af81c003f85f0dfed52db632817dabb51b65c0318ffbf5ff51995bbb08"}, - {file = "mypy-0.981-cp38-cp38-win_amd64.whl", hash = "sha256:e1acf62a8c4f7c092462c738aa2c2489e275ed386320c10b2e9bff31f6f7e8d6"}, - {file = "mypy-0.981-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b6ede64e52257931315826fdbfc6ea878d89a965580d1a65638ef77cb551f56d"}, - {file = "mypy-0.981-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eb3978b191b9fa0488524bb4ffedf2c573340e8c2b4206fc191d44c7093abfb7"}, - {file = "mypy-0.981-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77f8fcf7b4b3cc0c74fb33ae54a4cd00bb854d65645c48beccf65fa10b17882c"}, - {file = "mypy-0.981-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64d2ce043a209a297df322eb4054dfbaa9de9e8738291706eaafda81ab2b362"}, - {file = "mypy-0.981-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2ee3dbc53d4df7e6e3b1c68ac6a971d3a4fb2852bf10a05fda228721dd44fae1"}, - {file = "mypy-0.981-cp39-cp39-win_amd64.whl", hash = "sha256:8e8e49aa9cc23aa4c926dc200ce32959d3501c4905147a66ce032f05cb5ecb92"}, - {file = "mypy-0.981-py3-none-any.whl", hash = "sha256:794f385653e2b749387a42afb1e14c2135e18daeb027e0d97162e4b7031210f8"}, - {file = "mypy-0.981.tar.gz", hash = "sha256:ad77c13037d3402fbeffda07d51e3f228ba078d1c7096a73759c9419ea031bf4"}, + {file = "mypy-1.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0626db16705ab9f7fa6c249c017c887baf20738ce7f9129da162bb3075fc1af"}, + {file = "mypy-1.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1ace23f6bb4aec4604b86c4843276e8fa548d667dbbd0cb83a3ae14b18b2db6c"}, + {file = "mypy-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87edfaf344c9401942883fad030909116aa77b0fa7e6e8e1c5407e14549afe9a"}, + {file = "mypy-1.0.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0ab090d9240d6b4e99e1fa998c2d0aa5b29fc0fb06bd30e7ad6183c95fa07593"}, + {file = "mypy-1.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:7cc2c01dfc5a3cbddfa6c13f530ef3b95292f926329929001d45e124342cd6b7"}, + {file = "mypy-1.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:14d776869a3e6c89c17eb943100f7868f677703c8a4e00b3803918f86aafbc52"}, + {file = "mypy-1.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bb2782a036d9eb6b5a6efcdda0986774bf798beef86a62da86cb73e2a10b423d"}, + {file = "mypy-1.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cfca124f0ac6707747544c127880893ad72a656e136adc935c8600740b21ff5"}, + {file = "mypy-1.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8845125d0b7c57838a10fd8925b0f5f709d0e08568ce587cc862aacce453e3dd"}, + {file = "mypy-1.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b1b9e1ed40544ef486fa8ac022232ccc57109f379611633ede8e71630d07d2"}, + {file = "mypy-1.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c7cf862aef988b5fbaa17764ad1d21b4831436701c7d2b653156a9497d92c83c"}, + {file = "mypy-1.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cd187d92b6939617f1168a4fe68f68add749902c010e66fe574c165c742ed88"}, + {file = "mypy-1.0.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4e5175026618c178dfba6188228b845b64131034ab3ba52acaffa8f6c361f805"}, + {file = "mypy-1.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2f6ac8c87e046dc18c7d1d7f6653a66787a4555085b056fe2d599f1f1a2a2d21"}, + {file = "mypy-1.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7306edca1c6f1b5fa0bc9aa645e6ac8393014fa82d0fa180d0ebc990ebe15964"}, + {file = "mypy-1.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3cfad08f16a9c6611e6143485a93de0e1e13f48cfb90bcad7d5fde1c0cec3d36"}, + {file = "mypy-1.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67cced7f15654710386e5c10b96608f1ee3d5c94ca1da5a2aad5889793a824c1"}, + {file = "mypy-1.0.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a86b794e8a56ada65c573183756eac8ac5b8d3d59daf9d5ebd72ecdbb7867a43"}, + {file = "mypy-1.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:50979d5efff8d4135d9db293c6cb2c42260e70fb010cbc697b1311a4d7a39ddb"}, + {file = "mypy-1.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3ae4c7a99e5153496243146a3baf33b9beff714464ca386b5f62daad601d87af"}, + {file = "mypy-1.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5e398652d005a198a7f3c132426b33c6b85d98aa7dc852137a2a3be8890c4072"}, + {file = "mypy-1.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be78077064d016bc1b639c2cbcc5be945b47b4261a4f4b7d8923f6c69c5c9457"}, + {file = "mypy-1.0.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92024447a339400ea00ac228369cd242e988dd775640755fa4ac0c126e49bb74"}, + {file = "mypy-1.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:fe523fcbd52c05040c7bee370d66fee8373c5972171e4fbc323153433198592d"}, + {file = "mypy-1.0.0-py3-none-any.whl", hash = "sha256:2efa963bdddb27cb4a0d42545cd137a8d2b883bd181bbc4525b568ef6eca258f"}, + {file = "mypy-1.0.0.tar.gz", hash = "sha256:f34495079c8d9da05b183f9f7daec2878280c2ad7cc81da686ef0b484cea2ecf"}, ] [package.dependencies] @@ -1186,6 +1202,7 @@ typing-extensions = ">=3.10" [package.extras] dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] python2 = ["typed-ast (>=1.4.0,<2)"] reports = ["lxml"] @@ -1203,18 +1220,18 @@ files = [ [[package]] name = "mypy-zope" -version = "0.3.11" +version = "0.9.0" description = "Plugin for mypy to support zope interfaces" category = "dev" optional = false python-versions = "*" files = [ - {file = "mypy-zope-0.3.11.tar.gz", hash = "sha256:d4255f9f04d48c79083bbd4e2fea06513a6ac7b8de06f8c4ce563fd85142ca05"}, - {file = "mypy_zope-0.3.11-py3-none-any.whl", hash = "sha256:ec080a6508d1f7805c8d2054f9fdd13c849742ce96803519e1fdfa3d3cab7140"}, + {file = "mypy-zope-0.9.0.tar.gz", hash = "sha256:88bf6cd056e38b338e6956055958a7805b4ff84404ccd99e29883a3647a1aeb3"}, + {file = "mypy_zope-0.9.0-py3-none-any.whl", hash = "sha256:e1bb4b57084f76ff8a154a3e07880a1af2ac6536c491dad4b143d529f72c5d15"}, ] [package.dependencies] -mypy = "0.981" +mypy = "1.0.0" "zope.interface" = "*" "zope.schema" = "*" @@ -1968,28 +1985,29 @@ jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] [[package]] name = "ruff" -version = "0.0.237" +version = "0.0.252" description = "An extremely fast Python linter, written in Rust." category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.0.237-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:2ea04d826ffca58a7ae926115a801960c757d53c9027f2ca9acbe84c9f2b2f04"}, - {file = "ruff-0.0.237-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:8ed113937fab9f73f8c1a6c0350bb4fe03e951370139c6e0adb81f48a8dcf4c6"}, - {file = "ruff-0.0.237-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e9bcb71a3efb5fe886eb48d739cfae5df4a15617e7b5a7668aa45ebf74c0d3fa"}, - {file = "ruff-0.0.237-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:80ce10718abbf502818c0d650ebab99fdcef5e937a1ded3884493ddff804373c"}, - {file = "ruff-0.0.237-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0cc6cb7c1efcc260df5a939435649610a28f9f438b8b313384c8985ac6574f9f"}, - {file = "ruff-0.0.237-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:7eef0c7a1e45a4e30328ae101613575944cbf47a3a11494bf9827722da6c66b3"}, - {file = "ruff-0.0.237-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0d122433a21ce4a21fbba34b73fc3add0ccddd1643b3ff5abb8d2767952f872e"}, - {file = "ruff-0.0.237-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b76311335adda4de3c1d471e64e89a49abfeebf02647e3db064e7740e7f36ed6"}, - {file = "ruff-0.0.237-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46c5977b643aaf2b6f84641265f835b6c7f67fcca38dbae08c4f15602e084ca0"}, - {file = "ruff-0.0.237-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3d6ed86d0d4d742360a262d52191581f12b669a68e59ae3b52e80d7483b3d7b3"}, - {file = "ruff-0.0.237-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fedfb60f986c26cdb1809db02866e68508db99910c587d2c4066a5c07aa85593"}, - {file = "ruff-0.0.237-py3-none-musllinux_1_2_i686.whl", hash = "sha256:bb96796be5919871fa9ae7e88968ba9e14306d9a3f217ca6c204f68a5abeccdd"}, - {file = "ruff-0.0.237-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:ea239cfedf67b74ea4952e1074bb99a4281c2145441d70bc7e2f058d5c49f1c9"}, - {file = "ruff-0.0.237-py3-none-win32.whl", hash = "sha256:8d6a1d21ae15da2b1dcffeee2606e90de0e6717e72957da7d16ab6ae18dd0058"}, - {file = "ruff-0.0.237-py3-none-win_amd64.whl", hash = "sha256:525e5ec81cee29b993f77976026a6bf44528a14aa6edb1ef47bd8079147395ae"}, - {file = "ruff-0.0.237.tar.gz", hash = "sha256:630c575f543733adf6c19a11d9a02ca9ecc364bd7140af8a4c854d4728be6b56"}, + {file = "ruff-0.0.252-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:349367a227c4db7abbc3a9993efea8a608b5bea4bb4a1e5fc6f0d56819524f92"}, + {file = "ruff-0.0.252-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:ce77f9106d96b4faf7865860fb5155b9deaf6f699d9c279118c5ad947739ecaf"}, + {file = "ruff-0.0.252-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edadb0b050293b4e60dab979ba6a4e734d9c899cbe316a0ee5b65e3cdd39c750"}, + {file = "ruff-0.0.252-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4efdae98937d1e4d23ab0b7fc7e8e6b6836cc7d2d42238ceeacbc793ef780542"}, + {file = "ruff-0.0.252-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c8546d879f7d3f669379a03e7b103d90e11901976ab508aeda59c03dfd8a359e"}, + {file = "ruff-0.0.252-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:83fdc7169b6c1fb5fe8d1cdf345697f558c1b433ef97df9ca11defa2a8f3ee9e"}, + {file = "ruff-0.0.252-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84ed9be1a17e2a556a571a5b959398633dd10910abd8dcf8b098061e746e892d"}, + {file = "ruff-0.0.252-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f5e77bd9ba4438cf2ee32154e2673afe22f538ef29f5d65ca47e3dc46c42cf8"}, + {file = "ruff-0.0.252-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a5179b94b45c0f8512eaff3ab304c14714a46df2e9ca72a9d96084adc376b71"}, + {file = "ruff-0.0.252-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:92efd8a71157595df5bc46aaaa0613d8a2fbc5cddc53ae7b749c16025c324732"}, + {file = "ruff-0.0.252-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fd350fc10832cfd28e681d829a8aa83ea3e653326e0ea9d98637dfb8d46177d2"}, + {file = "ruff-0.0.252-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f119240c9631216e846166e06023b1d878e25fbac93bf20da50069e91cfbfaee"}, + {file = "ruff-0.0.252-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5c5a49f89f5ede93d16eddfeeadd7e5739ec703e8f63ac95eac30236b9e49da3"}, + {file = "ruff-0.0.252-py3-none-win32.whl", hash = "sha256:89a897dc743f2fe063483ea666097e72e848f4bbe40493fe0533e61799959f6e"}, + {file = "ruff-0.0.252-py3-none-win_amd64.whl", hash = "sha256:cdc89ad6ff88519b1fb1816ac82a9ad910762c90ff5fd64dda7691b72d36aff7"}, + {file = "ruff-0.0.252-py3-none-win_arm64.whl", hash = "sha256:4b594a17cf53077165429486650658a0e1b2ac6ab88954f5afd50d2b1b5657a9"}, + {file = "ruff-0.0.252.tar.gz", hash = "sha256:6992611ab7bdbe7204e4831c95ddd3febfeece2e6f5e44bbed044454c7db0f63"}, ] [[package]] @@ -2545,14 +2563,14 @@ files = [ [[package]] name = "types-bleach" -version = "5.0.3.1" +version = "6.0.0.0" description = "Typing stubs for bleach" category = "dev" optional = false python-versions = "*" files = [ - {file = "types-bleach-5.0.3.1.tar.gz", hash = "sha256:ce8772ea5126dab1883851b41e3aeff229aa5213ced36096990344e632e92373"}, - {file = "types_bleach-5.0.3.1-py3-none-any.whl", hash = "sha256:af5f1b3a54ff279f54c29eccb2e6988ebb6718bc4061469588a5fd4880a79287"}, + {file = "types-bleach-6.0.0.0.tar.gz", hash = "sha256:770ce9c7ea6173743ef1a4a70f2619bb1819bf53c7cd0336d939af93f488fbe2"}, + {file = "types_bleach-6.0.0.0-py3-none-any.whl", hash = "sha256:75f55f035837c5fce2cd0bd5162a2a90057680a89c9275588a5c12f5f597a14a"}, ] [[package]] @@ -2584,18 +2602,6 @@ types-enum34 = "*" types-ipaddress = "*" [[package]] -name = "types-docutils" -version = "0.19.1.1" -description = "Typing stubs for docutils" -category = "dev" -optional = false -python-versions = "*" -files = [ - {file = "types-docutils-0.19.1.1.tar.gz", hash = "sha256:be0a51ba1c7dd215d9d2df66d6845e63c1009b4bbf4c5beb87a0d9745cdba962"}, - {file = "types_docutils-0.19.1.1-py3-none-any.whl", hash = "sha256:a024cada35f0c13cc45eb0b68a102719018a634013690b7fef723bcbfadbd1f1"}, -] - -[[package]] name = "types-enum34" version = "1.1.8" description = "Typing stubs for enum34" @@ -2621,38 +2627,38 @@ files = [ [[package]] name = "types-jsonschema" -version = "4.17.0.3" +version = "4.17.0.5" description = "Typing stubs for jsonschema" category = "dev" optional = false python-versions = "*" files = [ - {file = "types-jsonschema-4.17.0.3.tar.gz", hash = "sha256:746aa466ffed9a1acc7bdbd0ac0b5e068f00be2ee008c1d1e14b0944a8c8b24b"}, - {file = "types_jsonschema-4.17.0.3-py3-none-any.whl", hash = "sha256:c8d5b26b7c8da6a48d7fb1ce029b97e0ff6e74db3727efb968c69f39ad013685"}, + {file = "types-jsonschema-4.17.0.5.tar.gz", hash = "sha256:7adc7bfca4afe291de0c93eca9367aa72a4fbe8ce87fe15642c600ad97d45dd6"}, + {file = "types_jsonschema-4.17.0.5-py3-none-any.whl", hash = "sha256:79ac8a7763fe728947af90a24168b91621edf7e8425bf3670abd4ea0d4758fba"}, ] [[package]] name = "types-opentracing" -version = "2.4.10.1" +version = "2.4.10.3" description = "Typing stubs for opentracing" category = "dev" optional = false python-versions = "*" files = [ - {file = "types-opentracing-2.4.10.1.tar.gz", hash = "sha256:49e7e52b8b6e221865a9201fc8c2df0bcda8e7098d4ebb35903dbfa4b4d29195"}, - {file = "types_opentracing-2.4.10.1-py3-none-any.whl", hash = "sha256:eb63394acd793e7d9e327956242349fee14580a87c025408dc268d4dd883cc24"}, + {file = "types-opentracing-2.4.10.3.tar.gz", hash = "sha256:b277f114265b41216714f9c77dffcab57038f1730fd141e2c55c5c9f6f2caa87"}, + {file = "types_opentracing-2.4.10.3-py3-none-any.whl", hash = "sha256:60244d718fcd9de7043645ecaf597222d550432507098ab2e6268f7b589a7fa7"}, ] [[package]] name = "types-pillow" -version = "9.4.0.10" +version = "9.4.0.13" description = "Typing stubs for Pillow" category = "dev" optional = false python-versions = "*" files = [ - {file = "types-Pillow-9.4.0.10.tar.gz", hash = "sha256:341c2345610bba452d1724757c7b997a60f593cf003c101ba239db003a0ae389"}, - {file = "types_Pillow-9.4.0.10-py3-none-any.whl", hash = "sha256:302ce81cfb61aacc8983a3a2ec682cbef66522a2fe8e640f648ac2e3d6f6af53"}, + {file = "types-Pillow-9.4.0.13.tar.gz", hash = "sha256:4510aa98a28947bf63f2b29edebbd11b7cff8647d90b867cec9b3674c0a8c321"}, + {file = "types_Pillow-9.4.0.13-py3-none-any.whl", hash = "sha256:14a8a19021b8fe569a9fef9edc64a8d8a4aef340e38669d4fb3dc05cfd941130"}, ] [[package]] @@ -2711,19 +2717,16 @@ types-urllib3 = "<1.27" [[package]] name = "types-setuptools" -version = "67.1.0.0" +version = "67.4.0.3" description = "Typing stubs for setuptools" category = "dev" optional = false python-versions = "*" files = [ - {file = "types-setuptools-67.1.0.0.tar.gz", hash = "sha256:162a39d22e3a5eb802197c84f16b19e798101bbd33d9437837fbb45627da5627"}, - {file = "types_setuptools-67.1.0.0-py3-none-any.whl", hash = "sha256:5bd7a10d93e468bfcb10d24cb8ea5e12ac4f4ac91267293959001f1448cf0619"}, + {file = "types-setuptools-67.4.0.3.tar.gz", hash = "sha256:19e958dfdbf1c5a628e54c2a7ee84935051afb7278d0c1cdb08ac194757ee3b1"}, + {file = "types_setuptools-67.4.0.3-py3-none-any.whl", hash = "sha256:3c83c3a6363dd3ddcdd054796705605f0fa8b8e5a39390e07a05e5f7af054978"}, ] -[package.dependencies] -types-docutils = "*" - [[package]] name = "types-urllib3" version = "1.26.10" @@ -2738,14 +2741,14 @@ files = [ [[package]] name = "typing-extensions" -version = "4.4.0" +version = "4.5.0" description = "Backported and Experimental Type Hints for Python 3.7+" category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "typing_extensions-4.4.0-py3-none-any.whl", hash = "sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e"}, - {file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"}, + {file = "typing_extensions-4.5.0-py3-none-any.whl", hash = "sha256:fb33085c39dd998ac16d1431ebc293a8b3eedd00fd4a32de0ff79002c19511b4"}, + {file = "typing_extensions-4.5.0.tar.gz", hash = "sha256:5cb5f4a79139d699607b3ef622a1dedafa84e115ab0024e0d9c044a9479ca7cb"}, ] [[package]] @@ -3027,4 +3030,4 @@ user-search = ["pyicu"] [metadata] lock-version = "2.0" python-versions = "^3.7.1" -content-hash = "95cb043fa56e1e3275ba7f74b68b2191bd5886eea3e06b8cd370d7fc9fea3c07" +content-hash = "7bcffef7b6e6d4b1113222e2ca152b3798c997872789c8a1ea01238f199d56fe" diff --git a/pyproject.toml b/pyproject.toml index 7f74b552c1..27785b6e13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ manifest-path = "rust/Cargo.toml" [tool.poetry] name = "matrix-synapse" -version = "1.77.0" +version = "1.78.0" description = "Homeserver for the Matrix decentralised comms protocol" authors = ["Matrix.org Team and Contributors <packages@matrix.org>"] license = "Apache-2.0" @@ -154,7 +154,9 @@ python = "^3.7.1" # we use the TYPE_CHECKER.redefine method added in jsonschema 3.0.0 jsonschema = ">=3.0.0" # frozendict 2.1.2 is broken on Debian 10: https://github.com/Marco-Sulla/python-frozendict/issues/41 -frozendict = ">=1,!=2.1.2" +# We cannot test our wheels against the 2.3.5 release in CI. Putting in an upper bound for this +# because frozendict has been more trouble than it's worth; we would like to move to immutabledict. +frozendict = ">=1,!=2.1.2,<2.3.5" # We require 2.1.0 or higher for type hints. Previous guard was >= 1.1.0 unpaddedbase64 = ">=2.1.0" # We require 1.5.0 to work around an issue when running against the C implementation of @@ -311,7 +313,7 @@ all = [ # We pin black so that our tests don't start failing on new releases. isort = ">=5.10.1" black = ">=22.3.0" -ruff = "0.0.237" +ruff = "0.0.252" # Typechecking mypy = "*" diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 8213dfd9ea..79b553dbb0 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -14,6 +14,7 @@ #![feature(test)] use std::collections::BTreeSet; + use synapse::push::{ evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, JsonValue, PushRules, SimpleJsonValue, @@ -44,8 +45,6 @@ fn bench_match_exact(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, false, - BTreeSet::new(), - false, 10, Some(0), Default::default(), @@ -54,15 +53,13 @@ fn bench_match_exact(b: &mut Bencher) { vec![], false, false, - false, ) .unwrap(); let condition = Condition::Known(synapse::push::KnownCondition::EventMatch( EventMatchCondition { key: "room_id".into(), - pattern: Some("!room:server".into()), - pattern_type: None, + pattern: "!room:server".into(), }, )); @@ -94,8 +91,6 @@ fn bench_match_word(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, false, - BTreeSet::new(), - false, 10, Some(0), Default::default(), @@ -104,15 +99,13 @@ fn bench_match_word(b: &mut Bencher) { vec![], false, false, - false, ) .unwrap(); let condition = Condition::Known(synapse::push::KnownCondition::EventMatch( EventMatchCondition { key: "content.body".into(), - pattern: Some("test".into()), - pattern_type: None, + pattern: "test".into(), }, )); @@ -144,8 +137,6 @@ fn bench_match_word_miss(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, false, - BTreeSet::new(), - false, 10, Some(0), Default::default(), @@ -154,15 +145,13 @@ fn bench_match_word_miss(b: &mut Bencher) { vec![], false, false, - false, ) .unwrap(); let condition = Condition::Known(synapse::push::KnownCondition::EventMatch( EventMatchCondition { key: "content.body".into(), - pattern: Some("foobar".into()), - pattern_type: None, + pattern: "foobar".into(), }, )); @@ -194,8 +183,6 @@ fn bench_eval_message(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, false, - BTreeSet::new(), - false, 10, Some(0), Default::default(), @@ -204,7 +191,6 @@ fn bench_eval_message(b: &mut Bencher) { vec![], false, false, - false, ) .unwrap(); diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index dcbca340fe..ec8d96656a 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -21,13 +21,13 @@ use lazy_static::lazy_static; use serde_json::Value; use super::KnownCondition; -use crate::push::Action; -use crate::push::Condition; -use crate::push::EventMatchCondition; -use crate::push::PushRule; -use crate::push::RelatedEventMatchCondition; +use crate::push::RelatedEventMatchTypeCondition; use crate::push::SetTweak; use crate::push::TweakValue; +use crate::push::{Action, EventPropertyIsCondition, SimpleJsonValue}; +use crate::push::{Condition, EventMatchTypeCondition}; +use crate::push::{EventMatchCondition, EventMatchPatternType}; +use crate::push::{EventPropertyIsTypeCondition, PushRule}; const HIGHLIGHT_ACTION: Action = Action::SetTweak(SetTweak { set_tweak: Cow::Borrowed("highlight"), @@ -72,8 +72,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( EventMatchCondition { key: Cow::Borrowed("content.m.relates_to.rel_type"), - pattern: Some(Cow::Borrowed("m.replace")), - pattern_type: None, + pattern: Cow::Borrowed("m.replace"), }, ))]), actions: Cow::Borrowed(&[]), @@ -86,8 +85,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( EventMatchCondition { key: Cow::Borrowed("content.msgtype"), - pattern: Some(Cow::Borrowed("m.notice")), - pattern_type: None, + pattern: Cow::Borrowed("m.notice"), }, ))]), actions: Cow::Borrowed(&[Action::DontNotify]), @@ -100,18 +98,15 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("m.room.member")), - pattern_type: None, + pattern: Cow::Borrowed("m.room.member"), })), Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("content.membership"), - pattern: Some(Cow::Borrowed("invite")), - pattern_type: None, + pattern: Cow::Borrowed("invite"), })), - Condition::Known(KnownCondition::EventMatch(EventMatchCondition { + Condition::Known(KnownCondition::EventMatchType(EventMatchTypeCondition { key: Cow::Borrowed("state_key"), - pattern: None, - pattern_type: Some(Cow::Borrowed("user_id")), + pattern_type: Cow::Borrowed(&EventMatchPatternType::UserId), })), ]), actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION, SOUND_ACTION]), @@ -124,8 +119,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("m.room.member")), - pattern_type: None, + pattern: Cow::Borrowed("m.room.member"), }, ))]), actions: Cow::Borrowed(&[Action::DontNotify]), @@ -135,11 +129,10 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ PushRule { rule_id: Cow::Borrowed("global/override/.im.nheko.msc3664.reply"), priority_class: 5, - conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelatedEventMatch( - RelatedEventMatchCondition { - key: Some(Cow::Borrowed("sender")), - pattern: None, - pattern_type: Some(Cow::Borrowed("user_id")), + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelatedEventMatchType( + RelatedEventMatchTypeCondition { + key: Cow::Borrowed("sender"), + pattern_type: Cow::Borrowed(&EventMatchPatternType::UserId), rel_type: Cow::Borrowed("m.in_reply_to"), include_fallbacks: None, }, @@ -151,7 +144,12 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ PushRule { rule_id: Cow::Borrowed(".org.matrix.msc3952.is_user_mention"), priority_class: 5, - conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::IsUserMention)]), + conditions: Cow::Borrowed(&[Condition::Known( + KnownCondition::ExactEventPropertyContainsType(EventPropertyIsTypeCondition { + key: Cow::Borrowed("content.org.matrix.msc3952.mentions.user_ids"), + value_type: Cow::Borrowed(&EventMatchPatternType::UserId), + }), + )]), actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), default: true, default_enabled: true, @@ -168,7 +166,10 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ rule_id: Cow::Borrowed(".org.matrix.msc3952.is_room_mention"), priority_class: 5, conditions: Cow::Borrowed(&[ - Condition::Known(KnownCondition::IsRoomMention), + Condition::Known(KnownCondition::EventPropertyIs(EventPropertyIsCondition { + key: Cow::Borrowed("content.org.matrix.msc3952.mentions.room"), + value: Cow::Borrowed(&SimpleJsonValue::Bool(true)), + })), Condition::Known(KnownCondition::SenderNotificationPermission { key: Cow::Borrowed("room"), }), @@ -186,8 +187,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ }), Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("content.body"), - pattern: Some(Cow::Borrowed("@room")), - pattern_type: None, + pattern: Cow::Borrowed("@room"), })), ]), actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION]), @@ -200,13 +200,11 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("m.room.tombstone")), - pattern_type: None, + pattern: Cow::Borrowed("m.room.tombstone"), })), Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("state_key"), - pattern: Some(Cow::Borrowed("")), - pattern_type: None, + pattern: Cow::Borrowed(""), })), ]), actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION]), @@ -219,8 +217,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("m.reaction")), - pattern_type: None, + pattern: Cow::Borrowed("m.reaction"), }, ))]), actions: Cow::Borrowed(&[]), @@ -233,13 +230,11 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("m.room.server_acl")), - pattern_type: None, + pattern: Cow::Borrowed("m.room.server_acl"), })), Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("state_key"), - pattern: Some(Cow::Borrowed("")), - pattern_type: None, + pattern: Cow::Borrowed(""), })), ]), actions: Cow::Borrowed(&[]), @@ -252,8 +247,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.response")), - pattern_type: None, + pattern: Cow::Borrowed("org.matrix.msc3381.poll.response"), }, ))]), actions: Cow::Borrowed(&[]), @@ -265,11 +259,10 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule { rule_id: Cow::Borrowed("global/content/.m.rule.contains_user_name"), priority_class: 4, - conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( - EventMatchCondition { + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatchType( + EventMatchTypeCondition { key: Cow::Borrowed("content.body"), - pattern: None, - pattern_type: Some(Cow::Borrowed("user_localpart")), + pattern_type: Cow::Borrowed(&EventMatchPatternType::UserLocalpart), }, ))]), actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), @@ -284,8 +277,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("m.call.invite")), - pattern_type: None, + pattern: Cow::Borrowed("m.call.invite"), }, ))]), actions: Cow::Borrowed(&[Action::Notify, RING_ACTION, HIGHLIGHT_FALSE_ACTION]), @@ -298,8 +290,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("m.room.message")), - pattern_type: None, + pattern: Cow::Borrowed("m.room.message"), })), Condition::Known(KnownCondition::RoomMemberCount { is: Some(Cow::Borrowed("2")), @@ -315,8 +306,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("m.room.encrypted")), - pattern_type: None, + pattern: Cow::Borrowed("m.room.encrypted"), })), Condition::Known(KnownCondition::RoomMemberCount { is: Some(Cow::Borrowed("2")), @@ -335,8 +325,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), // MSC3933: Type changed from template rule - see MSC. - pattern: Some(Cow::Borrowed("org.matrix.msc1767.encrypted")), - pattern_type: None, + pattern: Cow::Borrowed("org.matrix.msc1767.encrypted"), })), Condition::Known(KnownCondition::RoomMemberCount { is: Some(Cow::Borrowed("2")), @@ -360,8 +349,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), // MSC3933: Type changed from template rule - see MSC. - pattern: Some(Cow::Borrowed("org.matrix.msc1767.message")), - pattern_type: None, + pattern: Cow::Borrowed("org.matrix.msc1767.message"), })), Condition::Known(KnownCondition::RoomMemberCount { is: Some(Cow::Borrowed("2")), @@ -385,8 +373,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), // MSC3933: Type changed from template rule - see MSC. - pattern: Some(Cow::Borrowed("org.matrix.msc1767.file")), - pattern_type: None, + pattern: Cow::Borrowed("org.matrix.msc1767.file"), })), Condition::Known(KnownCondition::RoomMemberCount { is: Some(Cow::Borrowed("2")), @@ -410,8 +397,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), // MSC3933: Type changed from template rule - see MSC. - pattern: Some(Cow::Borrowed("org.matrix.msc1767.image")), - pattern_type: None, + pattern: Cow::Borrowed("org.matrix.msc1767.image"), })), Condition::Known(KnownCondition::RoomMemberCount { is: Some(Cow::Borrowed("2")), @@ -435,8 +421,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), // MSC3933: Type changed from template rule - see MSC. - pattern: Some(Cow::Borrowed("org.matrix.msc1767.video")), - pattern_type: None, + pattern: Cow::Borrowed("org.matrix.msc1767.video"), })), Condition::Known(KnownCondition::RoomMemberCount { is: Some(Cow::Borrowed("2")), @@ -460,8 +445,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), // MSC3933: Type changed from template rule - see MSC. - pattern: Some(Cow::Borrowed("org.matrix.msc1767.audio")), - pattern_type: None, + pattern: Cow::Borrowed("org.matrix.msc1767.audio"), })), Condition::Known(KnownCondition::RoomMemberCount { is: Some(Cow::Borrowed("2")), @@ -482,8 +466,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("m.room.message")), - pattern_type: None, + pattern: Cow::Borrowed("m.room.message"), }, ))]), actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), @@ -496,8 +479,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("m.room.encrypted")), - pattern_type: None, + pattern: Cow::Borrowed("m.room.encrypted"), }, ))]), actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), @@ -511,8 +493,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), // MSC3933: Type changed from template rule - see MSC. - pattern: Some(Cow::Borrowed("m.encrypted")), - pattern_type: None, + pattern: Cow::Borrowed("m.encrypted"), })), // MSC3933: Add condition on top of template rule - see MSC. Condition::Known(KnownCondition::RoomVersionSupports { @@ -531,8 +512,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), // MSC3933: Type changed from template rule - see MSC. - pattern: Some(Cow::Borrowed("m.message")), - pattern_type: None, + pattern: Cow::Borrowed("m.message"), })), // MSC3933: Add condition on top of template rule - see MSC. Condition::Known(KnownCondition::RoomVersionSupports { @@ -551,8 +531,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), // MSC3933: Type changed from template rule - see MSC. - pattern: Some(Cow::Borrowed("m.file")), - pattern_type: None, + pattern: Cow::Borrowed("m.file"), })), // MSC3933: Add condition on top of template rule - see MSC. Condition::Known(KnownCondition::RoomVersionSupports { @@ -571,8 +550,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), // MSC3933: Type changed from template rule - see MSC. - pattern: Some(Cow::Borrowed("m.image")), - pattern_type: None, + pattern: Cow::Borrowed("m.image"), })), // MSC3933: Add condition on top of template rule - see MSC. Condition::Known(KnownCondition::RoomVersionSupports { @@ -591,8 +569,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), // MSC3933: Type changed from template rule - see MSC. - pattern: Some(Cow::Borrowed("m.video")), - pattern_type: None, + pattern: Cow::Borrowed("m.video"), })), // MSC3933: Add condition on top of template rule - see MSC. Condition::Known(KnownCondition::RoomVersionSupports { @@ -611,8 +588,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), // MSC3933: Type changed from template rule - see MSC. - pattern: Some(Cow::Borrowed("m.audio")), - pattern_type: None, + pattern: Cow::Borrowed("m.audio"), })), // MSC3933: Add condition on top of template rule - see MSC. Condition::Known(KnownCondition::RoomVersionSupports { @@ -630,18 +606,15 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[ Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("im.vector.modular.widgets")), - pattern_type: None, + pattern: Cow::Borrowed("im.vector.modular.widgets"), })), Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("content.type"), - pattern: Some(Cow::Borrowed("jitsi")), - pattern_type: None, + pattern: Cow::Borrowed("jitsi"), })), Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("state_key"), - pattern: Some(Cow::Borrowed("*")), - pattern_type: None, + pattern: Cow::Borrowed("*"), })), ]), actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), @@ -657,8 +630,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ }), Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.start")), - pattern_type: None, + pattern: Cow::Borrowed("org.matrix.msc3381.poll.start"), })), ]), actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION]), @@ -671,8 +643,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.start")), - pattern_type: None, + pattern: Cow::Borrowed("org.matrix.msc3381.poll.start"), }, ))]), actions: Cow::Borrowed(&[Action::Notify]), @@ -688,8 +659,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ }), Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.end")), - pattern_type: None, + pattern: Cow::Borrowed("org.matrix.msc3381.poll.end"), })), ]), actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION]), @@ -702,8 +672,7 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch( EventMatchCondition { key: Cow::Borrowed("type"), - pattern: Some(Cow::Borrowed("org.matrix.msc3381.poll.end")), - pattern_type: None, + pattern: Cow::Borrowed("org.matrix.msc3381.poll.end"), }, ))]), actions: Cow::Borrowed(&[Action::Notify]), diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index 2eaa06ad76..67fe6a4823 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::{BTreeMap, BTreeSet}; +use std::borrow::Cow; +use std::collections::BTreeMap; -use crate::push::JsonValue; use anyhow::{Context, Error}; use lazy_static::lazy_static; use log::warn; @@ -23,9 +23,10 @@ use regex::Regex; use super::{ utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType}, - Action, Condition, EventMatchCondition, ExactEventMatchCondition, FilteredPushRules, - KnownCondition, RelatedEventMatchCondition, SimpleJsonValue, + Action, Condition, EventPropertyIsCondition, FilteredPushRules, KnownCondition, + SimpleJsonValue, }; +use crate::push::{EventMatchPatternType, JsonValue}; lazy_static! { /// Used to parse the `is` clause in the room member count condition. @@ -71,10 +72,6 @@ pub struct PushRuleEvaluator { /// True if the event has a mentions property and MSC3952 support is enabled. has_mentions: bool, - /// The user mentions that were part of the message. - user_mentions: BTreeSet<String>, - /// True if the message is a room message. - room_mention: bool, /// The number of users in the room. room_member_count: u64, @@ -100,9 +97,6 @@ pub struct PushRuleEvaluator { /// flag as MSC1767 (extensible events core). msc3931_enabled: bool, - /// If MSC3758 (exact_event_match push rule condition) is enabled. - msc3758_exact_event_match: bool, - /// If MSC3966 (exact_event_property_contains push rule condition) is enabled. msc3966_exact_event_property_contains: bool, } @@ -115,8 +109,6 @@ impl PushRuleEvaluator { pub fn py_new( flattened_keys: BTreeMap<String, JsonValue>, has_mentions: bool, - user_mentions: BTreeSet<String>, - room_mention: bool, room_member_count: u64, sender_power_level: Option<i64>, notification_power_levels: BTreeMap<String, i64>, @@ -124,7 +116,6 @@ impl PushRuleEvaluator { related_event_match_enabled: bool, room_version_feature_flags: Vec<String>, msc3931_enabled: bool, - msc3758_exact_event_match: bool, msc3966_exact_event_property_contains: bool, ) -> Result<Self, Error> { let body = match flattened_keys.get("content.body") { @@ -136,8 +127,6 @@ impl PushRuleEvaluator { flattened_keys, body, has_mentions, - user_mentions, - room_mention, room_member_count, notification_power_levels, sender_power_level, @@ -145,7 +134,6 @@ impl PushRuleEvaluator { related_event_match_enabled, room_version_feature_flags, msc3931_enabled, - msc3758_exact_event_match, msc3966_exact_event_property_contains, }) } @@ -260,26 +248,84 @@ impl PushRuleEvaluator { }; let result = match known_condition { - KnownCondition::EventMatch(event_match) => { - self.match_event_match(event_match, user_id)? - } - KnownCondition::ExactEventMatch(exact_event_match) => { - self.match_exact_event_match(exact_event_match)? + KnownCondition::EventMatch(event_match) => self.match_event_match( + &self.flattened_keys, + &event_match.key, + &event_match.pattern, + )?, + KnownCondition::EventMatchType(event_match) => { + // The `pattern_type` can either be "user_id" or "user_localpart", + // either way if we don't have a `user_id` then the condition can't + // match. + let user_id = if let Some(user_id) = user_id { + user_id + } else { + return Ok(false); + }; + + let pattern = match &*event_match.pattern_type { + EventMatchPatternType::UserId => user_id, + EventMatchPatternType::UserLocalpart => get_localpart_from_id(user_id)?, + }; + + self.match_event_match(&self.flattened_keys, &event_match.key, pattern)? } - KnownCondition::RelatedEventMatch(event_match) => { - self.match_related_event_match(event_match, user_id)? + KnownCondition::EventPropertyIs(event_property_is) => { + self.match_event_property_is(event_property_is)? } - KnownCondition::ExactEventPropertyContains(exact_event_match) => { - self.match_exact_event_property_contains(exact_event_match)? + KnownCondition::RelatedEventMatch(event_match) => self.match_related_event_match( + &event_match.rel_type.clone(), + event_match.include_fallbacks, + event_match.key.clone(), + event_match.pattern.clone(), + )?, + KnownCondition::RelatedEventMatchType(event_match) => { + // The `pattern_type` can either be "user_id" or "user_localpart", + // either way if we don't have a `user_id` then the condition can't + // match. + let user_id = if let Some(user_id) = user_id { + user_id + } else { + return Ok(false); + }; + + let pattern = match &*event_match.pattern_type { + EventMatchPatternType::UserId => user_id, + EventMatchPatternType::UserLocalpart => get_localpart_from_id(user_id)?, + }; + + self.match_related_event_match( + &event_match.rel_type.clone(), + event_match.include_fallbacks, + Some(event_match.key.clone()), + Some(Cow::Borrowed(pattern)), + )? } - KnownCondition::IsUserMention => { - if let Some(uid) = user_id { - self.user_mentions.contains(uid) + KnownCondition::ExactEventPropertyContains(event_property_is) => self + .match_exact_event_property_contains( + event_property_is.key.clone(), + event_property_is.value.clone(), + )?, + KnownCondition::ExactEventPropertyContainsType(exact_event_match) => { + // The `pattern_type` can either be "user_id" or "user_localpart", + // either way if we don't have a `user_id` then the condition can't + // match. + let user_id = if let Some(user_id) = user_id { + user_id } else { - false - } + return Ok(false); + }; + + let pattern = match &*exact_event_match.value_type { + EventMatchPatternType::UserId => user_id, + EventMatchPatternType::UserLocalpart => get_localpart_from_id(user_id)?, + }; + + self.match_exact_event_property_contains( + exact_event_match.key.clone(), + Cow::Borrowed(&SimpleJsonValue::Str(pattern.to_string())), + )? } - KnownCondition::IsRoomMention => self.room_mention, KnownCondition::ContainsDisplayName => { if let Some(dn) = display_name { if !dn.is_empty() { @@ -330,32 +376,12 @@ impl PushRuleEvaluator { /// Evaluates a `event_match` condition. fn match_event_match( &self, - event_match: &EventMatchCondition, - user_id: Option<&str>, + flattened_event: &BTreeMap<String, JsonValue>, + key: &str, + pattern: &str, ) -> Result<bool, Error> { - let pattern = if let Some(pattern) = &event_match.pattern { - pattern - } else if let Some(pattern_type) = &event_match.pattern_type { - // The `pattern_type` can either be "user_id" or "user_localpart", - // either way if we don't have a `user_id` then the condition can't - // match. - let user_id = if let Some(user_id) = user_id { - user_id - } else { - return Ok(false); - }; - - match &**pattern_type { - "user_id" => user_id, - "user_localpart" => get_localpart_from_id(user_id)?, - _ => return Ok(false), - } - } else { - return Ok(false); - }; - let haystack = if let Some(JsonValue::Value(SimpleJsonValue::Str(haystack))) = - self.flattened_keys.get(&*event_match.key) + flattened_event.get(key) { haystack } else { @@ -364,7 +390,7 @@ impl PushRuleEvaluator { // For the content.body we match against "words", but for everything // else we match against the entire value. - let match_type = if event_match.key == "content.body" { + let match_type = if key == "content.body" { GlobMatchType::Word } else { GlobMatchType::Whole @@ -374,20 +400,15 @@ impl PushRuleEvaluator { compiled_pattern.is_match(haystack) } - /// Evaluates a `exact_event_match` condition. (MSC3758) - fn match_exact_event_match( + /// Evaluates a `event_property_is` condition. + fn match_event_property_is( &self, - exact_event_match: &ExactEventMatchCondition, + event_property_is: &EventPropertyIsCondition, ) -> Result<bool, Error> { - // First check if the feature is enabled. - if !self.msc3758_exact_event_match { - return Ok(false); - } - - let value = &exact_event_match.value; + let value = &event_property_is.value; let haystack = if let Some(JsonValue::Value(haystack)) = - self.flattened_keys.get(&*exact_event_match.key) + self.flattened_keys.get(&*event_property_is.key) { haystack } else { @@ -400,8 +421,10 @@ impl PushRuleEvaluator { /// Evaluates a `related_event_match` condition. (MSC3664) fn match_related_event_match( &self, - event_match: &RelatedEventMatchCondition, - user_id: Option<&str>, + rel_type: &str, + include_fallbacks: Option<bool>, + key: Option<Cow<str>>, + pattern: Option<Cow<str>>, ) -> Result<bool, Error> { // First check if related event matching is enabled... if !self.related_event_match_enabled { @@ -409,7 +432,7 @@ impl PushRuleEvaluator { } // get the related event, fail if there is none. - let event = if let Some(event) = self.related_events_flattened.get(&*event_match.rel_type) { + let event = if let Some(event) = self.related_events_flattened.get(rel_type) { event } else { return Ok(false); @@ -417,81 +440,38 @@ impl PushRuleEvaluator { // If we are not matching fallbacks, don't match if our special key indicating this is a // fallback relation is not present. - if !event_match.include_fallbacks.unwrap_or(false) - && event.contains_key("im.vector.is_falling_back") - { + if !include_fallbacks.unwrap_or(false) && event.contains_key("im.vector.is_falling_back") { return Ok(false); } - // if we have no key, accept the event as matching, if it existed without matching any - // fields. - let key = if let Some(key) = &event_match.key { - key - } else { - return Ok(true); - }; - - let pattern = if let Some(pattern) = &event_match.pattern { - pattern - } else if let Some(pattern_type) = &event_match.pattern_type { - // The `pattern_type` can either be "user_id" or "user_localpart", - // either way if we don't have a `user_id` then the condition can't - // match. - let user_id = if let Some(user_id) = user_id { - user_id - } else { - return Ok(false); - }; - - match &**pattern_type { - "user_id" => user_id, - "user_localpart" => get_localpart_from_id(user_id)?, - _ => return Ok(false), - } - } else { - return Ok(false); - }; - - let haystack = - if let Some(JsonValue::Value(SimpleJsonValue::Str(haystack))) = event.get(&**key) { - haystack - } else { - return Ok(false); - }; - - // For the content.body we match against "words", but for everything - // else we match against the entire value. - let match_type = if key == "content.body" { - GlobMatchType::Word - } else { - GlobMatchType::Whole - }; - - let mut compiled_pattern = get_glob_matcher(pattern, match_type)?; - compiled_pattern.is_match(haystack) + match (key, pattern) { + // if we have no key, accept the event as matching. + (None, _) => Ok(true), + // There was a key, so we *must* have a pattern to go with it. + (Some(_), None) => Ok(false), + // If there is a key & pattern, check if they're in the flattened event (given by rel_type). + (Some(key), Some(pattern)) => self.match_event_match(event, &key, &pattern), + } } - /// Evaluates a `exact_event_property_contains` condition. (MSC3758) + /// Evaluates a `exact_event_property_contains` condition. (MSC3966) fn match_exact_event_property_contains( &self, - exact_event_match: &ExactEventMatchCondition, + key: Cow<str>, + value: Cow<SimpleJsonValue>, ) -> Result<bool, Error> { // First check if the feature is enabled. if !self.msc3966_exact_event_property_contains { return Ok(false); } - let value = &exact_event_match.value; - - let haystack = if let Some(JsonValue::Array(haystack)) = - self.flattened_keys.get(&*exact_event_match.key) - { + let haystack = if let Some(JsonValue::Array(haystack)) = self.flattened_keys.get(&*key) { haystack } else { return Ok(false); }; - Ok(haystack.contains(&**value)) + Ok(haystack.contains(&value)) } /// Match the member count against an 'is' condition @@ -528,8 +508,6 @@ fn push_rule_evaluator() { let evaluator = PushRuleEvaluator::py_new( flattened_keys, false, - BTreeSet::new(), - false, 10, Some(0), BTreeMap::new(), @@ -538,7 +516,6 @@ fn push_rule_evaluator() { vec![], true, true, - true, ) .unwrap(); @@ -561,8 +538,6 @@ fn test_requires_room_version_supports_condition() { let evaluator = PushRuleEvaluator::py_new( flattened_keys, false, - BTreeSet::new(), - false, 10, Some(0), BTreeMap::new(), @@ -571,7 +546,6 @@ fn test_requires_room_version_supports_condition() { flags, true, true, - true, ) .unwrap(); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 253b5f367c..7fde88e825 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -328,16 +328,23 @@ pub enum Condition { #[serde(tag = "kind")] pub enum KnownCondition { EventMatch(EventMatchCondition), - #[serde(rename = "com.beeper.msc3758.exact_event_match")] - ExactEventMatch(ExactEventMatchCondition), + // Identical to event_match but gives predefined patterns. Cannot be added by users. + #[serde(skip_deserializing, rename = "event_match")] + EventMatchType(EventMatchTypeCondition), + EventPropertyIs(EventPropertyIsCondition), #[serde(rename = "im.nheko.msc3664.related_event_match")] RelatedEventMatch(RelatedEventMatchCondition), + // Identical to related_event_match but gives predefined patterns. Cannot be added by users. + #[serde(skip_deserializing, rename = "im.nheko.msc3664.related_event_match")] + RelatedEventMatchType(RelatedEventMatchTypeCondition), #[serde(rename = "org.matrix.msc3966.exact_event_property_contains")] - ExactEventPropertyContains(ExactEventMatchCondition), - #[serde(rename = "org.matrix.msc3952.is_user_mention")] - IsUserMention, - #[serde(rename = "org.matrix.msc3952.is_room_mention")] - IsRoomMention, + ExactEventPropertyContains(EventPropertyIsCondition), + // Identical to exact_event_property_contains but gives predefined patterns. Cannot be added by users. + #[serde( + skip_deserializing, + rename = "org.matrix.msc3966.exact_event_property_contains" + )] + ExactEventPropertyContainsType(EventPropertyIsTypeCondition), ContainsDisplayName, RoomMemberCount { #[serde(skip_serializing_if = "Option::is_none")] @@ -364,23 +371,45 @@ impl<'source> FromPyObject<'source> for Condition { } } -/// The body of a [`Condition::EventMatch`] +/// The body of a [`Condition::EventMatch`] with a pattern. #[derive(Serialize, Deserialize, Debug, Clone)] pub struct EventMatchCondition { pub key: Cow<'static, str>, - #[serde(skip_serializing_if = "Option::is_none")] - pub pattern: Option<Cow<'static, str>>, - #[serde(skip_serializing_if = "Option::is_none")] - pub pattern_type: Option<Cow<'static, str>>, + pub pattern: Cow<'static, str>, +} + +#[derive(Serialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +pub enum EventMatchPatternType { + UserId, + UserLocalpart, +} + +/// The body of a [`Condition::EventMatch`] that uses user_id or user_localpart as a pattern. +#[derive(Serialize, Debug, Clone)] +pub struct EventMatchTypeCondition { + pub key: Cow<'static, str>, + // During serialization, the pattern_type property gets replaced with a + // pattern property of the correct value in synapse.push.clientformat.format_push_rules_for_user. + pub pattern_type: Cow<'static, EventMatchPatternType>, } -/// The body of a [`Condition::ExactEventMatch`] +/// The body of a [`Condition::EventPropertyIs`] #[derive(Serialize, Deserialize, Debug, Clone)] -pub struct ExactEventMatchCondition { +pub struct EventPropertyIsCondition { pub key: Cow<'static, str>, pub value: Cow<'static, SimpleJsonValue>, } +/// The body of a [`Condition::EventPropertyIs`] that uses user_id or user_localpart as a pattern. +#[derive(Serialize, Debug, Clone)] +pub struct EventPropertyIsTypeCondition { + pub key: Cow<'static, str>, + // During serialization, the pattern_type property gets replaced with a + // pattern property of the correct value in synapse.push.clientformat.format_push_rules_for_user. + pub value_type: Cow<'static, EventMatchPatternType>, +} + /// The body of a [`Condition::RelatedEventMatch`] #[derive(Serialize, Deserialize, Debug, Clone)] pub struct RelatedEventMatchCondition { @@ -388,8 +417,18 @@ pub struct RelatedEventMatchCondition { pub key: Option<Cow<'static, str>>, #[serde(skip_serializing_if = "Option::is_none")] pub pattern: Option<Cow<'static, str>>, + pub rel_type: Cow<'static, str>, #[serde(skip_serializing_if = "Option::is_none")] - pub pattern_type: Option<Cow<'static, str>>, + pub include_fallbacks: Option<bool>, +} + +/// The body of a [`Condition::RelatedEventMatch`] that uses user_id or user_localpart as a pattern. +#[derive(Serialize, Debug, Clone)] +pub struct RelatedEventMatchTypeCondition { + // This is only used if pattern_type exists (and thus key must exist), so is + // a bit simpler than RelatedEventMatchCondition. + pub key: Cow<'static, str>, + pub pattern_type: Cow<'static, EventMatchPatternType>, pub rel_type: Cow<'static, str>, #[serde(skip_serializing_if = "Option::is_none")] pub include_fallbacks: Option<bool>, @@ -573,8 +612,7 @@ impl FilteredPushRules { fn test_serialize_condition() { let condition = Condition::Known(KnownCondition::EventMatch(EventMatchCondition { key: "content.body".into(), - pattern: Some("coffee".into()), - pattern_type: None, + pattern: "coffee".into(), })); let json = serde_json::to_string(&condition).unwrap(); @@ -588,7 +626,33 @@ fn test_serialize_condition() { fn test_deserialize_condition() { let json = r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"#; - let _: Condition = serde_json::from_str(json).unwrap(); + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::EventMatch(_)) + )); +} + +#[test] +fn test_serialize_event_match_condition_with_pattern_type() { + let condition = Condition::Known(KnownCondition::EventMatchType(EventMatchTypeCondition { + key: "content.body".into(), + pattern_type: Cow::Owned(EventMatchPatternType::UserId), + })); + + let json = serde_json::to_string(&condition).unwrap(); + assert_eq!( + json, + r#"{"kind":"event_match","key":"content.body","pattern_type":"user_id"}"# + ) +} + +#[test] +fn test_cannot_deserialize_event_match_condition_with_pattern_type() { + let json = r#"{"kind":"event_match","key":"content.body","pattern_type":"user_id"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!(condition, Condition::Unknown(_))); } #[test] @@ -603,78 +667,84 @@ fn test_deserialize_unstable_msc3664_condition() { } #[test] -fn test_deserialize_unstable_msc3931_condition() { - let json = - r#"{"kind":"org.matrix.msc3931.room_version_supports","feature":"org.example.feature"}"#; - - let condition: Condition = serde_json::from_str(json).unwrap(); - assert!(matches!( - condition, - Condition::Known(KnownCondition::RoomVersionSupports { feature: _ }) +fn test_serialize_unstable_msc3664_condition_with_pattern_type() { + let condition = Condition::Known(KnownCondition::RelatedEventMatchType( + RelatedEventMatchTypeCondition { + key: "content.body".into(), + pattern_type: Cow::Owned(EventMatchPatternType::UserId), + rel_type: "m.in_reply_to".into(), + include_fallbacks: Some(true), + }, )); + + let json = serde_json::to_string(&condition).unwrap(); + assert_eq!( + json, + r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern_type":"user_id","rel_type":"m.in_reply_to","include_fallbacks":true}"# + ) } #[test] -fn test_deserialize_unstable_msc3758_condition() { - // A string condition should work. - let json = - r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":"foo"}"#; +fn test_cannot_deserialize_unstable_msc3664_condition_with_pattern_type() { + let json = r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern_type":"user_id","rel_type":"m.in_reply_to"}"#; let condition: Condition = serde_json::from_str(json).unwrap(); + // Since pattern is optional on RelatedEventMatch it deserializes it to that + // instead of RelatedEventMatchType. assert!(matches!( condition, - Condition::Known(KnownCondition::ExactEventMatch(_)) + Condition::Known(KnownCondition::RelatedEventMatch(_)) )); +} - // A boolean condition should work. +#[test] +fn test_deserialize_unstable_msc3931_condition() { let json = - r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":true}"#; + r#"{"kind":"org.matrix.msc3931.room_version_supports","feature":"org.example.feature"}"#; let condition: Condition = serde_json::from_str(json).unwrap(); assert!(matches!( condition, - Condition::Known(KnownCondition::ExactEventMatch(_)) + Condition::Known(KnownCondition::RoomVersionSupports { feature: _ }) )); +} - // An integer condition should work. - let json = r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":1}"#; +#[test] +fn test_deserialize_event_property_is_condition() { + // A string condition should work. + let json = r#"{"kind":"event_property_is","key":"content.value","value":"foo"}"#; let condition: Condition = serde_json::from_str(json).unwrap(); assert!(matches!( condition, - Condition::Known(KnownCondition::ExactEventMatch(_)) + Condition::Known(KnownCondition::EventPropertyIs(_)) )); - // A null condition should work - let json = - r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":null}"#; + // A boolean condition should work. + let json = r#"{"kind":"event_property_is","key":"content.value","value":true}"#; let condition: Condition = serde_json::from_str(json).unwrap(); assert!(matches!( condition, - Condition::Known(KnownCondition::ExactEventMatch(_)) + Condition::Known(KnownCondition::EventPropertyIs(_)) )); -} -#[test] -fn test_deserialize_unstable_msc3952_user_condition() { - let json = r#"{"kind":"org.matrix.msc3952.is_user_mention"}"#; + // An integer condition should work. + let json = r#"{"kind":"event_property_is","key":"content.value","value":1}"#; let condition: Condition = serde_json::from_str(json).unwrap(); assert!(matches!( condition, - Condition::Known(KnownCondition::IsUserMention) + Condition::Known(KnownCondition::EventPropertyIs(_)) )); -} -#[test] -fn test_deserialize_unstable_msc3952_room_condition() { - let json = r#"{"kind":"org.matrix.msc3952.is_room_mention"}"#; + // A null condition should work + let json = r#"{"kind":"event_property_is","key":"content.value","value":null}"#; let condition: Condition = serde_json::from_str(json).unwrap(); assert!(matches!( condition, - Condition::Known(KnownCondition::IsRoomMention) + Condition::Known(KnownCondition::EventPropertyIs(_)) )); } diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 392c509a8a..9e4ed3246e 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -112,7 +112,7 @@ python3 -m black "${files[@]}" # Catch any common programming mistakes in Python code. # --quiet suppresses the update check. -ruff --quiet "${files[@]}" +ruff --quiet --fix "${files[@]}" # Catch any common programming mistakes in Rust code. # diff --git a/stubs/sortedcontainers/sortedlist.pyi b/stubs/sortedcontainers/sortedlist.pyi index 1fe1a136f1..0e745c0a79 100644 --- a/stubs/sortedcontainers/sortedlist.pyi +++ b/stubs/sortedcontainers/sortedlist.pyi @@ -29,7 +29,6 @@ _Repr = Callable[[], str] def recursive_repr(fillvalue: str = ...) -> Callable[[_Repr], _Repr]: ... class SortedList(MutableSequence[_T]): - DEFAULT_LOAD_FACTOR: int = ... def __init__( self, diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 7b33c30cc9..c040944aac 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union +from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Tuple, Union from synapse.types import JsonDict, JsonValue @@ -58,8 +58,6 @@ class PushRuleEvaluator: self, flattened_keys: Mapping[str, JsonValue], has_mentions: bool, - user_mentions: Set[str], - room_mention: bool, room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], @@ -67,7 +65,6 @@ class PushRuleEvaluator: related_event_match_enabled: bool, room_version_feature_flags: Tuple[str, ...], msc3931_enabled: bool, - msc3758_exact_event_match: bool, msc3966_exact_event_property_contains: bool, ): ... def run( diff --git a/synapse/__init__.py b/synapse/__init__.py index fbfd506a43..a203ed533a 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -1,5 +1,6 @@ # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018-9 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" This is a reference implementation of a Matrix homeserver. +""" This is an implementation of a Matrix homeserver. """ import json diff --git a/synapse/_scripts/move_remote_media_to_new_store.py b/synapse/_scripts/move_remote_media_to_new_store.py index 819afaaca6..0dd36bee20 100755 --- a/synapse/_scripts/move_remote_media_to_new_store.py +++ b/synapse/_scripts/move_remote_media_to_new_store.py @@ -37,7 +37,7 @@ import os import shutil import sys -from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.media.filepath import MediaFilePaths logger = logging.getLogger() diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index 2b74a40166..19ca399d44 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -47,7 +47,6 @@ def request_registration( _print: Callable[[str], None] = print, exit: Callable[[int], None] = sys.exit, ) -> None: - url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),) # Get the nonce @@ -154,7 +153,6 @@ def register_new_user( def main() -> None: - logging.captureWarnings(True) parser = argparse.ArgumentParser( diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 5e137dbbf7..2c9cbf8b27 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -94,61 +94,80 @@ reactor = cast(ISynapseReactor, reactor_) logger = logging.getLogger("synapse_port_db") +# SQLite doesn't have a dedicated boolean type (it stores True/False as 1/0). This means +# portdb will read sqlite bools as integers, then try to insert them into postgres +# boolean columns---which fails. Lacking some Python-parseable metaschema, we must +# specify which integer columns should be inserted as booleans into postgres. BOOLEAN_COLUMNS = { - "events": ["processed", "outlier", "contains_url"], - "rooms": ["is_public", "has_auth_chain_index"], + "access_tokens": ["used"], + "account_validity": ["email_sent"], + "device_lists_changes_in_room": ["converted_to_destinations"], + "device_lists_outbound_pokes": ["sent"], + "devices": ["hidden"], + "e2e_fallback_keys_json": ["used"], + "e2e_room_keys": ["is_verified"], "event_edges": ["is_state"], + "events": ["processed", "outlier", "contains_url"], + "local_media_repository": ["safe_from_quarantine"], "presence_list": ["accepted"], "presence_stream": ["currently_active"], "public_room_list_stream": ["visibility"], - "devices": ["hidden"], - "device_lists_outbound_pokes": ["sent"], - "users_who_share_rooms": ["share_private"], - "e2e_room_keys": ["is_verified"], - "account_validity": ["email_sent"], + "pushers": ["enabled"], "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], - "local_media_repository": ["safe_from_quarantine"], + "rooms": ["is_public", "has_auth_chain_index"], "users": ["shadow_banned", "approved"], - "e2e_fallback_keys_json": ["used"], - "access_tokens": ["used"], - "device_lists_changes_in_room": ["converted_to_destinations"], - "pushers": ["enabled"], + "un_partial_stated_event_stream": ["rejection_status_changed"], + "users_who_share_rooms": ["share_private"], } +# These tables are never deleted from in normal operation [*], so we can resume porting +# over rows from a previous attempt rather than starting from scratch. +# +# [*]: We do delete from many of these tables when purging a room, and +# presumably when purging old events. So we might e.g. +# +# 1. Run portdb and port half of some table. +# 2. Stop portdb. +# 3. Purge something, deleting some of the rows we've ported over. +# 4. Restart portdb. The rows deleted from sqlite are still present in postgres. +# +# But this isn't the end of the world: we should be able to repeat the purge +# on the postgres DB when porting completes. APPEND_ONLY_TABLES = [ + "cache_invalidation_stream_by_instance", + "event_auth", + "event_edges", + "event_json", "event_reference_hashes", + "event_search", + "event_to_state_groups", "events", - "event_json", - "state_events", - "room_memberships", - "topics", - "room_names", - "rooms", + "ex_outlier_stream", "local_media_repository", "local_media_repository_thumbnails", + "presence_stream", + "public_room_list_stream", + "push_rules_stream", + "received_transactions", + "redactions", + "rejections", "remote_media_cache", "remote_media_cache_thumbnails", - "redactions", - "event_edges", - "event_auth", - "received_transactions", + "room_memberships", + "room_names", + "rooms", "sent_transactions", - "transaction_id_to_pdu", - "users", + "state_events", + "state_group_edges", "state_groups", "state_groups_state", - "event_to_state_groups", - "rejections", - "event_search", - "presence_stream", - "push_rules_stream", - "ex_outlier_stream", - "cache_invalidation_stream_by_instance", - "public_room_list_stream", - "state_group_edges", "stream_ordering_to_exterm", + "topics", + "transaction_id_to_pdu", + "un_partial_stated_event_stream", + "users", ] @@ -1186,7 +1205,6 @@ class CursesProgress(Progress): if self.finished: status = "Time spent: %s (Done!)" % (duration_str,) else: - if self.total_processed > 0: left = float(self.total_remaining) / self.total_processed diff --git a/synapse/_scripts/synctl.py b/synapse/_scripts/synctl.py index b4c96ad7f3..077b90935e 100755 --- a/synapse/_scripts/synctl.py +++ b/synapse/_scripts/synctl.py @@ -167,7 +167,6 @@ Worker = collections.namedtuple( def main() -> None: - parser = argparse.ArgumentParser() parser.add_argument( diff --git a/synapse/app/_base.py b/synapse/app/_base.py index a5aa2185a2..28062dd69d 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -213,7 +213,7 @@ def handle_startup_exception(e: Exception) -> NoReturn: def redirect_stdio_to_logs() -> None: streams = [("stdout", LogLevel.info), ("stderr", LogLevel.error)] - for (stream, level) in streams: + for stream, level in streams: oldStream = getattr(sys, stream) loggingFile = LoggingFile( logger=twisted.logger.Logger(namespace=stream), diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index fe7afb9475..b05fe2c589 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -17,7 +17,7 @@ import logging import os import sys import tempfile -from typing import List, Optional +from typing import List, Mapping, Optional from twisted.internet import defer, task @@ -44,6 +44,7 @@ from synapse.storage.databases.main.event_push_actions import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.filtering import FilteringWorkerStore +from synapse.storage.databases.main.media_repository import MediaRepositoryStore from synapse.storage.databases.main.profile import ProfileWorkerStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore @@ -86,6 +87,7 @@ class AdminCmdSlavedStore( RegistrationWorkerStore, RoomWorkerStore, ProfileWorkerStore, + MediaRepositoryStore, ): def __init__( self, @@ -149,7 +151,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(events_file, "a") as f: for event in events: - print(json.dumps(event.get_pdu_json()), file=f) + json.dump(event.get_pdu_json(), fp=f) def write_state( self, room_id: str, event_id: str, state: StateMap[EventBase] @@ -162,7 +164,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(event_file, "a") as f: for event in state.values(): - print(json.dumps(event.get_pdu_json()), file=f) + json.dump(event.get_pdu_json(), fp=f) def write_invite( self, room_id: str, event: EventBase, state: StateMap[EventBase] @@ -178,7 +180,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(invite_state, "a") as f: for event in state.values(): - print(json.dumps(event), file=f) + json.dump(event, fp=f) def write_knock( self, room_id: str, event: EventBase, state: StateMap[EventBase] @@ -194,7 +196,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): with open(knock_state, "a") as f: for event in state.values(): - print(json.dumps(event), file=f) + json.dump(event, fp=f) def write_profile(self, profile: JsonDict) -> None: user_directory = os.path.join(self.base_directory, "user_data") @@ -202,7 +204,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): profile_file = os.path.join(user_directory, "profile") with open(profile_file, "a") as f: - print(json.dumps(profile), file=f) + json.dump(profile, fp=f) def write_devices(self, devices: List[JsonDict]) -> None: user_directory = os.path.join(self.base_directory, "user_data") @@ -211,7 +213,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): for device in devices: with open(device_file, "a") as f: - print(json.dumps(device), file=f) + json.dump(device, fp=f) def write_connections(self, connections: List[JsonDict]) -> None: user_directory = os.path.join(self.base_directory, "user_data") @@ -220,7 +222,28 @@ class FileExfiltrationWriter(ExfiltrationWriter): for connection in connections: with open(connection_file, "a") as f: - print(json.dumps(connection), file=f) + json.dump(connection, fp=f) + + def write_account_data( + self, file_name: str, account_data: Mapping[str, JsonDict] + ) -> None: + account_data_directory = os.path.join( + self.base_directory, "user_data", "account_data" + ) + os.makedirs(account_data_directory, exist_ok=True) + + account_data_file = os.path.join(account_data_directory, file_name) + + with open(account_data_file, "a") as f: + json.dump(account_data, fp=f) + + def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None: + file_directory = os.path.join(self.base_directory, "media_ids") + os.makedirs(file_directory, exist_ok=True) + media_id_file = os.path.join(file_directory, media_id) + + with open(media_id_file, "w") as f: + json.dump(media_metadata, fp=f) def finished(self) -> str: return self.base_directory diff --git a/synapse/app/complement_fork_starter.py b/synapse/app/complement_fork_starter.py index 920538f44d..c8dc3f9d76 100644 --- a/synapse/app/complement_fork_starter.py +++ b/synapse/app/complement_fork_starter.py @@ -219,7 +219,7 @@ def main() -> None: # memory space and don't need to repeat the work of loading the code! # Instead of using fork() directly, we use the multiprocessing library, # which uses fork() on Unix platforms. - for (func, worker_args) in zip(worker_functions, args_by_worker): + for func, worker_args in zip(worker_functions, args_by_worker): process = multiprocessing.Process( target=_worker_entrypoint, args=(func, proxy_reactor, worker_args) ) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 946f3a3807..0dec24369a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -157,7 +157,6 @@ class GenericWorkerServer(HomeServer): DATASTORE_CLASS = GenericWorkerSlavedStore # type: ignore def _listen_http(self, listener_config: ListenerConfig) -> None: - assert listener_config.http_options is not None # We always include a health resource. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 6176a70eb2..b8830b1a9c 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -321,7 +321,6 @@ def setup(config_options: List[str]) -> SynapseHomeServer: and not config.registration.registrations_require_3pid and not config.registration.registration_requires_token ): - raise ConfigError( "You have enabled open registration without any verification. This is a known vector for " "spam and abuse. If you would like to allow public registration, please consider adding email, " diff --git a/synapse/config/consent.py b/synapse/config/consent.py index be74609dc4..5bfd0cbb71 100644 --- a/synapse/config/consent.py +++ b/synapse/config/consent.py @@ -22,7 +22,6 @@ from ._base import Config class ConsentConfig(Config): - section = "consent" def __init__(self, *args: Any): diff --git a/synapse/config/database.py b/synapse/config/database.py index 928fec8dfe..596d8769fe 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -154,7 +154,6 @@ class DatabaseConfig(Config): logger.warning(NON_SQLITE_DATABASE_PATH_WARNING) def set_databasepath(self, database_path: str) -> None: - if database_path != ":memory:": database_path = self.abspath(database_path) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 1d294f8798..489f2601ac 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -166,22 +166,20 @@ class ExperimentalConfig(Config): # MSC3391: Removing account data. self.msc3391_enabled = experimental.get("msc3391_enabled", False) - # MSC3925: do not replace events with their edits - self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False) - - # MSC3758: exact_event_match push rule condition - self.msc3758_exact_event_match = experimental.get( - "msc3758_exact_event_match", False + # MSC3873: Disambiguate event_match keys. + self.msc3873_escape_event_match_key = experimental.get( + "msc3873_escape_event_match_key", False ) - # MSC3873: Disambiguate event_match keys. - self.msc3783_escape_event_match_key = experimental.get( - "msc3783_escape_event_match_key", False + # MSC3966: exact_event_property_contains push rule condition. + self.msc3966_exact_event_property_contains = experimental.get( + "msc3966_exact_event_property_contains", False ) - # MSC3952: Intentional mentions - self.msc3952_intentional_mentions = experimental.get( - "msc3952_intentional_mentions", False + # MSC3952: Intentional mentions, this depends on MSC3966. + self.msc3952_intentional_mentions = ( + experimental.get("msc3952_intentional_mentions", False) + and self.msc3966_exact_event_property_contains ) # MSC3959: Do not generate notifications for edits. @@ -193,3 +191,6 @@ class ExperimentalConfig(Config): self.msc3966_exact_event_property_contains = experimental.get( "msc3966_exact_event_property_contains", False ) + + # MSC3967: Do not require UIA when first uploading cross signing keys + self.msc3967_enabled = experimental.get("msc3967_enabled", False) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 4d2b298a70..c205a78039 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -56,7 +56,6 @@ from .workers import WorkerConfig class HomeServerConfig(RootConfig): - config_classes = [ ModulesConfig, ServerConfig, diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 5c13fe428a..a5514e70a2 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -46,7 +46,6 @@ class RatelimitConfig(Config): section = "ratelimiting" def read_config(self, config: JsonDict, **kwargs: Any) -> None: - # Load the new-style messages config if it exists. Otherwise fall back # to the old method. if "rc_message" in config: @@ -87,9 +86,18 @@ class RatelimitConfig(Config): defaults={"per_second": 0.1, "burst_count": 5}, ) + # It is reasonable to login with a bunch of devices at once (i.e. when + # setting up an account), but it is *not* valid to continually be + # logging into new devices. rc_login_config = config.get("rc_login", {}) - self.rc_login_address = RatelimitSettings(rc_login_config.get("address", {})) - self.rc_login_account = RatelimitSettings(rc_login_config.get("account", {})) + self.rc_login_address = RatelimitSettings( + rc_login_config.get("address", {}), + defaults={"per_second": 0.003, "burst_count": 5}, + ) + self.rc_login_account = RatelimitSettings( + rc_login_config.get("account", {}), + defaults={"per_second": 0.003, "burst_count": 5}, + ) self.rc_login_failed_attempts = RatelimitSettings( rc_login_config.get("failed_attempts", {}) ) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index e4759711ed..ecb3edbe3a 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -116,7 +116,6 @@ class ContentRepositoryConfig(Config): section = "media" def read_config(self, config: JsonDict, **kwargs: Any) -> None: - # Only enable the media repo if either the media repo is enabled or the # current worker app is the media repo. if ( @@ -179,11 +178,13 @@ class ContentRepositoryConfig(Config): for i, provider_config in enumerate(storage_providers): # We special case the module "file_system" so as not to need to # expose FileStorageProviderBackend - if provider_config["module"] == "file_system": - provider_config["module"] = ( - "synapse.rest.media.v1.storage_provider" - ".FileStorageProviderBackend" - ) + if ( + provider_config["module"] == "file_system" + or provider_config["module"] == "synapse.rest.media.v1.storage_provider" + ): + provider_config[ + "module" + ] = "synapse.media.storage_provider.FileStorageProviderBackend" provider_class, parsed_config = load_module( provider_config, ("media_storage_providers", "<item %i>" % i) diff --git a/synapse/config/server.py b/synapse/config/server.py index ecdaa2d9dd..0e46b849cf 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -177,6 +177,7 @@ KNOWN_RESOURCES = { "client", "consent", "federation", + "health", "keys", "media", "metrics", @@ -734,7 +735,6 @@ class ServerConfig(Config): listeners: Optional[List[dict]], **kwargs: Any, ) -> str: - _, bind_port = parse_and_validate_server_name(server_name) if bind_port is not None: unsecure_port = bind_port - 400 diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 336fe3e0da..318270ebb8 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -30,7 +30,6 @@ class TlsConfig(Config): section = "tls" def read_config(self, config: JsonDict, **kwargs: Any) -> None: - self.tls_certificate_file = self.abspath(config.get("tls_certificate_path")) self.tls_private_key_file = self.abspath(config.get("tls_private_key_path")) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 86cd4af9bd..d710607c63 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -399,7 +399,7 @@ class Keyring: # We now convert the returned list of results into a map from server # name to key ID to FetchKeyResult, to return. to_return: Dict[str, Dict[str, FetchKeyResult]] = {} - for (request, results) in zip(deduped_requests, results_per_request): + for request, results in zip(deduped_requests, results_per_request): to_return_by_server = to_return.setdefault(request.server_name, {}) for key_id, key_result in results.items(): existing = to_return_by_server.get(key_id) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index e0d82ad81c..a91a5d1e3c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -23,6 +23,7 @@ from synapse.types import JsonDict, StateMap if TYPE_CHECKING: from synapse.storage.controllers import StorageControllers + from synapse.storage.databases import StateGroupDataStore from synapse.storage.databases.main import DataStore from synapse.types.state import StateFilter @@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase): partial_state: bool state_map_before_event: Optional[StateMap[str]] = None + @classmethod + async def batch_persist_unpersisted_contexts( + cls, + events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]], + room_id: str, + last_known_state_group: int, + datastore: "StateGroupDataStore", + ) -> List[Tuple[EventBase, EventContext]]: + """ + Takes a list of events and their associated unpersisted contexts and persists + the unpersisted contexts, returning a list of events and persisted contexts. + Note that all the events must be in a linear chain (ie a <- b <- c). + + Args: + events_and_context: A list of events and their unpersisted contexts + room_id: the room_id for the events + last_known_state_group: the last persisted state group + datastore: a state datastore + """ + amended_events_and_context = await datastore.store_state_deltas_for_batched( + events_and_context, room_id, last_known_state_group + ) + + events_and_persisted_context = [] + for event, unpersisted_context in amended_events_and_context: + if event.is_state(): + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.state_group_before_event, + delta_ids=unpersisted_context.state_delta_due_to_event, + ) + else: + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.prev_group_for_state_group_before_event, + delta_ids=unpersisted_context.delta_ids_to_state_group_before_event, + ) + events_and_persisted_context.append((event, context)) + return events_and_persisted_context + async def get_prev_state_ids( self, state_filter: Optional["StateFilter"] = None ) -> StateMap[str]: diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 623a2c71ea..765c15bb51 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -33,8 +33,8 @@ from typing_extensions import Literal import synapse from synapse.api.errors import Codes from synapse.logging.opentracing import trace -from synapse.rest.media.v1._base import FileInfo -from synapse.rest.media.v1.media_storage import ReadableFileWrapper +from synapse.media._base import FileInfo +from synapse.media.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import JsonDict, RoomAlias, UserProfile from synapse.util.async_helpers import delay_cancellation, maybe_awaitable diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index e7532a568a..1b7c6de974 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -49,6 +49,8 @@ CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]] ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable] +ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable] +ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable] def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: @@ -82,7 +84,6 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: # correctly, we need to await its result. Therefore it doesn't make a lot of # sense to make it go through the run() wrapper. if f.__name__ == "check_event_allowed": - # We need to wrap check_event_allowed because its old form would return either # a boolean or a dict, but now we want to return the dict separately from the # boolean. @@ -104,7 +105,6 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: return wrap_check_event_allowed if f.__name__ == "on_create_room": - # We need to wrap on_create_room because its old form would return a boolean # if the room creation is denied, but now we just want it to raise an # exception. @@ -181,6 +181,12 @@ class ThirdPartyEventRules: ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = [] self._on_threepid_bind_callbacks: List[ON_THREEPID_BIND_CALLBACK] = [] + self._on_add_user_third_party_identifier_callbacks: List[ + ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = [] + self._on_remove_user_third_party_identifier_callbacks: List[ + ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = [] def register_third_party_rules_callbacks( self, @@ -200,6 +206,12 @@ class ThirdPartyEventRules: ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = None, on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None, + on_add_user_third_party_identifier: Optional[ + ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = None, + on_remove_user_third_party_identifier: Optional[ + ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = None, ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: @@ -237,6 +249,11 @@ class ThirdPartyEventRules: if on_threepid_bind is not None: self._on_threepid_bind_callbacks.append(on_threepid_bind) + if on_add_user_third_party_identifier is not None: + self._on_add_user_third_party_identifier_callbacks.append( + on_add_user_third_party_identifier + ) + async def check_event_allowed( self, event: EventBase, @@ -552,6 +569,9 @@ class ThirdPartyEventRules: local homeserver, not when it's created on an identity server (and then kept track of so that it can be unbound on the same IS later on). + THIS MODULE CALLBACK METHOD HAS BEEN DEPRECATED. Please use the + `on_add_user_third_party_identifier` callback method instead. + Args: user_id: the user being associated with the threepid. medium: the threepid's medium. @@ -564,3 +584,44 @@ class ThirdPartyEventRules: logger.exception( "Failed to run module API callback %s: %s", callback, e ) + + async def on_add_user_third_party_identifier( + self, user_id: str, medium: str, address: str + ) -> None: + """Called when an association between a user's Matrix ID and a third-party ID + (email, phone number) has successfully been registered on the homeserver. + + Args: + user_id: The User ID included in the association. + medium: The medium of the third-party ID (email, msisdn). + address: The address of the third-party ID (i.e. an email address). + """ + for callback in self._on_add_user_third_party_identifier_callbacks: + try: + await callback(user_id, medium, address) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) + + async def on_remove_user_third_party_identifier( + self, user_id: str, medium: str, address: str + ) -> None: + """Called when an association between a user's Matrix ID and a third-party ID + (email, phone number) has been successfully removed on the homeserver. + + This is called *after* any known bindings on identity servers for this + association have been removed. + + Args: + user_id: The User ID included in the removed association. + medium: The medium of the third-party ID (email, msisdn). + address: The address of the third-party ID (i.e. an email address). + """ + for callback in self._on_remove_user_third_party_identifier_callbacks: + try: + await callback(user_id, medium, address) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ebf8c7ed83..b9c15ffcdb 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -38,8 +38,7 @@ from synapse.api.constants import ( ) from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion -from synapse.types import JsonDict -from synapse.util.frozenutils import unfreeze +from synapse.types import JsonDict, Requester from . import EventBase @@ -317,8 +316,9 @@ class SerializeEventConfig: as_client_event: bool = True # Function to convert from federation format to client format event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1 - # ID of the user's auth token - used for namespacing of transaction IDs - token_id: Optional[int] = None + # The entity that requested the event. This is used to determine whether to include + # the transaction_id in the unsigned section of the event. + requester: Optional[Requester] = None # List of event fields to include. If empty, all fields will be returned. only_event_fields: Optional[List[str]] = None # Some events can have stripped room state stored in the `unsigned` field. @@ -368,11 +368,24 @@ def serialize_event( e.unsigned["redacted_because"], time_now_ms, config=config ) - if config.token_id is not None: - if config.token_id == getattr(e.internal_metadata, "token_id", None): - txn_id = getattr(e.internal_metadata, "txn_id", None) - if txn_id is not None: - d["unsigned"]["transaction_id"] = txn_id + # If we have a txn_id saved in the internal_metadata, we should include it in the + # unsigned section of the event if it was sent by the same session as the one + # requesting the event. + # There is a special case for guests, because they only have one access token + # without associated access_token_id, so we always include the txn_id for events + # they sent. + txn_id = getattr(e.internal_metadata, "txn_id", None) + if txn_id is not None and config.requester is not None: + event_token_id = getattr(e.internal_metadata, "token_id", None) + if config.requester.user.to_string() == e.sender and ( + ( + event_token_id is not None + and config.requester.access_token_id is not None + and event_token_id == config.requester.access_token_id + ) + or config.requester.is_guest + ): + d["unsigned"]["transaction_id"] = txn_id # invite_room_state and knock_room_state are a list of stripped room state events # that are meant to provide metadata about a room to an invitee/knocker. They are @@ -403,14 +416,6 @@ class EventClientSerializer: clients. """ - def __init__(self, inhibit_replacement_via_edits: bool = False): - """ - Args: - inhibit_replacement_via_edits: If this is set to True, then events are - never replaced by their edits. - """ - self._inhibit_replacement_via_edits = inhibit_replacement_via_edits - def serialize_event( self, event: Union[JsonDict, EventBase], @@ -418,7 +423,6 @@ class EventClientSerializer: *, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None, - apply_edits: bool = True, ) -> JsonDict: """Serializes a single event. @@ -428,10 +432,7 @@ class EventClientSerializer: config: Event serialization config bundle_aggregations: A map from event_id to the aggregations to be bundled into the event. - apply_edits: Whether the content of the event should be modified to reflect - any replacement in `bundle_aggregations[<event_id>].replace`. - See also the `inhibit_replacement_via_edits` constructor arg: if that is - set to True, then this argument is ignored. + Returns: The serialized event """ @@ -450,38 +451,10 @@ class EventClientSerializer: config, bundle_aggregations, serialized_event, - apply_edits=apply_edits, ) return serialized_event - def _apply_edit( - self, orig_event: EventBase, serialized_event: JsonDict, edit: EventBase - ) -> None: - """Replace the content, preserving existing relations of the serialized event. - - Args: - orig_event: The original event. - serialized_event: The original event, serialized. This is modified. - edit: The event which edits the above. - """ - - # Ensure we take copies of the edit content, otherwise we risk modifying - # the original event. - edit_content = edit.content.copy() - - # Unfreeze the event content if necessary, so that we may modify it below - edit_content = unfreeze(edit_content) - serialized_event["content"] = edit_content.get("m.new_content", {}) - - # Check for existing relations - relates_to = orig_event.content.get("m.relates_to") - if relates_to: - # Keep the relations, ensuring we use a dict copy of the original - serialized_event["content"]["m.relates_to"] = relates_to.copy() - else: - serialized_event["content"].pop("m.relates_to", None) - def _inject_bundled_aggregations( self, event: EventBase, @@ -489,7 +462,6 @@ class EventClientSerializer: config: SerializeEventConfig, bundled_aggregations: Dict[str, "BundledAggregations"], serialized_event: JsonDict, - apply_edits: bool, ) -> None: """Potentially injects bundled aggregations into the unsigned portion of the serialized event. @@ -504,9 +476,6 @@ class EventClientSerializer: While serializing the bundled aggregations this map may be searched again for additional events in a recursive manner. serialized_event: The serialized event which may be modified. - apply_edits: Whether the content of the event should be modified to reflect - any replacement in `aggregations.replace` (subject to the - `inhibit_replacement_via_edits` constructor arg). """ # We have already checked that aggregations exist for this event. @@ -516,22 +485,12 @@ class EventClientSerializer: # being serialized. serialized_aggregations = {} - if event_aggregations.annotations: - serialized_aggregations[ - RelationTypes.ANNOTATION - ] = event_aggregations.annotations - if event_aggregations.references: serialized_aggregations[ RelationTypes.REFERENCE ] = event_aggregations.references if event_aggregations.replace: - # If there is an edit, optionally apply it to the event. - edit = event_aggregations.replace - if apply_edits and not self._inhibit_replacement_via_edits: - self._apply_edit(event, serialized_event, edit) - # Include information about it in the relations dict. # # Matrix spec v1.5 (https://spec.matrix.org/v1.5/client-server-api/#server-side-aggregation-of-mreplace-relationships) @@ -539,10 +498,7 @@ class EventClientSerializer: # `sender` of the edit; however MSC3925 proposes extending it to the whole # of the edit, which is what we do here. serialized_aggregations[RelationTypes.REPLACE] = self.serialize_event( - edit, - time_now, - config=config, - apply_edits=False, + event_aggregations.replace, time_now, config=config ) # Include any threaded replies to this event. diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index d720b5fd3f..3063df7990 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -314,7 +314,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): # stream position. keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]} - for ((destination, edu_key), pos) in keyed_edus.items(): + for (destination, edu_key), pos in keyed_edus.items(): rows.append( ( pos, @@ -329,7 +329,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): j = self.edus.bisect_right(to_token) + 1 edus = self.edus.items()[i:j] - for (pos, edu) in edus: + for pos, edu in edus: rows.append((pos, EduRow(edu))) # Sort rows based on pos diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 797de46dbc..7e01c18c6c 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -155,9 +155,6 @@ class AccountDataHandler: max_stream_id = await self._store.remove_account_data_for_room( user_id, room_id, account_data_type ) - if max_stream_id is None: - # The referenced account data did not exist, so no delete occurred. - return None self._notifier.on_new_event( StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] @@ -230,9 +227,6 @@ class AccountDataHandler: max_stream_id = await self._store.remove_account_data_for_user( user_id, account_data_type ) - if max_stream_id is None: - # The referenced account data did not exist, so no delete occurred. - return None self._notifier.on_new_event( StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] @@ -248,7 +242,6 @@ class AccountDataHandler: instance_name=random.choice(self._account_data_writers), user_id=user_id, account_data_type=account_data_type, - content={}, ) return response["max_stream_id"] diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index b03c214b14..b06f25b03c 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -14,7 +14,7 @@ import abc import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set from synapse.api.constants import Direction, Membership from synapse.events import EventBase @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class AdminHandler: def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastores().main + self._store = hs.get_datastores().main self._device_handler = hs.get_device_handler() self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state @@ -38,7 +38,7 @@ class AdminHandler: async def get_whois(self, user: UserID) -> JsonDict: connections = [] - sessions = await self.store.get_user_ip_and_agents(user) + sessions = await self._store.get_user_ip_and_agents(user) for session in sessions: connections.append( { @@ -57,7 +57,7 @@ class AdminHandler: async def get_user(self, user: UserID) -> Optional[JsonDict]: """Function to get user details""" - user_info_dict = await self.store.get_user_by_id(user.to_string()) + user_info_dict = await self._store.get_user_by_id(user.to_string()) if user_info_dict is None: return None @@ -89,11 +89,11 @@ class AdminHandler: } # Add additional user metadata - profile = await self.store.get_profileinfo(user.localpart) - threepids = await self.store.user_get_threepids(user.to_string()) + profile = await self._store.get_profileinfo(user.localpart) + threepids = await self._store.user_get_threepids(user.to_string()) external_ids = [ ({"auth_provider": auth_provider, "external_id": external_id}) - for auth_provider, external_id in await self.store.get_external_ids_by_user( + for auth_provider, external_id in await self._store.get_external_ids_by_user( user.to_string() ) ] @@ -101,7 +101,7 @@ class AdminHandler: user_info_dict["avatar_url"] = profile.avatar_url user_info_dict["threepids"] = threepids user_info_dict["external_ids"] = external_ids - user_info_dict["erased"] = await self.store.is_user_erased(user.to_string()) + user_info_dict["erased"] = await self._store.is_user_erased(user.to_string()) return user_info_dict @@ -117,7 +117,7 @@ class AdminHandler: The returned value is that returned by `writer.finished()`. """ # Get all rooms the user is in or has been in - rooms = await self.store.get_rooms_for_local_user_where_membership_is( + rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, membership_list=( Membership.JOIN, @@ -131,7 +131,7 @@ class AdminHandler: # We only try and fetch events for rooms the user has been in. If # they've been e.g. invited to a room without joining then we handle # those separately. - rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id) + rooms_user_has_been_in = await self._store.get_rooms_user_has_been_in(user_id) for index, room in enumerate(rooms): room_id = room.room_id @@ -140,7 +140,7 @@ class AdminHandler: "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms) ) - forgotten = await self.store.did_forget(user_id, room_id) + forgotten = await self._store.did_forget(user_id, room_id) if forgotten: logger.info("[%s] User forgot room %d, ignoring", user_id, room_id) continue @@ -152,14 +152,14 @@ class AdminHandler: if room.membership == Membership.INVITE: event_id = room.event_id - invite = await self.store.get_event(event_id, allow_none=True) + invite = await self._store.get_event(event_id, allow_none=True) if invite: invited_state = invite.unsigned["invite_room_state"] writer.write_invite(room_id, invite, invited_state) if room.membership == Membership.KNOCK: event_id = room.event_id - knock = await self.store.get_event(event_id, allow_none=True) + knock = await self._store.get_event(event_id, allow_none=True) if knock: knock_state = knock.unsigned["knock_room_state"] writer.write_knock(room_id, knock, knock_state) @@ -170,7 +170,7 @@ class AdminHandler: # were joined. We estimate that point by looking at the # stream_ordering of the last membership if it wasn't a join. if room.membership == Membership.JOIN: - stream_ordering = self.store.get_room_max_stream_ordering() + stream_ordering = self._store.get_room_max_stream_ordering() else: stream_ordering = room.stream_ordering @@ -197,7 +197,7 @@ class AdminHandler: # events that we have and then filtering, this isn't the most # efficient method perhaps but it does guarantee we get everything. while True: - events, _ = await self.store.paginate_room_events( + events, _ = await self._store.paginate_room_events( room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS ) if not events: @@ -252,16 +252,49 @@ class AdminHandler: profile = await self.get_user(UserID.from_string(user_id)) if profile is not None: writer.write_profile(profile) + logger.info("[%s] Written profile", user_id) # Get all devices the user has devices = await self._device_handler.get_devices_by_user(user_id) writer.write_devices(devices) + logger.info("[%s] Written %s devices", user_id, len(devices)) # Get all connections the user has connections = await self.get_whois(UserID.from_string(user_id)) writer.write_connections( connections["devices"][""]["sessions"][0]["connections"] ) + logger.info("[%s] Written %s connections", user_id, len(connections)) + + # Get all account data the user has global and in rooms + global_data = await self._store.get_global_account_data_for_user(user_id) + by_room_data = await self._store.get_room_account_data_for_user(user_id) + writer.write_account_data("global", global_data) + for room_id in by_room_data: + writer.write_account_data(room_id, by_room_data[room_id]) + logger.info( + "[%s] Written account data for %s rooms", user_id, len(by_room_data) + ) + + # Get all media ids the user has + limit = 100 + start = 0 + while True: + media_ids, total = await self._store.get_local_media_by_user_paginate( + start, limit, user_id + ) + for media in media_ids: + writer.write_media_id(media["media_id"], media) + + logger.info( + "[%s] Written %d media_ids of %s", + user_id, + (start + len(media_ids)), + total, + ) + if (start + limit) >= total: + break + start += limit return writer.finished() @@ -341,6 +374,30 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod + def write_account_data( + self, file_name: str, account_data: Mapping[str, JsonDict] + ) -> None: + """Write the account data of a user. + + Args: + file_name: file name to write data + account_data: mapping of global or room account_data + """ + raise NotImplementedError() + + @abc.abstractmethod + def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None: + """Write the media's metadata of a user. + Exports only the metadata, as this can be fetched from the database via + read only. In order to access the files, a connection to the correct + media repository would be required. + + Args: + media_id: ID of the media. + media_metadata: Metadata of one media file. + """ + + @abc.abstractmethod def finished(self) -> Any: """Called when all data has successfully been exported and written. diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 5d1d21cdc8..ec3ab968e9 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -737,7 +737,7 @@ class ApplicationServicesHandler: ) ret = [] - for (success, result) in results: + for success, result in results: if success: ret.extend(result) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 57a6854b1e..308e38edea 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -201,7 +201,7 @@ class AuthHandler: for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: inst = auth_checker_class(hs) if inst.is_enabled(): - self.checkers[inst.AUTH_TYPE] = inst # type: ignore + self.checkers[inst.AUTH_TYPE] = inst self.bcrypt_rounds = hs.config.registration.bcrypt_rounds @@ -815,7 +815,6 @@ class AuthHandler: now_ms = self._clock.time_msec() if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms: - raise SynapseError( HTTPStatus.FORBIDDEN, "The supplied refresh token has expired", @@ -1543,6 +1542,17 @@ class AuthHandler: async def add_threepid( self, user_id: str, medium: str, address: str, validated_at: int ) -> None: + """ + Adds an association between a user's Matrix ID and a third-party ID (email, + phone number). + + Args: + user_id: The ID of the user to associate. + medium: The medium of the third-party ID (email, msisdn). + address: The address of the third-party ID (i.e. an email address). + validated_at: The timestamp in ms of when the validation that the user owns + this third-party ID occurred. + """ # check if medium has a valid value if medium not in ["email", "msisdn"]: raise SynapseError( @@ -1567,42 +1577,44 @@ class AuthHandler: user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) + # Inform Synapse modules that a 3PID association has been created. + await self._third_party_rules.on_add_user_third_party_identifier( + user_id, medium, address + ) + + # Deprecated method for informing Synapse modules that a 3PID association + # has successfully been created. await self._third_party_rules.on_threepid_bind(user_id, medium, address) - async def delete_threepid( - self, user_id: str, medium: str, address: str, id_server: Optional[str] = None - ) -> bool: - """Attempts to unbind the 3pid on the identity servers and deletes it - from the local database. + async def delete_local_threepid( + self, user_id: str, medium: str, address: str + ) -> None: + """Deletes an association between a third-party ID and a user ID from the local + database. This method does not unbind the association from any identity servers. + + If `medium` is 'email' and a pusher is associated with this third-party ID, the + pusher will also be deleted. Args: user_id: ID of user to remove the 3pid from. medium: The medium of the 3pid being removed: "email" or "msisdn". address: The 3pid address to remove. - id_server: Use the given identity server when unbinding - any threepids. If None then will attempt to unbind using the - identity server specified when binding (if known). - - Returns: - Returns True if successfully unbound the 3pid on - the identity server, False if identity server doesn't support the - unbind API. """ - # 'Canonicalise' email addresses as per above if medium == "email": address = canonicalise_email(address) - result = await self.hs.get_identity_handler().try_unbind_threepid( - user_id, medium, address, id_server + await self.store.user_delete_threepid(user_id, medium, address) + + # Inform Synapse modules that a 3PID association has been deleted. + await self._third_party_rules.on_remove_user_third_party_identifier( + user_id, medium, address ) - await self.store.user_delete_threepid(user_id, medium, address) if medium == "email": await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id="m.email", pushkey=address, user_id=user_id ) - return result async def hash(self, password: str) -> str: """Computes a secure hash of password. @@ -2259,7 +2271,6 @@ class PasswordAuthProvider: async def on_logged_out( self, user_id: str, device_id: Optional[str], access_token: str ) -> None: - # call all of the on_logged_out callbacks for callback in self.on_logged_out_callbacks: try: diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index d24f649382..d31263c717 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -100,26 +100,28 @@ class DeactivateAccountHandler: # unbinding identity_server_supports_unbinding = True - # Retrieve the 3PIDs this user has bound to an identity server - threepids = await self.store.user_get_bound_threepids(user_id) - - for threepid in threepids: + # Attempt to unbind any known bound threepids to this account from identity + # server(s). + bound_threepids = await self.store.user_get_bound_threepids(user_id) + for threepid in bound_threepids: try: result = await self._identity_handler.try_unbind_threepid( user_id, threepid["medium"], threepid["address"], id_server ) - identity_server_supports_unbinding &= result except Exception: # Do we want this to be a fatal error or should we carry on? logger.exception("Failed to remove threepid from ID server") raise SynapseError(400, "Failed to remove threepid from ID server") - await self.store.user_delete_threepid( + + identity_server_supports_unbinding &= result + + # Remove any local threepid associations for this account. + local_threepids = await self.store.user_get_threepids(user_id) + for threepid in local_threepids: + await self._auth_handler.delete_local_threepid( user_id, threepid["medium"], threepid["address"] ) - # Remove all 3PIDs this user has bound to the homeserver - await self.store.user_delete_threepids(user_id) - # delete any devices belonging to the user, which will also # delete corresponding access tokens. await self._device_handler.delete_all_devices_for_user(user_id) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index a5798e9483..1fb23cc9bf 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -497,9 +497,11 @@ class DirectoryHandler: raise SynapseError(403, "Not allowed to publish room") # Check if publishing is blocked by a third party module - allowed_by_third_party_rules = await ( - self.third_party_event_rules.check_visibility_can_be_modified( - room_id, visibility + allowed_by_third_party_rules = ( + await ( + self.third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility + ) ) ) if not allowed_by_third_party_rules: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 43cbece21b..4e9c8d8db0 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1301,6 +1301,20 @@ class E2eKeysHandler: return desired_key_data + async def is_cross_signing_set_up_for_user(self, user_id: str) -> bool: + """Checks if the user has cross-signing set up + + Args: + user_id: The user to check + + Returns: + True if the user has cross-signing set up, False otherwise + """ + existing_master_key = await self.store.get_e2e_cross_signing_key( + user_id, "master" + ) + return existing_master_key is not None + def _check_cross_signing_key( key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 83f53ceb88..50317ec753 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -188,7 +188,6 @@ class E2eRoomKeysHandler: # XXX: perhaps we should use a finer grained lock here? async with self._upload_linearizer.queue(user_id): - # Check that the version we're trying to upload is the current version try: version_info = await self.store.get_e2e_room_keys_version_info(user_id) diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 46dd63c3f0..c508861b6a 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -236,7 +236,6 @@ class EventAuthHandler: # in any of them. allowed_rooms = await self.get_rooms_that_allow_join(state_ids) if not await self.is_user_in_rooms(allowed_rooms, user_id): - # If this is a remote request, the user might be in an allowed room # that we do not know about. if get_domain_from_id(user_id) != self._server_name: diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 949b69cb41..68c07f0265 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -23,7 +23,7 @@ from synapse.events.utils import SerializeEventConfig from synapse.handlers.presence import format_user_presence_state from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, Requester, UserID from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -46,13 +46,12 @@ class EventStreamHandler: async def get_stream( self, - auth_user_id: str, + requester: Requester, pagin_config: PaginationConfig, timeout: int = 0, as_client_event: bool = True, affect_presence: bool = True, room_id: Optional[str] = None, - is_guest: bool = False, ) -> JsonDict: """Fetches the events stream for a given user.""" @@ -62,13 +61,12 @@ class EventStreamHandler: raise SynapseError(403, "This room has been blocked on this server") # send any outstanding server notices to the user. - await self._server_notices_sender.on_user_syncing(auth_user_id) + await self._server_notices_sender.on_user_syncing(requester.user.to_string()) - auth_user = UserID.from_string(auth_user_id) presence_handler = self.hs.get_presence_handler() context = await presence_handler.user_syncing( - auth_user_id, + requester.user.to_string(), affect_presence=affect_presence, presence_state=PresenceState.ONLINE, ) @@ -82,10 +80,10 @@ class EventStreamHandler: timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) stream_result = await self.notifier.get_events_for( - auth_user, + requester.user, pagin_config, timeout, - is_guest=is_guest, + is_guest=requester.is_guest, explicit_room_id=room_id, ) events = stream_result.events @@ -102,7 +100,7 @@ class EventStreamHandler: if event.membership != Membership.JOIN: continue # Send down presence. - if event.state_key == auth_user_id: + if event.state_key == requester.user.to_string(): # Send down presence for everyone in the room. users: Iterable[str] = await self.store.get_users_in_room( event.room_id @@ -124,7 +122,9 @@ class EventStreamHandler: chunks = self._event_serializer.serialize_events( events, time_now, - config=SerializeEventConfig(as_client_event=as_client_event), + config=SerializeEventConfig( + as_client_event=as_client_event, requester=requester + ), ) chunk = { diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5371336529..50f8041f17 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -952,7 +952,20 @@ class FederationHandler: # # Note that this requires the /send_join request to come back to the # same server. + prev_event_ids = None if room_version.msc3083_join_rules: + # Note that the room's state can change out from under us and render our + # nice join rules-conformant event non-conformant by the time we build the + # event. When this happens, our validation at the end fails and we respond + # to the requesting server with a 403, which is misleading — it indicates + # that the user is not allowed to join the room and the joining server + # should not bother retrying via this homeserver or any others, when + # in fact we've just messed up with building the event. + # + # To reduce the likelihood of this race, we capture the forward extremities + # of the room (prev_event_ids) just before fetching the current state, and + # hope that the state we fetch corresponds to the prev events we chose. + prev_event_ids = await self.store.get_prev_events_for_room(room_id) state_ids = await self._state_storage_controller.get_current_state_ids( room_id ) @@ -995,7 +1008,8 @@ class FederationHandler: unpersisted_context, _, ) = await self.event_creation_handler.create_new_client_event( - builder=builder + builder=builder, + prev_event_ids=prev_event_ids, ) except SynapseError as e: logger.warning("Failed to create join to %s because %s", room_id, e) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 1a29abde98..b3be7a86f0 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -124,7 +124,6 @@ class InitialSyncHandler: as_client_event: bool = True, include_archived: bool = False, ) -> JsonDict: - memberships = [Membership.INVITE, Membership.JOIN] if include_archived: memberships.append(Membership.LEAVE) @@ -319,11 +318,9 @@ class InitialSyncHandler: ) is_peeking = member_event_id is None - user_id = requester.user.to_string() - if membership == Membership.JOIN: result = await self._room_initial_sync_joined( - user_id, room_id, pagin_config, membership, is_peeking + requester, room_id, pagin_config, membership, is_peeking ) elif membership == Membership.LEAVE: # The member_event_id will always be available if membership is set @@ -331,10 +328,16 @@ class InitialSyncHandler: assert member_event_id result = await self._room_initial_sync_parted( - user_id, room_id, pagin_config, membership, member_event_id, is_peeking + requester, + room_id, + pagin_config, + membership, + member_event_id, + is_peeking, ) account_data_events = [] + user_id = requester.user.to_string() tags = await self.store.get_tags_for_room(user_id, room_id) if tags: account_data_events.append( @@ -351,7 +354,7 @@ class InitialSyncHandler: async def _room_initial_sync_parted( self, - user_id: str, + requester: Requester, room_id: str, pagin_config: PaginationConfig, membership: str, @@ -370,13 +373,17 @@ class InitialSyncHandler: ) messages = await filter_events_for_client( - self._storage_controllers, user_id, messages, is_peeking=is_peeking + self._storage_controllers, + requester.user.to_string(), + messages, + is_peeking=is_peeking, ) start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token) end_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, stream_token) time_now = self.clock.time_msec() + serialize_options = SerializeEventConfig(requester=requester) return { "membership": membership, @@ -384,14 +391,18 @@ class InitialSyncHandler: "messages": { "chunk": ( # Don't bundle aggregations as this is a deprecated API. - self._event_serializer.serialize_events(messages, time_now) + self._event_serializer.serialize_events( + messages, time_now, config=serialize_options + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), }, "state": ( # Don't bundle aggregations as this is a deprecated API. - self._event_serializer.serialize_events(room_state.values(), time_now) + self._event_serializer.serialize_events( + room_state.values(), time_now, config=serialize_options + ) ), "presence": [], "receipts": [], @@ -399,7 +410,7 @@ class InitialSyncHandler: async def _room_initial_sync_joined( self, - user_id: str, + requester: Requester, room_id: str, pagin_config: PaginationConfig, membership: str, @@ -411,9 +422,12 @@ class InitialSyncHandler: # TODO: These concurrently time_now = self.clock.time_msec() + serialize_options = SerializeEventConfig(requester=requester) # Don't bundle aggregations as this is a deprecated API. state = self._event_serializer.serialize_events( - current_state.values(), time_now + current_state.values(), + time_now, + config=serialize_options, ) now_token = self.hs.get_event_sources().get_current_token() @@ -451,7 +465,10 @@ class InitialSyncHandler: if not receipts: return [] - return ReceiptEventSource.filter_out_private_receipts(receipts, user_id) + return ReceiptEventSource.filter_out_private_receipts( + receipts, + requester.user.to_string(), + ) presence, receipts, (messages, token) = await make_deferred_yieldable( gather_results( @@ -470,20 +487,23 @@ class InitialSyncHandler: ) messages = await filter_events_for_client( - self._storage_controllers, user_id, messages, is_peeking=is_peeking + self._storage_controllers, + requester.user.to_string(), + messages, + is_peeking=is_peeking, ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) end_token = now_token - time_now = self.clock.time_msec() - ret = { "room_id": room_id, "messages": { "chunk": ( # Don't bundle aggregations as this is a deprecated API. - self._event_serializer.serialize_events(messages, time_now) + self._event_serializer.serialize_events( + messages, time_now, config=serialize_options + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 77d92f1574..d283a938c0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,6 +16,7 @@ # limitations under the License. import logging import random +from builtins import dict from http import HTTPStatus from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple @@ -50,7 +51,7 @@ from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase, relation_from_event from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext, UnpersistedEventContextBase -from synapse.events.utils import maybe_upsert_event_field +from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler from synapse.logging import opentracing @@ -245,8 +246,11 @@ class MessageHandler: ) room_state = room_state_events[membership_event_id] - now = self.clock.time_msec() - events = self._event_serializer.serialize_events(room_state.values(), now) + events = self._event_serializer.serialize_events( + room_state.values(), + self.clock.time_msec(), + config=SerializeEventConfig(requester=requester), + ) return events async def _user_can_see_state_at_event( @@ -574,7 +578,7 @@ class EventCreationHandler: state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, - ) -> Tuple[EventBase, EventContext, Optional[dict]]: + ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]: """ Given a dict from a client, create a new event. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for @@ -723,8 +727,6 @@ class EventCreationHandler: current_state_group=current_state_group, ) - context = await unpersisted_context.persist(event) - # In an ideal world we wouldn't need the second part of this condition. However, # this behaviour isn't spec'd yet, meaning we should be able to deactivate this # behaviour. Another reason is that this code is also evaluated each time a new @@ -741,7 +743,7 @@ class EventCreationHandler: assert state_map is not None prev_event_id = state_map.get((EventTypes.Member, event.sender)) else: - prev_state_ids = await context.get_prev_state_ids( + prev_state_ids = await unpersisted_context.get_prev_state_ids( StateFilter.from_types([(EventTypes.Member, None)]) ) prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) @@ -766,8 +768,7 @@ class EventCreationHandler: ) self.validator.validate_new(event, self.config) - - return event, context, new_event + return event, unpersisted_context, new_event async def _is_exempt_from_privacy_policy( self, builder: EventBuilder, requester: Requester @@ -1007,7 +1008,11 @@ class EventCreationHandler: max_retries = 5 for i in range(max_retries): try: - event, context, third_party_event_dict = await self.create_event( + ( + event, + unpersisted_context, + third_party_event_dict, + ) = await self.create_event( requester, event_dict, txn_id=txn_id, @@ -1018,6 +1023,7 @@ class EventCreationHandler: historical=historical, depth=depth, ) + context = await unpersisted_context.persist(event) assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( event.sender, @@ -1209,7 +1215,6 @@ class EventCreationHandler: if for_batch: assert prev_event_ids is not None assert state_map is not None - assert current_state_group is not None auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) event = await builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth @@ -2067,7 +2072,7 @@ class EventCreationHandler: max_retries = 5 for i in range(max_retries): try: - event, context, _ = await self.create_event( + event, unpersisted_context, _ = await self.create_event( requester, { "type": EventTypes.Dummy, @@ -2076,6 +2081,7 @@ class EventCreationHandler: "sender": user_id, }, ) + context = await unpersisted_context.persist(event) event.internal_metadata.proactively_send = False diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index ceefa16b49..8c79c055ba 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -579,7 +579,9 @@ class PaginationHandler: time_now = self.clock.time_msec() - serialize_options = SerializeEventConfig(as_client_event=as_client_event) + serialize_options = SerializeEventConfig( + as_client_event=as_client_event, requester=requester + ) chunk = { "chunk": ( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 87af31aa27..4ad2233573 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -777,7 +777,6 @@ class PresenceHandler(BasePresenceHandler): ) if self.unpersisted_users_changes: - await self.store.update_presence( [ self.user_to_current_state[user_id] @@ -823,7 +822,6 @@ class PresenceHandler(BasePresenceHandler): now = self.clock.time_msec() with Measure(self.clock, "presence_update_states"): - # NOTE: We purposefully don't await between now and when we've # calculated what we want to do with the new states, to avoid races. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c611efb760..e4e506e62c 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -476,7 +476,7 @@ class RegistrationHandler: # create room expects the localpart of the room alias config["room_alias_name"] = room_alias.localpart - info, _ = await room_creation_handler.create_room( + room_id, _, _ = await room_creation_handler.create_room( fake_requester, config=config, ratelimit=False, @@ -490,7 +490,7 @@ class RegistrationHandler: user_id, authenticated_entity=self._server_name ), target=UserID.from_string(user_id), - room_id=info["room_id"], + room_id=room_id, # Since it was just created, there are no remote hosts. remote_room_hosts=[], action="join", diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 0fb15391e0..1d09fdf135 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -20,6 +20,7 @@ import attr from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event +from synapse.events.utils import SerializeEventConfig from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import trace from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent @@ -60,13 +61,12 @@ class BundledAggregations: Some values require additional processing during serialization. """ - annotations: Optional[JsonDict] = None references: Optional[JsonDict] = None replace: Optional[EventBase] = None thread: Optional[_ThreadAggregation] = None def __bool__(self) -> bool: - return bool(self.annotations or self.references or self.replace or self.thread) + return bool(self.references or self.replace or self.thread) class RelationsHandler: @@ -152,16 +152,23 @@ class RelationsHandler: ) now = self._clock.time_msec() + serialize_options = SerializeEventConfig(requester=requester) return_value: JsonDict = { "chunk": self._event_serializer.serialize_events( - events, now, bundle_aggregations=aggregations + events, + now, + bundle_aggregations=aggregations, + config=serialize_options, ), } if include_original_event: # Do not bundle aggregations when retrieving the original event because # we want the content before relations are applied to it. return_value["original_event"] = self._event_serializer.serialize_event( - event, now, bundle_aggregations=None + event, + now, + bundle_aggregations=None, + config=serialize_options, ) if next_token: @@ -227,67 +234,6 @@ class RelationsHandler: e.msg, ) - async def get_annotations_for_events( - self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() - ) -> Dict[str, List[JsonDict]]: - """Get a list of annotations to the given events, grouped by event type and - aggregation key, sorted by count. - - This is used e.g. to get the what and how many reactions have happened - on an event. - - Args: - event_ids: Fetch events that relate to these event IDs. - ignored_users: The users ignored by the requesting user. - - Returns: - A map of event IDs to a list of groups of annotations that match. - Each entry is a dict with `type`, `key` and `count` fields. - """ - # Get the base results for all users. - full_results = await self._main_store.get_aggregation_groups_for_events( - event_ids - ) - - # Avoid additional logic if there are no ignored users. - if not ignored_users: - return { - event_id: results - for event_id, results in full_results.items() - if results - } - - # Then subtract off the results for any ignored users. - ignored_results = await self._main_store.get_aggregation_groups_for_users( - [event_id for event_id, results in full_results.items() if results], - ignored_users, - ) - - filtered_results = {} - for event_id, results in full_results.items(): - # If no annotations, skip. - if not results: - continue - - # If there are not ignored results for this event, copy verbatim. - if event_id not in ignored_results: - filtered_results[event_id] = results - continue - - # Otherwise, subtract out the ignored results. - event_ignored_results = ignored_results[event_id] - for result in results: - key = (result["type"], result["key"]) - if key in event_ignored_results: - # Ensure to not modify the cache. - result = result.copy() - result["count"] -= event_ignored_results[key] - if result["count"] <= 0: - continue - filtered_results.setdefault(event_id, []).append(result) - - return filtered_results - async def get_references_for_events( self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() ) -> Dict[str, List[_RelatedEvent]]: @@ -531,17 +477,6 @@ class RelationsHandler: # (as that is what makes it part of the thread). relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD - async def _fetch_annotations() -> None: - """Fetch any annotations (ie, reactions) to bundle with this event.""" - annotations_by_event_id = await self.get_annotations_for_events( - events_by_id.keys(), ignored_users=ignored_users - ) - for event_id, annotations in annotations_by_event_id.items(): - if annotations: - results.setdefault(event_id, BundledAggregations()).annotations = { - "chunk": annotations - } - async def _fetch_references() -> None: """Fetch any references to bundle with this event.""" references_by_event_id = await self.get_references_for_events( @@ -575,7 +510,6 @@ class RelationsHandler: await make_deferred_yieldable( gather_results( ( - run_in_background(_fetch_annotations), run_in_background(_fetch_references), run_in_background(_fetch_edits), ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9fb7209549..c70afa3176 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -51,6 +51,7 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase +from synapse.events.snapshot import UnpersistedEventContext from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM @@ -211,7 +212,7 @@ class RoomCreationHandler: # the required power level to send the tombstone event. ( tombstone_event, - tombstone_context, + tombstone_unpersisted_context, _, ) = await self.event_creation_handler.create_event( requester, @@ -226,6 +227,9 @@ class RoomCreationHandler: }, }, ) + tombstone_context = await tombstone_unpersisted_context.persist( + tombstone_event + ) validate_event_for_room_version(tombstone_event) await self._event_auth_handler.check_auth_rules_from_context( tombstone_event @@ -691,13 +695,14 @@ class RoomCreationHandler: config: JsonDict, ratelimit: bool = True, creator_join_profile: Optional[JsonDict] = None, - ) -> Tuple[dict, int]: + ) -> Tuple[str, Optional[RoomAlias], int]: """Creates a new room. Args: - requester: - The user who requested the room creation. - config : A dict of configuration options. + requester: The user who requested the room creation. + config: A dict of configuration options. This will be the body of + a /createRoom request; see + https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom ratelimit: set to False to disable the rate limiter creator_join_profile: @@ -708,14 +713,17 @@ class RoomCreationHandler: `avatar_url` and/or `displayname`. Returns: - First, a dict containing the keys `room_id` and, if an alias - was, requested, `room_alias`. Secondly, the stream_id of the - last persisted event. + A 3-tuple containing: + - the room ID; + - if requested, the room alias, otherwise None; and + - the `stream_id` of the last persisted event. Raises: - SynapseError if the room ID couldn't be stored, 3pid invitation config - validation failed, or something went horribly wrong. - ResourceLimitError if server is blocked to some resource being - exceeded + SynapseError: + if the room ID couldn't be stored, 3pid invitation config + validation failed, or something went horribly wrong. + ResourceLimitError: + if server is blocked to some resource being + exceeded """ user_id = requester.user.to_string() @@ -865,9 +873,11 @@ class RoomCreationHandler: ) # Check whether this visibility value is blocked by a third party module - allowed_by_third_party_rules = await ( - self.third_party_event_rules.check_visibility_can_be_modified( - room_id, visibility + allowed_by_third_party_rules = ( + await ( + self.third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility + ) ) ) if not allowed_by_third_party_rules: @@ -1025,11 +1035,6 @@ class RoomCreationHandler: last_sent_event_id = member_event_id depth += 1 - result = {"room_id": room_id} - - if room_alias: - result["room_alias"] = room_alias.to_string() - # Always wait for room creation to propagate before returning await self._replication.wait_for_stream_position( self.hs.config.worker.events_shard_config.get_instance(room_id), @@ -1037,7 +1042,7 @@ class RoomCreationHandler: last_stream_id, ) - return result, last_stream_id + return room_id, room_alias, last_stream_id async def _send_events_for_new_room( self, @@ -1092,7 +1097,11 @@ class RoomCreationHandler: content: JsonDict, for_batch: bool, **kwargs: Any, - ) -> Tuple[EventBase, synapse.events.snapshot.EventContext, Optional[dict]]: + ) -> Tuple[ + EventBase, + synapse.events.snapshot.UnpersistedEventContextBase, + Optional[dict], + ]: """ Creates an event and associated event context. Args: @@ -1113,7 +1122,7 @@ class RoomCreationHandler: ( new_event, - new_context, + new_unpersisted_context, third_party_event, ) = await self.event_creation_handler.create_event( creator, @@ -1122,13 +1131,13 @@ class RoomCreationHandler: depth=depth, state_map=state_map, for_batch=for_batch, - current_state_group=current_state_group, ) + depth += 1 prev_event = [new_event.event_id] state_map[(new_event.type, new_event.state_key)] = new_event.event_id - return new_event, new_context, third_party_event + return new_event, new_unpersisted_context, third_party_event try: config = self._presets_dict[preset_config] @@ -1138,10 +1147,10 @@ class RoomCreationHandler: ) creation_content.update({"creator": creator_id}) - creation_event, creation_context, _ = await create_event( + creation_event, unpersisted_creation_context, _ = await create_event( EventTypes.Create, creation_content, False ) - + creation_context = await unpersisted_creation_context.persist(creation_event) logger.debug("Sending %s in new room", EventTypes.Member) ev = await self.event_creation_handler.handle_new_client_event( requester=creator, @@ -1186,7 +1195,6 @@ class RoomCreationHandler: power_event, power_context, power_tp_event = await create_event( EventTypes.PowerLevels, pl_content, True ) - current_state_group = power_context._state_group events_to_send.append((power_event, power_context)) if power_tp_event: third_party_events_to_append.append(power_tp_event) @@ -1237,7 +1245,6 @@ class RoomCreationHandler: power_level_content, True, ) - current_state_group = pl_context._state_group events_to_send.append((pl_event, pl_context)) if pl_tp_event: third_party_events_to_append.append(pl_tp_event) @@ -1246,7 +1253,6 @@ class RoomCreationHandler: room_alias_event, room_alias_context, ra_tp_event = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True ) - current_state_group = room_alias_context._state_group events_to_send.append((room_alias_event, room_alias_context)) if ra_tp_event: third_party_events_to_append.append(ra_tp_event) @@ -1257,7 +1263,6 @@ class RoomCreationHandler: {"join_rule": config["join_rules"]}, True, ) - current_state_group = join_rules_context._state_group events_to_send.append((join_rules_event, join_rules_context)) if jr_tp_event: third_party_events_to_append.append(jr_tp_event) @@ -1268,7 +1273,6 @@ class RoomCreationHandler: {"history_visibility": config["history_visibility"]}, True, ) - current_state_group = visibility_context._state_group events_to_send.append((visibility_event, visibility_context)) if vis_tp_event: third_party_events_to_append.append(vis_tp_event) @@ -1284,7 +1288,6 @@ class RoomCreationHandler: {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, True, ) - current_state_group = guest_access_context._state_group events_to_send.append((guest_access_event, guest_access_context)) if ga_tp_event: third_party_events_to_append.append(ga_tp_event) @@ -1293,7 +1296,6 @@ class RoomCreationHandler: event, context, tp_event = await create_event( etype, content, True, state_key=state_key ) - current_state_group = context._state_group events_to_send.append((event, context)) if tp_event: third_party_events_to_append.append(tp_event) @@ -1325,9 +1327,16 @@ class RoomCreationHandler: context = await unpersisted_context.persist(event) events_to_send.append((event, context)) + datastore = self.hs.get_datastores().state + events_and_context = ( + await UnpersistedEventContext.batch_persist_unpersisted_contexts( + events_to_send, room_id, current_state_group, datastore + ) + ) + last_event = await self.event_creation_handler.handle_new_client_event( creator, - events_to_send, + events_and_context, ignore_shadow_ban=True, ratelimit=False, ) @@ -1867,7 +1876,7 @@ class RoomShutdownHandler: new_room_user_id, authenticated_entity=requester_user_id ) - info, stream_id = await self._room_creation_handler.create_room( + new_room_id, _, stream_id = await self._room_creation_handler.create_room( room_creator_requester, config={ "preset": RoomCreationPreset.PUBLIC_CHAT, @@ -1876,7 +1885,6 @@ class RoomShutdownHandler: }, ratelimit=False, ) - new_room_id = info["room_id"] logger.info( "Shutting down room %r, joining to new room: %r", room_id, new_room_id @@ -1929,6 +1937,7 @@ class RoomShutdownHandler: # Join users to new room if new_room_user_id: + assert new_room_id is not None await self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, @@ -1961,6 +1970,7 @@ class RoomShutdownHandler: aliases_for_room = await self.store.get_aliases_for_room(room_id) + assert new_room_id is not None await self.store.update_aliases_for_room( room_id, new_room_id, requester_user_id ) diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 3d432ac295..8b5e02af17 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -327,7 +327,11 @@ class RoomBatchHandler: # Mark all events as historical event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - event, context, _ = await self.event_creation_handler.create_event( + ( + event, + unpersisted_context, + _, + ) = await self.event_creation_handler.create_event( await self.create_requester_for_user_id_from_app_service( ev["sender"], app_service_requester.app_service ), @@ -345,7 +349,7 @@ class RoomBatchHandler: historical=True, depth=inherited_depth, ) - + context = await unpersisted_context.persist(event) assert context._state_group # Normally this is done when persisting the event but we have to @@ -374,7 +378,7 @@ class RoomBatchHandler: # correct stream_ordering as they are backfilled (which decrements). # Events are sorted by (topological_ordering, stream_ordering) # where topological_ordering is just depth. - for (event, context) in reversed(events_to_persist): + for event, context in reversed(events_to_persist): # This call can't raise `PartialStateConflictError` since we forbid # use of the historical batch API during partial state await self.event_creation_handler.handle_new_client_event( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index f136ee0e97..9d3096df8d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -207,6 +207,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): @abc.abstractmethod async def remote_knock( self, + requester: Requester, remote_room_hosts: List[str], room_id: str, user: UserID, @@ -416,7 +417,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): try: ( event, - context, + unpersisted_context, third_party_event, ) = await self.event_creation_handler.create_event( requester, @@ -439,7 +440,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier=outlier, historical=historical, ) - + context = await unpersisted_context.persist(event) prev_state_ids = await context.get_prev_state_ids( StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -1088,7 +1089,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) return await self.remote_knock( - remote_room_hosts, room_id, target, content + requester, remote_room_hosts, room_id, target, content ) return await self._local_membership_update( @@ -1964,7 +1965,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): try: ( event, - context, + unpersisted_context, third_party_event_dict, ) = await self.event_creation_handler.create_event( requester, @@ -1974,6 +1975,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): auth_event_ids=auth_event_ids, outlier=True, ) + context = await unpersisted_context.persist(event) event.internal_metadata.out_of_band_membership = True events_and_context = [(event, context)] @@ -2013,6 +2015,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): async def remote_knock( self, + requester: Requester, remote_room_hosts: List[str], room_id: str, user: UserID, diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index ba261702d4..76e36b8a6d 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -113,6 +113,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): async def remote_knock( self, + requester: Requester, remote_room_hosts: List[str], room_id: str, user: UserID, @@ -123,9 +124,10 @@ class RoomMemberWorkerHandler(RoomMemberHandler): Implements RoomMemberHandler.remote_knock """ ret = await self._remote_knock_client( + requester=requester, remote_room_hosts=remote_room_hosts, room_id=room_id, - user=user, + user_id=user.to_string(), content=content, ) return ret["event_id"], ret["stream_id"] diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 9bbf83047d..aad4706f14 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -23,7 +23,8 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.events import EventBase -from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID +from synapse.events.utils import SerializeEventConfig +from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID from synapse.types.state import StateFilter from synapse.visibility import filter_events_for_client @@ -109,12 +110,12 @@ class SearchHandler: return historical_room_ids async def search( - self, user: UserID, content: JsonDict, batch: Optional[str] = None + self, requester: Requester, content: JsonDict, batch: Optional[str] = None ) -> JsonDict: """Performs a full text search for a user. Args: - user: The user performing the search. + requester: The user performing the search. content: Search parameters batch: The next_batch parameter. Used for pagination. @@ -199,7 +200,7 @@ class SearchHandler: ) return await self._search( - user, + requester, batch_group, batch_group_key, batch_token, @@ -217,7 +218,7 @@ class SearchHandler: async def _search( self, - user: UserID, + requester: Requester, batch_group: Optional[str], batch_group_key: Optional[str], batch_token: Optional[str], @@ -235,7 +236,7 @@ class SearchHandler: """Performs a full text search for a user. Args: - user: The user performing the search. + requester: The user performing the search. batch_group: Pagination information. batch_group_key: Pagination information. batch_token: Pagination information. @@ -269,7 +270,7 @@ class SearchHandler: # TODO: Search through left rooms too rooms = await self.store.get_rooms_for_local_user_where_membership_is( - user.to_string(), + requester.user.to_string(), membership_list=[Membership.JOIN], # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], ) @@ -303,13 +304,13 @@ class SearchHandler: if order_by == "rank": search_result, sender_group = await self._search_by_rank( - user, room_ids, search_term, keys, search_filter + requester.user, room_ids, search_term, keys, search_filter ) # Unused return values for rank search. global_next_batch = None elif order_by == "recent": search_result, global_next_batch = await self._search_by_recent( - user, + requester.user, room_ids, search_term, keys, @@ -334,7 +335,7 @@ class SearchHandler: assert after_limit is not None contexts = await self._calculate_event_contexts( - user, + requester.user, search_result.allowed_events, before_limit, after_limit, @@ -363,27 +364,37 @@ class SearchHandler: # The returned events. search_result.allowed_events, ), - user.to_string(), + requester.user.to_string(), ) # We're now about to serialize the events. We should not make any # blocking calls after this. Otherwise, the 'age' will be wrong. time_now = self.clock.time_msec() + serialize_options = SerializeEventConfig(requester=requester) for context in contexts.values(): context["events_before"] = self._event_serializer.serialize_events( - context["events_before"], time_now, bundle_aggregations=aggregations + context["events_before"], + time_now, + bundle_aggregations=aggregations, + config=serialize_options, ) context["events_after"] = self._event_serializer.serialize_events( - context["events_after"], time_now, bundle_aggregations=aggregations + context["events_after"], + time_now, + bundle_aggregations=aggregations, + config=serialize_options, ) results = [ { "rank": search_result.rank_map[e.event_id], "result": self._event_serializer.serialize_event( - e, time_now, bundle_aggregations=aggregations + e, + time_now, + bundle_aggregations=aggregations, + config=serialize_options, ), "context": contexts.get(e.event_id, {}), } @@ -398,7 +409,9 @@ class SearchHandler: if state_results: rooms_cat_res["state"] = { - room_id: self._event_serializer.serialize_events(state_events, time_now) + room_id: self._event_serializer.serialize_events( + state_events, time_now, config=serialize_options + ) for room_id, state_events in state_results.items() } diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4e4595312c..fd6d946c37 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1297,7 +1297,6 @@ class SyncHandler: return RoomNotifCounts.empty() with Measure(self.clock, "unread_notifs_for_room_id"): - return await self.store.get_unread_event_push_actions_by_room_for_user( room_id, sync_config.user.to_string(), diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 332edcca24..78a75bfed6 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -13,7 +13,8 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type from twisted.web.client import PartialDownloadError @@ -27,19 +28,28 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class UserInteractiveAuthChecker: +class UserInteractiveAuthChecker(ABC): """Abstract base class for an interactive auth checker""" - def __init__(self, hs: "HomeServer"): + # This should really be an "abstract class property", i.e. it should + # be an error to instantiate a subclass that doesn't specify an AUTH_TYPE. + # But calling this a `ClassVar` is simpler than a decorator stack of + # @property @abstractmethod and @classmethod (if that's even the right order). + AUTH_TYPE: ClassVar[str] + + def __init__(self, hs: "HomeServer"): # noqa: B027 pass + @abstractmethod def is_enabled(self) -> bool: """Check if the configuration of the homeserver allows this checker to work Returns: True if this login type is enabled. """ + raise NotImplementedError() + @abstractmethod async def check_auth(self, authdict: dict, clientip: str) -> Any: """Given the authentication dict from the client, attempt to check this step @@ -304,7 +314,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): ) -INTERACTIVE_AUTH_CHECKERS = [ +INTERACTIVE_AUTH_CHECKERS: Sequence[Type[UserInteractiveAuthChecker]] = [ DummyAuthChecker, TermsAuthChecker, RecaptchaAuthChecker, diff --git a/synapse/http/client.py b/synapse/http/client.py index a05f297933..ae48e7c3f0 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -44,6 +44,7 @@ from twisted.internet.interfaces import ( IAddress, IDelayedCall, IHostResolution, + IOpenSSLContextFactory, IReactorCore, IReactorPluggableNameResolver, IReactorTime, @@ -958,8 +959,8 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory): self._context = SSL.Context(SSL.SSLv23_METHOD) self._context.set_verify(VERIFY_NONE, lambda *_: False) - def getContext(self, hostname=None, port=None): + def getContext(self) -> SSL.Context: return self._context - def creatorForNetloc(self, hostname: bytes, port: int): + def creatorForNetloc(self, hostname: bytes, port: int) -> IOpenSSLContextFactory: return self diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index b92f1d3d1a..3302d4e48a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -440,7 +440,7 @@ class MatrixFederationHttpClient: Args: request: details of request to be sent - retry_on_dns_fail: true if the request should be retied on DNS failures + retry_on_dns_fail: true if the request should be retried on DNS failures timeout: number of milliseconds to wait for the response headers (including connecting to the server), *for each attempt*. @@ -475,7 +475,7 @@ class MatrixFederationHttpClient: (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -871,7 +871,7 @@ class MatrixFederationHttpClient: (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -958,7 +958,7 @@ class MatrixFederationHttpClient: (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -1036,6 +1036,8 @@ class MatrixFederationHttpClient: args: A dictionary used to create query strings, defaults to None. + retry_on_dns_fail: true if the request should be retried on DNS failures + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. @@ -1063,7 +1065,7 @@ class MatrixFederationHttpClient: (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -1141,7 +1143,7 @@ class MatrixFederationHttpClient: (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -1197,7 +1199,7 @@ class MatrixFederationHttpClient: (except 429). NotRetryingDestination: If we are not yet ready to retry this server. - FederationDeniedError: If this destination is not on our + FederationDeniedError: If this destination is not on our federation whitelist RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. @@ -1267,7 +1269,7 @@ class MatrixFederationHttpClient: def _flatten_response_never_received(e: BaseException) -> str: if hasattr(e, "reasons"): reasons = ", ".join( - _flatten_response_never_received(f.value) for f in e.reasons # type: ignore[attr-defined] + _flatten_response_never_received(f.value) for f in e.reasons ) return "%s:[%s]" % (type(e).__name__, reasons) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 6c7cf1b294..c70eee649c 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -188,7 +188,7 @@ from typing import ( ) import attr -from typing_extensions import ParamSpec +from typing_extensions import Concatenate, ParamSpec from twisted.internet import defer from twisted.web.http import Request @@ -445,7 +445,7 @@ def init_tracer(hs: "HomeServer") -> None: opentracing = None # type: ignore[assignment] return - if not opentracing or not JaegerConfig: + if opentracing is None or JaegerConfig is None: raise ConfigError( "The server has been configured to use opentracing but opentracing is not " "installed." @@ -524,6 +524,7 @@ def whitelisted_homeserver(destination: str) -> bool: # Start spans and scopes + # Could use kwargs but I want these to be explicit def start_active_span( operation_name: str, @@ -872,7 +873,7 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte def _custom_sync_async_decorator( func: Callable[P, R], - wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]], + wrapping_logic: Callable[Concatenate[Callable[P, R], P], ContextManager[None]], ) -> Callable[P, R]: """ Decorates a function that is sync or async (coroutines), or that returns a Twisted @@ -902,10 +903,14 @@ def _custom_sync_async_decorator( """ if inspect.iscoroutinefunction(func): - + # In this branch, R = Awaitable[RInner], for some other type RInner @wraps(func) - async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + async def _wrapper( + *args: P.args, **kwargs: P.kwargs + ) -> Any: # Return type is RInner with wrapping_logic(func, *args, **kwargs): + # type-ignore: func() returns R, but mypy doesn't know that R is + # Awaitable here. return await func(*args, **kwargs) # type: ignore[misc] else: @@ -972,7 +977,11 @@ def trace_with_opname( if not opentracing: return func - return _custom_sync_async_decorator(func, _wrapping_logic) + # type-ignore: mypy seems to be confused by the ParamSpecs here. + # I think the problem is https://github.com/python/mypy/issues/12909 + return _custom_sync_async_decorator( + func, _wrapping_logic # type: ignore[arg-type] + ) return _decorator @@ -1018,7 +1027,9 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: set_tag(SynapseTags.FUNC_KWARGS, str(kwargs)) yield - return _custom_sync_async_decorator(func, _wrapping_logic) + # type-ignore: mypy seems to be confused by the ParamSpecs here. + # I think the problem is https://github.com/python/mypy/issues/12909 + return _custom_sync_async_decorator(func, _wrapping_logic) # type: ignore[arg-type] @contextlib.contextmanager diff --git a/synapse/media/_base.py b/synapse/media/_base.py new file mode 100644 index 0000000000..ef8334ae25 --- /dev/null +++ b/synapse/media/_base.py @@ -0,0 +1,479 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import urllib +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type + +import attr + +from twisted.internet.interfaces import IConsumer +from twisted.protocols.basic import FileSender +from twisted.web.server import Request + +from synapse.api.errors import Codes, SynapseError, cs_error +from synapse.http.server import finish_request, respond_with_json +from synapse.http.site import SynapseRequest +from synapse.logging.context import make_deferred_yieldable +from synapse.util.stringutils import is_ascii, parse_and_validate_server_name + +logger = logging.getLogger(__name__) + +# list all text content types that will have the charset default to UTF-8 when +# none is given +TEXT_CONTENT_TYPES = [ + "text/css", + "text/csv", + "text/html", + "text/calendar", + "text/plain", + "text/javascript", + "application/json", + "application/ld+json", + "application/rtf", + "image/svg+xml", + "text/xml", +] + + +def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: + """Parses the server name, media ID and optional file name from the request URI + + Also performs some rough validation on the server name. + + Args: + request: The `Request`. + + Returns: + A tuple containing the parsed server name, media ID and optional file name. + + Raises: + SynapseError(404): if parsing or validation fail for any reason + """ + try: + # The type on postpath seems incorrect in Twisted 21.2.0. + postpath: List[bytes] = request.postpath # type: ignore + assert postpath + + # This allows users to append e.g. /test.png to the URL. Useful for + # clients that parse the URL to see content type. + server_name_bytes, media_id_bytes = postpath[:2] + server_name = server_name_bytes.decode("utf-8") + media_id = media_id_bytes.decode("utf8") + + # Validate the server name, raising if invalid + parse_and_validate_server_name(server_name) + + file_name = None + if len(postpath) > 2: + try: + file_name = urllib.parse.unquote(postpath[-1].decode("utf-8")) + except UnicodeDecodeError: + pass + return server_name, media_id, file_name + except Exception: + raise SynapseError( + 404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN + ) + + +def respond_404(request: SynapseRequest) -> None: + respond_with_json( + request, + 404, + cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND), + send_cors=True, + ) + + +async def respond_with_file( + request: SynapseRequest, + media_type: str, + file_path: str, + file_size: Optional[int] = None, + upload_name: Optional[str] = None, +) -> None: + logger.debug("Responding with %r", file_path) + + if os.path.isfile(file_path): + if file_size is None: + stat = os.stat(file_path) + file_size = stat.st_size + + add_file_headers(request, media_type, file_size, upload_name) + + with open(file_path, "rb") as f: + await make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) + + finish_request(request) + else: + respond_404(request) + + +def add_file_headers( + request: Request, + media_type: str, + file_size: Optional[int], + upload_name: Optional[str], +) -> None: + """Adds the correct response headers in preparation for responding with the + media. + + Args: + request + media_type: The media/content type. + file_size: Size in bytes of the media, if known. + upload_name: The name of the requested file, if any. + """ + + def _quote(x: str) -> str: + return urllib.parse.quote(x.encode("utf-8")) + + # Default to a UTF-8 charset for text content types. + # ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16' + if media_type.lower() in TEXT_CONTENT_TYPES: + content_type = media_type + "; charset=UTF-8" + else: + content_type = media_type + + request.setHeader(b"Content-Type", content_type.encode("UTF-8")) + if upload_name: + # RFC6266 section 4.1 [1] defines both `filename` and `filename*`. + # + # `filename` is defined to be a `value`, which is defined by RFC2616 + # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token` + # is (essentially) a single US-ASCII word, and a `quoted-string` is a + # US-ASCII string surrounded by double-quotes, using backslash as an + # escape character. Note that %-encoding is *not* permitted. + # + # `filename*` is defined to be an `ext-value`, which is defined in + # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`, + # where `value-chars` is essentially a %-encoded string in the given charset. + # + # [1]: https://tools.ietf.org/html/rfc6266#section-4.1 + # [2]: https://tools.ietf.org/html/rfc2616#section-3.6 + # [3]: https://tools.ietf.org/html/rfc5987#section-3.2.1 + + # We avoid the quoted-string version of `filename`, because (a) synapse didn't + # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we + # may as well just do the filename* version. + if _can_encode_filename_as_token(upload_name): + disposition = "inline; filename=%s" % (upload_name,) + else: + disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),) + + request.setHeader(b"Content-Disposition", disposition.encode("ascii")) + + # cache for at least a day. + # XXX: we might want to turn this off for data we don't want to + # recommend caching as it's sensitive or private - or at least + # select private. don't bother setting Expires as all our + # clients are smart enough to be happy with Cache-Control + request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") + if file_size is not None: + request.setHeader(b"Content-Length", b"%d" % (file_size,)) + + # Tell web crawlers to not index, archive, or follow links in media. This + # should help to prevent things in the media repo from showing up in web + # search results. + request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex") + + +# separators as defined in RFC2616. SP and HT are handled separately. +# see _can_encode_filename_as_token. +_FILENAME_SEPARATOR_CHARS = { + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + '"', + "/", + "[", + "]", + "?", + "=", + "{", + "}", +} + + +def _can_encode_filename_as_token(x: str) -> bool: + for c in x: + # from RFC2616: + # + # token = 1*<any CHAR except CTLs or separators> + # + # separators = "(" | ")" | "<" | ">" | "@" + # | "," | ";" | ":" | "\" | <"> + # | "/" | "[" | "]" | "?" | "=" + # | "{" | "}" | SP | HT + # + # CHAR = <any US-ASCII character (octets 0 - 127)> + # + # CTL = <any US-ASCII control character + # (octets 0 - 31) and DEL (127)> + # + if ord(c) >= 127 or ord(c) <= 32 or c in _FILENAME_SEPARATOR_CHARS: + return False + return True + + +async def respond_with_responder( + request: SynapseRequest, + responder: "Optional[Responder]", + media_type: str, + file_size: Optional[int], + upload_name: Optional[str] = None, +) -> None: + """Responds to the request with given responder. If responder is None then + returns 404. + + Args: + request + responder + media_type: The media/content type. + file_size: Size in bytes of the media. If not known it should be None + upload_name: The name of the requested file, if any. + """ + if not responder: + respond_404(request) + return + + # If we have a responder we *must* use it as a context manager. + with responder: + if request._disconnected: + logger.warning( + "Not sending response to request %s, already disconnected.", request + ) + return + + logger.debug("Responding to media request with responder %s", responder) + add_file_headers(request, media_type, file_size, upload_name) + try: + await responder.write_to_consumer(request) + except Exception as e: + # The majority of the time this will be due to the client having gone + # away. Unfortunately, Twisted simply throws a generic exception at us + # in that case. + logger.warning("Failed to write to consumer: %s %s", type(e), e) + + # Unregister the producer, if it has one, so Twisted doesn't complain + if request.producer: + request.unregisterProducer() + + finish_request(request) + + +class Responder(ABC): + """Represents a response that can be streamed to the requester. + + Responder is a context manager which *must* be used, so that any resources + held can be cleaned up. + """ + + @abstractmethod + def write_to_consumer(self, consumer: IConsumer) -> Awaitable: + """Stream response into consumer + + Args: + consumer: The consumer to stream into. + + Returns: + Resolves once the response has finished being written + """ + raise NotImplementedError() + + def __enter__(self) -> None: # noqa: B027 + pass + + def __exit__( # noqa: B027 + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + pass + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThumbnailInfo: + """Details about a generated thumbnail.""" + + width: int + height: int + method: str + # Content type of thumbnail, e.g. image/png + type: str + # The size of the media file, in bytes. + length: Optional[int] = None + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FileInfo: + """Details about a requested/uploaded file.""" + + # The server name where the media originated from, or None if local. + server_name: Optional[str] + # The local ID of the file. For local files this is the same as the media_id + file_id: str + # If the file is for the url preview cache + url_cache: bool = False + # Whether the file is a thumbnail or not. + thumbnail: Optional[ThumbnailInfo] = None + + # The below properties exist to maintain compatibility with third-party modules. + @property + def thumbnail_width(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.width + + @property + def thumbnail_height(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.height + + @property + def thumbnail_method(self) -> Optional[str]: + if not self.thumbnail: + return None + return self.thumbnail.method + + @property + def thumbnail_type(self) -> Optional[str]: + if not self.thumbnail: + return None + return self.thumbnail.type + + @property + def thumbnail_length(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.length + + +def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: + """ + Get the filename of the downloaded file by inspecting the + Content-Disposition HTTP header. + + Args: + headers: The HTTP request headers. + + Returns: + The filename, or None. + """ + content_disposition = headers.get(b"Content-Disposition", [b""]) + + # No header, bail out. + if not content_disposition[0]: + return None + + _, params = _parse_header(content_disposition[0]) + + upload_name = None + + # First check if there is a valid UTF-8 filename + upload_name_utf8 = params.get(b"filename*", None) + if upload_name_utf8: + if upload_name_utf8.lower().startswith(b"utf-8''"): + upload_name_utf8 = upload_name_utf8[7:] + # We have a filename*= section. This MUST be ASCII, and any UTF-8 + # bytes are %-quoted. + try: + # Once it is decoded, we can then unquote the %-encoded + # parts strictly into a unicode string. + upload_name = urllib.parse.unquote( + upload_name_utf8.decode("ascii"), errors="strict" + ) + except UnicodeDecodeError: + # Incorrect UTF-8. + pass + + # If there isn't check for an ascii name. + if not upload_name: + upload_name_ascii = params.get(b"filename", None) + if upload_name_ascii and is_ascii(upload_name_ascii): + upload_name = upload_name_ascii.decode("ascii") + + # This may be None here, indicating we did not find a matching name. + return upload_name + + +def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: + """Parse a Content-type like header. + + Cargo-culted from `cgi`, but works on bytes rather than strings. + + Args: + line: header to be parsed + + Returns: + The main content-type, followed by the parameter dictionary + """ + parts = _parseparam(b";" + line) + key = next(parts) + pdict = {} + for p in parts: + i = p.find(b"=") + if i >= 0: + name = p[:i].strip().lower() + value = p[i + 1 :].strip() + + # strip double-quotes + if len(value) >= 2 and value[0:1] == value[-1:] == b'"': + value = value[1:-1] + value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"') + pdict[name] = value + + return key, pdict + + +def _parseparam(s: bytes) -> Generator[bytes, None, None]: + """Generator which splits the input on ;, respecting double-quoted sequences + + Cargo-culted from `cgi`, but works on bytes rather than strings. + + Args: + s: header to be parsed + + Returns: + The split input + """ + while s[:1] == b";": + s = s[1:] + + # look for the next ; + end = s.find(b";") + + # if there is an odd number of " marks between here and the next ;, skip to the + # next ; instead + while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2: + end = s.find(b";", end + 1) + + if end < 0: + end = len(s) + f = s[:end] + yield f.strip() + s = s[end:] diff --git a/synapse/rest/media/v1/filepath.py b/synapse/media/filepath.py index 1f6441c412..1f6441c412 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/media/filepath.py diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/media/media_repository.py index c70e1837af..b81e3c2b0c 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/media/media_repository.py @@ -32,18 +32,10 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) -from synapse.config._base import ConfigError from synapse.config.repository import ThumbnailRequirement -from synapse.http.server import UnrecognizedRequestResource from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import UserID -from synapse.util.async_helpers import Linearizer -from synapse.util.retryutils import NotRetryingDestination -from synapse.util.stringutils import random_string - -from ._base import ( +from synapse.media._base import ( FileInfo, Responder, ThumbnailInfo, @@ -51,15 +43,15 @@ from ._base import ( respond_404, respond_with_responder, ) -from .config_resource import MediaConfigResource -from .download_resource import DownloadResource -from .filepath import MediaFilePaths -from .media_storage import MediaStorage -from .preview_url_resource import PreviewUrlResource -from .storage_provider import StorageProviderWrapper -from .thumbnail_resource import ThumbnailResource -from .thumbnailer import Thumbnailer, ThumbnailError -from .upload_resource import UploadResource +from synapse.media.filepath import MediaFilePaths +from synapse.media.media_storage import MediaStorage +from synapse.media.storage_provider import StorageProviderWrapper +from synapse.media.thumbnailer import Thumbnailer, ThumbnailError +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import UserID +from synapse.util.async_helpers import Linearizer +from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import random_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -1044,69 +1036,3 @@ class MediaRepository: removed_media.append(media_id) return removed_media, len(removed_media) - - -class MediaRepositoryResource(UnrecognizedRequestResource): - """File uploading and downloading. - - Uploads are POSTed to a resource which returns a token which is used to GET - the download:: - - => POST /_matrix/media/r0/upload HTTP/1.1 - Content-Type: <media-type> - Content-Length: <content-length> - - <media> - - <= HTTP/1.1 200 OK - Content-Type: application/json - - { "content_uri": "mxc://<server-name>/<media-id>" } - - => GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1 - - <= HTTP/1.1 200 OK - Content-Type: <media-type> - Content-Disposition: attachment;filename=<upload-filename> - - <media> - - Clients can get thumbnails by supplying a desired width and height and - thumbnailing method:: - - => GET /_matrix/media/r0/thumbnail/<server_name> - /<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1 - - <= HTTP/1.1 200 OK - Content-Type: image/jpeg or image/png - - <thumbnail> - - The thumbnail methods are "crop" and "scale". "scale" tries to return an - image where either the width or the height is smaller than the requested - size. The client should then scale and letterbox the image if it needs to - fit within a given rectangle. "crop" tries to return an image where the - width and height are close to the requested size and the aspect matches - the requested size. The client should scale the image if it needs to fit - within a given rectangle. - """ - - def __init__(self, hs: "HomeServer"): - # If we're not configured to use it, raise if we somehow got here. - if not hs.config.media.can_load_media_repo: - raise ConfigError("Synapse is not configured to use a media repo.") - - super().__init__() - media_repo = hs.get_media_repository() - - self.putChild(b"upload", UploadResource(hs, media_repo)) - self.putChild(b"download", DownloadResource(hs, media_repo)) - self.putChild( - b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) - ) - if hs.config.media.url_preview_enabled: - self.putChild( - b"preview_url", - PreviewUrlResource(hs, media_repo, media_repo.media_storage), - ) - self.putChild(b"config", MediaConfigResource(hs)) diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py new file mode 100644 index 0000000000..a7e22a91e1 --- /dev/null +++ b/synapse/media/media_storage.py @@ -0,0 +1,374 @@ +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import logging +import os +import shutil +from types import TracebackType +from typing import ( + IO, + TYPE_CHECKING, + Any, + Awaitable, + BinaryIO, + Callable, + Generator, + Optional, + Sequence, + Tuple, + Type, +) + +import attr + +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IConsumer +from twisted.protocols.basic import FileSender + +import synapse +from synapse.api.errors import NotFoundError +from synapse.logging.context import defer_to_thread, make_deferred_yieldable +from synapse.util import Clock +from synapse.util.file_consumer import BackgroundFileConsumer + +from ._base import FileInfo, Responder +from .filepath import MediaFilePaths + +if TYPE_CHECKING: + from synapse.media.storage_provider import StorageProvider + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class MediaStorage: + """Responsible for storing/fetching files from local sources. + + Args: + hs + local_media_directory: Base path where we store media on disk + filepaths + storage_providers: List of StorageProvider that are used to fetch and store files. + """ + + def __init__( + self, + hs: "HomeServer", + local_media_directory: str, + filepaths: MediaFilePaths, + storage_providers: Sequence["StorageProvider"], + ): + self.hs = hs + self.reactor = hs.get_reactor() + self.local_media_directory = local_media_directory + self.filepaths = filepaths + self.storage_providers = storage_providers + self.spam_checker = hs.get_spam_checker() + self.clock = hs.get_clock() + + async def store_file(self, source: IO, file_info: FileInfo) -> str: + """Write `source` to the on disk media store, and also any other + configured storage providers + + Args: + source: A file like object that should be written + file_info: Info about the file to store + + Returns: + the file path written to in the primary media store + """ + + with self.store_into_file(file_info) as (f, fname, finish_cb): + # Write to the main repository + await self.write_to_file(source, f) + await finish_cb() + + return fname + + async def write_to_file(self, source: IO, output: IO) -> None: + """Asynchronously write the `source` to `output`.""" + await defer_to_thread(self.reactor, _write_file_synchronously, source, output) + + @contextlib.contextmanager + def store_into_file( + self, file_info: FileInfo + ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]: + """Context manager used to get a file like object to write into, as + described by file_info. + + Actually yields a 3-tuple (file, fname, finish_cb), where file is a file + like object that can be written to, fname is the absolute path of file + on disk, and finish_cb is a function that returns an awaitable. + + fname can be used to read the contents from after upload, e.g. to + generate thumbnails. + + finish_cb must be called and waited on after the file has been + successfully been written to. Should not be called if there was an + error. + + Args: + file_info: Info about the file to store + + Example: + + with media_storage.store_into_file(info) as (f, fname, finish_cb): + # .. write into f ... + await finish_cb() + """ + + path = self._file_info_to_path(file_info) + fname = os.path.join(self.local_media_directory, path) + + dirname = os.path.dirname(fname) + os.makedirs(dirname, exist_ok=True) + + finished_called = [False] + + try: + with open(fname, "wb") as f: + + async def finish() -> None: + # Ensure that all writes have been flushed and close the + # file. + f.flush() + f.close() + + spam_check = await self.spam_checker.check_media_file_for_spam( + ReadableFileWrapper(self.clock, fname), file_info + ) + if spam_check != synapse.module_api.NOT_SPAM: + logger.info("Blocking media due to spam checker") + # Note that we'll delete the stored media, due to the + # try/except below. The media also won't be stored in + # the DB. + # We currently ignore any additional field returned by + # the spam-check API. + raise SpamMediaException(errcode=spam_check[0]) + + for provider in self.storage_providers: + await provider.store_file(path, file_info) + + finished_called[0] = True + + yield f, fname, finish + except Exception as e: + try: + os.remove(fname) + except Exception: + pass + + raise e from None + + if not finished_called: + raise Exception("Finished callback not called") + + async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: + """Attempts to fetch media described by file_info from the local cache + and configured storage providers. + + Args: + file_info + + Returns: + Returns a Responder if the file was found, otherwise None. + """ + paths = [self._file_info_to_path(file_info)] + + # fallback for remote thumbnails with no method in the filename + if file_info.thumbnail and file_info.server_name: + paths.append( + self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + ) + ) + + for path in paths: + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + logger.debug("responding with local file %s", local_path) + return FileResponder(open(local_path, "rb")) + logger.debug("local file %s did not exist", local_path) + + for provider in self.storage_providers: + for path in paths: + res: Any = await provider.fetch(path, file_info) + if res: + logger.debug("Streaming %s from %s", path, provider) + return res + logger.debug("%s not found on %s", path, provider) + + return None + + async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: + """Ensures that the given file is in the local cache. Attempts to + download it from storage providers if it isn't. + + Args: + file_info + + Returns: + Full path to local file + """ + path = self._file_info_to_path(file_info) + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + return local_path + + # Fallback for paths without method names + # Should be removed in the future + if file_info.thumbnail and file_info.server_name: + legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + ) + legacy_local_path = os.path.join(self.local_media_directory, legacy_path) + if os.path.exists(legacy_local_path): + return legacy_local_path + + dirname = os.path.dirname(local_path) + os.makedirs(dirname, exist_ok=True) + + for provider in self.storage_providers: + res: Any = await provider.fetch(path, file_info) + if res: + with res: + consumer = BackgroundFileConsumer( + open(local_path, "wb"), self.reactor + ) + await res.write_to_consumer(consumer) + await consumer.wait() + return local_path + + raise NotFoundError() + + def _file_info_to_path(self, file_info: FileInfo) -> str: + """Converts file_info into a relative path. + + The path is suitable for storing files under a directory, e.g. used to + store files on local FS under the base media repository directory. + """ + if file_info.url_cache: + if file_info.thumbnail: + return self.filepaths.url_cache_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.url_cache_filepath_rel(file_info.file_id) + + if file_info.server_name: + if file_info.thumbnail: + return self.filepaths.remote_media_thumbnail_rel( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.remote_media_filepath_rel( + file_info.server_name, file_info.file_id + ) + + if file_info.thumbnail: + return self.filepaths.local_media_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.local_media_filepath_rel(file_info.file_id) + + +def _write_file_synchronously(source: IO, dest: IO) -> None: + """Write `source` to the file like `dest` synchronously. Should be called + from a thread. + + Args: + source: A file like object that's to be written + dest: A file like object to be written to + """ + source.seek(0) # Ensure we read from the start of the file + shutil.copyfileobj(source, dest) + + +class FileResponder(Responder): + """Wraps an open file that can be sent to a request. + + Args: + open_file: A file like object to be streamed ot the client, + is closed when finished streaming. + """ + + def __init__(self, open_file: IO): + self.open_file = open_file + + def write_to_consumer(self, consumer: IConsumer) -> Deferred: + return make_deferred_yieldable( + FileSender().beginFileTransfer(self.open_file, consumer) + ) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.open_file.close() + + +class SpamMediaException(NotFoundError): + """The media was blocked by a spam checker, so we simply 404 the request (in + the same way as if it was quarantined). + """ + + +@attr.s(slots=True, auto_attribs=True) +class ReadableFileWrapper: + """Wrapper that allows reading a file in chunks, yielding to the reactor, + and writing to a callback. + + This is simplified `FileSender` that takes an IO object rather than an + `IConsumer`. + """ + + CHUNK_SIZE = 2**14 + + clock: Clock + path: str + + async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None: + """Reads the file in chunks and calls the callback with each chunk.""" + + with open(self.path, "rb") as file: + while True: + chunk = file.read(self.CHUNK_SIZE) + if not chunk: + break + + callback(chunk) + + # We yield to the reactor by sleeping for 0 seconds. + await self.clock.sleep(0) diff --git a/synapse/rest/media/v1/oembed.py b/synapse/media/oembed.py index 7592aa5d47..c0eaf04be5 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/media/oembed.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, List, Optional import attr -from synapse.rest.media.v1.preview_html import parse_html_description +from synapse.media.preview_html import parse_html_description from synapse.types import JsonDict from synapse.util import json_decoder diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/media/preview_html.py index 516d0434f0..516d0434f0 100644 --- a/synapse/rest/media/v1/preview_html.py +++ b/synapse/media/preview_html.py diff --git a/synapse/media/storage_provider.py b/synapse/media/storage_provider.py new file mode 100644 index 0000000000..1c9b71d69c --- /dev/null +++ b/synapse/media/storage_provider.py @@ -0,0 +1,181 @@ +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import logging +import os +import shutil +from typing import TYPE_CHECKING, Callable, Optional + +from synapse.config._base import Config +from synapse.logging.context import defer_to_thread, run_in_background +from synapse.util.async_helpers import maybe_awaitable + +from ._base import FileInfo, Responder +from .media_storage import FileResponder + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class StorageProvider(metaclass=abc.ABCMeta): + """A storage provider is a service that can store uploaded media and + retrieve them. + """ + + @abc.abstractmethod + async def store_file(self, path: str, file_info: FileInfo) -> None: + """Store the file described by file_info. The actual contents can be + retrieved by reading the file in file_info.upload_path. + + Args: + path: Relative path of file in local cache + file_info: The metadata of the file. + """ + + @abc.abstractmethod + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: + """Attempt to fetch the file described by file_info and stream it + into writer. + + Args: + path: Relative path of file in local cache + file_info: The metadata of the file. + + Returns: + Returns a Responder if the provider has the file, otherwise returns None. + """ + + +class StorageProviderWrapper(StorageProvider): + """Wraps a storage provider and provides various config options + + Args: + backend: The storage provider to wrap. + store_local: Whether to store new local files or not. + store_synchronous: Whether to wait for file to be successfully + uploaded, or todo the upload in the background. + store_remote: Whether remote media should be uploaded + """ + + def __init__( + self, + backend: StorageProvider, + store_local: bool, + store_synchronous: bool, + store_remote: bool, + ): + self.backend = backend + self.store_local = store_local + self.store_synchronous = store_synchronous + self.store_remote = store_remote + + def __str__(self) -> str: + return "StorageProviderWrapper[%s]" % (self.backend,) + + async def store_file(self, path: str, file_info: FileInfo) -> None: + if not file_info.server_name and not self.store_local: + return None + + if file_info.server_name and not self.store_remote: + return None + + if file_info.url_cache: + # The URL preview cache is short lived and not worth offloading or + # backing up. + return None + + if self.store_synchronous: + # store_file is supposed to return an Awaitable, but guard + # against improper implementations. + await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore + else: + # TODO: Handle errors. + async def store() -> None: + try: + return await maybe_awaitable( + self.backend.store_file(path, file_info) + ) + except Exception: + logger.exception("Error storing file") + + run_in_background(store) + + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: + if file_info.url_cache: + # Files in the URL preview cache definitely aren't stored here, + # so avoid any potentially slow I/O or network access. + return None + + # store_file is supposed to return an Awaitable, but guard + # against improper implementations. + return await maybe_awaitable(self.backend.fetch(path, file_info)) + + +class FileStorageProviderBackend(StorageProvider): + """A storage provider that stores files in a directory on a filesystem. + + Args: + hs + config: The config returned by `parse_config`. + """ + + def __init__(self, hs: "HomeServer", config: str): + self.hs = hs + self.cache_directory = hs.config.media.media_store_path + self.base_directory = config + + def __str__(self) -> str: + return "FileStorageProviderBackend[%s]" % (self.base_directory,) + + async def store_file(self, path: str, file_info: FileInfo) -> None: + """See StorageProvider.store_file""" + + primary_fname = os.path.join(self.cache_directory, path) + backup_fname = os.path.join(self.base_directory, path) + + dirname = os.path.dirname(backup_fname) + os.makedirs(dirname, exist_ok=True) + + # mypy needs help inferring the type of the second parameter, which is generic + shutil_copyfile: Callable[[str, str], str] = shutil.copyfile + await defer_to_thread( + self.hs.get_reactor(), + shutil_copyfile, + primary_fname, + backup_fname, + ) + + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: + """See StorageProvider.fetch""" + + backup_fname = os.path.join(self.base_directory, path) + if os.path.isfile(backup_fname): + return FileResponder(open(backup_fname, "rb")) + + return None + + @staticmethod + def parse_config(config: dict) -> str: + """Called on startup to parse config supplied. This should parse + the config and raise if there is a problem. + + The returned value is passed into the constructor. + + In this case we only care about a single param, the directory, so let's + just pull that out. + """ + return Config.ensure_directory(config["directory"]) diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/media/thumbnailer.py index 9480cc5763..f909a4fb9a 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/media/thumbnailer.py @@ -38,7 +38,6 @@ class ThumbnailError(Exception): class Thumbnailer: - FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} @staticmethod diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index b01372565d..8ce5887229 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -87,7 +87,6 @@ class LaterGauge(Collector): ] def collect(self) -> Iterable[Metric]: - g = GaugeMetricFamily(self.name, self.desc, labels=self.labels) try: diff --git a/synapse/metrics/_gc.py b/synapse/metrics/_gc.py index b7d47ce3e7..a22c4e5bbd 100644 --- a/synapse/metrics/_gc.py +++ b/synapse/metrics/_gc.py @@ -139,7 +139,6 @@ def install_gc_manager() -> None: class PyPyGCStats(Collector): def collect(self) -> Iterable[Metric]: - # @stats is a pretty-printer object with __str__() returning a nice table, # plus some fields that contain data from that table. # unfortunately, fields are pretty-printed themselves (i. e. '4.5MB'). diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d22dd19d38..424239e3df 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -64,9 +64,11 @@ from synapse.events.third_party_rules import ( CHECK_EVENT_ALLOWED_CALLBACK, CHECK_THREEPID_CAN_BE_INVITED_CALLBACK, CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK, + ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK, ON_CREATE_ROOM_CALLBACK, ON_NEW_EVENT_CALLBACK, ON_PROFILE_UPDATE_CALLBACK, + ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK, ON_THREEPID_BIND_CALLBACK, ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK, ) @@ -357,6 +359,12 @@ class ModuleApi: ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = None, on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None, + on_add_user_third_party_identifier: Optional[ + ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = None, + on_remove_user_third_party_identifier: Optional[ + ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = None, ) -> None: """Registers callbacks for third party event rules capabilities. @@ -373,6 +381,8 @@ class ModuleApi: on_profile_update=on_profile_update, on_user_deactivation_status_changed=on_user_deactivation_status_changed, on_threepid_bind=on_threepid_bind, + on_add_user_third_party_identifier=on_add_user_third_party_identifier, + on_remove_user_third_party_identifier=on_remove_user_third_party_identifier, ) def register_presence_router_callbacks( @@ -1576,14 +1586,14 @@ class ModuleApi: ) requester = create_requester(user_id) - room_id_and_alias, _ = await self._hs.get_room_creation_handler().create_room( + room_id, room_alias, _ = await self._hs.get_room_creation_handler().create_room( requester=requester, config=config, ratelimit=ratelimit, creator_join_profile=creator_join_profile, ) - - return room_id_and_alias["room_id"], room_id_and_alias.get("room_alias", None) + room_alias_str = room_alias.to_string() if room_alias else None + return room_id, room_alias_str async def set_displayname( self, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 2e917c90c4..ba12b6d79a 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -23,7 +23,6 @@ from typing import ( Mapping, Optional, Sequence, - Set, Tuple, Union, ) @@ -276,7 +275,7 @@ class BulkPushRuleEvaluator: if related_event is not None: related_events[relation_type] = _flatten_dict( related_event, - msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, ) reply_event_id = ( @@ -294,7 +293,7 @@ class BulkPushRuleEvaluator: if related_event is not None: related_events["m.in_reply_to"] = _flatten_dict( related_event, - msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, ) # indicate that this is from a fallback relation. @@ -330,7 +329,6 @@ class BulkPushRuleEvaluator: context: EventContext, event_id_to_event: Mapping[str, EventBase], ) -> None: - if ( not event.internal_metadata.is_notifiable() or event.internal_metadata.is_historical() @@ -397,30 +395,17 @@ class BulkPushRuleEvaluator: del notification_levels[key] # Pull out any user and room mentions. - mentions = event.content.get(EventContentFields.MSC3952_MENTIONS) - has_mentions = self._intentional_mentions_enabled and isinstance(mentions, dict) - user_mentions: Set[str] = set() - room_mention = False - if has_mentions: - # mypy seems to have lost the type even though it must be a dict here. - assert isinstance(mentions, dict) - # Remove out any non-string items and convert to a set. - user_mentions_raw = mentions.get("user_ids") - if isinstance(user_mentions_raw, list): - user_mentions = set( - filter(lambda item: isinstance(item, str), user_mentions_raw) - ) - # Room mention is only true if the value is exactly true. - room_mention = mentions.get("room") is True + has_mentions = ( + self._intentional_mentions_enabled + and EventContentFields.MSC3952_MENTIONS in event.content + ) evaluator = PushRuleEvaluator( _flatten_dict( event, - msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, ), has_mentions, - user_mentions, - room_mention, room_member_count, sender_power_level, notification_levels, @@ -428,7 +413,6 @@ class BulkPushRuleEvaluator: self._related_event_match_enabled, event.room_version.msc3931_push_features, self.hs.config.experimental.msc1767_enabled, # MSC3931 flag - self.hs.config.experimental.msc3758_exact_event_match, self.hs.config.experimental.msc3966_exact_event_property_contains, ) @@ -512,7 +496,7 @@ def _flatten_dict( prefix: Optional[List[str]] = None, result: Optional[Dict[str, JsonValue]] = None, *, - msc3783_escape_event_match_key: bool = False, + msc3873_escape_event_match_key: bool = False, ) -> Dict[str, JsonValue]: """ Given a JSON dictionary (or event) which might contain sub dictionaries, @@ -541,7 +525,7 @@ def _flatten_dict( if result is None: result = {} for key, value in d.items(): - if msc3783_escape_event_match_key: + if msc3873_escape_event_match_key: # Escape periods in the key with a backslash (and backslashes with an # extra backslash). This is since a period is used as a separator between # nested fields. @@ -557,7 +541,7 @@ def _flatten_dict( value, prefix=(prefix + [key]), result=result, - msc3783_escape_event_match_key=msc3783_escape_event_match_key, + msc3873_escape_event_match_key=msc3873_escape_event_match_key, ) # `room_version` should only ever be set when looking at the top level of an event diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index bb76c169c6..222afbdcc8 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -41,11 +41,12 @@ def format_push_rules_for_user( rulearray.append(template_rule) - pattern_type = template_rule.pop("pattern_type", None) - if pattern_type == "user_id": - template_rule["pattern"] = user.to_string() - elif pattern_type == "user_localpart": - template_rule["pattern"] = user.localpart + for type_key in ("pattern", "value"): + type_value = template_rule.pop(f"{type_key}_type", None) + if type_value == "user_id": + template_rule[type_key] = user.to_string() + elif type_value == "user_localpart": + template_rule[type_key] = user.localpart template_rule["enabled"] = enabled diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py index 2374f810c9..111ec07e64 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py @@ -265,7 +265,6 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): @staticmethod async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override] - return {} async def _handle_request( # type: ignore[override] diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index ecea6fc915..cc3929dcf5 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -195,7 +195,6 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): async def _serialize_payload( # type: ignore[override] user_id: str, device_id: str, keys: JsonDict ) -> JsonDict: - return { "user_id": user_id, "device_id": device_id, diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 9fa1060d48..67b01db67e 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -142,17 +142,12 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint): } async def _handle_request( # type: ignore[override] - self, - request: SynapseRequest, - content: JsonDict, - room_id: str, - user_id: str, + self, request: SynapseRequest, content: JsonDict, room_id: str, user_id: str ) -> Tuple[int, JsonDict]: remote_room_hosts = content["remote_room_hosts"] event_content = content["content"] requester = Requester.deserialize(self.store, content["requester"]) - request.requester = requester logger.debug("remote_knock: %s on room: %s", user_id, room_id) @@ -277,16 +272,12 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint): } async def _handle_request( # type: ignore[override] - self, - request: SynapseRequest, - content: JsonDict, - knock_event_id: str, + self, request: SynapseRequest, content: JsonDict, knock_event_id: str ) -> Tuple[int, JsonDict]: txn_id = content["txn_id"] event_content = content["content"] requester = Requester.deserialize(self.store, content["requester"]) - request.requester = requester # hopefully we're now on the master, so this won't recurse! @@ -363,3 +354,5 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) + ReplicationRemoteKnockRestServlet(hs).register(http_server) + ReplicationRemoteRescindKnockRestServlet(hs).register(http_server) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index cc0528bd8e..424854efbe 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -370,15 +370,23 @@ class ReplicationDataHandler: # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): logger.info( - "Waiting for repl stream %r to reach %s (%s)", + "Waiting for repl stream %r to reach %s (%s); currently at: %s", stream_name, position, instance_name, + current_position, ) try: await make_deferred_yieldable(deferred) except defer.TimeoutError: - logger.error("Timed out waiting for stream %s", stream_name) + logger.error( + "Timed out waiting for repl stream %r to reach %s (%s)" + "; currently at: %s", + stream_name, + position, + instance_name, + self._streams[stream_name].current_token(instance_name), + ) return logger.info( diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index fd1c0ec6af..dfc061eb5e 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -328,7 +328,6 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): outbound_redis_connection: txredisapi.ConnectionHandler, channel_names: List[str], ): - super().__init__( hs, uuid="subscriber", diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 9d17eff714..347467d863 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -238,6 +238,24 @@ class ReplicationStreamer: except Exception: logger.exception("Failed to replicate") + # The last token we send may not match the current + # token, in which case we want to send out a `POSITION` + # to tell other workers the actual current position. + if updates[-1][0] < current_token: + logger.info( + "Sending position: %s -> %s", + stream.NAME, + current_token, + ) + self.command_handler.send_command( + PositionCommand( + stream.NAME, + self._instance_name, + updates[-1][0], + current_token, + ) + ) + logger.debug("No more pending updates, breaking poke loop") finally: self.pending_updates = False diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 14b6705862..ad9b760713 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -139,7 +139,6 @@ class EventsStream(Stream): current_token: Token, target_row_count: int, ) -> StreamUpdateResult: - # the events stream merges together three separate sources: # * new events # * current_state changes diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 14c4e6ebbb..2e19e055d3 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -108,8 +108,7 @@ class ClientRestResource(JsonResource): if is_main_process: logout.register_servlets(hs, client_resource) sync.register_servlets(hs, client_resource) - if is_main_process: - filter.register_servlets(hs, client_resource) + filter.register_servlets(hs, client_resource) account.register_servlets(hs, client_resource) register.register_servlets(hs, client_resource) if is_main_process: @@ -140,7 +139,7 @@ class ClientRestResource(JsonResource): relations.register_servlets(hs, client_resource) if is_main_process: password_policy.register_servlets(hs, client_resource) - knock.register_servlets(hs, client_resource) + knock.register_servlets(hs, client_resource) # moving to /_synapse/admin if is_main_process: diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index a3beb74e2c..c546ef7e23 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -53,11 +53,11 @@ class EventReportsRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports$") def __init__(self, hs: "HomeServer"): - self.auth = hs.get_auth() - self.store = hs.get_datastores().main + self._auth = hs.get_auth() + self._store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self.auth, request) + await assert_requester_is_admin(self._auth, request) start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) @@ -79,7 +79,7 @@ class EventReportsRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) - event_reports, total = await self.store.get_event_reports_paginate( + event_reports, total = await self._store.get_event_reports_paginate( start, limit, direction, user_id, room_id ) ret = {"event_reports": event_reports, "total": total} @@ -108,13 +108,13 @@ class EventReportDetailRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.auth = hs.get_auth() - self.store = hs.get_datastores().main + self._auth = hs.get_auth() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, report_id: str ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self.auth, request) + await assert_requester_is_admin(self._auth, request) message = ( "The report_id parameter must be a string representing a positive integer." @@ -131,8 +131,33 @@ class EventReportDetailRestServlet(RestServlet): HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM ) - ret = await self.store.get_event_report(resolved_report_id) + ret = await self._store.get_event_report(resolved_report_id) if not ret: raise NotFoundError("Event report not found") return HTTPStatus.OK, ret + + async def on_DELETE( + self, request: SynapseRequest, report_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + message = ( + "The report_id parameter must be a string representing a positive integer." + ) + try: + resolved_report_id = int(report_id) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + if resolved_report_id < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + if await self._store.delete_event_report(resolved_report_id): + return HTTPStatus.OK, {} + + raise NotFoundError("Event report not found") diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 1d6e4982d7..4de56bf13f 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -75,7 +75,6 @@ class RoomRestV2Servlet(RestServlet): async def on_DELETE( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) await assert_user_is_admin(self._auth, requester) @@ -144,7 +143,6 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self._auth, request) if not RoomID.is_valid(room_id): @@ -181,7 +179,6 @@ class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, delete_id: str ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self._auth, request) delete_status = self._pagination_handler.get_delete_status(delete_id) @@ -438,7 +435,6 @@ class RoomStateRestServlet(RestServlet): class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): - PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$") def __init__(self, hs: "HomeServer"): diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 0c0bf540b9..357e9a574d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -304,13 +304,20 @@ class UserRestServletV2(RestServlet): # remove old threepids for medium, address in del_threepids: try: - await self.auth_handler.delete_threepid( - user_id, medium, address, None + # Attempt to remove any known bindings of this third-party ID + # and user ID from identity servers. + await self.hs.get_identity_handler().try_unbind_threepid( + user_id, medium, address, id_server=None ) except Exception: logger.exception("Failed to remove threepids") raise SynapseError(500, "Failed to remove threepids") + # Delete the local association of this user ID and third-party ID. + await self.auth_handler.delete_local_threepid( + user_id, medium, address + ) + # add new threepids current_time = self.hs.get_clock().time_msec() for medium, address in add_threepids: @@ -683,8 +690,12 @@ class AccountValidityRenewServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if self.account_activity_handler.on_legacy_admin_request_callback: - expiration_ts = await ( - self.account_activity_handler.on_legacy_admin_request_callback(request) + expiration_ts = ( + await ( + self.account_activity_handler.on_legacy_admin_request_callback( + request + ) + ) ) else: body = parse_json_object_from_request(request) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 662f5bf762..484d7440a4 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -768,7 +768,9 @@ class ThreepidDeleteRestServlet(RestServlet): user_id = requester.user.to_string() try: - ret = await self.auth_handler.delete_threepid( + # Attempt to remove any known bindings of this third-party ID + # and user ID from identity servers. + ret = await self.hs.get_identity_handler().try_unbind_threepid( user_id, body.medium, body.address, body.id_server ) except Exception: @@ -783,6 +785,11 @@ class ThreepidDeleteRestServlet(RestServlet): else: id_server_unbind_result = "no-support" + # Delete the local association of this user ID and third-party ID. + await self.auth_handler.delete_local_threepid( + user_id, body.medium, body.address + ) + return 200, {"id_server_unbind_result": id_server_unbind_result} diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index eb77337044..276a1b405d 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -97,7 +97,6 @@ class AuthRestServlet(RestServlet): return None async def on_POST(self, request: Request, stagetype: str) -> None: - session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 486c6dbbc5..dab4a77f7e 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -255,7 +255,7 @@ class DehydratedDeviceServlet(RestServlet): """ - PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=()) + PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device$", releases=()) def __init__(self, hs: "HomeServer"): super().__init__() diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 782e7d14e8..694d77d287 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -17,6 +17,7 @@ import logging from typing import TYPE_CHECKING, Dict, List, Tuple, Union from synapse.api.errors import SynapseError +from synapse.events.utils import SerializeEventConfig from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_string from synapse.http.site import SynapseRequest @@ -43,9 +44,8 @@ class EventStreamRestServlet(RestServlet): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - is_guest = requester.is_guest args: Dict[bytes, List[bytes]] = request.args # type: ignore - if is_guest: + if requester.is_guest: if b"room_id" not in args: raise SynapseError(400, "Guest users must specify room_id param") room_id = parse_string(request, "room_id") @@ -63,13 +63,12 @@ class EventStreamRestServlet(RestServlet): as_client_event = b"raw" not in args chunk = await self.event_stream_handler.get_stream( - requester.user.to_string(), + requester, pagin_config, timeout=timeout, as_client_event=as_client_event, - affect_presence=(not is_guest), + affect_presence=(not requester.is_guest), room_id=room_id, - is_guest=is_guest, ) return 200, chunk @@ -91,9 +90,12 @@ class EventRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) event = await self.event_handler.get_event(requester.user, None, event_id) - time_now = self.clock.time_msec() if event: - result = self._event_serializer.serialize_event(event, time_now) + result = self._event_serializer.serialize_event( + event, + self.clock.time_msec(), + config=SerializeEventConfig(requester=requester), + ) return 200, result else: return 404, "Event not found." diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index cc1c2f9731..236199897c 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -79,7 +79,6 @@ class CreateFilterRestServlet(RestServlet): async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: - target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 7873b363c0..32bb8b9a91 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -312,15 +312,29 @@ class SigningKeyUploadServlet(RestServlet): user_id = requester.user.to_string() body = parse_json_object_from_request(request) - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - "add a device signing key to your account", - # Allow skipping of UI auth since this is frequently called directly - # after login and it is silly to ask users to re-auth immediately. - can_skip_ui_auth=True, - ) + if self.hs.config.experimental.msc3967_enabled: + if await self.e2e_keys_handler.is_cross_signing_set_up_for_user(user_id): + # If we already have a master key then cross signing is set up and we require UIA to reset + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "reset the device signing key on your account", + # Do not allow skipping of UIA auth. + can_skip_ui_auth=False, + ) + # Otherwise we don't require UIA since we are setting up cross signing for first time + else: + # Previous behaviour is to always require UIA but allow it to be skipped + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "add a device signing key to your account", + # Allow skipping of UI auth since this is frequently called directly + # after login and it is silly to ask users to re-auth immediately. + can_skip_ui_auth=True, + ) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) return 200, result diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py index ad025c8a45..4fa66904ba 100644 --- a/synapse/rest/client/knock.py +++ b/synapse/rest/client/knock.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple from synapse.api.constants import Membership from synapse.api.errors import SynapseError @@ -24,8 +24,6 @@ from synapse.http.servlet import ( parse_strings_from_args, ) from synapse.http.site import SynapseRequest -from synapse.logging.opentracing import set_tag -from synapse.rest.client.transactions import HttpTransactionCache from synapse.types import JsonDict, RoomAlias, RoomID if TYPE_CHECKING: @@ -45,7 +43,6 @@ class KnockRoomAliasServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.txns = HttpTransactionCache(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -53,7 +50,6 @@ class KnockRoomAliasServlet(RestServlet): self, request: SynapseRequest, room_identifier: str, - txn_id: Optional[str] = None, ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -67,7 +63,6 @@ class KnockRoomAliasServlet(RestServlet): # twisted.web.server.Request.args is incorrectly defined as Optional[Any] args: Dict[bytes, List[bytes]] = request.args # type: ignore - remote_room_hosts = parse_strings_from_args( args, "server_name", required=False ) @@ -86,7 +81,6 @@ class KnockRoomAliasServlet(RestServlet): target=requester.user, room_id=room_id, action=Membership.KNOCK, - txn_id=txn_id, third_party_signed=None, remote_room_hosts=remote_room_hosts, content=event_content, @@ -94,15 +88,6 @@ class KnockRoomAliasServlet(RestServlet): return 200, {"room_id": room_id} - def on_PUT( - self, request: SynapseRequest, room_identifier: str, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: - set_tag("txn_id", txn_id) - - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_identifier, txn_id - ) - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: KnockRoomAliasServlet(hs).register(http_server) diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 61268e3af1..ea10042569 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -72,6 +72,12 @@ class NotificationsServlet(RestServlet): next_token = None + serialize_options = SerializeEventConfig( + event_format=format_event_for_client_v2_without_room_id, + requester=requester, + ) + now = self.clock.time_msec() + for pa in push_actions: returned_pa = { "room_id": pa.room_id, @@ -81,10 +87,8 @@ class NotificationsServlet(RestServlet): "event": ( self._event_serializer.serialize_event( notif_events[pa.event_id], - self.clock.time_msec(), - config=SerializeEventConfig( - event_format=format_event_for_client_v2_without_room_id - ), + now, + config=serialize_options, ) ), } diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 3cb1e7e375..bce806f2bb 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -628,10 +628,12 @@ class RegisterRestServlet(RestServlet): if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - desired_username = await ( - self.password_auth_provider.get_username_for_registration( - auth_result, - params, + desired_username = ( + await ( + self.password_auth_provider.get_username_for_registration( + auth_result, + params, + ) ) ) @@ -682,9 +684,11 @@ class RegisterRestServlet(RestServlet): session_id ) - display_name = await ( - self.password_auth_provider.get_displayname_for_registration( - auth_result, params + display_name = ( + await ( + self.password_auth_provider.get_displayname_for_registration( + auth_result, params + ) ) ) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index d0db85cca7..61e4cf0213 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -37,7 +37,7 @@ from synapse.api.errors import ( UnredactedContentDeletedError, ) from synapse.api.filtering import Filter -from synapse.events.utils import format_event_for_client_v2 +from synapse.events.utils import SerializeEventConfig, format_event_for_client_v2 from synapse.http.server import HttpServer from synapse.http.servlet import ( ResolveRoomIdMixin, @@ -160,11 +160,11 @@ class RoomCreateRestServlet(TransactionRestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - info, _ = await self._room_creation_handler.create_room( + room_id, _, _ = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) - return 200, info + return 200, {"room_id": room_id} def get_room_config(self, request: Request) -> JsonDict: user_supplied_config = parse_json_object_from_request(request) @@ -814,11 +814,13 @@ class RoomEventServlet(RestServlet): [event], requester.user.to_string() ) - time_now = self.clock.time_msec() # per MSC2676, /rooms/{roomId}/event/{eventId}, should return the # *original* event, rather than the edited version event_dict = self._event_serializer.serialize_event( - event, time_now, bundle_aggregations=aggregations, apply_edits=False + event, + self.clock.time_msec(), + bundle_aggregations=aggregations, + config=SerializeEventConfig(requester=requester), ) return 200, event_dict @@ -863,24 +865,30 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() + serializer_options = SerializeEventConfig(requester=requester) results = { "events_before": self._event_serializer.serialize_events( event_context.events_before, time_now, bundle_aggregations=event_context.aggregations, + config=serializer_options, ), "event": self._event_serializer.serialize_event( event_context.event, time_now, bundle_aggregations=event_context.aggregations, + config=serializer_options, ), "events_after": self._event_serializer.serialize_events( event_context.events_after, time_now, bundle_aggregations=event_context.aggregations, + config=serializer_options, ), "state": self._event_serializer.serialize_events( - event_context.state, time_now + event_context.state, + time_now, + config=serializer_options, ), "start": event_context.start, "end": event_context.end, @@ -926,7 +934,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): self.auth = hs.get_auth() def register(self, http_server: HttpServer) -> None: - # /rooms/$roomid/[invite|join|leave] + # /rooms/$roomid/[join|invite|leave|ban|unban|kick] PATTERNS = ( "/rooms/(?P<room_id>[^/]*)/" "(?P<membership_action>join|invite|leave|ban|unban|kick)" @@ -1192,7 +1200,7 @@ class SearchRestServlet(RestServlet): content = parse_json_object_from_request(request) batch = parse_string(request, "next_batch") - results = await self.search_handler.search(requester.user, content, batch) + results = await self.search_handler.search(requester, content, batch) return 200, results diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 10be4a781b..ef284ecc11 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -15,9 +15,7 @@ import logging import re from http import HTTPStatus -from typing import TYPE_CHECKING, Awaitable, Tuple - -from twisted.web.server import Request +from typing import TYPE_CHECKING, Tuple from synapse.api.constants import EventContentFields from synapse.api.errors import AuthError, Codes, SynapseError @@ -30,7 +28,6 @@ from synapse.http.servlet import ( parse_strings_from_args, ) from synapse.http.site import SynapseRequest -from synapse.rest.client.transactions import HttpTransactionCache from synapse.types import JsonDict if TYPE_CHECKING: @@ -79,7 +76,6 @@ class RoomBatchSendEventRestServlet(RestServlet): self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() self.room_batch_handler = hs.get_room_batch_handler() - self.txns = HttpTransactionCache(hs) async def on_POST( self, request: SynapseRequest, room_id: str @@ -249,16 +245,6 @@ class RoomBatchSendEventRestServlet(RestServlet): return HTTPStatus.OK, response_dict - def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]: - return HTTPStatus.NOT_IMPLEMENTED, "Not implemented" - - def on_PUT( - self, request: SynapseRequest, room_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id - ) - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: msc2716_enabled = hs.config.experimental.msc2716_enabled diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index f2013faeb2..e578b26fa3 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -16,7 +16,7 @@ import logging from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -from synapse.api.constants import EduTypes, Membership, PresenceState +from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState @@ -38,7 +38,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.opentracing import trace_with_opname -from synapse.types import JsonDict, StreamToken +from synapse.types import JsonDict, Requester, StreamToken from synapse.util import json_decoder from ._base import client_patterns, set_timeline_upper_limit @@ -139,7 +139,28 @@ class SyncRestServlet(RestServlet): device_id, ) - request_key = (user, timeout, since, filter_id, full_state, device_id) + # Stream position of the last ignored users account data event for this user, + # if we're initial syncing. + # We include this in the request key to invalidate an initial sync + # in the response cache once the set of ignored users has changed. + # (We filter out ignored users from timeline events, so our sync response + # is invalid once the set of ignored users changes.) + last_ignore_accdata_streampos: Optional[int] = None + if not since: + # No `since`, so this is an initial sync. + last_ignore_accdata_streampos = await self.store.get_latest_stream_id_for_global_account_data_by_type_for_user( + user.to_string(), AccountDataTypes.IGNORED_USER_LIST + ) + + request_key = ( + user, + timeout, + since, + filter_id, + full_state, + device_id, + last_ignore_accdata_streampos, + ) if filter_id is None: filter_collection = self.filtering.DEFAULT_FILTER_COLLECTION @@ -205,7 +226,7 @@ class SyncRestServlet(RestServlet): # We know that the the requester has an access token since appservices # cannot use sync. response_content = await self.encode_response( - time_now, sync_result, requester.access_token_id, filter_collection + time_now, sync_result, requester, filter_collection ) logger.debug("Event formatting complete") @@ -216,7 +237,7 @@ class SyncRestServlet(RestServlet): self, time_now: int, sync_result: SyncResult, - access_token_id: Optional[int], + requester: Requester, filter: FilterCollection, ) -> JsonDict: logger.debug("Formatting events in sync response") @@ -229,12 +250,12 @@ class SyncRestServlet(RestServlet): serialize_options = SerializeEventConfig( event_format=event_formatter, - token_id=access_token_id, + requester=requester, only_event_fields=filter.event_fields, ) stripped_serialize_options = SerializeEventConfig( event_format=event_formatter, - token_id=access_token_id, + requester=requester, include_stripped_room_state=True, ) diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/config_resource.py index a95804d327..a95804d327 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/config_resource.py diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/download_resource.py index 048a042692..8f270cf4cc 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/download_resource.py @@ -22,11 +22,10 @@ from synapse.http.server import ( ) from synapse.http.servlet import parse_boolean from synapse.http.site import SynapseRequest - -from ._base import parse_media_id, respond_404 +from synapse.media._base import parse_media_id, respond_404 if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/media_repository_resource.py b/synapse/rest/media/media_repository_resource.py new file mode 100644 index 0000000000..5ebaa3b032 --- /dev/null +++ b/synapse/rest/media/media_repository_resource.py @@ -0,0 +1,93 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from synapse.config._base import ConfigError +from synapse.http.server import UnrecognizedRequestResource + +from .config_resource import MediaConfigResource +from .download_resource import DownloadResource +from .preview_url_resource import PreviewUrlResource +from .thumbnail_resource import ThumbnailResource +from .upload_resource import UploadResource + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class MediaRepositoryResource(UnrecognizedRequestResource): + """File uploading and downloading. + + Uploads are POSTed to a resource which returns a token which is used to GET + the download:: + + => POST /_matrix/media/r0/upload HTTP/1.1 + Content-Type: <media-type> + Content-Length: <content-length> + + <media> + + <= HTTP/1.1 200 OK + Content-Type: application/json + + { "content_uri": "mxc://<server-name>/<media-id>" } + + => GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1 + + <= HTTP/1.1 200 OK + Content-Type: <media-type> + Content-Disposition: attachment;filename=<upload-filename> + + <media> + + Clients can get thumbnails by supplying a desired width and height and + thumbnailing method:: + + => GET /_matrix/media/r0/thumbnail/<server_name> + /<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1 + + <= HTTP/1.1 200 OK + Content-Type: image/jpeg or image/png + + <thumbnail> + + The thumbnail methods are "crop" and "scale". "scale" tries to return an + image where either the width or the height is smaller than the requested + size. The client should then scale and letterbox the image if it needs to + fit within a given rectangle. "crop" tries to return an image where the + width and height are close to the requested size and the aspect matches + the requested size. The client should scale the image if it needs to fit + within a given rectangle. + """ + + def __init__(self, hs: "HomeServer"): + # If we're not configured to use it, raise if we somehow got here. + if not hs.config.media.can_load_media_repo: + raise ConfigError("Synapse is not configured to use a media repo.") + + super().__init__() + media_repo = hs.get_media_repository() + + self.putChild(b"upload", UploadResource(hs, media_repo)) + self.putChild(b"download", DownloadResource(hs, media_repo)) + self.putChild( + b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) + ) + if hs.config.media.url_preview_enabled: + self.putChild( + b"preview_url", + PreviewUrlResource(hs, media_repo, media_repo.media_storage), + ) + self.putChild(b"config", MediaConfigResource(hs)) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/preview_url_resource.py index a8f6fd6b35..7ada728757 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/preview_url_resource.py @@ -40,21 +40,19 @@ from synapse.http.server import ( from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.media._base import FileInfo, get_filename_from_headers +from synapse.media.media_storage import MediaStorage +from synapse.media.oembed import OEmbedProvider +from synapse.media.preview_html import decode_body, parse_html_to_open_graph from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.rest.media.v1._base import get_filename_from_headers -from synapse.rest.media.v1.media_storage import MediaStorage -from synapse.rest.media.v1.oembed import OEmbedProvider -from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph from synapse.types import JsonDict, UserID from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.stringutils import random_string -from ._base import FileInfo - if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -163,6 +161,10 @@ class PreviewUrlResource(DirectServeJsonResource): 7. Stores the result in the database cache. 4. Returns the result. + If any additional requests (e.g. from oEmbed autodiscovery, step 5.3 or + image thumbnailing, step 5.4 or 6.4) fails then the URL preview as a whole + does not fail. As much information as possible is returned. + The in-memory cache expires after 1 hour. Expired entries in the database cache (and their associated media files) are @@ -364,16 +366,25 @@ class PreviewUrlResource(DirectServeJsonResource): oembed_url = self._oembed.autodiscover_from_html(tree) og_from_oembed: JsonDict = {} if oembed_url: - oembed_info = await self._handle_url( - oembed_url, user, allow_data_urls=True - ) - ( - og_from_oembed, - author_name, - expiration_ms, - ) = await self._handle_oembed_response( - url, oembed_info, expiration_ms - ) + try: + oembed_info = await self._handle_url( + oembed_url, user, allow_data_urls=True + ) + except Exception as e: + # Fetching the oEmbed info failed, don't block the entire URL preview. + logger.warning( + "oEmbed fetch failed during URL preview: %s errored with %s", + oembed_url, + e, + ) + else: + ( + og_from_oembed, + author_name, + expiration_ms, + ) = await self._handle_oembed_response( + url, oembed_info, expiration_ms + ) # Parse Open Graph information from the HTML in case the oEmbed # response failed or is incomplete. diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py index 5f725c7600..4ee2a0dbda 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/thumbnail_resource.py @@ -27,9 +27,7 @@ from synapse.http.server import ( ) from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest -from synapse.rest.media.v1.media_storage import MediaStorage - -from ._base import ( +from synapse.media._base import ( FileInfo, ThumbnailInfo, parse_media_id, @@ -37,9 +35,10 @@ from ._base import ( respond_with_file, respond_with_responder, ) +from synapse.media.media_storage import MediaStorage if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -69,7 +68,8 @@ class ThumbnailResource(DirectServeJsonResource): width = parse_integer(request, "width", required=True) height = parse_integer(request, "height", required=True) method = parse_string(request, "method", "scale") - m_type = parse_string(request, "type", "image/png") + # TODO Parse the Accept header to get an prioritised list of thumbnail types. + m_type = "image/png" if server_name == self.server_name: if self.dynamic_thumbnails: diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/upload_resource.py index 97548b54e5..697348613b 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/upload_resource.py @@ -20,10 +20,10 @@ from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_bytes_from_args from synapse.http.site import SynapseRequest -from synapse.rest.media.v1.media_storage import SpamMediaException +from synapse.media.media_storage import SpamMediaException if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index d30878f704..88427a5737 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -1,5 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,466 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# -import logging -import os -import urllib -from types import TracebackType -from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type - -import attr - -from twisted.internet.interfaces import IConsumer -from twisted.protocols.basic import FileSender -from twisted.web.server import Request - -from synapse.api.errors import Codes, SynapseError, cs_error -from synapse.http.server import finish_request, respond_with_json -from synapse.http.site import SynapseRequest -from synapse.logging.context import make_deferred_yieldable -from synapse.util.stringutils import is_ascii, parse_and_validate_server_name - -logger = logging.getLogger(__name__) - -# list all text content types that will have the charset default to UTF-8 when -# none is given -TEXT_CONTENT_TYPES = [ - "text/css", - "text/csv", - "text/html", - "text/calendar", - "text/plain", - "text/javascript", - "application/json", - "application/ld+json", - "application/rtf", - "image/svg+xml", - "text/xml", -] - - -def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: - """Parses the server name, media ID and optional file name from the request URI - - Also performs some rough validation on the server name. - - Args: - request: The `Request`. - - Returns: - A tuple containing the parsed server name, media ID and optional file name. - - Raises: - SynapseError(404): if parsing or validation fail for any reason - """ - try: - # The type on postpath seems incorrect in Twisted 21.2.0. - postpath: List[bytes] = request.postpath # type: ignore - assert postpath - - # This allows users to append e.g. /test.png to the URL. Useful for - # clients that parse the URL to see content type. - server_name_bytes, media_id_bytes = postpath[:2] - server_name = server_name_bytes.decode("utf-8") - media_id = media_id_bytes.decode("utf8") - - # Validate the server name, raising if invalid - parse_and_validate_server_name(server_name) - - file_name = None - if len(postpath) > 2: - try: - file_name = urllib.parse.unquote(postpath[-1].decode("utf-8")) - except UnicodeDecodeError: - pass - return server_name, media_id, file_name - except Exception: - raise SynapseError( - 404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN - ) - - -def respond_404(request: SynapseRequest) -> None: - respond_with_json( - request, - 404, - cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND), - send_cors=True, - ) - - -async def respond_with_file( - request: SynapseRequest, - media_type: str, - file_path: str, - file_size: Optional[int] = None, - upload_name: Optional[str] = None, -) -> None: - logger.debug("Responding with %r", file_path) - - if os.path.isfile(file_path): - if file_size is None: - stat = os.stat(file_path) - file_size = stat.st_size - - add_file_headers(request, media_type, file_size, upload_name) - - with open(file_path, "rb") as f: - await make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) - - finish_request(request) - else: - respond_404(request) - - -def add_file_headers( - request: Request, - media_type: str, - file_size: Optional[int], - upload_name: Optional[str], -) -> None: - """Adds the correct response headers in preparation for responding with the - media. - - Args: - request - media_type: The media/content type. - file_size: Size in bytes of the media, if known. - upload_name: The name of the requested file, if any. - """ - - def _quote(x: str) -> str: - return urllib.parse.quote(x.encode("utf-8")) - - # Default to a UTF-8 charset for text content types. - # ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16' - if media_type.lower() in TEXT_CONTENT_TYPES: - content_type = media_type + "; charset=UTF-8" - else: - content_type = media_type - - request.setHeader(b"Content-Type", content_type.encode("UTF-8")) - if upload_name: - # RFC6266 section 4.1 [1] defines both `filename` and `filename*`. - # - # `filename` is defined to be a `value`, which is defined by RFC2616 - # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token` - # is (essentially) a single US-ASCII word, and a `quoted-string` is a - # US-ASCII string surrounded by double-quotes, using backslash as an - # escape character. Note that %-encoding is *not* permitted. - # - # `filename*` is defined to be an `ext-value`, which is defined in - # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`, - # where `value-chars` is essentially a %-encoded string in the given charset. - # - # [1]: https://tools.ietf.org/html/rfc6266#section-4.1 - # [2]: https://tools.ietf.org/html/rfc2616#section-3.6 - # [3]: https://tools.ietf.org/html/rfc5987#section-3.2.1 - - # We avoid the quoted-string version of `filename`, because (a) synapse didn't - # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we - # may as well just do the filename* version. - if _can_encode_filename_as_token(upload_name): - disposition = "inline; filename=%s" % (upload_name,) - else: - disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),) - - request.setHeader(b"Content-Disposition", disposition.encode("ascii")) - - # cache for at least a day. - # XXX: we might want to turn this off for data we don't want to - # recommend caching as it's sensitive or private - or at least - # select private. don't bother setting Expires as all our - # clients are smart enough to be happy with Cache-Control - request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") - if file_size is not None: - request.setHeader(b"Content-Length", b"%d" % (file_size,)) - - # Tell web crawlers to not index, archive, or follow links in media. This - # should help to prevent things in the media repo from showing up in web - # search results. - request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex") - - -# separators as defined in RFC2616. SP and HT are handled separately. -# see _can_encode_filename_as_token. -_FILENAME_SEPARATOR_CHARS = { - "(", - ")", - "<", - ">", - "@", - ",", - ";", - ":", - "\\", - '"', - "/", - "[", - "]", - "?", - "=", - "{", - "}", -} - - -def _can_encode_filename_as_token(x: str) -> bool: - for c in x: - # from RFC2616: - # - # token = 1*<any CHAR except CTLs or separators> - # - # separators = "(" | ")" | "<" | ">" | "@" - # | "," | ";" | ":" | "\" | <"> - # | "/" | "[" | "]" | "?" | "=" - # | "{" | "}" | SP | HT - # - # CHAR = <any US-ASCII character (octets 0 - 127)> - # - # CTL = <any US-ASCII control character - # (octets 0 - 31) and DEL (127)> - # - if ord(c) >= 127 or ord(c) <= 32 or c in _FILENAME_SEPARATOR_CHARS: - return False - return True - - -async def respond_with_responder( - request: SynapseRequest, - responder: "Optional[Responder]", - media_type: str, - file_size: Optional[int], - upload_name: Optional[str] = None, -) -> None: - """Responds to the request with given responder. If responder is None then - returns 404. - - Args: - request - responder - media_type: The media/content type. - file_size: Size in bytes of the media. If not known it should be None - upload_name: The name of the requested file, if any. - """ - if not responder: - respond_404(request) - return - - # If we have a responder we *must* use it as a context manager. - with responder: - if request._disconnected: - logger.warning( - "Not sending response to request %s, already disconnected.", request - ) - return - - logger.debug("Responding to media request with responder %s", responder) - add_file_headers(request, media_type, file_size, upload_name) - try: - - await responder.write_to_consumer(request) - except Exception as e: - # The majority of the time this will be due to the client having gone - # away. Unfortunately, Twisted simply throws a generic exception at us - # in that case. - logger.warning("Failed to write to consumer: %s %s", type(e), e) - - # Unregister the producer, if it has one, so Twisted doesn't complain - if request.producer: - request.unregisterProducer() - - finish_request(request) - - -class Responder: - """Represents a response that can be streamed to the requester. - - Responder is a context manager which *must* be used, so that any resources - held can be cleaned up. - """ - - def write_to_consumer(self, consumer: IConsumer) -> Awaitable: - """Stream response into consumer - - Args: - consumer: The consumer to stream into. - - Returns: - Resolves once the response has finished being written - """ - - def __enter__(self) -> None: - pass - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - pass - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class ThumbnailInfo: - """Details about a generated thumbnail.""" - - width: int - height: int - method: str - # Content type of thumbnail, e.g. image/png - type: str - # The size of the media file, in bytes. - length: Optional[int] = None - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class FileInfo: - """Details about a requested/uploaded file.""" - - # The server name where the media originated from, or None if local. - server_name: Optional[str] - # The local ID of the file. For local files this is the same as the media_id - file_id: str - # If the file is for the url preview cache - url_cache: bool = False - # Whether the file is a thumbnail or not. - thumbnail: Optional[ThumbnailInfo] = None - - # The below properties exist to maintain compatibility with third-party modules. - @property - def thumbnail_width(self) -> Optional[int]: - if not self.thumbnail: - return None - return self.thumbnail.width - - @property - def thumbnail_height(self) -> Optional[int]: - if not self.thumbnail: - return None - return self.thumbnail.height - - @property - def thumbnail_method(self) -> Optional[str]: - if not self.thumbnail: - return None - return self.thumbnail.method - - @property - def thumbnail_type(self) -> Optional[str]: - if not self.thumbnail: - return None - return self.thumbnail.type - - @property - def thumbnail_length(self) -> Optional[int]: - if not self.thumbnail: - return None - return self.thumbnail.length - - -def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: - """ - Get the filename of the downloaded file by inspecting the - Content-Disposition HTTP header. - - Args: - headers: The HTTP request headers. - - Returns: - The filename, or None. - """ - content_disposition = headers.get(b"Content-Disposition", [b""]) - - # No header, bail out. - if not content_disposition[0]: - return None - - _, params = _parse_header(content_disposition[0]) - - upload_name = None - - # First check if there is a valid UTF-8 filename - upload_name_utf8 = params.get(b"filename*", None) - if upload_name_utf8: - if upload_name_utf8.lower().startswith(b"utf-8''"): - upload_name_utf8 = upload_name_utf8[7:] - # We have a filename*= section. This MUST be ASCII, and any UTF-8 - # bytes are %-quoted. - try: - # Once it is decoded, we can then unquote the %-encoded - # parts strictly into a unicode string. - upload_name = urllib.parse.unquote( - upload_name_utf8.decode("ascii"), errors="strict" - ) - except UnicodeDecodeError: - # Incorrect UTF-8. - pass - - # If there isn't check for an ascii name. - if not upload_name: - upload_name_ascii = params.get(b"filename", None) - if upload_name_ascii and is_ascii(upload_name_ascii): - upload_name = upload_name_ascii.decode("ascii") - - # This may be None here, indicating we did not find a matching name. - return upload_name - - -def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: - """Parse a Content-type like header. - - Cargo-culted from `cgi`, but works on bytes rather than strings. - - Args: - line: header to be parsed - - Returns: - The main content-type, followed by the parameter dictionary - """ - parts = _parseparam(b";" + line) - key = next(parts) - pdict = {} - for p in parts: - i = p.find(b"=") - if i >= 0: - name = p[:i].strip().lower() - value = p[i + 1 :].strip() - - # strip double-quotes - if len(value) >= 2 and value[0:1] == value[-1:] == b'"': - value = value[1:-1] - value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"') - pdict[name] = value - - return key, pdict - - -def _parseparam(s: bytes) -> Generator[bytes, None, None]: - """Generator which splits the input on ;, respecting double-quoted sequences - - Cargo-culted from `cgi`, but works on bytes rather than strings. - - Args: - s: header to be parsed - - Returns: - The split input - """ - while s[:1] == b";": - s = s[1:] - - # look for the next ; - end = s.find(b";") - - # if there is an odd number of " marks between here and the next ;, skip to the - # next ; instead - while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2: - end = s.find(b";", end + 1) - - if end < 0: - end = len(s) - f = s[:end] - yield f.strip() - s = s[end:] +# This exists purely for backwards compatibility with media providers and spam checkers. +from synapse.media._base import FileInfo, Responder # noqa: F401 diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index db25848744..11b0e8e231 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -1,4 +1,4 @@ -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,364 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import logging -import os -import shutil -from types import TracebackType -from typing import ( - IO, - TYPE_CHECKING, - Any, - Awaitable, - BinaryIO, - Callable, - Generator, - Optional, - Sequence, - Tuple, - Type, -) - -import attr - -from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IConsumer -from twisted.protocols.basic import FileSender - -import synapse -from synapse.api.errors import NotFoundError -from synapse.logging.context import defer_to_thread, make_deferred_yieldable -from synapse.util import Clock -from synapse.util.file_consumer import BackgroundFileConsumer - -from ._base import FileInfo, Responder -from .filepath import MediaFilePaths - -if TYPE_CHECKING: - from synapse.rest.media.v1.storage_provider import StorageProvider - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class MediaStorage: - """Responsible for storing/fetching files from local sources. - - Args: - hs - local_media_directory: Base path where we store media on disk - filepaths - storage_providers: List of StorageProvider that are used to fetch and store files. - """ - - def __init__( - self, - hs: "HomeServer", - local_media_directory: str, - filepaths: MediaFilePaths, - storage_providers: Sequence["StorageProvider"], - ): - self.hs = hs - self.reactor = hs.get_reactor() - self.local_media_directory = local_media_directory - self.filepaths = filepaths - self.storage_providers = storage_providers - self.spam_checker = hs.get_spam_checker() - self.clock = hs.get_clock() - - async def store_file(self, source: IO, file_info: FileInfo) -> str: - """Write `source` to the on disk media store, and also any other - configured storage providers - - Args: - source: A file like object that should be written - file_info: Info about the file to store - - Returns: - the file path written to in the primary media store - """ - - with self.store_into_file(file_info) as (f, fname, finish_cb): - # Write to the main repository - await self.write_to_file(source, f) - await finish_cb() - - return fname - - async def write_to_file(self, source: IO, output: IO) -> None: - """Asynchronously write the `source` to `output`.""" - await defer_to_thread(self.reactor, _write_file_synchronously, source, output) - - @contextlib.contextmanager - def store_into_file( - self, file_info: FileInfo - ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]: - """Context manager used to get a file like object to write into, as - described by file_info. - - Actually yields a 3-tuple (file, fname, finish_cb), where file is a file - like object that can be written to, fname is the absolute path of file - on disk, and finish_cb is a function that returns an awaitable. - - fname can be used to read the contents from after upload, e.g. to - generate thumbnails. - - finish_cb must be called and waited on after the file has been - successfully been written to. Should not be called if there was an - error. - - Args: - file_info: Info about the file to store - - Example: - - with media_storage.store_into_file(info) as (f, fname, finish_cb): - # .. write into f ... - await finish_cb() - """ - - path = self._file_info_to_path(file_info) - fname = os.path.join(self.local_media_directory, path) - - dirname = os.path.dirname(fname) - os.makedirs(dirname, exist_ok=True) - - finished_called = [False] - - try: - with open(fname, "wb") as f: - - async def finish() -> None: - # Ensure that all writes have been flushed and close the - # file. - f.flush() - f.close() - - spam_check = await self.spam_checker.check_media_file_for_spam( - ReadableFileWrapper(self.clock, fname), file_info - ) - if spam_check != synapse.module_api.NOT_SPAM: - logger.info("Blocking media due to spam checker") - # Note that we'll delete the stored media, due to the - # try/except below. The media also won't be stored in - # the DB. - # We currently ignore any additional field returned by - # the spam-check API. - raise SpamMediaException(errcode=spam_check[0]) - - for provider in self.storage_providers: - await provider.store_file(path, file_info) - - finished_called[0] = True - - yield f, fname, finish - except Exception as e: - try: - os.remove(fname) - except Exception: - pass - - raise e from None - - if not finished_called: - raise Exception("Finished callback not called") - - async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: - """Attempts to fetch media described by file_info from the local cache - and configured storage providers. - - Args: - file_info - - Returns: - Returns a Responder if the file was found, otherwise None. - """ - paths = [self._file_info_to_path(file_info)] - - # fallback for remote thumbnails with no method in the filename - if file_info.thumbnail and file_info.server_name: - paths.append( - self.filepaths.remote_media_thumbnail_rel_legacy( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - ) - ) - - for path in paths: - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - logger.debug("responding with local file %s", local_path) - return FileResponder(open(local_path, "rb")) - logger.debug("local file %s did not exist", local_path) - - for provider in self.storage_providers: - for path in paths: - res: Any = await provider.fetch(path, file_info) - if res: - logger.debug("Streaming %s from %s", path, provider) - return res - logger.debug("%s not found on %s", path, provider) - - return None - - async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: - """Ensures that the given file is in the local cache. Attempts to - download it from storage providers if it isn't. - - Args: - file_info - - Returns: - Full path to local file - """ - path = self._file_info_to_path(file_info) - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - return local_path - - # Fallback for paths without method names - # Should be removed in the future - if file_info.thumbnail and file_info.server_name: - legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - ) - legacy_local_path = os.path.join(self.local_media_directory, legacy_path) - if os.path.exists(legacy_local_path): - return legacy_local_path - - dirname = os.path.dirname(local_path) - os.makedirs(dirname, exist_ok=True) - - for provider in self.storage_providers: - res: Any = await provider.fetch(path, file_info) - if res: - with res: - consumer = BackgroundFileConsumer( - open(local_path, "wb"), self.reactor - ) - await res.write_to_consumer(consumer) - await consumer.wait() - return local_path - - raise NotFoundError() - - def _file_info_to_path(self, file_info: FileInfo) -> str: - """Converts file_info into a relative path. - - The path is suitable for storing files under a directory, e.g. used to - store files on local FS under the base media repository directory. - """ - if file_info.url_cache: - if file_info.thumbnail: - return self.filepaths.url_cache_thumbnail_rel( - media_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - method=file_info.thumbnail.method, - ) - return self.filepaths.url_cache_filepath_rel(file_info.file_id) - - if file_info.server_name: - if file_info.thumbnail: - return self.filepaths.remote_media_thumbnail_rel( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - method=file_info.thumbnail.method, - ) - return self.filepaths.remote_media_filepath_rel( - file_info.server_name, file_info.file_id - ) - - if file_info.thumbnail: - return self.filepaths.local_media_thumbnail_rel( - media_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - method=file_info.thumbnail.method, - ) - return self.filepaths.local_media_filepath_rel(file_info.file_id) - - -def _write_file_synchronously(source: IO, dest: IO) -> None: - """Write `source` to the file like `dest` synchronously. Should be called - from a thread. - - Args: - source: A file like object that's to be written - dest: A file like object to be written to - """ - source.seek(0) # Ensure we read from the start of the file - shutil.copyfileobj(source, dest) - - -class FileResponder(Responder): - """Wraps an open file that can be sent to a request. - - Args: - open_file: A file like object to be streamed ot the client, - is closed when finished streaming. - """ - - def __init__(self, open_file: IO): - self.open_file = open_file - - def write_to_consumer(self, consumer: IConsumer) -> Deferred: - return make_deferred_yieldable( - FileSender().beginFileTransfer(self.open_file, consumer) - ) - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - self.open_file.close() - - -class SpamMediaException(NotFoundError): - """The media was blocked by a spam checker, so we simply 404 the request (in - the same way as if it was quarantined). - """ - - -@attr.s(slots=True, auto_attribs=True) -class ReadableFileWrapper: - """Wrapper that allows reading a file in chunks, yielding to the reactor, - and writing to a callback. - - This is simplified `FileSender` that takes an IO object rather than an - `IConsumer`. - """ - - CHUNK_SIZE = 2**14 - - clock: Clock - path: str - - async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None: - """Reads the file in chunks and calls the callback with each chunk.""" - - with open(self.path, "rb") as file: - while True: - chunk = file.read(self.CHUNK_SIZE) - if not chunk: - break - - callback(chunk) +# - # We yield to the reactor by sleeping for 0 seconds. - await self.clock.sleep(0) +# This exists purely for backwards compatibility with spam checkers. +from synapse.media.media_storage import ReadableFileWrapper # noqa: F401 diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 1c9b71d69c..d7653f30ae 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -1,4 +1,4 @@ -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,171 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# -import abc -import logging -import os -import shutil -from typing import TYPE_CHECKING, Callable, Optional - -from synapse.config._base import Config -from synapse.logging.context import defer_to_thread, run_in_background -from synapse.util.async_helpers import maybe_awaitable - -from ._base import FileInfo, Responder -from .media_storage import FileResponder - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class StorageProvider(metaclass=abc.ABCMeta): - """A storage provider is a service that can store uploaded media and - retrieve them. - """ - - @abc.abstractmethod - async def store_file(self, path: str, file_info: FileInfo) -> None: - """Store the file described by file_info. The actual contents can be - retrieved by reading the file in file_info.upload_path. - - Args: - path: Relative path of file in local cache - file_info: The metadata of the file. - """ - - @abc.abstractmethod - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: - """Attempt to fetch the file described by file_info and stream it - into writer. - - Args: - path: Relative path of file in local cache - file_info: The metadata of the file. - - Returns: - Returns a Responder if the provider has the file, otherwise returns None. - """ - - -class StorageProviderWrapper(StorageProvider): - """Wraps a storage provider and provides various config options - - Args: - backend: The storage provider to wrap. - store_local: Whether to store new local files or not. - store_synchronous: Whether to wait for file to be successfully - uploaded, or todo the upload in the background. - store_remote: Whether remote media should be uploaded - """ - - def __init__( - self, - backend: StorageProvider, - store_local: bool, - store_synchronous: bool, - store_remote: bool, - ): - self.backend = backend - self.store_local = store_local - self.store_synchronous = store_synchronous - self.store_remote = store_remote - - def __str__(self) -> str: - return "StorageProviderWrapper[%s]" % (self.backend,) - - async def store_file(self, path: str, file_info: FileInfo) -> None: - if not file_info.server_name and not self.store_local: - return None - - if file_info.server_name and not self.store_remote: - return None - - if file_info.url_cache: - # The URL preview cache is short lived and not worth offloading or - # backing up. - return None - - if self.store_synchronous: - # store_file is supposed to return an Awaitable, but guard - # against improper implementations. - await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore - else: - # TODO: Handle errors. - async def store() -> None: - try: - return await maybe_awaitable( - self.backend.store_file(path, file_info) - ) - except Exception: - logger.exception("Error storing file") - - run_in_background(store) - - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: - if file_info.url_cache: - # Files in the URL preview cache definitely aren't stored here, - # so avoid any potentially slow I/O or network access. - return None - - # store_file is supposed to return an Awaitable, but guard - # against improper implementations. - return await maybe_awaitable(self.backend.fetch(path, file_info)) - - -class FileStorageProviderBackend(StorageProvider): - """A storage provider that stores files in a directory on a filesystem. - - Args: - hs - config: The config returned by `parse_config`. - """ - - def __init__(self, hs: "HomeServer", config: str): - self.hs = hs - self.cache_directory = hs.config.media.media_store_path - self.base_directory = config - - def __str__(self) -> str: - return "FileStorageProviderBackend[%s]" % (self.base_directory,) - - async def store_file(self, path: str, file_info: FileInfo) -> None: - """See StorageProvider.store_file""" - - primary_fname = os.path.join(self.cache_directory, path) - backup_fname = os.path.join(self.base_directory, path) - - dirname = os.path.dirname(backup_fname) - os.makedirs(dirname, exist_ok=True) - - # mypy needs help inferring the type of the second parameter, which is generic - shutil_copyfile: Callable[[str, str], str] = shutil.copyfile - await defer_to_thread( - self.hs.get_reactor(), - shutil_copyfile, - primary_fname, - backup_fname, - ) - - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: - """See StorageProvider.fetch""" - - backup_fname = os.path.join(self.base_directory, path) - if os.path.isfile(backup_fname): - return FileResponder(open(backup_fname, "rb")) - - return None - - @staticmethod - def parse_config(config: dict) -> str: - """Called on startup to parse config supplied. This should parse - the config and raise if there is a problem. - - The returned value is passed into the constructor. - - In this case we only care about a single param, the directory, so let's - just pull that out. - """ - return Config.ensure_directory(config["directory"]) +# This exists purely for backwards compatibility with media providers. +from synapse.media.storage_provider import StorageProvider # noqa: F401 diff --git a/synapse/server.py b/synapse/server.py index e5a3475247..df80fc1beb 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -105,6 +105,7 @@ from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler from synapse.handlers.user_directory import UserDirectoryHandler from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient +from synapse.media.media_repository import MediaRepository from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager from synapse.module_api import ModuleApi from synapse.notifier import Notifier, ReplicationNotifier @@ -115,10 +116,7 @@ from synapse.replication.tcp.external_cache import ExternalCache from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer from synapse.replication.tcp.streams import STREAMS_MAP, Stream -from synapse.rest.media.v1.media_repository import ( - MediaRepository, - MediaRepositoryResource, -) +from synapse.rest.media.media_repository_resource import MediaRepositoryResource from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_sender import ServerNoticesSender from synapse.server_notices.worker_server_notices_sender import ( @@ -745,7 +743,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer(self.config.experimental.msc3925_inhibit_edit) + return EventClientSerializer() @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 564e3705c2..9732dbdb6e 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -178,7 +178,7 @@ class ServerNoticesManager: "avatar_url": self._config.servernotices.server_notices_mxid_avatar_url, } - info, _ = await self._room_creation_handler.create_room( + room_id, _, _ = await self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, @@ -188,7 +188,6 @@ class ServerNoticesManager: ratelimit=False, creator_join_profile=join_profile, ) - room_id = info["room_id"] self.maybe_get_notice_room_for_user.invalidate((user_id,)) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 837dc7646e..dc3948c170 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -43,7 +43,7 @@ from .event_federation import EventFederationStore from .event_push_actions import EventPushActionsStore from .events_bg_updates import EventsBackgroundUpdatesStore from .events_forward_extremities import EventForwardExtremitiesStore -from .filtering import FilteringStore +from .filtering import FilteringWorkerStore from .keys import KeyStore from .lock import LockStore from .media_repository import MediaRepositoryStore @@ -99,7 +99,7 @@ class DataStore( EventFederationStore, MediaRepositoryStore, RejectionsStore, - FilteringStore, + FilteringWorkerStore, PusherStore, PushRuleStore, ApplicationServiceTransactionStore, diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 95567826f2..a9843f6e17 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -40,7 +40,6 @@ from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, MultiWriterIdGenerator, StreamIdGenerator, ) @@ -64,14 +63,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ): super().__init__(database, db_conn, hs) - # `_can_write_to_account_data` indicates whether the current worker is allowed - # to write account data. A value of `True` implies that `_account_data_id_gen` - # is an `AbstractStreamIdGenerator` and not just a tracker. - self._account_data_id_gen: AbstractStreamIdTracker self._can_write_to_account_data = ( self._instance_name in hs.config.worker.writers.account_data ) + self._account_data_id_gen: AbstractStreamIdGenerator + if isinstance(database.engine, PostgresEngine): self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, @@ -237,6 +234,37 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) else: return None + async def get_latest_stream_id_for_global_account_data_by_type_for_user( + self, user_id: str, data_type: str + ) -> Optional[int]: + """ + Returns: + The stream ID of the account data, + or None if there is no such account data. + """ + + def get_latest_stream_id_for_global_account_data_by_type_for_user_txn( + txn: LoggingTransaction, + ) -> Optional[int]: + sql = """ + SELECT stream_id FROM account_data + WHERE user_id = ? AND account_data_type = ? + ORDER BY stream_id DESC + LIMIT 1 + """ + txn.execute(sql, (user_id, data_type)) + + row = txn.fetchone() + if row: + return row[0] + else: + return None + + return await self.db_pool.runInteraction( + "get_latest_stream_id_for_global_account_data_by_type_for_user", + get_latest_stream_id_for_global_account_data_by_type_for_user_txn, + ) + @cached(num_args=2, tree=True) async def get_account_data_for_room( self, user_id: str, room_id: str @@ -527,7 +555,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) The maximum stream ID. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) content_json = json_encoder.encode(content) @@ -554,7 +581,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def remove_account_data_for_room( self, user_id: str, room_id: str, account_data_type: str - ) -> Optional[int]: + ) -> int: """Delete the room account data for the user of a given type. Args: @@ -567,7 +594,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) data to delete. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def _remove_account_data_for_room_txn( txn: LoggingTransaction, next_id: int @@ -606,15 +632,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) next_id, ) - if not row_updated: - return None - - self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_room_account_data_for_user.invalidate((user_id,)) - self.get_account_data_for_room.invalidate((user_id, room_id)) - self.get_account_data_for_room_and_type.prefill( - (user_id, room_id, account_data_type), {} - ) + if row_updated: + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_room_account_data_for_user.invalidate((user_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id)) + self.get_account_data_for_room_and_type.prefill( + (user_id, room_id, account_data_type), {} + ) return self._account_data_id_gen.get_current_token() @@ -632,7 +656,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) The maximum stream ID. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( @@ -722,7 +745,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self, user_id: str, account_data_type: str, - ) -> Optional[int]: + ) -> int: """ Delete a single piece of user account data by type. @@ -739,7 +762,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) to delete. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def _remove_account_data_for_user_txn( txn: LoggingTransaction, next_id: int @@ -809,14 +831,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) next_id, ) - if not row_updated: - return None - - self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_global_account_data_for_user.invalidate((user_id,)) - self.get_global_account_data_by_type_for_user.prefill( - (user_id, account_data_type), {} - ) + if row_updated: + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_global_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.prefill( + (user_id, account_data_type), {} + ) return self._account_data_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 5b66431691..096dec7f87 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -266,9 +266,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,)) - self._attempt_to_invalidate_cache( - "get_aggregation_groups_for_event", (relates_to,) - ) self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 8e61aba454..0d75d9739a 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -721,8 +721,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) - for (user_id, messages_by_device) in edu["messages"].items(): - for (device_id, msg) in messages_by_device.items(): + for user_id, messages_by_device in edu["messages"].items(): + for device_id, msg in messages_by_device.items(): with start_active_span("store_outgoing_to_device_message"): set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"]) set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"]) @@ -959,7 +959,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): def _remove_dead_devices_from_device_inbox_txn( txn: LoggingTransaction, ) -> Tuple[int, bool]: - if "max_stream_id" in progress: max_stream_id = progress["max_stream_id"] else: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 1ca66d57d4..5503621ad6 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -52,7 +52,6 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.types import Cursor from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, StreamIdGenerator, ) from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key @@ -91,7 +90,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + self._device_list_id_gen = StreamIdGenerator( db_conn, hs.get_replication_notifier(), "device_lists_stream", @@ -512,7 +511,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): results.append(("org.matrix.signing_key_update", result)) if issue_8631_logger.isEnabledFor(logging.DEBUG): - for (user_id, edu) in results: + for user_id, edu in results: issue_8631_logger.debug( "device update to %s for %s from %s to %s: %s", destination, @@ -712,9 +711,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): The new stream ID. """ - # TODO: this looks like it's _writing_. Should this be on DeviceStore rather - # than DeviceWorkerStore? - async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -1316,7 +1313,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) """ count = 0 - for (destination, user_id, stream_id, device_id) in rows: + for destination, user_id, stream_id, device_id in rows: txn.execute( delete_sql, (destination, user_id, stream_id, stream_id, device_id) ) diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 6240f9a75e..9f8d2e4bea 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -108,7 +108,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): raise StoreError(404, "No backup with that version exists") values = [] - for (room_id, session_id, room_key) in room_keys: + for room_id, session_id, room_key in room_keys: values.append( ( user_id, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 2c2d145666..b9c39b1718 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -268,7 +268,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) # add each cross-signing signature to the correct device in the result dict. - for (user_id, key_id, device_id, signature) in cross_sigs_result: + for user_id, key_id, device_id, signature in cross_sigs_result: target_device_result = result[user_id][device_id] # We've only looked up cross-signatures for non-deleted devices with key # data. @@ -311,7 +311,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker # devices. user_list = [] user_device_list = [] - for (user_id, device_id) in query_list: + for user_id, device_id in query_list: if device_id is None: user_list.append(user_id) else: @@ -353,7 +353,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker txn.execute(sql, query_params) - for (user_id, device_id, display_name, key_json) in txn: + for user_id, device_id, display_name, key_json in txn: assert device_id is not None if include_deleted_devices: deleted_devices.remove((user_id, device_id)) @@ -382,7 +382,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker signature_query_clauses = [] signature_query_params = [] - for (user_id, device_id) in device_query: + for user_id, device_id in device_query: signature_query_clauses.append( "target_user_id = ? AND target_device_id = ? AND user_id = ?" ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ca780cca36..ff3edeb716 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1612,7 +1612,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas latest_events: List[str], limit: int, ) -> List[str]: - seen_events = set(earliest_events) front = set(latest_events) - seen_events event_results: List[str] = [] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 7996cbb557..a8a4ed4436 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -469,7 +469,6 @@ class PersistEventsStore: txn: LoggingTransaction, events: List[EventBase], ) -> None: - # We only care about state events, so this if there are no state events. if not any(e.is_state() for e in events): return @@ -2025,10 +2024,6 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_relations_for_event, (redacted_relates_to,) ) - if rel_type == RelationTypes.ANNOTATION: - self.store._invalidate_cache_and_stream( - txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) - ) if rel_type == RelationTypes.REFERENCE: self.store._invalidate_cache_and_stream( txn, self.store.get_references_for_event, (redacted_relates_to,) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 584536111d..daef3685b0 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -709,7 +709,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): nbrows = 0 last_row_event_id = "" - for (event_id, event_json_raw) in results: + for event_id, event_json_raw in results: try: event_json = db_to_json(event_json_raw) @@ -1167,7 +1167,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): results = list(txn) # (event_id, parent_id, rel_type) for each relation relations_to_insert: List[Tuple[str, str, str]] = [] - for (event_id, event_json_raw) in results: + for event_id, event_json_raw in results: try: event_json = db_to_json(event_json_raw) except Exception as e: @@ -1220,9 +1220,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined] ) self._invalidate_cache_and_stream( # type: ignore[attr-defined] - txn, self.get_aggregation_groups_for_event, cache_tuple # type: ignore[attr-defined] - ) - self._invalidate_cache_and_stream( # type: ignore[attr-defined] txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined] ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 6d0ef10258..20b7a68362 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -72,7 +72,6 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, MultiWriterIdGenerator, StreamIdGenerator, ) @@ -187,8 +186,8 @@ class EventsWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - self._stream_id_gen: AbstractStreamIdTracker - self._backfill_id_gen: AbstractStreamIdTracker + self._stream_id_gen: AbstractStreamIdGenerator + self._backfill_id_gen: AbstractStreamIdGenerator if isinstance(database.engine, PostgresEngine): # If we're using Postgres than we can use `MultiWriterIdGenerator` # regardless of whether this process writes to the streams or not. @@ -1493,7 +1492,7 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(redactions_sql + clause, args) - for (redacter, redacted) in txn: + for redacter, redacted in txn: d = event_dict.get(redacted) if d: d.redactions.append(redacter) diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 12f3b601f1..8e57c8e5a0 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -17,7 +17,7 @@ from typing import Optional, Tuple, Union, cast from canonicaljson import encode_canonical_json -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, StoreError, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict @@ -46,8 +46,6 @@ class FilteringWorkerStore(SQLBaseStore): return db_to_json(def_json) - -class FilteringStore(FilteringWorkerStore): async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int: def_json = encode_canonical_json(user_filter) @@ -79,4 +77,23 @@ class FilteringStore(FilteringWorkerStore): return filter_id - return await self.db_pool.runInteraction("add_user_filter", _do_txn) + attempts = 0 + while True: + # Try a few times. + # This is technically needed if a user tries to create two filters at once, + # leading to two concurrent transactions. + # The failure case would be: + # - SELECT filter_id ... filter_json = ? → both transactions return no rows + # - SELECT MAX(filter_id) ... → both transactions return e.g. 5 + # - INSERT INTO ... → both transactions insert filter_id = 6 + # One of the transactions will commit. The other will get a unique key + # constraint violation error (IntegrityError). This is not the same as a + # serialisability violation, which would be automatically retried by + # `runInteraction`. + try: + return await self.db_pool.runInteraction("add_user_filter", _do_txn) + except self.db_pool.engine.module.IntegrityError: + attempts += 1 + + if attempts >= 5: + raise StoreError(500, "Couldn't generate a filter ID.") diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index b202c5eb87..fa8be214ce 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -196,7 +196,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def get_local_media_by_user_paginate_txn( txn: LoggingTransaction, ) -> Tuple[List[Dict[str, Any]], int]: - # Set ordering order_by_column = MediaSortOrder(order_by).value diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 9b2bbe060d..9f862f00c1 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -46,7 +46,6 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, IdGenerator, StreamIdGenerator, ) @@ -118,7 +117,7 @@ class PushRulesWorkerStore( # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + self._push_rules_stream_id_gen = StreamIdGenerator( db_conn, hs.get_replication_notifier(), "push_rules_stream", diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index df53e726e6..9a24f7a655 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -36,7 +36,6 @@ from synapse.storage.database import ( ) from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, StreamIdGenerator, ) from synapse.types import JsonDict @@ -60,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore): # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + self._pushers_id_gen = StreamIdGenerator( db_conn, hs.get_replication_notifier(), "pushers", @@ -344,7 +343,6 @@ class PusherWorkerStore(SQLBaseStore): last_user = progress.get("last_user", "") def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT name FROM users WHERE deactivated = ? and name > ? @@ -392,7 +390,6 @@ class PusherWorkerStore(SQLBaseStore): last_pusher = progress.get("last_pusher", 0) def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT p.id, access_token FROM pushers AS p LEFT JOIN access_tokens AS a ON (p.access_token = a.id) @@ -449,7 +446,6 @@ class PusherWorkerStore(SQLBaseStore): last_pusher = progress.get("last_pusher", 0) def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT p.id, p.user_name, p.app_id, p.pushkey FROM pushers AS p diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index dddf49c2d5..074942b167 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -39,7 +39,7 @@ from synapse.storage.database import ( from synapse.storage.engines import PostgresEngine from synapse.storage.engines._base import IsolationLevel from synapse.storage.util.id_generators import ( - AbstractStreamIdTracker, + AbstractStreamIdGenerator, MultiWriterIdGenerator, StreamIdGenerator, ) @@ -65,7 +65,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._receipts_id_gen: AbstractStreamIdTracker + self._receipts_id_gen: AbstractStreamIdGenerator if isinstance(database.engine, PostgresEngine): self._can_write_to_receipts = ( @@ -768,7 +768,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "insert_receipt_conv", self._graph_to_linear, room_id, event_ids ) - async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + async with self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self._insert_linearized_receipt_txn, @@ -887,7 +887,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): def _populate_receipt_event_stream_ordering_txn( txn: LoggingTransaction, ) -> bool: - if "max_stream_id" in progress: max_stream_id = progress["max_stream_id"] else: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 9a55e17624..717237e024 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1002,19 +1002,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="user_delete_threepid", ) - async def user_delete_threepids(self, user_id: str) -> None: - """Delete all threepid this user has bound - - Args: - user_id: The user id to delete all threepids of - - """ - await self.db_pool.simple_delete( - "user_threepids", - keyvalues={"user_id": user_id}, - desc="user_delete_threepids", - ) - async def add_user_bound_threepid( self, user_id: str, medium: str, address: str, id_server: str ) -> None: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index fa3266c081..bc3a83919c 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -398,143 +398,6 @@ class RelationsWorkerStore(SQLBaseStore): return result is not None @cached() - async def get_aggregation_groups_for_event( - self, event_id: str - ) -> Sequence[JsonDict]: - raise NotImplementedError() - - @cachedList( - cached_method_name="get_aggregation_groups_for_event", list_name="event_ids" - ) - async def get_aggregation_groups_for_events( - self, event_ids: Collection[str] - ) -> Mapping[str, Optional[List[JsonDict]]]: - """Get a list of annotations on the given events, grouped by event type and - aggregation key, sorted by count. - - This is used e.g. to get the what and how many reactions have happend - on an event. - - Args: - event_ids: Fetch events that relate to these event IDs. - - Returns: - A map of event IDs to a list of groups of annotations that match. - Each entry is a dict with `type`, `key` and `count` fields. - """ - # The number of entries to return per event ID. - limit = 5 - - clause, args = make_in_list_sql_clause( - self.database_engine, "relates_to_id", event_ids - ) - args.append(RelationTypes.ANNOTATION) - - sql = f""" - SELECT - relates_to_id, - annotation.type, - aggregation_key, - COUNT(DISTINCT annotation.sender) - FROM events AS annotation - INNER JOIN event_relations USING (event_id) - INNER JOIN events AS parent ON - parent.event_id = relates_to_id - AND parent.room_id = annotation.room_id - WHERE - {clause} - AND relation_type = ? - GROUP BY relates_to_id, annotation.type, aggregation_key - ORDER BY relates_to_id, COUNT(*) DESC - """ - - def _get_aggregation_groups_for_events_txn( - txn: LoggingTransaction, - ) -> Mapping[str, List[JsonDict]]: - txn.execute(sql, args) - - result: Dict[str, List[JsonDict]] = {} - for event_id, type, key, count in cast( - List[Tuple[str, str, str, int]], txn - ): - event_results = result.setdefault(event_id, []) - - # Limit the number of results per event ID. - if len(event_results) == limit: - continue - - event_results.append({"type": type, "key": key, "count": count}) - - return result - - return await self.db_pool.runInteraction( - "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn - ) - - async def get_aggregation_groups_for_users( - self, event_ids: Collection[str], users: FrozenSet[str] - ) -> Dict[str, Dict[Tuple[str, str], int]]: - """Fetch the partial aggregations for an event for specific users. - - This is used, in conjunction with get_aggregation_groups_for_event, to - remove information from the results for ignored users. - - Args: - event_ids: Fetch events that relate to these event IDs. - users: The users to fetch information for. - - Returns: - A map of event ID to a map of (event type, aggregation key) to a - count of users. - """ - - if not users: - return {} - - events_sql, args = make_in_list_sql_clause( - self.database_engine, "relates_to_id", event_ids - ) - - users_sql, users_args = make_in_list_sql_clause( - self.database_engine, "annotation.sender", users - ) - args.extend(users_args) - args.append(RelationTypes.ANNOTATION) - - sql = f""" - SELECT - relates_to_id, - annotation.type, - aggregation_key, - COUNT(DISTINCT annotation.sender) - FROM events AS annotation - INNER JOIN event_relations USING (event_id) - INNER JOIN events AS parent ON - parent.event_id = relates_to_id - AND parent.room_id = annotation.room_id - WHERE {events_sql} AND {users_sql} AND relation_type = ? - GROUP BY relates_to_id, annotation.type, aggregation_key - ORDER BY relates_to_id, COUNT(*) DESC - """ - - def _get_aggregation_groups_for_users_txn( - txn: LoggingTransaction, - ) -> Dict[str, Dict[Tuple[str, str], int]]: - txn.execute(sql, args) - - result: Dict[str, Dict[Tuple[str, str], int]] = {} - for event_id, type, key, count in cast( - List[Tuple[str, str, str, int]], txn - ): - result.setdefault(event_id, {})[(type, key)] = count - - return result - - return await self.db_pool.runInteraction( - "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn - ) - - @cached() async def get_references_for_event(self, event_id: str) -> List[JsonDict]: raise NotImplementedError() diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 644bbb8878..3825bd6079 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1417,6 +1417,204 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): get_un_partial_stated_rooms_from_stream_txn, ) + async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]: + """Retrieve an event report + + Args: + report_id: ID of reported event in database + Returns: + JSON dict of information from an event report or None if the + report does not exist. + """ + + def _get_event_report_txn( + txn: LoggingTransaction, report_id: int + ) -> Optional[Dict[str, Any]]: + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.content, + events.sender, + room_stats_state.canonical_alias, + room_stats_state.name, + event_json.json AS event_json + FROM event_reports AS er + LEFT JOIN events + ON events.event_id = er.event_id + JOIN event_json + ON event_json.event_id = er.event_id + JOIN room_stats_state + ON room_stats_state.room_id = er.room_id + WHERE er.id = ? + """ + + txn.execute(sql, [report_id]) + row = txn.fetchone() + + if not row: + return None + + event_report = { + "id": row[0], + "received_ts": row[1], + "room_id": row[2], + "event_id": row[3], + "user_id": row[4], + "score": db_to_json(row[5]).get("score"), + "reason": db_to_json(row[5]).get("reason"), + "sender": row[6], + "canonical_alias": row[7], + "name": row[8], + "event_json": db_to_json(row[9]), + } + + return event_report + + return await self.db_pool.runInteraction( + "get_event_report", _get_event_report_txn, report_id + ) + + async def get_event_reports_paginate( + self, + start: int, + limit: int, + direction: Direction = Direction.BACKWARDS, + user_id: Optional[str] = None, + room_id: Optional[str] = None, + ) -> Tuple[List[Dict[str, Any]], int]: + """Retrieve a paginated list of event reports + + Args: + start: event offset to begin the query from + limit: number of rows to retrieve + direction: Whether to fetch the most recent first (backwards) or the + oldest first (forwards) + user_id: search for user_id. Ignored if user_id is None + room_id: search for room_id. Ignored if room_id is None + Returns: + Tuple of: + json list of event reports + total number of event reports matching the filter criteria + """ + + def _get_event_reports_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], int]: + filters = [] + args: List[object] = [] + + if user_id: + filters.append("er.user_id LIKE ?") + args.extend(["%" + user_id + "%"]) + if room_id: + filters.append("er.room_id LIKE ?") + args.extend(["%" + room_id + "%"]) + + if direction == Direction.BACKWARDS: + order = "DESC" + else: + order = "ASC" + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + # We join on room_stats_state despite not using any columns from it + # because the join can influence the number of rows returned; + # e.g. a room that doesn't have state, maybe because it was deleted. + # The query returning the total count should be consistent with + # the query returning the results. + sql = """ + SELECT COUNT(*) as total_event_reports + FROM event_reports AS er + JOIN room_stats_state ON room_stats_state.room_id = er.room_id + {} + """.format( + where_clause + ) + txn.execute(sql, args) + count = cast(Tuple[int], txn.fetchone())[0] + + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.content, + events.sender, + room_stats_state.canonical_alias, + room_stats_state.name + FROM event_reports AS er + LEFT JOIN events + ON events.event_id = er.event_id + JOIN room_stats_state + ON room_stats_state.room_id = er.room_id + {where_clause} + ORDER BY er.received_ts {order} + LIMIT ? + OFFSET ? + """.format( + where_clause=where_clause, + order=order, + ) + + args += [limit, start] + txn.execute(sql, args) + + event_reports = [] + for row in txn: + try: + s = db_to_json(row[5]).get("score") + r = db_to_json(row[5]).get("reason") + except Exception: + logger.error("Unable to parse json from event_reports: %s", row[0]) + continue + event_reports.append( + { + "id": row[0], + "received_ts": row[1], + "room_id": row[2], + "event_id": row[3], + "user_id": row[4], + "score": s, + "reason": r, + "sender": row[6], + "canonical_alias": row[7], + "name": row[8], + } + ) + + return event_reports, count + + return await self.db_pool.runInteraction( + "get_event_reports_paginate", _get_event_reports_paginate_txn + ) + + async def delete_event_report(self, report_id: int) -> bool: + """Remove an event report from database. + + Args: + report_id: Report to delete + + Returns: + Whether the report was successfully deleted or not. + """ + try: + await self.db_pool.simple_delete_one( + table="event_reports", + keyvalues={"id": report_id}, + desc="delete_event_report", + ) + except StoreError: + # Deletion failed because report does not exist + return False + + return True + class _BackgroundUpdates: REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" @@ -2139,7 +2337,19 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): reason: Optional[str], content: JsonDict, received_ts: int, - ) -> None: + ) -> int: + """Add an event report + + Args: + room_id: Room that contains the reported event. + event_id: The reported event. + user_id: User who reports the event. + reason: Description that the user specifies. + content: Report request body (score and reason). + received_ts: Time when the user submitted the report (milliseconds). + Returns: + Id of the event report. + """ next_id = self._event_reports_id_gen.get_next() await self.db_pool.simple_insert( table="event_reports", @@ -2154,184 +2364,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): }, desc="add_event_report", ) - - async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]: - """Retrieve an event report - - Args: - report_id: ID of reported event in database - Returns: - JSON dict of information from an event report or None if the - report does not exist. - """ - - def _get_event_report_txn( - txn: LoggingTransaction, report_id: int - ) -> Optional[Dict[str, Any]]: - - sql = """ - SELECT - er.id, - er.received_ts, - er.room_id, - er.event_id, - er.user_id, - er.content, - events.sender, - room_stats_state.canonical_alias, - room_stats_state.name, - event_json.json AS event_json - FROM event_reports AS er - LEFT JOIN events - ON events.event_id = er.event_id - JOIN event_json - ON event_json.event_id = er.event_id - JOIN room_stats_state - ON room_stats_state.room_id = er.room_id - WHERE er.id = ? - """ - - txn.execute(sql, [report_id]) - row = txn.fetchone() - - if not row: - return None - - event_report = { - "id": row[0], - "received_ts": row[1], - "room_id": row[2], - "event_id": row[3], - "user_id": row[4], - "score": db_to_json(row[5]).get("score"), - "reason": db_to_json(row[5]).get("reason"), - "sender": row[6], - "canonical_alias": row[7], - "name": row[8], - "event_json": db_to_json(row[9]), - } - - return event_report - - return await self.db_pool.runInteraction( - "get_event_report", _get_event_report_txn, report_id - ) - - async def get_event_reports_paginate( - self, - start: int, - limit: int, - direction: Direction = Direction.BACKWARDS, - user_id: Optional[str] = None, - room_id: Optional[str] = None, - ) -> Tuple[List[Dict[str, Any]], int]: - """Retrieve a paginated list of event reports - - Args: - start: event offset to begin the query from - limit: number of rows to retrieve - direction: Whether to fetch the most recent first (backwards) or the - oldest first (forwards) - user_id: search for user_id. Ignored if user_id is None - room_id: search for room_id. Ignored if room_id is None - Returns: - Tuple of: - json list of event reports - total number of event reports matching the filter criteria - """ - - def _get_event_reports_paginate_txn( - txn: LoggingTransaction, - ) -> Tuple[List[Dict[str, Any]], int]: - filters = [] - args: List[object] = [] - - if user_id: - filters.append("er.user_id LIKE ?") - args.extend(["%" + user_id + "%"]) - if room_id: - filters.append("er.room_id LIKE ?") - args.extend(["%" + room_id + "%"]) - - if direction == Direction.BACKWARDS: - order = "DESC" - else: - order = "ASC" - - where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" - - # We join on room_stats_state despite not using any columns from it - # because the join can influence the number of rows returned; - # e.g. a room that doesn't have state, maybe because it was deleted. - # The query returning the total count should be consistent with - # the query returning the results. - sql = """ - SELECT COUNT(*) as total_event_reports - FROM event_reports AS er - JOIN room_stats_state ON room_stats_state.room_id = er.room_id - {} - """.format( - where_clause - ) - txn.execute(sql, args) - count = cast(Tuple[int], txn.fetchone())[0] - - sql = """ - SELECT - er.id, - er.received_ts, - er.room_id, - er.event_id, - er.user_id, - er.content, - events.sender, - room_stats_state.canonical_alias, - room_stats_state.name - FROM event_reports AS er - LEFT JOIN events - ON events.event_id = er.event_id - JOIN room_stats_state - ON room_stats_state.room_id = er.room_id - {where_clause} - ORDER BY er.received_ts {order} - LIMIT ? - OFFSET ? - """.format( - where_clause=where_clause, - order=order, - ) - - args += [limit, start] - txn.execute(sql, args) - - event_reports = [] - for row in txn: - try: - s = db_to_json(row[5]).get("score") - r = db_to_json(row[5]).get("reason") - except Exception: - logger.error("Unable to parse json from event_reports: %s", row[0]) - continue - event_reports.append( - { - "id": row[0], - "received_ts": row[1], - "room_id": row[2], - "event_id": row[3], - "user_id": row[4], - "score": s, - "reason": r, - "sender": row[6], - "canonical_alias": row[7], - "name": row[8], - } - ) - - return event_reports, count - - return await self.db_pool.runInteraction( - "get_event_reports_paginate", _get_event_reports_paginate_txn - ) + return next_id async def block_room(self, room_id: str, user_id: str) -> None: """Marks the room as blocked. diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 3fe433f66c..a7aae661d8 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -122,7 +122,6 @@ class SearchWorkerStore(SQLBaseStore): class SearchBackgroundUpdateStore(SearchWorkerStore): - EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" @@ -615,7 +614,6 @@ class SearchStore(SearchBackgroundUpdateStore): """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): - # We use CROSS JOIN here to ensure we use the right indexes. # https://sqlite.org/optoverview.html#crossjoin # diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index ba325d390b..ebb2ae964f 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -490,7 +490,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index d7b7d0c3c9..d3393d8e49 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -461,7 +461,7 @@ class StatsStore(StateDeltasStore): insert_cols = [] qargs = [] - for (key, val) in chain( + for key, val in chain( keyvalues.items(), absolutes.items(), additive_relatives.items() ): insert_cols.append(key) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 818c46182e..ac5fbf6b86 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -87,6 +87,7 @@ MAX_STREAM_SIZE = 1000 _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" + # Used as return values for pagination APIs @attr.s(slots=True, frozen=True, auto_attribs=True) class _EventDictReturn: diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 6b33d809b6..6d72bd9f67 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -573,7 +573,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): def get_destination_rooms_paginate_txn( txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]: - if direction == Direction.BACKWARDS: order = "DESC" else: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index f6a6fd4079..f16a509ac4 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -14,6 +14,7 @@ import logging import re +import unicodedata from typing import ( TYPE_CHECKING, Iterable, @@ -98,7 +99,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): async def _populate_user_directory_createtables( self, progress: JsonDict, batch_size: int ) -> int: - # Get all the rooms that we want to process. def _make_staging_area(txn: LoggingTransaction) -> None: sql = ( @@ -491,6 +491,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): values={"display_name": display_name, "avatar_url": avatar_url}, ) + # The display name that goes into the database index. + index_display_name = display_name + if index_display_name is not None: + index_display_name = _filter_text_for_index(index_display_name) + if isinstance(self.database_engine, PostgresEngine): # We weight the localpart most highly, then display name and finally # server name @@ -508,11 +513,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id), - display_name, + index_display_name, ), ) elif isinstance(self.database_engine, Sqlite3Engine): - value = "%s %s" % (user_id, display_name) if display_name else user_id + value = ( + "%s %s" % (user_id, index_display_name) + if index_display_name + else user_id + ) self.db_pool.simple_upsert_txn( txn, table="user_directory_search", @@ -897,6 +906,41 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): return {"limited": limited, "results": results[0:limit]} +def _filter_text_for_index(text: str) -> str: + """Transforms text before it is inserted into the user directory index, or searched + for in the user directory index. + + Note that the user directory search table needs to be rebuilt whenever this function + changes. + """ + # Lowercase the text, to make searches case-insensitive. + # This is necessary for both PostgreSQL and SQLite. PostgreSQL's + # `to_tsquery/to_tsvector` functions don't lowercase non-ASCII characters when using + # the "C" collation, while SQLite just doesn't lowercase non-ASCII characters at + # all. + text = text.lower() + + # Normalize the text. NFKC normalization has two effects: + # 1. It canonicalizes the text, ie. maps all visually identical strings to the same + # string. For example, ["e", "◌́"] is mapped to ["é"]. + # 2. It maps strings that are roughly equivalent to the same string. + # For example, ["dž"] is mapped to ["d", "ž"], ["①"] to ["1"] and ["i⁹"] to + # ["i", "9"]. + text = unicodedata.normalize("NFKC", text) + + # Note that nothing is done to make searches accent-insensitive. + # That could be achieved by converting to NFKD form instead (with combining accents + # split out) and filtering out combining accents using `unicodedata.combining(c)`. + # The downside of this may be noisier search results, since search terms with + # explicit accents will match characters with no accents, or completely different + # accents. + # + # text = unicodedata.normalize("NFKD", text) + # text = "".join([c for c in text if not unicodedata.combining(c)]) + + return text + + def _parse_query_sqlite(search_term: str) -> str: """Takes a plain unicode string from the user and converts it into a form that can be passed to database. @@ -906,6 +950,7 @@ def _parse_query_sqlite(search_term: str) -> str: We specifically add both a prefix and non prefix matching term so that exact matches get ranked higher. """ + search_term = _filter_text_for_index(search_term) # Pull out the individual words, discarding any non-word characters. results = _parse_words(search_term) @@ -918,11 +963,21 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]: We use this so that we can add prefix matching, which isn't something that is supported by default. """ - results = _parse_words(search_term) + search_term = _filter_text_for_index(search_term) + + escaped_words = [] + for word in _parse_words(search_term): + # Postgres tsvector and tsquery quoting rules: + # words potentially containing punctuation should be quoted + # and then existing quotes and backslashes should be doubled + # See: https://www.postgresql.org/docs/current/datatype-textsearch.html#DATATYPE-TSQUERY - both = " & ".join("(%s:* | %s)" % (result, result) for result in results) - exact = " & ".join("%s" % (result,) for result in results) - prefix = " & ".join("%s:*" % (result,) for result in results) + quoted_word = word.replace("'", "''").replace("\\", "\\\\") + escaped_words.append(f"'{quoted_word}'") + + both = " & ".join("(%s:* | %s)" % (word, word) for word in escaped_words) + exact = " & ".join("%s" % (word,) for word in escaped_words) + prefix = " & ".join("%s:*" % (word,) for word in escaped_words) return both, exact, prefix @@ -944,6 +999,14 @@ def _parse_words(search_term: str) -> List[str]: if USE_ICU: return _parse_words_with_icu(search_term) + return _parse_words_with_regex(search_term) + + +def _parse_words_with_regex(search_term: str) -> List[str]: + """ + Break down search term into words, when we don't have ICU available. + See: `_parse_words` + """ return re.findall(r"([\w\-]+)", search_term, re.UNICODE) diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index d743282f13..097dea5182 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -251,7 +251,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 1a7232b276..bf4cdfdf29 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se import attr from synapse.api.constants import EventTypes +from synapse.events import EventBase +from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -257,14 +259,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): member_filter, non_member_filter = state_filter.get_member_split() # Now we look them up in the member and non-member caches - ( - non_member_state, - incomplete_groups_nm, - ) = self._get_state_for_groups_using_cache( + non_member_state, incomplete_groups_nm = self._get_state_for_groups_using_cache( groups, self._state_group_cache, state_filter=non_member_filter ) - (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( + member_state, incomplete_groups_m = self._get_state_for_groups_using_cache( groups, self._state_group_members_cache, state_filter=member_filter ) @@ -404,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): fetched_keys=non_member_types, ) + async def store_state_deltas_for_batched( + self, + events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]], + room_id: str, + prev_group: int, + ) -> List[Tuple[EventBase, UnpersistedEventContext]]: + """Generate and store state deltas for a group of events and contexts created to be + batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c). + + Args: + events_and_context: the events to generate and store a state groups for + and their associated contexts + room_id: the id of the room the events were created for + prev_group: the state group of the last event persisted before the batched events + were created + """ + + def insert_deltas_group_txn( + txn: LoggingTransaction, + events_and_context: List[Tuple[EventBase, UnpersistedEventContext]], + prev_group: int, + ) -> List[Tuple[EventBase, UnpersistedEventContext]]: + """Generate and store state groups for the provided events and contexts. + + Requires that we have the state as a delta from the last persisted state group. + + Returns: + A list of state groups + """ + is_in_db = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + num_state_groups = sum( + 1 for event, _ in events_and_context if event.is_state() + ) + + state_groups = self._state_group_seq_gen.get_next_mult_txn( + txn, num_state_groups + ) + + sg_before = prev_group + state_group_iter = iter(state_groups) + for event, context in events_and_context: + if not event.is_state(): + context.state_group_after_event = sg_before + context.state_group_before_event = sg_before + continue + + sg_after = next(state_group_iter) + context.state_group_after_event = sg_after + context.state_group_before_event = sg_before + context.state_delta_due_to_event = { + (event.type, event.state_key): event.event_id + } + sg_before = sg_after + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups", + keys=("id", "room_id", "event_id"), + values=[ + (context.state_group_after_event, room_id, event.event_id) + for event, context in events_and_context + if event.is_state() + ], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_group_edges", + keys=("state_group", "prev_state_group"), + values=[ + ( + context.state_group_after_event, + context.state_group_before_event, + ) + for event, context in events_and_context + if event.is_state() + ], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), + values=[ + ( + context.state_group_after_event, + room_id, + key[0], + key[1], + state_id, + ) + for event, context in events_and_context + if context.state_delta_due_to_event is not None + for key, state_id in context.state_delta_due_to_event.items() + ], + ) + return events_and_context + + return await self.db_pool.runInteraction( + "store_state_deltas_for_batched.insert_deltas_group", + insert_deltas_group_txn, + events_and_context, + prev_group, + ) + async def store_state_group( self, event_id: str, diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index a182e8a098..d1ccb7390a 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -25,7 +25,7 @@ try: except ImportError: class PostgresEngine(BaseDatabaseEngine): # type: ignore[no-redef] - def __new__(cls, *args: object, **kwargs: object) -> NoReturn: # type: ignore[misc] + def __new__(cls, *args: object, **kwargs: object) -> NoReturn: raise RuntimeError( f"Cannot create {cls.__name__} -- psycopg2 module is not installed" ) @@ -36,7 +36,7 @@ try: except ImportError: class Sqlite3Engine(BaseDatabaseEngine): # type: ignore[no-redef] - def __new__(cls, *args: object, **kwargs: object) -> NoReturn: # type: ignore[misc] + def __new__(cls, *args: object, **kwargs: object) -> NoReturn: raise RuntimeError( f"Cannot create {cls.__name__} -- sqlite3 module is not installed" ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 6c335a9315..2a1c6fa31b 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -563,7 +563,7 @@ def _apply_module_schemas( """ # This is the old way for password_auth_provider modules to make changes # to the database. This should instead be done using the module API - for (mod, _config) in config.authproviders.password_providers: + for mod, _config in config.authproviders.password_providers: if not hasattr(mod, "get_db_schema_files"): continue modname = ".".join((mod.__module__, mod.__name__)) @@ -591,7 +591,7 @@ def _apply_module_schema_files( (modname,), ) applied_deltas = {d for d, in cur} - for (name, stream) in names_and_streams: + for name, stream in names_and_streams: if name in applied_deltas: continue diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 0031df1e06..56a0048539 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -12,7 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. from types import TracebackType -from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import ( + Any, + Callable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) from typing_extensions import Protocol @@ -112,15 +123,35 @@ class DBAPI2Module(Protocol): # extends from this hierarchy. See # https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#exceptions # https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE - Warning: Type[Exception] - Error: Type[Exception] + # + # Note: rather than + # x: T + # we write + # @property + # def x(self) -> T: ... + # which expresses that the protocol attribute `x` is read-only. The mypy docs + # https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected + # explain why this is necessary for safety. TL;DR: we shouldn't be able to write + # to `x`, only read from it. See also https://github.com/python/mypy/issues/6002 . + @property + def Warning(self) -> Type[Exception]: + ... + + @property + def Error(self) -> Type[Exception]: + ... # Errors are divided into `InterfaceError`s (something went wrong in the database # driver) and `DatabaseError`s (something went wrong in the database). These are # both subclasses of `Error`, but we can't currently express this in type # annotations due to https://github.com/python/mypy/issues/8397 - InterfaceError: Type[Exception] - DatabaseError: Type[Exception] + @property + def InterfaceError(self) -> Type[Exception]: + ... + + @property + def DatabaseError(self) -> Type[Exception]: + ... # Everything below is a subclass of `DatabaseError`. @@ -128,7 +159,9 @@ class DBAPI2Module(Protocol): # - An integer was too big for its data type. # - An invalid date time was provided. # - A string contained a null code point. - DataError: Type[Exception] + @property + def DataError(self) -> Type[Exception]: + ... # Roughly: something went wrong in the database, but it's not within the application # programmer's control. Examples: @@ -138,28 +171,45 @@ class DBAPI2Module(Protocol): # - A serialisation failure occurred. # - The database ran out of resources, such as storage, memory, connections, etc. # - The database encountered an error from the operating system. - OperationalError: Type[Exception] + @property + def OperationalError(self) -> Type[Exception]: + ... # Roughly: we've given the database data which breaks a rule we asked it to enforce. # Examples: # - Stop, criminal scum! You violated the foreign key constraint # - Also check constraints, non-null constraints, etc. - IntegrityError: Type[Exception] + @property + def IntegrityError(self) -> Type[Exception]: + ... # Roughly: something went wrong within the database server itself. - InternalError: Type[Exception] + @property + def InternalError(self) -> Type[Exception]: + ... # Roughly: the application did something silly that needs to be fixed. Examples: # - We don't have permissions to do something. # - We tried to create a table with duplicate column names. # - We tried to use a reserved name. # - We referred to a column that doesn't exist. - ProgrammingError: Type[Exception] + @property + def ProgrammingError(self) -> Type[Exception]: + ... # Roughly: we've tried to do something that this database doesn't support. - NotSupportedError: Type[Exception] + @property + def NotSupportedError(self) -> Type[Exception]: + ... - def connect(self, **parameters: object) -> Connection: + # We originally wrote + # def connect(self, *args, **kwargs) -> Connection: ... + # But mypy doesn't seem to like that because sqlite3.connect takes a mandatory + # positional argument. We can't make that part of the signature though, because + # psycopg2.connect doesn't have a mandatory positional argument. Instead, we use + # the following slightly unusual workaround. + @property + def connect(self) -> Callable[..., Connection]: ... diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 9adff3f4f5..d2c874b9a8 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -93,8 +93,11 @@ def _load_current_id( return res -class AbstractStreamIdTracker(metaclass=abc.ABCMeta): - """Tracks the "current" stream ID of a stream that may have multiple writers. +class AbstractStreamIdGenerator(metaclass=abc.ABCMeta): + """Generates or tracks stream IDs for a stream that may have multiple writers. + + Each stream ID represents a write transaction, whose completion is tracked + so that the "current" stream ID of the stream can be determined. Stream IDs are monotonically increasing or decreasing integers representing write transactions. The "current" stream ID is the stream ID such that all transactions @@ -130,16 +133,6 @@ class AbstractStreamIdTracker(metaclass=abc.ABCMeta): """ raise NotImplementedError() - -class AbstractStreamIdGenerator(AbstractStreamIdTracker): - """Generates stream IDs for a stream that may have multiple writers. - - Each stream ID represents a write transaction, whose completion is tracked - so that the "current" stream ID of the stream can be determined. - - See `AbstractStreamIdTracker` for more details. - """ - @abc.abstractmethod def get_next(self) -> AsyncContextManager[int]: """ @@ -158,6 +151,15 @@ class AbstractStreamIdGenerator(AbstractStreamIdTracker): """ raise NotImplementedError() + @abc.abstractmethod + def get_next_txn(self, txn: LoggingTransaction) -> int: + """ + Usage: + stream_id_gen.get_next_txn(txn) + # ... persist events ... + """ + raise NotImplementedError() + class StreamIdGenerator(AbstractStreamIdGenerator): """Generates and tracks stream IDs for a stream with a single writer. @@ -263,6 +265,40 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) + def get_next_txn(self, txn: LoggingTransaction) -> int: + """ + Retrieve the next stream ID from within a database transaction. + + Clean-up functions will be called when the transaction finishes. + + Args: + txn: The database transaction object. + + Returns: + The next stream ID. + """ + if not self._is_writer: + raise Exception("Tried to allocate stream ID on non-writer") + + # Get the next stream ID. + with self._lock: + self._current += self._step + next_id = self._current + + self._unfinished_ids[next_id] = next_id + + def clear_unfinished_id(id_to_clear: int) -> None: + """A function to mark processing this ID as finished""" + with self._lock: + self._unfinished_ids.pop(id_to_clear) + + # Mark this ID as finished once the database transaction itself finishes. + txn.call_after(clear_unfinished_id, next_id) + txn.call_on_exception(clear_unfinished_id, next_id) + + # Return the new ID. + return next_id + def get_current_token(self) -> int: if not self._is_writer: return self._current @@ -568,7 +604,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): """ Usage: - stream_id = stream_id_gen.get_next(txn) + stream_id = stream_id_gen.get_next_txn(txn) # ... persist event ... """ diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 75268cbe15..80915216de 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -205,7 +205,7 @@ class LocalSequenceGenerator(SequenceGenerator): """ Args: get_first_callback: a callback which is called on the first call to - get_next_id_txn; should return the curreent maximum id + get_next_id_txn; should return the current maximum id """ # the callback. this is cleared after it is called, so that it can be GCed. self._callback: Optional[GetFirstCallbackType] = get_first_callback diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py index c6c8a0315c..8a48ffc48d 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from abc import ABC, abstractmethod from typing import Generic, List, Optional, Tuple, TypeVar from synapse.types import StrCollection, UserID @@ -22,7 +22,8 @@ K = TypeVar("K") R = TypeVar("R") -class EventSource(Generic[K, R]): +class EventSource(ABC, Generic[K, R]): + @abstractmethod async def get_new_events( self, user: UserID, @@ -32,4 +33,4 @@ class EventSource(Generic[K, R]): is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[R], K]: - ... + raise NotImplementedError() diff --git a/synapse/types/state.py b/synapse/types/state.py index 743a4f9217..4b3071acce 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -120,7 +120,7 @@ class StateFilter: def to_types(self) -> Iterable[Tuple[str, Optional[str]]]: """The inverse to `from_types`.""" - for (event_type, state_keys) in self.types.items(): + for event_type, state_keys in self.types.items(): if state_keys is None: yield event_type, None else: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 9387632d0d..6ffa56217e 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -98,7 +98,6 @@ class EvictionReason(Enum): @attr.s(slots=True, auto_attribs=True) class CacheMetric: - _cache: Sized _cache_type: str _cache_name: str diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 3b1e205700..1c0fde4966 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -183,7 +183,7 @@ def check_requirements(extra: Optional[str] = None) -> None: deps_unfulfilled = [] errors = [] - for (requirement, must_be_installed) in dependencies: + for requirement, must_be_installed in dependencies: try: dist: metadata.Distribution = metadata.distribution(requirement.name) except metadata.PackageNotFoundError: diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index f97f98a057..d00d34e652 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -211,7 +211,6 @@ def _check_yield_points( result = Failure() if current_context() != expected_context: - # This happens because the context is lost sometime *after* the # previous yield and *after* the current yield. E.g. the # deferred we waited on didn't follow the rules, or we forgot to diff --git a/synmark/__main__.py b/synmark/__main__.py index 35a59e347a..19de639187 100644 --- a/synmark/__main__.py +++ b/synmark/__main__.py @@ -34,12 +34,10 @@ def make_test(main): """ def _main(loops): - reactor = make_reactor() file_out = StringIO() with redirect_stderr(file_out): - d = Deferred() d.addCallback(lambda _: ensureDeferred(main(reactor, loops))) diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py index 9419892e95..8beb077e0a 100644 --- a/synmark/suites/logging.py +++ b/synmark/suites/logging.py @@ -30,7 +30,6 @@ from synapse.util import Clock class LineCounter(LineOnlyReceiver): - delimiter = b"\n" def __init__(self, *args, **kwargs): diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index febcc1499d..e2a3bad065 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast +from typing import List, Optional, Sequence, Tuple, cast from unittest.mock import Mock from typing_extensions import TypeAlias from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.appservice import ( ApplicationService, @@ -40,9 +41,6 @@ from tests.test_utils import simple_async_mock from ..utils import MockClock -if TYPE_CHECKING: - from twisted.internet.testing import MemoryReactor - class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): def setUp(self) -> None: diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 35dd9a20df..33af8770fd 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -24,7 +24,6 @@ from tests.test_utils import make_awaitable class RoomComplexityTests(unittest.FederatingHomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, @@ -37,7 +36,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): return config def test_complexity_simple(self) -> None: - u1 = self.register_user("u1", "pass") u1_token = self.login("u1", "pass") @@ -71,7 +69,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): self.assertEqual(complexity, 1.23) def test_join_too_large(self) -> None: - u1 = self.register_user("u1", "pass") handler = self.hs.get_room_member_handler() @@ -131,7 +128,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def test_join_too_large_once_joined(self) -> None: - u1 = self.register_user("u1", "pass") u1_token = self.login("u1", "pass") diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index bba6469b55..6c7738d810 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -34,7 +34,6 @@ from tests.unittest import override_config class FederationServerTests(unittest.FederatingHomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index 6f300b8e11..5569ccef8a 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes, JoinRules from synapse.api.room_versions import RoomVersions from synapse.rest.client import knock, login, room from synapse.server import HomeServer +from synapse.types import UserID from synapse.util import Clock from tests import unittest @@ -296,3 +297,58 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.assertEqual(args[0][0]["user_agent"], "user_agent") self.assertGreater(args[0][0]["last_seen"], 0) self.assertNotIn("access_token", args[0][0]) + + def test_account_data(self) -> None: + """Tests that user account data get exported.""" + # add account data + self.get_success( + self._store.add_account_data_for_user(self.user2, "m.global", {"a": 1}) + ) + self.get_success( + self._store.add_account_data_to_room( + self.user2, "test_room", "m.per_room", {"b": 2} + ) + ) + + writer = Mock() + + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) + + # two calls, one call for user data and one call for room data + writer.write_account_data.assert_called() + + args = writer.write_account_data.call_args_list[0][0] + self.assertEqual(args[0], "global") + self.assertEqual(args[1]["m.global"]["a"], 1) + + args = writer.write_account_data.call_args_list[1][0] + self.assertEqual(args[0], "test_room") + self.assertEqual(args[1]["m.per_room"]["b"], 2) + + def test_media_ids(self) -> None: + """Tests that media's metadata get exported.""" + + self.get_success( + self._store.store_local_media( + media_id="media_1", + media_type="image/png", + time_now_ms=self.clock.time_msec(), + upload_name=None, + media_length=50, + user_id=UserID.from_string(self.user2), + ) + ) + + writer = Mock() + + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) + + writer.write_media_id.assert_called_once() + + args = writer.write_media_id.call_args[0] + self.assertEqual(args[0], "media_1") + self.assertEqual(args[1]["media_id"], "media_1") + self.assertEqual(args[1]["media_length"], 50) + self.assertGreater(args[1]["created_ts"], 0) + self.assertIsNone(args[1]["upload_name"]) + self.assertIsNone(args[1]["last_access_ts"]) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 179c96adf2..2e838e6572 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer @@ -81,7 +81,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): def _create_duplicate_event( self, txn_id: str - ) -> Tuple[EventBase, EventContext, Optional[dict]]: + ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]: """Create a new event with the given transaction ID. All events produced by this method will be considered duplicates. """ @@ -109,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase): txn_id = "something_suitably_random" - event1, context, _ = self._create_duplicate_event(txn_id) + event1, unpersisted_context, _ = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event1)) ret_event1 = self.get_success( self.handler.handle_new_client_event( @@ -121,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertEqual(event1.event_id, ret_event1.event_id) - event2, context, _ = self._create_duplicate_event(txn_id) + event2, unpersisted_context, _ = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event2)) # We want to test that the deduplication at the persit event end works, # so we want to make sure we test with different events. @@ -142,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): # Let's test that calling `persist_event` directly also does the right # thing. - event3, context, _ = self._create_duplicate_event(txn_id) + event3, unpersisted_context, _ = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event3)) + self.assertNotEqual(event1.event_id, event3.event_id) ret_event3, event_pos3, _ = self.get_success( @@ -156,7 +160,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): # Let's test that calling `persist_events` directly also does the right # thing. - event4, context, _ = self._create_duplicate_event(txn_id) + event4, unpersisted_context, _ = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event4)) + self.assertNotEqual(event1.event_id, event3.event_id) events, _ = self.get_success( @@ -176,8 +182,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase): txn_id = "something_else_suitably_random" # Create two duplicate events to persist at the same time - event1, context1, _ = self._create_duplicate_event(txn_id) - event2, context2, _ = self._create_duplicate_event(txn_id) + event1, unpersisted_context1, _ = self._create_duplicate_event(txn_id) + context1 = self.get_success(unpersisted_context1.persist(event1)) + event2, unpersisted_context2, _ = self._create_duplicate_event(txn_id) + context2 = self.get_success(unpersisted_context2.persist(event2)) # Ensure their event IDs are different to start with self.assertNotEqual(event1.event_id, event2.event_id) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index ad42a7183e..161ff0a6c1 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -62,7 +62,7 @@ class TestSpamChecker: request_info: Collection[Tuple[str, str]], auth_provider_id: Optional[str], ) -> RegistrationBehaviour: - pass + return RegistrationBehaviour.ALLOW class DenyAll(TestSpamChecker): @@ -111,7 +111,7 @@ class TestLegacyRegistrationSpamChecker: username: Optional[str], request_info: Collection[Tuple[str, str]], ) -> RegistrationBehaviour: - pass + return RegistrationBehaviour.ALLOW class LegacyAllowAll(TestLegacyRegistrationSpamChecker): @@ -507,7 +507,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Lower the permissions of the inviter. event_creation_handler = self.hs.get_event_creation_handler() requester = create_requester(inviter) - event, context, _ = self.get_success( + + event, unpersisted_context, _ = self.get_success( event_creation_handler.create_event( requester, { @@ -519,6 +520,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_creation_handler.handle_new_client_event( requester, events_and_context=[(event, context)] diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py index 137deab138..d6f43a98fc 100644 --- a/tests/handlers/test_sso.py +++ b/tests/handlers/test_sso.py @@ -113,7 +113,6 @@ async def mock_get_file( headers: Optional[RawHeaders] = None, is_allowed_content_type: Optional[Callable[[str], bool]] = None, ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: - fake_response = FakeResponse(code=404) if url == "http://my.server/me.png": fake_response = FakeResponse( diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index f1a50c5bcb..d11ded6c5b 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -31,7 +31,6 @@ EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6 class StatsRoomTests(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, room.register_servlets, diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 9bb9c1afc6..c5746005b5 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -192,6 +192,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.helper.join(room, self.appservice.sender, tok=self.appservice.token) self._check_only_one_user_in_directory(user, room) + def test_search_term_with_colon_in_it_does_not_raise(self) -> None: + """ + Regression test: Test that search terms with colons in them are acceptable. + """ + u1 = self.register_user("user1", "pass") + self.get_success(self.handler.search_users(u1, "haha:paamayim-nekudotayim", 10)) + def test_user_not_in_users_table(self) -> None: """Unclear how it happens, but on matrix.org we've seen join events for users who aren't in the users table. Test that we don't fall over diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index acfdcd3bca..eb7f53fee5 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -30,7 +30,7 @@ from twisted.internet.interfaces import ( IOpenSSLClientConnectionCreator, IProtocolFactory, ) -from twisted.internet.protocol import Factory +from twisted.internet.protocol import Factory, Protocol from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web._newclient import ResponseNeverReceived from twisted.web.client import Agent @@ -63,7 +63,7 @@ from tests.http import ( get_test_ca_cert_file, ) from tests.server import FakeTransport, ThreadedMemoryReactorClock -from tests.utils import default_config +from tests.utils import checked_cast, default_config logger = logging.getLogger(__name__) @@ -146,8 +146,10 @@ class MatrixFederationAgentTests(unittest.TestCase): # # Normally this would be done by the TCP socket code in Twisted, but we are # stubbing that out here. - client_protocol = client_factory.buildProtocol(dummy_address) - assert isinstance(client_protocol, _WrappingProtocol) + # NB: we use a checked_cast here to workaround https://github.com/Shoobx/mypy-zope/issues/91) + client_protocol = checked_cast( + _WrappingProtocol, client_factory.buildProtocol(dummy_address) + ) client_protocol.makeConnection( FakeTransport(server_protocol, self.reactor, client_protocol) ) @@ -446,7 +448,6 @@ class MatrixFederationAgentTests(unittest.TestCase): server_ssl_protocol = _wrap_server_factory_for_tls( _get_test_protocol_factory() ).buildProtocol(dummy_address) - assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol) # Tell the HTTP server to send outgoing traffic back via the proxy's transport. proxy_server_transport = proxy_server.transport @@ -465,7 +466,8 @@ class MatrixFederationAgentTests(unittest.TestCase): else: assert isinstance(proxy_server_transport, FakeTransport) client_protocol = proxy_server_transport.other - c2s_transport = client_protocol.transport + assert isinstance(client_protocol, Protocol) + c2s_transport = checked_cast(FakeTransport, client_protocol.transport) c2s_transport.other = server_ssl_protocol self.reactor.advance(0) @@ -1529,7 +1531,7 @@ def _check_logcontext(context: LoggingContextOrSentinel) -> None: def _wrap_server_factory_for_tls( factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None -) -> IProtocolFactory: +) -> TLSMemoryBIOFactory: """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory The resultant factory will create a TLS server which presents a certificate signed by our test CA, valid for the domains in `sanlist` diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 7748f56ee6..6ab13357f9 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -46,7 +46,6 @@ class SrvResolverTestCase(unittest.TestCase): @defer.inlineCallbacks def do_lookup() -> Generator["Deferred[object]", object, List[Server]]: - with LoggingContext("one") as ctx: resolve_d = resolver.resolve_service(service_name) result: List[Server] diff --git a/tests/http/test_client.py b/tests/http/test_client.py index 9cfe1ad0de..f6d6684985 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py @@ -149,7 +149,7 @@ class BlacklistingAgentTest(TestCase): self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1" # Configure the reactor's DNS resolver. - for (domain, ip) in ( + for domain, ip in ( (self.safe_domain, self.safe_ip), (self.unsafe_domain, self.unsafe_ip), (self.allowed_domain, self.allowed_ip), diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index a817940730..cc175052ac 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -28,7 +28,7 @@ from twisted.internet.endpoints import ( _WrappingProtocol, ) from twisted.internet.interfaces import IProtocol, IProtocolFactory -from twisted.internet.protocol import Factory +from twisted.internet.protocol import Factory, Protocol from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web.http import HTTPChannel @@ -43,6 +43,7 @@ from tests.http import ( ) from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.unittest import TestCase +from tests.utils import checked_cast logger = logging.getLogger(__name__) @@ -620,7 +621,6 @@ class MatrixFederationAgentTests(TestCase): server_ssl_protocol = _wrap_server_factory_for_tls( _get_test_protocol_factory() ).buildProtocol(dummy_address) - assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol) # Tell the HTTP server to send outgoing traffic back via the proxy's transport. proxy_server_transport = proxy_server.transport @@ -644,7 +644,8 @@ class MatrixFederationAgentTests(TestCase): else: assert isinstance(proxy_server_transport, FakeTransport) client_protocol = proxy_server_transport.other - c2s_transport = client_protocol.transport + assert isinstance(client_protocol, Protocol) + c2s_transport = checked_cast(FakeTransport, client_protocol.transport) c2s_transport.other = server_ssl_protocol self.reactor.advance(0) @@ -757,12 +758,14 @@ class MatrixFederationAgentTests(TestCase): assert isinstance(proxy_server, HTTPChannel) # fish the transports back out so that we can do the old switcheroo - s2c_transport = proxy_server.transport - assert isinstance(s2c_transport, FakeTransport) - client_protocol = s2c_transport.other - assert isinstance(client_protocol, _WrappingProtocol) - c2s_transport = client_protocol.transport - assert isinstance(c2s_transport, FakeTransport) + # To help mypy out with the various Protocols and wrappers and mocks, we do + # some explicit casting. Without the casts, we hit the bug I reported at + # https://github.com/Shoobx/mypy-zope/issues/91 . + # We also double-checked these casts at runtime (test-time) because I found it + # quite confusing to deduce these types in the first place! + s2c_transport = checked_cast(FakeTransport, proxy_server.transport) + client_protocol = checked_cast(_WrappingProtocol, s2c_transport.other) + c2s_transport = checked_cast(FakeTransport, client_protocol.transport) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -822,9 +825,9 @@ class MatrixFederationAgentTests(TestCase): @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"}) def test_proxy_with_no_scheme(self) -> None: http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) - self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") - self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) + proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint) + self.assertEqual(proxy_ep._hostStr, "proxy.com") + self.assertEqual(proxy_ep._port, 8888) @patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"}) def test_proxy_with_unsupported_scheme(self) -> None: @@ -834,25 +837,21 @@ class MatrixFederationAgentTests(TestCase): @patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"}) def test_proxy_with_http_scheme(self) -> None: http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) - self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") - self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) + proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint) + self.assertEqual(proxy_ep._hostStr, "proxy.com") + self.assertEqual(proxy_ep._port, 8888) @patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"}) def test_proxy_with_https_scheme(self) -> None: https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - assert isinstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint) - self.assertEqual( - https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com" - ) - self.assertEqual( - https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._port, 8888 - ) + proxy_ep = checked_cast(_WrapperEndpoint, https_proxy_agent.http_proxy_endpoint) + self.assertEqual(proxy_ep._wrappedEndpoint._hostStr, "proxy.com") + self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888) def _wrap_server_factory_for_tls( factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None -) -> IProtocolFactory: +) -> TLSMemoryBIOFactory: """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory The resultant factory will create a TLS server which presents a certificate diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py index c08954d887..5191e31a8a 100644 --- a/tests/logging/test_remote_handler.py +++ b/tests/logging/test_remote_handler.py @@ -21,6 +21,7 @@ from synapse.logging import RemoteHandler from tests.logging import LoggerCleanupMixin from tests.server import FakeTransport, get_clock from tests.unittest import TestCase +from tests.utils import checked_cast def connect_logging_client( @@ -56,8 +57,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): client, server = connect_logging_client(self.reactor, 0) # Trigger data being sent - assert isinstance(client.transport, FakeTransport) - client.transport.flush() + client_transport = checked_cast(FakeTransport, client.transport) + client_transport.flush() # One log message, with a single trailing newline logs = server.data.decode("utf8").splitlines() @@ -89,8 +90,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) - assert isinstance(client.transport, FakeTransport) - client.transport.flush() + client_transport = checked_cast(FakeTransport, client.transport) + client_transport.flush() # Only the 7 infos made it through, the debugs were elided logs = server.data.splitlines() @@ -123,8 +124,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) - assert isinstance(client.transport, FakeTransport) - client.transport.flush() + client_transport = checked_cast(FakeTransport, client.transport) + client_transport.flush() # The 10 warnings made it through, the debugs and infos were elided logs = server.data.splitlines() @@ -148,8 +149,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) - assert isinstance(client.transport, FakeTransport) - client.transport.flush() + client_transport = checked_cast(FakeTransport, client.transport) + client_transport.flush() # The first five and last five warnings made it through, the debugs and # infos were elided diff --git a/tests/rest/media/v1/__init__.py b/tests/media/__init__.py index b1ee10cfcc..68910cbf5b 100644 --- a/tests/rest/media/v1/__init__.py +++ b/tests/media/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2018 New Vector Ltd +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/rest/media/v1/test_base.py b/tests/media/test_base.py index c73179151a..66498c744d 100644 --- a/tests/rest/media/v1/test_base.py +++ b/tests/media/test_base.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.rest.media.v1._base import get_filename_from_headers +from synapse.media._base import get_filename_from_headers from tests import unittest diff --git a/tests/rest/media/v1/test_filepath.py b/tests/media/test_filepath.py index 43e6f0f70a..95e3b83d5a 100644 --- a/tests/rest/media/v1/test_filepath.py +++ b/tests/media/test_filepath.py @@ -15,7 +15,7 @@ import inspect import os from typing import Iterable -from synapse.rest.media.v1.filepath import MediaFilePaths, _wrap_with_jail_check +from synapse.media.filepath import MediaFilePaths, _wrap_with_jail_check from tests import unittest diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/media/test_html_preview.py index 1062081a06..e7da75db3e 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/media/test_html_preview.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.rest.media.v1.preview_html import ( +from synapse.media.preview_html import ( _get_html_media_encodings, decode_body, parse_html_to_open_graph, diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/media/test_media_storage.py index 17a3b06a8e..870047d0f2 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -34,13 +34,13 @@ from synapse.events import EventBase from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable +from synapse.media._base import FileInfo +from synapse.media.filepath import MediaFilePaths +from synapse.media.media_storage import MediaStorage, ReadableFileWrapper +from synapse.media.storage_provider import FileStorageProviderBackend from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client import login -from synapse.rest.media.v1._base import FileInfo -from synapse.rest.media.v1.filepath import MediaFilePaths -from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper -from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias from synapse.util import Clock @@ -52,7 +52,6 @@ from tests.utils import default_config class MediaStorageTests(unittest.HomeserverTestCase): - needs_threadpool = True def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -207,7 +206,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): user_id = "@test:user" def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.fetches: List[ Tuple[ "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]", @@ -255,7 +253,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): config["max_image_pixels"] = 2000000 provider_config = { - "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend", + "module": "synapse.media.storage_provider.FileStorageProviderBackend", "store_local": True, "store_synchronous": False, "store_remote": True, @@ -268,7 +266,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - media_resource = hs.get_media_repository_resource() self.download_resource = media_resource.children[b"download"] self.thumbnail_resource = media_resource.children[b"thumbnail"] diff --git a/tests/rest/media/v1/test_oembed.py b/tests/media/test_oembed.py index 3f7f1dbab9..c8bf8421da 100644 --- a/tests/rest/media/v1/test_oembed.py +++ b/tests/media/test_oembed.py @@ -18,7 +18,7 @@ from parameterized import parameterized from twisted.test.proto_helpers import MemoryReactor -from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult +from synapse.media.oembed import OEmbedProvider, OEmbedResult from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 2fe879df47..af0341808d 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -33,7 +33,6 @@ from tests.unittest import HomeserverTestCase, override_config class TestBulkPushRuleEvaluator(HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -131,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): # Create a new message event, and try to evaluate it under the dodgy # power level event. - event, context, _ = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -146,6 +145,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): prev_event_ids=[pl_event_id], ) ) + context = self.get_success(unpersisted_context.persist(event)) bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise @@ -171,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): """Ensure that push rules are not calculated when disabled in the config""" # Create a new message event which should cause a notification. - event, context, _ = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -185,6 +185,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) bulk_evaluator = BulkPushRuleEvaluator(self.hs) # Mock the method which calculates push rules -- we do this instead of @@ -201,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) -> bool: """Returns true iff the `mentions` trigger an event push action.""" # Create a new message event which should cause a notification. - event, context, _ = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -212,7 +213,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) - + context = self.get_success(unpersisted_context.persist(event)) # Execute the push rule machinery. self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) @@ -227,7 +228,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) return len(result) > 0 - @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) + @override_config( + { + "experimental_features": { + "msc3952_intentional_mentions": True, + "msc3966_exact_event_property_contains": True, + } + } + ) def test_user_mentions(self) -> None: """Test the behavior of an event which includes invalid user mentions.""" bulk_evaluator = BulkPushRuleEvaluator(self.hs) @@ -323,7 +331,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) ) - @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) + @override_config( + { + "experimental_features": { + "msc3952_intentional_mentions": True, + "msc3966_exact_event_property_contains": True, + } + } + ) def test_room_mentions(self) -> None: """Test the behavior of an event which includes invalid room mentions.""" bulk_evaluator = BulkPushRuleEvaluator(self.hs) @@ -377,7 +392,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): bulk_evaluator = BulkPushRuleEvaluator(self.hs) # Create & persist an event to use as the parent of the relation. - event, context, _ = self.get_success( + event, unpersisted_context, _ = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -391,6 +406,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( self.event_creation_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 7563f33fdc..4ea5472eb4 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -39,7 +39,6 @@ class _User: class EmailPusherTests(HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -48,7 +47,6 @@ class EmailPusherTests(HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["email"] = { "enable_notifs": True, @@ -371,10 +369,8 @@ class EmailPusherTests(HomeserverTestCase): # disassociate the user's email address self.get_success( - self.auth_handler.delete_threepid( - user_id=self.user_id, - medium="email", - address="a@example.com", + self.auth_handler.delete_local_threepid( + user_id=self.user_id, medium="email", address="a@example.com" ) ) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 0554d247bc..ff5a9a66f5 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Set, Union, cast +from typing import Any, Dict, List, Optional, Union, cast import frozendict @@ -54,7 +54,7 @@ class FlattenDictTestCase(unittest.TestCase): self.assertEqual({"m.foo.b\\ar": "abc"}, _flatten_dict(input)) self.assertEqual( {"m\\.foo.b\\\\ar": "abc"}, - _flatten_dict(input, msc3783_escape_event_match_key=True), + _flatten_dict(input, msc3873_escape_event_match_key=True), ) def test_non_string(self) -> None: @@ -147,9 +147,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): self, content: JsonMapping, *, - has_mentions: bool = False, - user_mentions: Optional[Set[str]] = None, - room_mention: bool = False, related_events: Optional[JsonDict] = None, ) -> PushRuleEvaluator: event = FrozenEvent( @@ -168,9 +165,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): power_levels: Dict[str, Union[int, Dict[str, int]]] = {} return PushRuleEvaluator( _flatten_dict(event), - has_mentions, - user_mentions or set(), - room_mention, + False, room_member_count, sender_power_level, cast(Dict[str, int], power_levels.get("notifications", {})), @@ -178,7 +173,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): related_event_match_enabled=True, room_version_feature_flags=event.room_version.msc3931_push_features, msc3931_enabled=True, - msc3758_exact_event_match=True, msc3966_exact_event_property_contains=True, ) @@ -206,53 +200,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): # A display name with spaces should work fine. self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) - def test_user_mentions(self) -> None: - """Check for user mentions.""" - condition = {"kind": "org.matrix.msc3952.is_user_mention"} - - # No mentions shouldn't match. - evaluator = self._get_evaluator({}, has_mentions=True) - self.assertFalse(evaluator.matches(condition, "@user:test", None)) - - # An empty set shouldn't match - evaluator = self._get_evaluator({}, has_mentions=True, user_mentions=set()) - self.assertFalse(evaluator.matches(condition, "@user:test", None)) - - # The Matrix ID appearing anywhere in the mentions list should match - evaluator = self._get_evaluator( - {}, has_mentions=True, user_mentions={"@user:test"} - ) - self.assertTrue(evaluator.matches(condition, "@user:test", None)) - - evaluator = self._get_evaluator( - {}, has_mentions=True, user_mentions={"@another:test", "@user:test"} - ) - self.assertTrue(evaluator.matches(condition, "@user:test", None)) - - # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions - # since the BulkPushRuleEvaluator is what handles data sanitisation. - - def test_room_mentions(self) -> None: - """Check for room mentions.""" - condition = {"kind": "org.matrix.msc3952.is_room_mention"} - - # No room mention shouldn't match. - evaluator = self._get_evaluator({}, has_mentions=True) - self.assertFalse(evaluator.matches(condition, None, None)) - - # Room mention should match. - evaluator = self._get_evaluator({}, has_mentions=True, room_mention=True) - self.assertTrue(evaluator.matches(condition, None, None)) - - # A room mention and user mention is valid. - evaluator = self._get_evaluator( - {}, has_mentions=True, user_mentions={"@another:test"}, room_mention=True - ) - self.assertTrue(evaluator.matches(condition, None, None)) - - # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions - # since the BulkPushRuleEvaluator is what handles data sanitisation. - def _assert_matches( self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None ) -> None: @@ -424,12 +371,39 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "pattern should not match before a newline", ) + def test_event_match_pattern(self) -> None: + """Check that event_match conditions do not use a "pattern_type" from user data.""" + + # The pattern_type should not be deserialized into anything valid. + condition = { + "kind": "event_match", + "key": "content.value", + "pattern_type": "user_id", + } + self._assert_not_matches( + condition, + {"value": "@user:test"}, + "should not be possible to pass a pattern_type in", + ) + + # This is an internal-only condition which shouldn't get deserialized. + condition = { + "kind": "event_match_type", + "key": "content.value", + "pattern_type": "user_id", + } + self._assert_not_matches( + condition, + {"value": "@user:test"}, + "should not be possible to pass a pattern_type in", + ) + def test_exact_event_match_string(self) -> None: """Check that exact_event_match conditions work as expected for strings.""" # Test against a string value. condition = { - "kind": "com.beeper.msc3758.exact_event_match", + "kind": "event_property_is", "key": "content.value", "value": "foobaz", } @@ -467,11 +441,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): """Check that exact_event_match conditions work as expected for booleans.""" # Test against a True boolean value. - condition = { - "kind": "com.beeper.msc3758.exact_event_match", - "key": "content.value", - "value": True, - } + condition = {"kind": "event_property_is", "key": "content.value", "value": True} self._assert_matches( condition, {"value": True}, @@ -491,7 +461,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): # Test against a False boolean value. condition = { - "kind": "com.beeper.msc3758.exact_event_match", + "kind": "event_property_is", "key": "content.value", "value": False, } @@ -516,11 +486,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): def test_exact_event_match_null(self) -> None: """Check that exact_event_match conditions work as expected for null.""" - condition = { - "kind": "com.beeper.msc3758.exact_event_match", - "key": "content.value", - "value": None, - } + condition = {"kind": "event_property_is", "key": "content.value", "value": None} self._assert_matches( condition, {"value": None}, @@ -536,11 +502,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): def test_exact_event_match_integer(self) -> None: """Check that exact_event_match conditions work as expected for integers.""" - condition = { - "kind": "com.beeper.msc3758.exact_event_match", - "key": "content.value", - "value": 1, - } + condition = {"kind": "event_property_is", "key": "content.value", "value": 1} self._assert_matches( condition, {"value": 1}, diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index ddca9d696c..57c781a0c3 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -64,7 +64,6 @@ def patch__eq__(cls: object) -> Callable[[], None]: class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): - STORE_TYPE = EventsWorkerStore def setUp(self) -> None: diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py index bf927beb6a..bab77b2df7 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py @@ -141,3 +141,64 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase): self.get_success(ctx_worker1.__aexit__(None, None, None)) self.assertTrue(d.called) + + def test_wait_for_stream_position_rdata(self) -> None: + """Check that wait for stream position correctly waits for an update + from the correct instance, when RDATA is sent. + """ + store = self.hs.get_datastores().main + cmd_handler = self.hs.get_replication_command_handler() + data_handler = self.hs.get_replication_data_handler() + + worker1 = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={ + "worker_name": "worker1", + "run_background_tasks_on": "worker1", + "redis": {"enabled": True}, + }, + ) + + cache_id_gen = worker1.get_datastores().main._cache_id_gen + assert cache_id_gen is not None + + self.replicate() + + # First, make sure the master knows that `worker1` exists. + initial_token = cache_id_gen.get_current_token() + cmd_handler.send_command( + PositionCommand("caches", "worker1", initial_token, initial_token) + ) + self.replicate() + + # `wait_for_stream_position` should only return once master receives a + # notification that `next_token2` has persisted. + ctx_worker1 = cache_id_gen.get_next_mult(2) + next_token1, next_token2 = self.get_success(ctx_worker1.__aenter__()) + + d = defer.ensureDeferred( + data_handler.wait_for_stream_position("worker1", "caches", next_token2) + ) + self.assertFalse(d.called) + + # Insert an entry into the cache stream with token `next_token1`, but + # not `next_token2`. + self.get_success( + store.db_pool.simple_insert( + table="cache_invalidation_stream_by_instance", + values={ + "stream_id": next_token1, + "instance_name": "worker1", + "cache_func": "foo", + "keys": [], + "invalidation_ts": 0, + }, + ) + ) + + # Finish the context manager, triggering the data to be sent to master. + self.get_success(ctx_worker1.__aexit__(None, None, None)) + + # Master should get told about `next_token2`, so the deferred should + # resolve. + self.assertTrue(d.called) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index 03f2112b07..aaa488bced 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -28,7 +28,6 @@ from tests import unittest class DeviceRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -291,7 +290,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): class DevicesRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -415,7 +413,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 233eba3516..f189b07769 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -78,7 +78,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ Try to get an event report without authentication. """ - channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, {}) self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -473,7 +473,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): """ Try to get event report without authentication. """ - channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, {}) self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -599,3 +599,142 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): self.assertIn("room_id", content["event_json"]) self.assertIn("sender", content["event_json"]) self.assertIn("content", content["event_json"]) + + +class DeleteEventReportTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self._store = hs.get_datastores().main + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + # create report + event_id = self.get_success( + self._store.add_event_report( + "room_id", + "event_id", + self.other_user, + "this makes me sad", + {}, + self.clock.time_msec(), + ) + ) + + self.url = f"/_synapse/admin/v1/event_reports/{event_id}" + + def test_no_auth(self) -> None: + """ + Try to delete event report without authentication. + """ + channel = self.make_request("DELETE", self.url) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self) -> None: + """ + If the user is not a server admin, an error 403 is returned. + """ + + channel = self.make_request( + "DELETE", + self.url, + access_token=self.other_user_tok, + ) + + self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_delete_success(self) -> None: + """ + Testing delete a report. + """ + + channel = self.make_request( + "DELETE", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual({}, channel.json_body) + + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + + # check that report was deleted + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_invalid_report_id(self) -> None: + """ + Testing that an invalid `report_id` returns a 400. + """ + + # `report_id` is negative + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/event_reports/-123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is a non-numerical string + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/event_reports/abcdef", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is undefined + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/event_reports/", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + def test_report_id_not_found(self) -> None: + """ + Testing that a not existing `report_id` returns a 404. + """ + + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/event_reports/123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual("Event report not found", channel.json_body["error"]) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index db77a45ae3..6d04911d67 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -20,8 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes +from synapse.media.filepath import MediaFilePaths from synapse.rest.client import login, profile, room -from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.server import HomeServer from synapse.util import Clock @@ -34,7 +34,6 @@ INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -196,7 +195,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -594,7 +592,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -724,7 +721,6 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -821,7 +817,6 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 453a6e979c..9dbb778679 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1990,7 +1990,6 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): class JoinAliasRoomTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, room.register_servlets, diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index f71ff46d87..28b999573e 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -28,7 +28,6 @@ from tests.unittest import override_config class ServerNoticeTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 87efab59bb..c278f6bbad 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -28,8 +28,8 @@ import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions +from synapse.media.filepath import MediaFilePaths from synapse.rest.client import devices, login, logout, profile, register, room, sync -from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.server import HomeServer from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index e2ee1a1766..2b05dffc7d 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -40,7 +40,6 @@ from tests.unittest import override_config class PasswordResetTestCase(unittest.HomeserverTestCase): - servlets = [ account.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -408,7 +407,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): class DeactivateTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -492,7 +490,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase): class WhoamiTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -567,7 +564,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): - servlets = [ account.register_servlets, login.register_servlets, diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 208ec44829..0d8fe77b88 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -34,7 +34,7 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER -from tests.server import FakeChannel, make_request +from tests.server import FakeChannel from tests.unittest import override_config, skip_unless @@ -43,13 +43,15 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): super().__init__(hs) self.recaptcha_attempts: List[Tuple[dict, str]] = [] + def is_enabled(self) -> bool: + return True + def check_auth(self, authdict: dict, clientip: str) -> Any: self.recaptcha_attempts.append((authdict, clientip)) return succeed(True) class FallbackAuthTests(unittest.HomeserverTestCase): - servlets = [ auth.register_servlets, register.register_servlets, @@ -57,7 +59,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["enable_registration_captcha"] = True @@ -1319,16 +1320,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase): channel = self.submit_logout_token(logout_token) self.assertEqual(channel.code, 200) - # Now try to exchange the login token - channel = make_request( - self.hs.get_reactor(), - self.site, - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, - ) - # It should have failed - self.assertEqual(channel.code, 403) + # Now try to exchange the login token, it should fail. + self.helper.login_via_token(login_token, 403) @override_config( { diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index d1751e1557..c16e8d43f4 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -26,7 +26,6 @@ from tests.unittest import override_config class CapabilitiesTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, capabilities.register_servlets, diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index b1ca81a911..bb845179d3 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -38,7 +38,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["form_secret"] = "123abc" diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py index 7a88aa2cda..6490e883bf 100644 --- a/tests/rest/client/test_directory.py +++ b/tests/rest/client/test_directory.py @@ -28,7 +28,6 @@ from tests.unittest import override_config class DirectoryTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, directory.register_servlets, diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index 9fa1f82dfe..f31ebc8021 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -26,7 +26,6 @@ from tests import unittest class EphemeralMessageTestCase(unittest.HomeserverTestCase): - user_id = "@user:test" servlets = [ diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index a9b7db9db2..54df2a252c 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -38,7 +38,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["enable_registration_captcha"] = False config["enable_registration"] = True @@ -51,7 +50,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # register an account self.user_id = self.register_user("sid1", "pass") self.token = self.login(self.user_id, "pass") @@ -142,7 +140,6 @@ class GetEventsTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # register an account self.user_id = self.register_user("sid1", "pass") self.token = self.login(self.user_id, "pass") diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 830762fd53..91678abf13 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -25,7 +25,6 @@ PATH_PREFIX = "/_matrix/client/v2_alpha" class FilterTestCase(unittest.HomeserverTestCase): - user_id = "@apple:test" hijack_auth = True EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py index 741fecea77..8ee5489057 100644 --- a/tests/rest/client/test_keys.py +++ b/tests/rest/client/test_keys.py @@ -14,12 +14,21 @@ from http import HTTPStatus +from signedjson.key import ( + encode_verify_key_base64, + generate_signing_key, + get_verify_key, +) +from signedjson.sign import sign_json + from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import keys, login +from synapse.types import JsonDict from tests import unittest from tests.http.server._base import make_request_with_cancellation_test +from tests.unittest import override_config class KeyQueryTestCase(unittest.HomeserverTestCase): @@ -118,3 +127,135 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertIn(bob, channel.json_body["device_keys"]) + + def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: + # We only generate a master key to simplify the test. + master_signing_key = generate_signing_key(device_id) + master_verify_key = encode_verify_key_base64(get_verify_key(master_signing_key)) + + return { + "master_key": sign_json( + { + "user_id": user_id, + "usage": ["master"], + "keys": {"ed25519:" + master_verify_key: master_verify_key}, + }, + user_id, + master_signing_key, + ), + } + + def test_device_signing_with_uia(self) -> None: + """Device signing key upload requires UIA.""" + password = "wonderland" + device_id = "ABCDEFGHI" + alice_id = self.register_user("alice", password) + alice_token = self.login("alice", password, device_id=device_id) + + content = self.make_device_keys(alice_id, device_id) + + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + content, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # add UI auth + content["auth"] = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": alice_id}, + "password": password, + "session": session, + } + + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + content, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + @override_config({"ui_auth": {"session_timeout": "15m"}}) + def test_device_signing_with_uia_session_timeout(self) -> None: + """Device signing key upload requires UIA buy passes with grace period.""" + password = "wonderland" + device_id = "ABCDEFGHI" + alice_id = self.register_user("alice", password) + alice_token = self.login("alice", password, device_id=device_id) + + content = self.make_device_keys(alice_id, device_id) + + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + content, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + @override_config( + { + "experimental_features": {"msc3967_enabled": True}, + "ui_auth": {"session_timeout": "15s"}, + } + ) + def test_device_signing_with_msc3967(self) -> None: + """Device signing key follows MSC3967 behaviour when enabled.""" + password = "wonderland" + device_id = "ABCDEFGHI" + alice_id = self.register_user("alice", password) + alice_token = self.login("alice", password, device_id=device_id) + + keys1 = self.make_device_keys(alice_id, device_id) + + # Initial request should succeed as no existing keys are present. + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + keys1, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + keys2 = self.make_device_keys(alice_id, device_id) + + # Subsequent request should require UIA as keys already exist even though session_timeout is set. + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + keys2, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # add UI auth + keys2["auth"] = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": alice_id}, + "password": password, + "session": session, + } + + # Request should complete + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + keys2, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index ff5baa9f0a..62acf4f44e 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -89,7 +89,6 @@ ADDITIONAL_LOGIN_FLOWS = [ class LoginRestServletTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -737,7 +736,6 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): class CASTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, ] diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py index 6aedc1a11c..b8187db982 100644 --- a/tests/rest/client/test_login_token_request.py +++ b/tests/rest/client/test_login_token_request.py @@ -26,7 +26,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token" class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, admin.register_servlets, diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 67e16880e6..dcbb125a3b 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -35,7 +35,6 @@ class PresenceTestCase(unittest.HomeserverTestCase): servlets = [presence.register_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.presence_handler = Mock(spec=PresenceHandler) self.presence_handler.set_state.return_value = make_awaitable(None) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 8de5a342ae..27c93ad761 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -30,7 +30,6 @@ from tests import unittest class ProfileTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -324,7 +323,6 @@ class ProfileTestCase(unittest.HomeserverTestCase): class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -404,7 +402,6 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 4c561f9525..b228dba861 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -40,7 +40,6 @@ from tests.unittest import override_config class RegisterRestServletTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, register.register_servlets, @@ -797,7 +796,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): class AccountValidityTestCase(unittest.HomeserverTestCase): - servlets = [ register.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -913,7 +911,6 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): - servlets = [ register.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -1132,7 +1129,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): - servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index c8a6911d5e..fbbbcb23f1 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -30,7 +30,6 @@ from tests import unittest from tests.server import FakeChannel from tests.test_utils import make_awaitable from tests.test_utils.event_injection import inject_event -from tests.unittest import override_config class BaseRelationsTestCase(unittest.HomeserverTestCase): @@ -403,7 +402,7 @@ class RelationsTestCase(BaseRelationsTestCase): def test_edit(self) -> None: """Test that a simple edit works.""" - + orig_body = {"body": "Hi!", "msgtype": "m.text"} new_body = {"msgtype": "m.text", "body": "I've been edited!"} edit_event_content = { "msgtype": "m.text", @@ -424,9 +423,7 @@ class RelationsTestCase(BaseRelationsTestCase): access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual( - channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"} - ) + self.assertEqual(channel.json_body["content"], orig_body) self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content) # Request the room messages. @@ -443,7 +440,7 @@ class RelationsTestCase(BaseRelationsTestCase): ) # Request the room context. - # /context should return the edited event. + # /context should return the event. channel = self.make_request( "GET", f"/rooms/{self.room}/context/{self.parent_id}", @@ -453,7 +450,7 @@ class RelationsTestCase(BaseRelationsTestCase): self._assert_edit_bundle( channel.json_body["event"], edit_event_id, edit_event_content ) - self.assertEqual(channel.json_body["event"]["content"], new_body) + self.assertEqual(channel.json_body["event"]["content"], orig_body) # Request sync, but limit the timeline so it becomes limited (and includes # bundled aggregations). @@ -491,45 +488,11 @@ class RelationsTestCase(BaseRelationsTestCase): edit_event_content, ) - @override_config({"experimental_features": {"msc3925_inhibit_edit": True}}) - def test_edit_inhibit_replace(self) -> None: - """ - If msc3925_inhibit_edit is enabled, then the original event should not be - replaced. - """ - - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - edit_event_content = { - "msgtype": "m.text", - "body": "foo", - "m.new_content": new_body, - } - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content=edit_event_content, - ) - edit_event_id = channel.json_body["event_id"] - - # /context should return the *original* event. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual( - channel.json_body["event"]["content"], {"body": "Hi!", "msgtype": "m.text"} - ) - self._assert_edit_bundle( - channel.json_body["event"], edit_event_id, edit_event_content - ) - def test_multi_edit(self) -> None: """Test that multiple edits, including attempts by people who shouldn't be allowed, are correctly handled. """ - + orig_body = orig_body = {"body": "Hi!", "msgtype": "m.text"} self._send_relation( RelationTypes.REPLACE, "m.room.message", @@ -570,7 +533,7 @@ class RelationsTestCase(BaseRelationsTestCase): ) self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(channel.json_body["event"]["content"], new_body) + self.assertEqual(channel.json_body["event"]["content"], orig_body) self._assert_edit_bundle( channel.json_body["event"], edit_event_id, edit_event_content ) @@ -642,6 +605,7 @@ class RelationsTestCase(BaseRelationsTestCase): def test_edit_edit(self) -> None: """Test that an edit cannot be edited.""" + orig_body = {"body": "Hi!", "msgtype": "m.text"} new_body = {"msgtype": "m.text", "body": "Initial edit"} edit_event_content = { "msgtype": "m.text", @@ -675,14 +639,12 @@ class RelationsTestCase(BaseRelationsTestCase): access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual( - channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"} - ) + self.assertEqual(channel.json_body["content"], orig_body) # The relations information should not include the edit to the edit. self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content) - # /context should return the event updated for the *first* edit + # /context should return the bundled edit for the *first* edit # (The edit to the edit should be ignored.) channel = self.make_request( "GET", @@ -690,7 +652,7 @@ class RelationsTestCase(BaseRelationsTestCase): access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(channel.json_body["event"]["content"], new_body) + self.assertEqual(channel.json_body["event"]["content"], orig_body) self._assert_edit_bundle( channel.json_body["event"], edit_event_id, edit_event_content ) @@ -1080,48 +1042,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): ] assert_bundle(self._find_event_in_chunk(chunk)) - def test_annotation(self) -> None: - """ - Test that annotations get correctly bundled. - """ - # Setup by sending a variety of relations. - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - - def assert_annotations(bundled_aggregations: JsonDict) -> None: - self.assertEqual( - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - bundled_aggregations, - ) - - self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) - - def test_annotation_to_annotation(self) -> None: - """Any relation to an annotation should be ignored.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - event_id = channel.json_body["event_id"] - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=event_id - ) - - # Fetch the initial annotation event to see if it has bundled aggregations. - channel = self.make_request( - "GET", - f"/_matrix/client/v3/rooms/{self.room}/event/{event_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - # The first annotationt should not have any bundled aggregations. - self.assertNotIn("m.relations", channel.json_body["unsigned"]) - def test_reference(self) -> None: """ Test that references get correctly bundled. @@ -1138,7 +1058,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations, ) - self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7) + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6) def test_thread(self) -> None: """ @@ -1183,7 +1103,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 6) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # @@ -1208,9 +1128,10 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self._send_relation(RelationTypes.THREAD, "m.room.test") thread_2 = channel.json_body["event_id"] - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2 + channel = self._send_relation( + RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_2 ) + reference_event_id = channel.json_body["event_id"] def assert_thread(bundled_aggregations: JsonDict) -> None: self.assertEqual(2, bundled_aggregations.get("count")) @@ -1235,17 +1156,15 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self.assert_dict( { "m.relations": { - RelationTypes.ANNOTATION: { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 1}, - ] + RelationTypes.REFERENCE: { + "chunk": [{"event_id": reference_event_id}] }, } }, bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 6) def test_nested_thread(self) -> None: """ @@ -1330,7 +1249,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): thread_summary = relations_dict[RelationTypes.THREAD] self.assertIn("latest_event", thread_summary) latest_event_in_thread = thread_summary["latest_event"] - self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") # The latest event in the thread should have the edit appear under the # bundled aggregations. self.assertDictContainsSubset( @@ -1363,10 +1281,11 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self._send_relation(RelationTypes.THREAD, "m.room.test") thread_id = channel.json_body["event_id"] - # Annotate the thread. - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id + # Make a reference to the thread. + channel = self._send_relation( + RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_id ) + reference_event_id = channel.json_body["event_id"] channel = self.make_request( "GET", @@ -1377,9 +1296,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self.assertEqual( channel.json_body["unsigned"].get("m.relations"), { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, + RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]}, }, ) @@ -1396,9 +1313,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self.assertEqual( thread_message["unsigned"].get("m.relations"), { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, + RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]}, }, ) @@ -1410,7 +1325,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): Note that the spec allows for a server to return additional fields beyond what is specified. """ - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test") + reference_event_id = channel.json_body["event_id"] # Note that the sync filter does not include "unsigned" as a field. filter = urllib.parse.quote_plus( @@ -1428,7 +1344,12 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # Ensure there's bundled aggregations on it. self.assertIn("unsigned", parent_event) - self.assertIn("m.relations", parent_event["unsigned"]) + self.assertEqual( + parent_event["unsigned"].get("m.relations"), + { + RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]}, + }, + ) class RelationIgnoredUserTestCase(BaseRelationsTestCase): @@ -1475,53 +1396,8 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase): return before_aggregations[relation_type], after_aggregations[relation_type] - def test_annotation(self) -> None: - """Annotations should ignore""" - # Send 2 from us, 2 from the to be ignored user. - allowed_event_ids = [] - ignored_event_ids = [] - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - allowed_event_ids.append(channel.json_body["event_id"]) - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b") - allowed_event_ids.append(channel.json_body["event_id"]) - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="a", - access_token=self.user2_token, - ) - ignored_event_ids.append(channel.json_body["event_id"]) - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="c", - access_token=self.user2_token, - ) - ignored_event_ids.append(channel.json_body["event_id"]) - - before_aggregations, after_aggregations = self._test_ignored_user( - RelationTypes.ANNOTATION, allowed_event_ids, ignored_event_ids - ) - - self.assertCountEqual( - before_aggregations["chunk"], - [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - {"type": "m.reaction", "key": "c", "count": 1}, - ], - ) - - self.assertCountEqual( - after_aggregations["chunk"], - [ - {"type": "m.reaction", "key": "a", "count": 1}, - {"type": "m.reaction", "key": "b", "count": 1}, - ], - ) - def test_reference(self) -> None: - """Annotations should ignore""" + """Aggregations should exclude reference relations from ignored users""" channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") allowed_event_ids = [channel.json_body["event_id"]] @@ -1544,7 +1420,7 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase): ) def test_thread(self) -> None: - """Annotations should ignore""" + """Aggregations should exclude thread releations from ignored users""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") allowed_event_ids = [channel.json_body["event_id"]] @@ -1618,43 +1494,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): for t in threads ] - def test_redact_relation_annotation(self) -> None: - """ - Test that annotations of an event are properly handled after the - annotation is redacted. - - The redacted relation should not be included in bundled aggregations or - the response to relations. - """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - to_redact_event_id = channel.json_body["event_id"] - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - unredacted_event_id = channel.json_body["event_id"] - - # Both relations should exist. - event_ids = self._get_related_events() - relations = self._get_bundled_aggregations() - self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id]) - self.assertEquals( - relations["m.annotation"], - {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]}, - ) - - # Redact one of the reactions. - self._redact(to_redact_event_id) - - # The unredacted relation should still exist. - event_ids = self._get_related_events() - relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [unredacted_event_id]) - self.assertEquals( - relations["m.annotation"], - {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, - ) - def test_redact_relation_thread(self) -> None: """ Test that thread replies are properly handled after the thread reply redacted. @@ -1775,14 +1614,14 @@ class RelationRedactionTestCase(BaseRelationsTestCase): is redacted. """ # Add a relation - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") + channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test") related_event_id = channel.json_body["event_id"] # The relations should exist. event_ids = self._get_related_events() relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 1) - self.assertIn(RelationTypes.ANNOTATION, relations) + self.assertIn(RelationTypes.REFERENCE, relations) # Redact the original event. self._redact(self.parent_id) @@ -1792,8 +1631,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [related_event_id]) self.assertEquals( - relations["m.annotation"], - {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, + relations[RelationTypes.REFERENCE], + {"chunk": [{"event_id": related_event_id}]}, ) def test_redact_parent_thread(self) -> None: diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py index c0eb5d01a6..8dbd64be55 100644 --- a/tests/rest/client/test_rendezvous.py +++ b/tests/rest/client/test_rendezvous.py @@ -25,7 +25,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous" class RendezvousServletTestCase(unittest.HomeserverTestCase): - servlets = [ rendezvous.register_servlets, ] diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index cfad182b2f..a4900703c4 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -65,7 +65,6 @@ class RoomBase(unittest.HomeserverTestCase): servlets = [room.register_servlets, room.register_deprecated_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.hs = self.setup_test_homeserver( "red", federation_http_client=None, @@ -92,7 +91,6 @@ class RoomPermissionsTestCase(RoomBase): rmcreator_id = "@notme:red" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.helper.auth_user_id = self.rmcreator_id # create some rooms under the name rmcreator_id self.uncreated_rmid = "!aa:test" @@ -715,7 +713,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(33, channel.resource_usage.db_txn_count) + self.assertEqual(30, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -728,7 +726,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(36, channel.resource_usage.db_txn_count) + self.assertEqual(32, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id @@ -1127,7 +1125,6 @@ class RoomInviteRatelimitTestCase(RoomBase): class RoomJoinTestCase(RoomBase): - servlets = [ admin.register_servlets, login.register_servlets, @@ -2102,7 +2099,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): hijack_auth = False def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # Register the user who does the searching self.user_id2 = self.register_user("user", "pass") self.access_token = self.login("user", "pass") @@ -2195,7 +2191,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2203,7 +2198,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.url = b"/_matrix/client/r0/publicRooms" config = self.default_config() @@ -2225,7 +2219,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2233,7 +2226,6 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["allow_public_rooms_without_auth"] = True self.hs = self.setup_test_homeserver(config=config) @@ -2414,7 +2406,6 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2983,7 +2974,6 @@ class RelationsTestCase(PaginationTestCase): class ContextTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -3359,7 +3349,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): class ThreepidInviteTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets, login.register_servlets, @@ -3438,7 +3427,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): """ Test allowing/blocking threepid invites with a spam-check module. - In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.""" + In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`. + """ # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index b9047194dd..9c876c7a32 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -41,7 +41,6 @@ from tests.server import TimedOutException class FilterTestCase(unittest.HomeserverTestCase): - user_id = "@apple:test" servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -191,7 +190,6 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): class SyncTypingTests(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -892,7 +890,6 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): class ExcludeRoomTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 3277a116e8..7245830b01 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -137,6 +137,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): """Tests that a forbidden event is forbidden from being sent, but an allowed one can be sent. """ + # patch the rules module with a Mock which will return False for some event # types async def check( @@ -243,6 +244,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): def test_modify_event(self) -> None: """The module can return a modified version of the event""" + # first patch the event checker so that it will modify the event async def check( ev: EventBase, state: StateMap[EventBase] @@ -315,6 +317,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): def test_message_edit(self) -> None: """Ensure that the module doesn't cause issues with edited messages.""" + # first patch the event checker so that it will modify the event async def check( ev: EventBase, state: StateMap[EventBase] @@ -465,7 +468,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): async def test_fn( event: EventBase, state_events: StateMap[EventBase] ) -> Tuple[bool, Optional[JsonDict]]: - if event.is_state and event.type == EventTypes.PowerLevels: + if event.is_state() and event.type == EventTypes.PowerLevels: await api.create_and_send_event_into_room( { "room_id": event.room_id, @@ -971,3 +974,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # Check that the mock was called with the right parameters self.assertEqual(args, (user_id, "email", "foo@example.com")) + + def test_on_add_and_remove_user_third_party_identifier(self) -> None: + """Tests that the on_add_user_third_party_identifier and + on_remove_user_third_party_identifier module callbacks are called + just before associating and removing a 3PID to/from an account. + """ + # Pretend to be a Synapse module and register both callbacks as mocks. + third_party_rules = self.hs.get_third_party_event_rules() + on_add_user_third_party_identifier_callback_mock = Mock( + return_value=make_awaitable(None) + ) + on_remove_user_third_party_identifier_callback_mock = Mock( + return_value=make_awaitable(None) + ) + third_party_rules._on_threepid_bind_callbacks.append( + on_add_user_third_party_identifier_callback_mock + ) + third_party_rules._on_threepid_bind_callbacks.append( + on_remove_user_third_party_identifier_callback_mock + ) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Also register a normal user we can modify. + user_id = self.register_user("user", "password") + + # Add a 3PID to the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "threepids": [ + { + "medium": "email", + "address": "foo@example.com", + }, + ], + }, + access_token=admin_tok, + ) + + # Check that the mocked add callback was called with the appropriate + # 3PID details. + self.assertEqual(channel.code, 200, channel.json_body) + on_add_user_third_party_identifier_callback_mock.assert_called_once() + args = on_add_user_third_party_identifier_callback_mock.call_args[0] + self.assertEqual(args, (user_id, "email", "foo@example.com")) + + # Now remove the 3PID from the user + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "threepids": [], + }, + access_token=admin_tok, + ) + + # Check that the mocked remove callback was called with the appropriate + # 3PID details. + self.assertEqual(channel.code, 200, channel.json_body) + on_remove_user_third_party_identifier_callback_mock.assert_called_once() + args = on_remove_user_third_party_identifier_callback_mock.call_args[0] + self.assertEqual(args, (user_id, "email", "foo@example.com")) + + def test_on_remove_user_third_party_identifier_is_called_on_deactivate( + self, + ) -> None: + """Tests that the on_remove_user_third_party_identifier module callback is called + when a user is deactivated and their third-party ID associations are deleted. + """ + # Pretend to be a Synapse module and register both callbacks as mocks. + third_party_rules = self.hs.get_third_party_event_rules() + on_remove_user_third_party_identifier_callback_mock = Mock( + return_value=make_awaitable(None) + ) + third_party_rules._on_threepid_bind_callbacks.append( + on_remove_user_third_party_identifier_callback_mock + ) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Also register a normal user we can modify. + user_id = self.register_user("user", "password") + + # Add a 3PID to the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "threepids": [ + { + "medium": "email", + "address": "foo@example.com", + }, + ], + }, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Now deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "deactivated": True, + }, + access_token=admin_tok, + ) + + # Check that the mocked remove callback was called with the appropriate + # 3PID details. + self.assertEqual(channel.code, 200, channel.json_body) + on_remove_user_third_party_identifier_callback_mock.assert_called_once() + args = on_remove_user_third_party_identifier_callback_mock.call_args[0] + self.assertEqual(args, (user_id, "email", "foo@example.com")) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 8d6f2b6ff9..9532e5ddc1 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -36,6 +36,7 @@ from urllib.parse import urlencode import attr from typing_extensions import Literal +from twisted.test.proto_helpers import MemoryReactorClock from twisted.web.resource import Resource from twisted.web.server import Site @@ -67,6 +68,7 @@ class RestHelper: """ hs: HomeServer + reactor: MemoryReactorClock site: Site auth_user_id: Optional[str] @@ -142,7 +144,7 @@ class RestHelper: path = path + "?access_token=%s" % tok channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", path, @@ -216,7 +218,7 @@ class RestHelper: data["reason"] = reason channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", path, @@ -313,7 +315,7 @@ class RestHelper: data.update(extra_data or {}) channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "PUT", path, @@ -394,7 +396,7 @@ class RestHelper: path = path + "?access_token=%s" % tok channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "PUT", path, @@ -433,7 +435,7 @@ class RestHelper: path = path + f"?access_token={tok}" channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", path, @@ -488,7 +490,7 @@ class RestHelper: if body is not None: content = json.dumps(body).encode("utf8") - channel = make_request(self.hs.get_reactor(), self.site, method, path, content) + channel = make_request(self.reactor, self.site, method, path, content) assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, @@ -573,8 +575,8 @@ class RestHelper: image_length = len(image_data) path = "/_matrix/media/r0/upload?filename=%s" % (filename,) channel = make_request( - self.hs.get_reactor(), - FakeSite(resource, self.hs.get_reactor()), + self.reactor, + FakeSite(resource, self.reactor), "POST", path, content=image_data, @@ -603,7 +605,7 @@ class RestHelper: expect_code: The return code to expect from attempting the whoami request """ channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", "account/whoami", @@ -642,7 +644,7 @@ class RestHelper: ) -> Tuple[JsonDict, FakeAuthorizationGrant]: """Log in (as a new user) via OIDC - Returns the result of the final token login. + Returns the result of the final token login and the fake authorization grant. Requires that "oidc_config" in the homeserver config be set appropriately (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a @@ -672,10 +674,28 @@ class RestHelper: assert m, channel.text_body login_token = m.group(1) - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token and device id. + return self.login_via_token(login_token, expected_status), grant + + def login_via_token( + self, + login_token: str, + expected_status: int = 200, + ) -> JsonDict: + """Submit the matrix login token to the login API, which gives us our + matrix access token and device id.Log in (as a new user) via OIDC + + Returns the result of the token login. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. + """ + channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", "/login", @@ -684,7 +704,7 @@ class RestHelper: assert ( channel.code == expected_status ), f"unexpected status in response: {channel.code}" - return channel.json_body, grant + return channel.json_body def auth_via_oidc( self, @@ -805,7 +825,7 @@ class RestHelper: with fake_serer.patch_homeserver(hs=self.hs): # now hit the callback URI with the right params and a made-up code channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", callback_uri, @@ -849,7 +869,7 @@ class RestHelper: # is the easiest way of figuring out what the Host header ought to be set to # to keep Synapse happy. channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", uri, @@ -867,7 +887,7 @@ class RestHelper: location = get_location(channel) parts = urllib.parse.urlsplit(location) channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", urllib.parse.urlunsplit(("", "") + parts[2:]), @@ -900,9 +920,7 @@ class RestHelper: + urllib.parse.urlencode({"session": ui_auth_session_id}) ) # hit the redirect url (which will issue a cookie and state) - channel = make_request( - self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint - ) + channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint) # that should serve a confirmation page assert channel.code == HTTPStatus.OK, channel.text_body channel.extract_cookies(cookies) diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py index 23f227aed6..b59d9dfd4d 100644 --- a/tests/rest/media/test_media_retention.py +++ b/tests/rest/media/test_media_retention.py @@ -31,7 +31,6 @@ from tests.utils import MockClock class MediaRetentionTestCase(unittest.HomeserverTestCase): - ONE_DAY_IN_MS = 24 * 60 * 60 * 1000 THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/test_url_preview.py index 2c321f8d04..e91dc581c2 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/test_url_preview.py @@ -26,8 +26,8 @@ from twisted.internet.interfaces import IAddress, IResolutionReceiver from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor from synapse.config.oembed import OEmbedEndpointConfig -from synapse.rest.media.v1.media_repository import MediaRepositoryResource -from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS +from synapse.rest.media.media_repository_resource import MediaRepositoryResource +from synapse.rest.media.preview_url_resource import IMAGE_CACHE_EXPIRY_MS from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -58,7 +58,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["url_preview_enabled"] = True config["max_spider_size"] = 9999999 @@ -83,7 +82,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): config["media_store_path"] = self.media_store_path provider_config = { - "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend", + "module": "synapse.media.storage_provider.FileStorageProviderBackend", "store_local": True, "store_synchronous": False, "store_remote": True, @@ -118,7 +117,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.media_repo = hs.get_media_repository_resource() self.preview_url = self.media_repo.children[b"preview_url"] @@ -133,7 +131,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): addressTypes: Optional[Sequence[Type[IAddress]]] = None, transportSemantics: str = "TCP", ) -> IResolutionReceiver: - resolution = HostResolution(hostName) resolutionReceiver.resolutionBegan(resolution) if hostName not in self.lookups: @@ -660,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """If the preview image doesn't exist, ensure some data is returned.""" self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - end_content = ( + result = ( b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>""" ) @@ -681,8 +678,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" b'Content-Type: text/html; charset="utf8"\r\n\r\n' ) - % (len(end_content),) - + end_content + % (len(result),) + + result ) self.pump() @@ -691,6 +688,44 @@ class URLPreviewTests(unittest.HomeserverTestCase): # The image should not be in the result. self.assertNotIn("og:image", channel.json_body) + def test_oembed_failure(self) -> None: + """If the autodiscovered oEmbed URL fails, ensure some data is returned.""" + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + result = b""" + <title>oEmbed Autodiscovery Fail</title> + <link rel="alternate" type="application/json+oembed" + href="http://example.com/oembed?url=http%3A%2F%2Fmatrix.org&format=json" + title="matrixdotorg" /> + """ + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(result),) + + result + ) + + self.pump() + self.assertEqual(channel.code, 200) + + # The image should not be in the result. + self.assertEqual(channel.json_body["og:title"], "oEmbed Autodiscovery Fail") + def test_data_url(self) -> None: """ Requesting to preview a data URL is not supported. diff --git a/tests/server.py b/tests/server.py index 237bcad8ba..5de9722766 100644 --- a/tests/server.py +++ b/tests/server.py @@ -22,20 +22,25 @@ import warnings from collections import deque from io import SEEK_END, BytesIO from typing import ( + Any, + Awaitable, Callable, Dict, Iterable, List, MutableMapping, Optional, + Sequence, Tuple, Type, + TypeVar, Union, + cast, ) from unittest.mock import Mock import attr -from typing_extensions import Deque +from typing_extensions import Deque, ParamSpec from zope.interface import implementer from twisted.internet import address, threads, udp @@ -44,8 +49,10 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( IAddress, + IConnector, IConsumer, IHostnameResolver, + IProducer, IProtocol, IPullProducer, IPushProducer, @@ -54,6 +61,8 @@ from twisted.internet.interfaces import ( IResolverSimple, ITransport, ) +from twisted.internet.protocol import ClientFactory, DatagramProtocol +from twisted.python import threadpool from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http_headers import Headers @@ -61,6 +70,7 @@ from twisted.web.resource import IResource from twisted.web.server import Request, Site from synapse.config.database import DatabaseConnectionConfig +from synapse.config.homeserver import HomeServerConfig from synapse.events.presence_router import load_legacy_presence_router from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.third_party_rules import load_legacy_third_party_event_rules @@ -88,6 +98,9 @@ from tests.utils import ( logger = logging.getLogger(__name__) +R = TypeVar("R") +P = ParamSpec("P") + # the type of thing that can be passed into `make_request` in the headers list CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]] @@ -98,12 +111,14 @@ class TimedOutException(Exception): """ -@implementer(IConsumer) +@implementer(ITransport, IPushProducer, IConsumer) @attr.s(auto_attribs=True) class FakeChannel: """ A fake Twisted Web Channel (the part that interfaces with the wire). + + See twisted.web.http.HTTPChannel. """ site: Union[Site, "FakeSite"] @@ -142,7 +157,7 @@ class FakeChannel: Raises an exception if the request has not yet completed. """ - if not self.is_finished: + if not self.is_finished(): raise Exception("Request not yet completed") return self.result["body"].decode("utf8") @@ -165,27 +180,36 @@ class FakeChannel: h.addRawHeader(*i) return h - def writeHeaders(self, version, code, reason, headers): + def writeHeaders( + self, version: bytes, code: bytes, reason: bytes, headers: Headers + ) -> None: self.result["version"] = version self.result["code"] = code self.result["reason"] = reason self.result["headers"] = headers - def write(self, content: bytes) -> None: - assert isinstance(content, bytes), "Should be bytes! " + repr(content) + def write(self, data: bytes) -> None: + assert isinstance(data, bytes), "Should be bytes! " + repr(data) if "body" not in self.result: self.result["body"] = b"" - self.result["body"] += content + self.result["body"] += data + + def writeSequence(self, data: Iterable[bytes]) -> None: + for x in data: + self.write(x) + + def loseConnection(self) -> None: + self.unregisterProducer() + self.transport.loseConnection() # Type ignore: mypy doesn't like the fact that producer isn't an IProducer. - def registerProducer( # type: ignore[override] - self, - producer: Union[IPullProducer, IPushProducer], - streaming: bool, - ) -> None: - self._producer = producer + def registerProducer(self, producer: IProducer, streaming: bool) -> None: + # TODO This should ensure that the IProducer is an IPushProducer or + # IPullProducer, unfortunately twisted.protocols.basic.FileSender does + # implement those, but doesn't declare it. + self._producer = cast(Union[IPushProducer, IPullProducer], producer) self.producerStreaming = streaming def _produce() -> None: @@ -202,6 +226,16 @@ class FakeChannel: self._producer = None + def stopProducing(self) -> None: + if self._producer is not None: + self._producer.stopProducing() + + def pauseProducing(self) -> None: + raise NotImplementedError() + + def resumeProducing(self) -> None: + raise NotImplementedError() + def requestDone(self, _self: Request) -> None: self.result["done"] = True if isinstance(_self, SynapseRequest): @@ -281,12 +315,12 @@ class FakeSite: self.reactor = reactor self.experimental_cors_msc3886 = experimental_cors_msc3886 - def getResourceFor(self, request): + def getResourceFor(self, request: Request) -> IResource: return self._resource def make_request( - reactor, + reactor: MemoryReactorClock, site: Union[Site, FakeSite], method: Union[bytes, str], path: Union[bytes, str], @@ -409,19 +443,21 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): A MemoryReactorClock that supports callFromThread. """ - def __init__(self): + def __init__(self) -> None: self.threadpool = ThreadPool(self) self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {} - self._udp = [] + self._udp: List[udp.Port] = [] self.lookups: Dict[str, str] = {} - self._thread_callbacks: Deque[Callable[[], None]] = deque() + self._thread_callbacks: Deque[Callable[..., R]] = deque() lookups = self.lookups @implementer(IResolverSimple) class FakeResolver: - def getHostByName(self, name, timeout=None): + def getHostByName( + self, name: str, timeout: Optional[Sequence[int]] = None + ) -> "Deferred[str]": if name not in lookups: return fail(DNSLookupError("OH NO: unknown %s" % (name,))) return succeed(lookups[name]) @@ -432,25 +468,44 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver: raise NotImplementedError() - def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): + def listenUDP( + self, + port: int, + protocol: DatagramProtocol, + interface: str = "", + maxPacketSize: int = 8196, + ) -> udp.Port: p = udp.Port(port, protocol, interface, maxPacketSize, self) p.startListening() self._udp.append(p) return p - def callFromThread(self, callback, *args, **kwargs): + def callFromThread( + self, callable: Callable[..., Any], *args: object, **kwargs: object + ) -> None: """ Make the callback fire in the next reactor iteration. """ - cb = lambda: callback(*args, **kwargs) + cb = lambda: callable(*args, **kwargs) # it's not safe to call callLater() here, so we append the callback to a # separate queue. self._thread_callbacks.append(cb) - def getThreadPool(self): - return self.threadpool + def callInThread( + self, callable: Callable[..., Any], *args: object, **kwargs: object + ) -> None: + raise NotImplementedError() + + def suggestThreadPoolSize(self, size: int) -> None: + raise NotImplementedError() + + def getThreadPool(self) -> "threadpool.ThreadPool": + # Cast to match super-class. + return cast(threadpool.ThreadPool, self.threadpool) - def add_tcp_client_callback(self, host: str, port: int, callback: Callable): + def add_tcp_client_callback( + self, host: str, port: int, callback: Callable[[], None] + ) -> None: """Add a callback that will be invoked when we receive a connection attempt to the given IP/port using `connectTCP`. @@ -459,7 +514,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): """ self._tcp_callbacks[(host, port)] = callback - def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None): + def connectTCP( + self, + host: str, + port: int, + factory: ClientFactory, + timeout: float = 30, + bindAddress: Optional[Tuple[str, int]] = None, + ) -> IConnector: """Fake L{IReactorTCP.connectTCP}.""" conn = super().connectTCP( @@ -472,7 +534,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): return conn - def advance(self, amount): + def advance(self, amount: float) -> None: # first advance our reactor's time, and run any "callLater" callbacks that # makes ready super().advance(amount) @@ -500,25 +562,33 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): class ThreadPool: """ Threadless thread pool. + + See twisted.python.threadpool.ThreadPool """ - def __init__(self, reactor): + def __init__(self, reactor: IReactorTime): self._reactor = reactor - def start(self): + def start(self) -> None: pass - def stop(self): + def stop(self) -> None: pass - def callInThreadWithCallback(self, onResult, function, *args, **kwargs): - def _(res): + def callInThreadWithCallback( + self, + onResult: Callable[[bool, Union[Failure, R]], None], + function: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> "Deferred[None]": + def _(res: Any) -> None: if isinstance(res, Failure): onResult(False, res) else: onResult(True, res) - d = Deferred() + d: "Deferred[None]" = Deferred() d.addCallback(lambda x: function(*args, **kwargs)) d.addBoth(_) self._reactor.callLater(0, d.callback, True) @@ -535,7 +605,9 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None: for database in server.get_datastores().databases: pool = database._db_pool - def runWithConnection(func, *args, **kwargs): + def runWithConnection( + func: Callable[..., R], *args: Any, **kwargs: Any + ) -> Awaitable[R]: return threads.deferToThreadPool( pool._reactor, pool.threadpool, @@ -545,20 +617,23 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None: **kwargs, ) - def runInteraction(interaction, *args, **kwargs): + def runInteraction( + desc: str, func: Callable[..., R], *args: Any, **kwargs: Any + ) -> Awaitable[R]: return threads.deferToThreadPool( pool._reactor, pool.threadpool, pool._runInteraction, - interaction, + desc, + func, *args, **kwargs, ) - pool.runWithConnection = runWithConnection - pool.runInteraction = runInteraction + pool.runWithConnection = runWithConnection # type: ignore[assignment] + pool.runInteraction = runInteraction # type: ignore[assignment] # Replace the thread pool with a threadless 'thread' pool - pool.threadpool = ThreadPool(clock._reactor) + pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment] pool.running = True # We've just changed the Databases to run DB transactions on the same @@ -573,7 +648,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: @implementer(ITransport) -@attr.s(cmp=False) +@attr.s(cmp=False, auto_attribs=True) class FakeTransport: """ A twisted.internet.interfaces.ITransport implementation which sends all its data @@ -588,48 +663,50 @@ class FakeTransport: If you want bidirectional communication, you'll need two instances. """ - other = attr.ib() + other: IProtocol """The Protocol object which will receive any data written to this transport. - - :type: twisted.internet.interfaces.IProtocol """ - _reactor = attr.ib() + _reactor: IReactorTime """Test reactor - - :type: twisted.internet.interfaces.IReactorTime """ - _protocol = attr.ib(default=None) + _protocol: Optional[IProtocol] = None """The Protocol which is producing data for this transport. Optional, but if set will get called back for connectionLost() notifications etc. """ - _peer_address: Optional[IAddress] = attr.ib(default=None) + _peer_address: IAddress = attr.Factory( + lambda: address.IPv4Address("TCP", "127.0.0.1", 5678) + ) """The value to be returned by getPeer""" - _host_address: Optional[IAddress] = attr.ib(default=None) + _host_address: IAddress = attr.Factory( + lambda: address.IPv4Address("TCP", "127.0.0.1", 1234) + ) """The value to be returned by getHost""" disconnecting = False disconnected = False connected = True - buffer = attr.ib(default=b"") - producer = attr.ib(default=None) - autoflush = attr.ib(default=True) + buffer: bytes = b"" + producer: Optional[IPushProducer] = None + autoflush: bool = True - def getPeer(self) -> Optional[IAddress]: + def getPeer(self) -> IAddress: return self._peer_address - def getHost(self) -> Optional[IAddress]: + def getHost(self) -> IAddress: return self._host_address - def loseConnection(self, reason=None): + def loseConnection(self) -> None: if not self.disconnecting: - logger.info("FakeTransport: loseConnection(%s)", reason) + logger.info("FakeTransport: loseConnection()") self.disconnecting = True if self._protocol: - self._protocol.connectionLost(reason) + self._protocol.connectionLost( + Failure(RuntimeError("FakeTransport.loseConnection()")) + ) # if we still have data to write, delay until that is done if self.buffer: @@ -640,38 +717,38 @@ class FakeTransport: self.connected = False self.disconnected = True - def abortConnection(self): + def abortConnection(self) -> None: logger.info("FakeTransport: abortConnection()") if not self.disconnecting: self.disconnecting = True if self._protocol: - self._protocol.connectionLost(None) + self._protocol.connectionLost(None) # type: ignore[arg-type] self.disconnected = True - def pauseProducing(self): + def pauseProducing(self) -> None: if not self.producer: return self.producer.pauseProducing() - def resumeProducing(self): + def resumeProducing(self) -> None: if not self.producer: return self.producer.resumeProducing() - def unregisterProducer(self): + def unregisterProducer(self) -> None: if not self.producer: return self.producer = None - def registerProducer(self, producer, streaming): + def registerProducer(self, producer: IPushProducer, streaming: bool) -> None: self.producer = producer self.producerStreaming = streaming - def _produce(): + def _produce() -> None: if not self.producer: # we've been unregistered return @@ -683,7 +760,7 @@ class FakeTransport: if not streaming: self._reactor.callLater(0.0, _produce) - def write(self, byt): + def write(self, byt: bytes) -> None: if self.disconnecting: raise Exception("Writing to disconnecting FakeTransport") @@ -695,11 +772,11 @@ class FakeTransport: if self.autoflush: self._reactor.callLater(0.0, self.flush) - def writeSequence(self, seq): + def writeSequence(self, seq: Iterable[bytes]) -> None: for x in seq: self.write(x) - def flush(self, maxbytes=None): + def flush(self, maxbytes: Optional[int] = None) -> None: if not self.buffer: # nothing to do. Don't write empty buffers: it upsets the # TLSMemoryBIOProtocol @@ -750,17 +827,17 @@ def connect_client( class TestHomeServer(HomeServer): - DATASTORE_CLASS = DataStore + DATASTORE_CLASS = DataStore # type: ignore[assignment] def setup_test_homeserver( - cleanup_func, - name="test", - config=None, - reactor=None, + cleanup_func: Callable[[Callable[[], None]], None], + name: str = "test", + config: Optional[HomeServerConfig] = None, + reactor: Optional[ISynapseReactor] = None, homeserver_to_use: Type[HomeServer] = TestHomeServer, - **kwargs, -): + **kwargs: Any, +) -> HomeServer: """ Setup a homeserver suitable for running tests against. Keyword arguments are passed to the Homeserver constructor. @@ -775,13 +852,14 @@ def setup_test_homeserver( HomeserverTestCase. """ if reactor is None: - from twisted.internet import reactor + from twisted.internet import reactor as _reactor + + reactor = cast(ISynapseReactor, _reactor) if config is None: config = default_config(name, parse=True) config.caches.resize_all_caches() - config.ldap_enabled = False if "clock" not in kwargs: kwargs["clock"] = MockClock() @@ -832,6 +910,8 @@ def setup_test_homeserver( # Create the database before we actually try and connect to it, based off # the template database we generate in setupdb() if isinstance(db_engine, PostgresEngine): + import psycopg2.extensions + db_conn = db_engine.module.connect( database=POSTGRES_BASE_DB, user=POSTGRES_USER, @@ -839,6 +919,7 @@ def setup_test_homeserver( port=POSTGRES_PORT, password=POSTGRES_PASSWORD, ) + assert isinstance(db_conn, psycopg2.extensions.connection) db_conn.autocommit = True cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) @@ -867,14 +948,15 @@ def setup_test_homeserver( hs.setup_background_tasks() if isinstance(db_engine, PostgresEngine): - database = hs.get_datastores().databases[0] + database_pool = hs.get_datastores().databases[0] # We need to do cleanup on PostgreSQL - def cleanup(): + def cleanup() -> None: import psycopg2 + import psycopg2.extensions # Close all the db pools - database._db_pool.close() + database_pool._db_pool.close() dropped = False @@ -886,6 +968,7 @@ def setup_test_homeserver( port=POSTGRES_PORT, password=POSTGRES_PASSWORD, ) + assert isinstance(db_conn, psycopg2.extensions.connection) db_conn.autocommit = True cur = db_conn.cursor() @@ -918,23 +1001,23 @@ def setup_test_homeserver( # Need to let the HS build an auth handler and then mess with it # because AuthHandler's constructor requires the HS, so we can't make one # beforehand and pass it in to the HS's constructor (chicken / egg) - async def hash(p): + async def hash(p: str) -> str: return hashlib.md5(p.encode("utf8")).hexdigest() - hs.get_auth_handler().hash = hash + hs.get_auth_handler().hash = hash # type: ignore[assignment] - async def validate_hash(p, h): + async def validate_hash(p: str, h: str) -> bool: return hashlib.md5(p.encode("utf8")).hexdigest() == h - hs.get_auth_handler().validate_hash = validate_hash + hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment] # Make the threadpool and database transactions synchronous for testing. _make_test_homeserver_synchronous(hs) # Load any configured modules into the homeserver module_api = hs.get_module_api() - for module, config in hs.config.modules.loaded_modules: - module(config=config, api=module_api) + for module, module_config in hs.config.modules.loaded_modules: + module(config=module_config, api=module_api) load_legacy_spam_checkers(hs) load_legacy_third_party_event_rules(hs) diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index 6540ed53f1..3fdf5a6d52 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -25,7 +25,6 @@ from tests import unittest class ConsentNoticesTests(unittest.HomeserverTestCase): - servlets = [ sync.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -34,7 +33,6 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - tmpdir = self.mktemp() os.mkdir(tmpdir) self.consent_notice_message = "consent %(consent_uri)s" diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py index 373707b275..b6d5c474b0 100644 --- a/tests/storage/databases/main/test_deviceinbox.py +++ b/tests/storage/databases/main/test_deviceinbox.py @@ -23,7 +23,6 @@ from tests.unittest import HomeserverTestCase class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): - servlets = [ admin.register_servlets, devices.register_servlets, diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py index ac77aec003..71db47405e 100644 --- a/tests/storage/databases/main/test_receipts.py +++ b/tests/storage/databases/main/test_receipts.py @@ -26,7 +26,6 @@ from tests.unittest import HomeserverTestCase class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, @@ -62,6 +61,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): keys and expected receipt key-values after duplicate receipts have been removed. """ + # First, undo the background update. def drop_receipts_unique_index(txn: LoggingTransaction) -> None: txn.execute(f"DROP INDEX IF EXISTS {index_name}") diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 3108ca3444..dbd8f3a85e 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -27,7 +27,6 @@ from tests.unittest import HomeserverTestCase class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index 1bfd11ceae..b12691a9d3 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -140,3 +140,25 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): # No one ignores the user now. self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", set()) + + def test_ignoring_users_with_latest_stream_ids(self) -> None: + """Test that ignoring users updates the latest stream ID for the ignored + user list account data.""" + + def get_latest_ignore_streampos(user_id: str) -> Optional[int]: + return self.get_success( + self.store.get_latest_stream_id_for_global_account_data_by_type_for_user( + user_id, AccountDataTypes.IGNORED_USER_LIST + ) + ) + + self.assertIsNone(get_latest_ignore_streampos("@user:test")) + + self._update_ignore_list("@other:test", "@another:remote") + + self.assertEqual(get_latest_ignore_streampos("@user:test"), 2) + + # Add one user, remove one user, and leave one user. + self._update_ignore_list("@foo:test", "@another:remote") + + self.assertEqual(get_latest_ignore_streampos("@user:test"), 3) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index d570684c99..7de109966d 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -43,8 +43,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") self.requester = create_requester(self.user) - info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) - self.room_id = info["room_id"] + self.room_id, _, _ = self.get_success( + self.room_creator.create_room(self.requester, {}) + ) def run_background_update(self) -> None: """Re run the background update to clean up the extremities.""" @@ -275,10 +276,9 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") self.requester = create_requester(self.user) - info, _ = self.get_success( + self.room_id, _, _ = self.get_success( self.room_creator.create_room(self.requester, {"visibility": "public"}) ) - self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() homeserver.config.consent.user_consent_version = self.CONSENT_VERSION diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 7f7f4ef892..cd0079871c 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -656,7 +656,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): class ClientIpAuthTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 8d7310d8e5..2a9aa9e21c 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -417,7 +417,6 @@ class EventChainStoreTestCase(HomeserverTestCase): def fetch_chains( self, events: List[EventBase] ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]: - # Fetch the map from event ID -> (chain ID, sequence number) rows = self.get_success( self.store.db_pool.simple_select_many_batch( @@ -492,7 +491,6 @@ class LinkMapTestCase(unittest.TestCase): class EventChainBackgroundUpdateTestCase(HomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, @@ -524,7 +522,8 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_prev_events_for_room(room_id) ) - event, context, _ = self.get_success( + + event, unpersisted_context, _ = self.get_success( event_handler.create_event( self.requester, { @@ -537,6 +536,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): prev_event_ids=latest_event_ids, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] @@ -546,7 +546,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): assert state_ids1 is not None state1 = set(state_ids1.values()) - event, context, _ = self.get_success( + event, unpersisted_context, _ = self.get_success( event_handler.create_event( self.requester, { @@ -559,6 +559,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): prev_event_ids=latest_event_ids, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 8fc7936ab0..3e1984c15c 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -672,7 +672,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): complete_event_dict_map: Dict[str, JsonDict] = {} stream_ordering = 0 - for (event_id, prev_event_ids) in event_graph.items(): + for event_id, prev_event_ids in event_graph.items(): depth = depth_map[event_id] complete_event_dict_map[event_id] = { diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index a91411168c..6897addbd3 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -33,8 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): events = [(3, 2), (6, 2), (4, 6)] for event_count, extrems in events: - info, _ = self.get_success(room_creator.create_room(requester, {})) - room_id = info["room_id"] + room_id, _, _ = self.get_success(room_creator.create_room(requester, {})) last_event = None diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 76c06a9d1e..aa19c3bd30 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -774,7 +774,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): self.assertEqual(r, 3) # add a bunch of dummy events to the events table - for (stream_ordering, ts) in ( + for stream_ordering, ts in ( (3, 110), (4, 120), (5, 120), diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index d8f42c5d05..857e2caf2e 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -23,7 +23,6 @@ from tests.unittest import HomeserverTestCase class PurgeTests(HomeserverTestCase): - user_id = "@red:server" servlets = [room.register_servlets] diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index 12c17f1073..1b52eef23f 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -50,12 +50,14 @@ class ReceiptTestCase(HomeserverTestCase): self.otherRequester = create_requester(self.otherUser) # Create a test room - info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {})) - self.room_id1 = info["room_id"] + self.room_id1, _, _ = self.get_success( + self.room_creator.create_room(self.ourRequester, {}) + ) # Create a second test room - info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {})) - self.room_id2 = info["room_id"] + self.room_id2, _, _ = self.get_success( + self.room_creator.create_room(self.ourRequester, {}) + ) # Join the second user to the first room memberEvent, memberEventContext = self.get_success( diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 8794401823..f4c4661aaf 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -27,7 +27,6 @@ from tests.test_utils import event_injection class RoomMemberStoreTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, register_servlets_for_client_rest_resource, @@ -35,7 +34,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None: # type: ignore[override] - # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastores().main @@ -48,7 +46,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.u_charlie = UserID.from_string("@charlie:elsewhere") def test_one_member(self) -> None: - # Alice creates the room, and is automatically joined self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 333a67e84d..8b30832a35 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -242,7 +242,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -259,7 +259,7 @@ class StateStoreTestCase(HomeserverTestCase): state_dict, ) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -272,7 +272,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with wildcard types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -289,7 +289,7 @@ class StateStoreTestCase(HomeserverTestCase): state_dict, ) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -309,7 +309,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -327,7 +327,7 @@ class StateStoreTestCase(HomeserverTestCase): state_dict, ) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -341,7 +341,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -392,7 +392,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] room_id = self.room.to_string() - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -404,7 +404,7 @@ class StateStoreTestCase(HomeserverTestCase): self.assertDictEqual({}, state_dict) room_id = self.room.to_string() - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -417,7 +417,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # wildcard types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -428,7 +428,7 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -447,7 +447,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -459,7 +459,7 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -473,7 +473,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -485,7 +485,7 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -496,3 +496,129 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) + + def test_batched_state_group_storing(self) -> None: + creation_event = self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, "", {} + ) + state_to_event = self.get_success( + self.storage.state.get_state_groups( + self.room.to_string(), [creation_event.event_id] + ) + ) + current_state_group = list(state_to_event.keys())[0] + + # create some unpersisted events and event contexts to store against room + events_and_context = [] + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Name, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"name": "first rename of room"}, + }, + ) + + event1, unpersisted_context1 = self.get_success( + self.event_creation_handler.create_new_client_event(builder) + ) + events_and_context.append((event1, unpersisted_context1)) + + builder2 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.JoinRules, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"join_rule": "private"}, + }, + ) + + event2, unpersisted_context2 = self.get_success( + self.event_creation_handler.create_new_client_event(builder2) + ) + events_and_context.append((event2, unpersisted_context2)) + + builder3 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Message, + "sender": self.u_alice.to_string(), + "room_id": self.room.to_string(), + "content": {"body": "hello from event 3", "msgtype": "m.text"}, + }, + ) + + event3, unpersisted_context3 = self.get_success( + self.event_creation_handler.create_new_client_event(builder3) + ) + events_and_context.append((event3, unpersisted_context3)) + + builder4 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.JoinRules, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"join_rule": "public"}, + }, + ) + + event4, unpersisted_context4 = self.get_success( + self.event_creation_handler.create_new_client_event(builder4) + ) + events_and_context.append((event4, unpersisted_context4)) + + processed_events_and_context = self.get_success( + self.hs.get_datastores().state.store_state_deltas_for_batched( + events_and_context, self.room.to_string(), current_state_group + ) + ) + + # check that only state events are in state_groups, and all state events are in state_groups + res = self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups", + keyvalues=None, + retcols=("event_id",), + ) + ) + + events = [] + for result in res: + self.assertNotIn(event3.event_id, result) + events.append(result.get("event_id")) + + for event, _ in processed_events_and_context: + if event.is_state(): + self.assertIn(event.event_id, events) + + # check that each unique state has state group in state_groups_state and that the + # type/state key is correct, and check that each state event's state group + # has an entry and prev event in state_group_edges + for event, context in processed_events_and_context: + if event.is_state(): + state = self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups_state", + keyvalues={"state_group": context.state_group_after_event}, + retcols=("type", "state_key"), + ) + ) + self.assertEqual(event.type, state[0].get("type")) + self.assertEqual(event.state_key, state[0].get("state_key")) + + groups = self.get_success( + self.store.db_pool.simple_select_list( + table="state_group_edges", + keyvalues={"state_group": str(context.state_group_after_event)}, + retcols=("*",), + ) + ) + self.assertEqual( + context.state_group_before_event, groups[0].get("prev_state_group") + ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index f1ca523d23..8c72aa1722 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -25,6 +25,11 @@ from synapse.rest.client import login, register, room from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.background_updates import _BackgroundUpdateHandler +from synapse.storage.databases.main import user_directory +from synapse.storage.databases.main.user_directory import ( + _parse_words_with_icu, + _parse_words_with_regex, +) from synapse.storage.roommember import ProfileInfo from synapse.util import Clock @@ -42,7 +47,7 @@ ALICE = "@alice:a" BOB = "@bob:b" BOBBY = "@bobby:a" # The localpart isn't 'Bela' on purpose so we can test looking up display names. -BELA = "@somenickname:a" +BELA = "@somenickname:example.org" class GetUserDirectoryTables: @@ -423,6 +428,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): class UserDirectoryStoreTestCase(HomeserverTestCase): + use_icu = False + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main @@ -434,6 +441,12 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None)) self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))) + self._restore_use_icu = user_directory.USE_ICU + user_directory.USE_ICU = self.use_icu + + def tearDown(self) -> None: + user_directory.USE_ICU = self._restore_use_icu + def test_search_user_dir(self) -> None: # normally when alice searches the directory she should just find # bob because bobby doesn't share a room with her. @@ -478,6 +491,159 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, ) + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_start_of_user_id(self) -> None: + """Tests that a user can look up another user by searching for the start + of their user ID. + """ + r = self.get_success(self.store.search_user_dir(ALICE, "somenickname:exa", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, + ) + + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_ascii_case_insensitivity(self) -> None: + """Tests that a user can look up another user by searching for their name in a + different case. + """ + CHARLIE = "@someuser:example.org" + self.get_success( + self.store.update_profile_in_user_dir(CHARLIE, "Charlie", None) + ) + + r = self.get_success(self.store.search_user_dir(ALICE, "cHARLIE", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": CHARLIE, "display_name": "Charlie", "avatar_url": None}, + ) + + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_unicode_case_insensitivity(self) -> None: + """Tests that a user can look up another user by searching for their name in a + different case. + """ + IVAN = "@someuser:example.org" + self.get_success(self.store.update_profile_in_user_dir(IVAN, "Иван", None)) + + r = self.get_success(self.store.search_user_dir(ALICE, "иВАН", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": IVAN, "display_name": "Иван", "avatar_url": None}, + ) + + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_dotted_dotless_i_case_insensitivity(self) -> None: + """Tests that a user can look up another user by searching for their name in a + different case, when their name contains dotted or dotless "i"s. + + Some languages have dotted and dotless versions of "i", which are considered to + be different letters: i <-> İ, ı <-> I. To make things difficult, they reuse the + ASCII "i" and "I" code points, despite having different lowercase / uppercase + forms. + """ + USER = "@someuser:example.org" + + expected_matches = [ + # (search_term, display_name) + # A search for "i" should match "İ". + ("iiiii", "İİİİİ"), + # A search for "I" should match "ı". + ("IIIII", "ııııı"), + # A search for "ı" should match "I". + ("ııııı", "IIIII"), + # A search for "İ" should match "i". + ("İİİİİ", "iiiii"), + ] + + for search_term, display_name in expected_matches: + self.get_success( + self.store.update_profile_in_user_dir(USER, display_name, None) + ) + + r = self.get_success(self.store.search_user_dir(ALICE, search_term, 10)) + self.assertFalse(r["limited"]) + self.assertEqual( + 1, + len(r["results"]), + f"searching for {search_term!r} did not match {display_name!r}", + ) + self.assertDictEqual( + r["results"][0], + {"user_id": USER, "display_name": display_name, "avatar_url": None}, + ) + + # We don't test for negative matches, to allow implementations that consider all + # the i variants to be the same. + + test_search_user_dir_dotted_dotless_i_case_insensitivity.skip = "not supported" # type: ignore + + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_unicode_normalization(self) -> None: + """Tests that a user can look up another user by searching for their name with + either composed or decomposed accents. + """ + AMELIE = "@someuser:example.org" + + expected_matches = [ + # (search_term, display_name) + ("Ame\u0301lie", "Amélie"), + ("Amélie", "Ame\u0301lie"), + ] + + for search_term, display_name in expected_matches: + self.get_success( + self.store.update_profile_in_user_dir(AMELIE, display_name, None) + ) + + r = self.get_success(self.store.search_user_dir(ALICE, search_term, 10)) + self.assertFalse(r["limited"]) + self.assertEqual( + 1, + len(r["results"]), + f"searching for {search_term!r} did not match {display_name!r}", + ) + self.assertDictEqual( + r["results"][0], + {"user_id": AMELIE, "display_name": display_name, "avatar_url": None}, + ) + + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_accent_insensitivity(self) -> None: + """Tests that a user can look up another user by searching for their name + without any accents. + """ + AMELIE = "@someuser:example.org" + self.get_success(self.store.update_profile_in_user_dir(AMELIE, "Amélie", None)) + + r = self.get_success(self.store.search_user_dir(ALICE, "amelie", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": AMELIE, "display_name": "Amélie", "avatar_url": None}, + ) + + # It may be desirable for "é"s in search terms to not match plain "e"s and we + # really don't want "é"s in search terms to match "e"s with different accents. + # But we don't test for this to allow implementations that consider all + # "e"-lookalikes to be the same. + + test_search_user_dir_accent_insensitivity.skip = "not supported yet" # type: ignore + + +class UserDirectoryStoreTestCaseWithIcu(UserDirectoryStoreTestCase): + use_icu = True + + if not icu: + skip = "Requires PyICU" + class UserDirectoryICUTestCase(HomeserverTestCase): if not icu: @@ -513,3 +679,33 @@ class UserDirectoryICUTestCase(HomeserverTestCase): r["results"][0], {"user_id": ALICE, "display_name": display_name, "avatar_url": None}, ) + + def test_icu_word_boundary_punctuation(self) -> None: + """ + Tests the behaviour of punctuation with the ICU tokeniser. + + Seems to depend on underlying version of ICU. + """ + + # Note: either tokenisation is fine, because Postgres actually splits + # words itself afterwards. + self.assertIn( + _parse_words_with_icu("lazy'fox jumped:over the.dog"), + ( + # ICU 66 on Ubuntu 20.04 + ["lazy'fox", "jumped", "over", "the", "dog"], + # ICU 70 on Ubuntu 22.04 + ["lazy'fox", "jumped:over", "the.dog"], + # pyicu 2.10.2 on Alpine edge / macOS + ["lazy'fox", "jumped", "over", "the.dog"], + ), + ) + + def test_regex_word_boundary_punctuation(self) -> None: + """ + Tests the behaviour of punctuation with the non-ICU tokeniser + """ + self.assertEqual( + _parse_words_with_regex("lazy'fox jumped:over the.dog"), + ["lazy", "fox", "jumped", "over", "the", "dog"], + ) diff --git a/tests/test_federation.py b/tests/test_federation.py index 82dfd88b99..46d2f99eac 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -47,7 +47,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): room_creator.create_room( our_user, room_creator._presets_dict["public_chat"], ratelimit=False ) - )[0]["room_id"] + )[0] self.store = self.hs.get_datastores().main diff --git a/tests/test_mau.py b/tests/test_mau.py index 4e7665a22b..ff21098a59 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -32,7 +32,6 @@ from tests.utils import default_config class TestMauLimit(unittest.HomeserverTestCase): - servlets = [register.register_servlets, sync.register_servlets] def default_config(self) -> JsonDict: diff --git a/tests/unittest.py b/tests/unittest.py index 2d73911747..4b31f84494 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -45,7 +45,7 @@ from typing_extensions import Concatenate, ParamSpec, Protocol from twisted.internet.defer import Deferred, ensureDeferred from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool -from twisted.test.proto_helpers import MemoryReactor +from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock from twisted.trial import unittest from twisted.web.resource import Resource from twisted.web.server import Request @@ -82,7 +82,7 @@ from tests.server import ( ) from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils.logging_setup import setup_logging -from tests.utils import default_config, setupdb +from tests.utils import checked_cast, default_config, setupdb setupdb() setup_logging() @@ -296,7 +296,12 @@ class HomeserverTestCase(TestCase): from tests.rest.client.utils import RestHelper - self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None)) + self.helper = RestHelper( + self.hs, + checked_cast(MemoryReactorClock, self.hs.get_reactor()), + self.site, + getattr(self, "user_id", None), + ) if hasattr(self, "user_id"): if self.hijack_auth: @@ -718,7 +723,7 @@ class HomeserverTestCase(TestCase): event_creator = self.hs.get_event_creation_handler() requester = create_requester(user) - event, context, _ = self.get_success( + event, unpersisted_context, _ = self.get_success( event_creator.create_event( requester, { @@ -730,7 +735,7 @@ class HomeserverTestCase(TestCase): prev_event_ids=prev_event_ids, ) ) - + context = self.get_success(unpersisted_context.persist(event)) if soft_failed: event.internal_metadata.soft_failed = True diff --git a/tests/utils.py b/tests/utils.py index a22f1d1241..3badfb7d41 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,7 +15,7 @@ import atexit import os -from typing import Any, Callable, Dict, List, Tuple, Union, overload +from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload import attr from typing_extensions import Literal, ParamSpec @@ -343,3 +343,27 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None: context = await unpersisted_context.persist(event) await persistence_store.persist_event(event, context) + + +T = TypeVar("T") + + +def checked_cast(type: Type[T], x: object) -> T: + """A version of typing.cast that is checked at runtime. + + We have our own function for this for two reasons: + + 1. typing.cast itself is deliberately a no-op at runtime, see + https://docs.python.org/3/library/typing.html#typing.cast + 2. To help workaround a mypy-zope bug https://github.com/Shoobx/mypy-zope/issues/91 + where mypy would erroneously consider `isinstance(x, type)` to be false in all + circumstances. + + For this to make sense, `T` needs to be something that `isinstance` can check; see + https://docs.python.org/3/library/functions.html?highlight=isinstance#isinstance + https://docs.python.org/3/glossary.html#term-abstract-base-class + https://docs.python.org/3/library/typing.html#typing.runtime_checkable + for more details. + """ + assert isinstance(x, type) + return x |