diff options
97 files changed, 1232 insertions, 490 deletions
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 28fc6d45e6..f184727ced 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -566,6 +566,29 @@ jobs: - run: cargo test + # We want to ensure that the cargo benchmarks still compile, which requires a + # nightly compiler. + cargo-bench: + if: ${{ needs.changes.outputs.rust == 'true' }} + runs-on: ubuntu-latest + needs: + - linting-done + - changes + + steps: + - uses: actions/checkout@v3 + + - name: Install Rust + # There don't seem to be versioned releases of this action per se: for each rust + # version there is a branch which gets constantly rebased on top of master. + # We pin to a specific commit for paranoia's sake. + uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f + with: + toolchain: nightly-2022-12-01 + - uses: Swatinem/rust-cache@v2 + + - run: cargo bench --no-run + # a job which marks all the other jobs as complete, thus allowing PRs to be merged. tests-done: if: ${{ always() }} @@ -577,6 +600,7 @@ jobs: - portdb - complement - cargo-test + - cargo-bench runs-on: ubuntu-latest steps: - uses: matrix-org/done-action@v2 @@ -588,3 +612,4 @@ jobs: skippable: | lint-newsfile cargo-test + cargo-bench diff --git a/changelog.d/14823.feature b/changelog.d/14823.feature new file mode 100644 index 0000000000..8293e99eff --- /dev/null +++ b/changelog.d/14823.feature @@ -0,0 +1 @@ +Experimental support for [MSC3952](https://github.com/matrix-org/matrix-spec-proposals/pull/3952): intentional mentions. diff --git a/changelog.d/14866.bugfix b/changelog.d/14866.bugfix new file mode 100644 index 0000000000..540f918cbd --- /dev/null +++ b/changelog.d/14866.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.53.0 where `next_batch` tokens from `/sync` could not be used with the `/relations` endpoint. diff --git a/changelog.d/14879.misc b/changelog.d/14879.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14879.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/changelog.d/14880.bugfix b/changelog.d/14880.bugfix new file mode 100644 index 0000000000..e56c567082 --- /dev/null +++ b/changelog.d/14880.bugfix @@ -0,0 +1 @@ +Fix a bug when using the `send_local_online_presence_to` module API. diff --git a/changelog.d/14886.misc b/changelog.d/14886.misc new file mode 100644 index 0000000000..9f5384e60e --- /dev/null +++ b/changelog.d/14886.misc @@ -0,0 +1 @@ +Add missing type hints. \ No newline at end of file diff --git a/changelog.d/14887.misc b/changelog.d/14887.misc new file mode 100644 index 0000000000..9f5384e60e --- /dev/null +++ b/changelog.d/14887.misc @@ -0,0 +1 @@ +Add missing type hints. \ No newline at end of file diff --git a/changelog.d/14904.misc b/changelog.d/14904.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14904.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/changelog.d/14916.misc b/changelog.d/14916.misc new file mode 100644 index 0000000000..59914d4b8a --- /dev/null +++ b/changelog.d/14916.misc @@ -0,0 +1 @@ +Document how to handle Dependabot pull requests. diff --git a/changelog.d/14920.misc b/changelog.d/14920.misc new file mode 100644 index 0000000000..7988897d39 --- /dev/null +++ b/changelog.d/14920.misc @@ -0,0 +1 @@ +Fix typo in release script. diff --git a/changelog.d/14922.misc b/changelog.d/14922.misc new file mode 100644 index 0000000000..2cc3614dfd --- /dev/null +++ b/changelog.d/14922.misc @@ -0,0 +1 @@ +Use `StrCollection` to avoid potential bugs with `Collection[str]`. diff --git a/changelog.d/14926.bugfix b/changelog.d/14926.bugfix new file mode 100644 index 0000000000..f1f34cd6ba --- /dev/null +++ b/changelog.d/14926.bugfix @@ -0,0 +1 @@ +Fix a regression introduced in Synapse 1.69.0 which can result in database corruption when database migrations are interrupted on sqlite. diff --git a/changelog.d/14927.misc b/changelog.d/14927.misc new file mode 100644 index 0000000000..9f5384e60e --- /dev/null +++ b/changelog.d/14927.misc @@ -0,0 +1 @@ +Add missing type hints. \ No newline at end of file diff --git a/changelog.d/14935.misc b/changelog.d/14935.misc new file mode 100644 index 0000000000..0ad74b90eb --- /dev/null +++ b/changelog.d/14935.misc @@ -0,0 +1 @@ +Bump ijson from 3.1.4 to 3.2.0.post0. diff --git a/changelog.d/14936.misc b/changelog.d/14936.misc new file mode 100644 index 0000000000..bd5001173f --- /dev/null +++ b/changelog.d/14936.misc @@ -0,0 +1 @@ +Bump types-pyyaml from 6.0.12.2 to 6.0.12.3. diff --git a/changelog.d/14937.misc b/changelog.d/14937.misc new file mode 100644 index 0000000000..061568f010 --- /dev/null +++ b/changelog.d/14937.misc @@ -0,0 +1 @@ +Bump types-jsonschema from 4.17.0.2 to 4.17.0.3. diff --git a/changelog.d/14938.misc b/changelog.d/14938.misc new file mode 100644 index 0000000000..f2fbabe8bb --- /dev/null +++ b/changelog.d/14938.misc @@ -0,0 +1 @@ +Bump types-pillow from 9.4.0.3 to 9.4.0.5. diff --git a/changelog.d/14942.bugfix b/changelog.d/14942.bugfix new file mode 100644 index 0000000000..a3ca3eb7e9 --- /dev/null +++ b/changelog.d/14942.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.68.0 where we were unable to service remote joins in rooms with `@room` notification levels set to `null` in their (malformed) power levels. diff --git a/changelog.d/14943.feature b/changelog.d/14943.feature new file mode 100644 index 0000000000..8293e99eff --- /dev/null +++ b/changelog.d/14943.feature @@ -0,0 +1 @@ +Experimental support for [MSC3952](https://github.com/matrix-org/matrix-spec-proposals/pull/3952): intentional mentions. diff --git a/changelog.d/14944.bugfix b/changelog.d/14944.bugfix new file mode 100644 index 0000000000..5fe1fb322b --- /dev/null +++ b/changelog.d/14944.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.64 where boolean power levels were erroneously permitted in [v10 rooms](https://spec.matrix.org/v1.5/rooms/v10/). diff --git a/changelog.d/14945.misc b/changelog.d/14945.misc new file mode 100644 index 0000000000..654174f9a8 --- /dev/null +++ b/changelog.d/14945.misc @@ -0,0 +1 @@ +Fix various long-standing bugs in Synapse's config, event and request handling where booleans were unintentionally accepted where an integer was expected. diff --git a/changelog.d/14947.bugfix b/changelog.d/14947.bugfix new file mode 100644 index 0000000000..b9e768c44c --- /dev/null +++ b/changelog.d/14947.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where sending messages on servers with presence enabled would spam "Re-starting finished log context" log lines. diff --git a/changelog.d/14950.misc b/changelog.d/14950.misc new file mode 100644 index 0000000000..6602776b3f --- /dev/null +++ b/changelog.d/14950.misc @@ -0,0 +1 @@ +Faster joins: tag `v2/send_join/` requests to indicate if they served a partial join response. diff --git a/docs/development/dependencies.md b/docs/development/dependencies.md index b734cc5826..c4449c51f7 100644 --- a/docs/development/dependencies.md +++ b/docs/development/dependencies.md @@ -258,6 +258,20 @@ because [`build`](https://github.com/pypa/build) is a standardish tool which doesn't require poetry. (It's what we use in CI too). However, you could try `poetry build` too. +## ...handle a Dependabot pull request? + +Synapse uses Dependabot to keep the `poetry.lock` file up-to-date. When it +creates a pull request a GitHub Action will run to automatically create a changelog +file. Ensure that: + +* the lockfile changes look reasonable; +* the upstream changelog file (linked in the description) doesn't include any + breaking changes; +* continuous integration passes (due to permissions, the GitHub Actions run on + the changelog commit will fail, look at the initial commit of the pull request); + +In particular, any updates to the type hints (usually packages which start with `types-`) +should be safe to merge if linting passes. # Troubleshooting diff --git a/mypy.ini b/mypy.ini index 63366dad99..978d92940b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -35,19 +35,12 @@ exclude = (?x) |tests/api/test_auth.py |tests/app/test_openid_listener.py |tests/appservice/test_scheduler.py - |tests/events/test_presence_router.py - |tests/events/test_utils.py |tests/federation/test_federation_catch_up.py |tests/federation/test_federation_sender.py - |tests/federation/transport/test_knocking.py - |tests/handlers/test_typing.py |tests/http/federation/test_matrix_federation_agent.py |tests/http/federation/test_srv_resolver.py |tests/http/test_proxyagent.py - |tests/logging/__init__.py - |tests/logging/test_terse_json.py |tests/module_api/test_api.py - |tests/rest/client/test_transactions.py |tests/rest/media/v1/test_media_storage.py |tests/server.py |tests/test_state.py @@ -87,12 +80,18 @@ disallow_untyped_defs = True [mypy-tests.crypto.*] disallow_untyped_defs = True +[mypy-tests.events.*] +disallow_untyped_defs = True + [mypy-tests.federation.transport.test_client] disallow_untyped_defs = True [mypy-tests.handlers.*] disallow_untyped_defs = True +[mypy-tests.logging.*] +disallow_untyped_defs = True + [mypy-tests.metrics.*] disallow_untyped_defs = True diff --git a/poetry.lock b/poetry.lock index 17a6645b55..99628643f0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -579,74 +579,90 @@ files = [ [[package]] name = "ijson" -version = "3.1.4" +version = "3.2.0.post0" description = "Iterative JSON parser with standard Python iterator interfaces" category = "main" optional = false python-versions = "*" files = [ - {file = "ijson-3.1.4-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:6c1a777096be5f75ffebb335c6d2ebc0e489b231496b7f2ca903aa061fe7d381"}, - {file = "ijson-3.1.4-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:475fc25c3d2a86230b85777cae9580398b42eed422506bf0b6aacfa936f7bfcd"}, - {file = "ijson-3.1.4-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:f587699b5a759e30accf733e37950cc06c4118b72e3e146edcea77dded467426"}, - {file = "ijson-3.1.4-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:339b2b4c7bbd64849dd69ef94ee21e29dcd92c831f47a281fdd48122bb2a715a"}, - {file = "ijson-3.1.4-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:446ef8980504da0af8d20d3cb6452c4dc3d8aa5fd788098985e899b913191fe6"}, - {file = "ijson-3.1.4-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:3997a2fdb28bc04b9ab0555db5f3b33ed28d91e9d42a3bf2c1842d4990beb158"}, - {file = "ijson-3.1.4-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:fa10a1d88473303ec97aae23169d77c5b92657b7fb189f9c584974c00a79f383"}, - {file = "ijson-3.1.4-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:9a5bf5b9d8f2ceaca131ee21fc7875d0f34b95762f4f32e4d65109ca46472147"}, - {file = "ijson-3.1.4-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:81cc8cee590c8a70cca3c9aefae06dd7cb8e9f75f3a7dc12b340c2e332d33a2a"}, - {file = "ijson-3.1.4-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4ea5fc50ba158f72943d5174fbc29ebefe72a2adac051c814c87438dc475cf78"}, - {file = "ijson-3.1.4-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:3b98861a4280cf09d267986cefa46c3bd80af887eae02aba07488d80eb798afa"}, - {file = "ijson-3.1.4-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:068c692efba9692406b86736dcc6803e4a0b6280d7f0b7534bff3faec677ff38"}, - {file = "ijson-3.1.4-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:86884ac06ac69cea6d89ab7b84683b3b4159c4013e4a20276d3fc630fe9b7588"}, - {file = "ijson-3.1.4-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:41e5886ff6fade26f10b87edad723d2db14dcbb1178717790993fcbbb8ccd333"}, - {file = "ijson-3.1.4-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:24b58933bf777d03dc1caa3006112ec7f9e6f6db6ffe1f5f5bd233cb1281f719"}, - {file = "ijson-3.1.4-cp35-cp35m-manylinux2014_aarch64.whl", hash = "sha256:13f80aad0b84d100fb6a88ced24bade21dc6ddeaf2bba3294b58728463194f50"}, - {file = "ijson-3.1.4-cp35-cp35m-win32.whl", hash = "sha256:fa9a25d0bd32f9515e18a3611690f1de12cb7d1320bd93e9da835936b41ad3ff"}, - {file = "ijson-3.1.4-cp35-cp35m-win_amd64.whl", hash = "sha256:c4c1bf98aaab4c8f60d238edf9bcd07c896cfcc51c2ca84d03da22aad88957c5"}, - {file = "ijson-3.1.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f0f2a87c423e8767368aa055310024fa28727f4454463714fef22230c9717f64"}, - {file = "ijson-3.1.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:15507de59d74d21501b2a076d9c49abf927eb58a51a01b8f28a0a0565db0a99f"}, - {file = "ijson-3.1.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:2e6bd6ad95ab40c858592b905e2bbb4fe79bbff415b69a4923dafe841ffadcb4"}, - {file = "ijson-3.1.4-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:68e295bb12610d086990cedc89fb8b59b7c85740d66e9515aed062649605d0bf"}, - {file = "ijson-3.1.4-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:3bb461352c0f0f2ec460a4b19400a665b8a5a3a2da663a32093df1699642ee3f"}, - {file = "ijson-3.1.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:f91c75edd6cf1a66f02425bafc59a22ec29bc0adcbc06f4bfd694d92f424ceb3"}, - {file = "ijson-3.1.4-cp36-cp36m-win32.whl", hash = "sha256:4c53cc72f79a4c32d5fc22efb85aa22f248e8f4f992707a84bdc896cc0b1ecf9"}, - {file = "ijson-3.1.4-cp36-cp36m-win_amd64.whl", hash = "sha256:ac9098470c1ff6e5c23ec0946818bc102bfeeeea474554c8d081dc934be20988"}, - {file = "ijson-3.1.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:dcd6f04df44b1945b859318010234651317db2c4232f75e3933f8bb41c4fa055"}, - {file = "ijson-3.1.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:5a2f40c053c837591636dc1afb79d85e90b9a9d65f3d9963aae31d1eb11bfed2"}, - {file = "ijson-3.1.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:f50337e3b8e72ec68441b573c2848f108a8976a57465c859b227ebd2a2342901"}, - {file = "ijson-3.1.4-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:454918f908abbed3c50a0a05c14b20658ab711b155e4f890900e6f60746dd7cc"}, - {file = "ijson-3.1.4-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:387c2ec434cc1bc7dc9bd33ec0b70d95d443cc1e5934005f26addc2284a437ab"}, - {file = "ijson-3.1.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:179ed6fd42e121d252b43a18833df2de08378fac7bce380974ef6f5e522afefa"}, - {file = "ijson-3.1.4-cp37-cp37m-win32.whl", hash = "sha256:26a6a550b270df04e3f442e2bf0870c9362db4912f0e7bdfd300f30ea43115a2"}, - {file = "ijson-3.1.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ff8cf7507d9d8939264068c2cff0a23f99703fa2f31eb3cb45a9a52798843586"}, - {file = "ijson-3.1.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:09c9d7913c88a6059cd054ff854958f34d757402b639cf212ffbec201a705a0d"}, - {file = "ijson-3.1.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:702ba9a732116d659a5e950ee176be6a2e075998ef1bcde11cbf79a77ed0f717"}, - {file = "ijson-3.1.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:667841591521158770adc90793c2bdbb47c94fe28888cb802104b8bbd61f3d51"}, - {file = "ijson-3.1.4-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:df641dd07b38c63eecd4f454db7b27aa5201193df160f06b48111ba97ab62504"}, - {file = "ijson-3.1.4-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:9348e7d507eb40b52b12eecff3d50934fcc3d2a15a2f54ec1127a36063b9ba8f"}, - {file = "ijson-3.1.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:93455902fdc33ba9485c7fae63ac95d96e0ab8942224a357113174bbeaff92e9"}, - {file = "ijson-3.1.4-cp38-cp38-win32.whl", hash = "sha256:5b725f2e984ce70d464b195f206fa44bebbd744da24139b61fec72de77c03a16"}, - {file = "ijson-3.1.4-cp38-cp38-win_amd64.whl", hash = "sha256:a5965c315fbb2dc9769dfdf046eb07daf48ae20b637da95ec8d62b629be09df4"}, - {file = "ijson-3.1.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b8ee7dbb07cec9ba29d60cfe4954b3cc70adb5f85bba1f72225364b59c1cf82b"}, - {file = "ijson-3.1.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:d9e01c55d501e9c3d686b6ee3af351c9c0c8c3e45c5576bd5601bee3e1300b09"}, - {file = "ijson-3.1.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:297f26f27a04cd0d0a2f865d154090c48ea11b239cabe0a17a6c65f0314bd1ca"}, - {file = "ijson-3.1.4-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:9239973100338a4138d09d7a4602bd289861e553d597cd67390c33bfc452253e"}, - {file = "ijson-3.1.4-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:2a64c66a08f56ed45a805691c2fd2e1caef00edd6ccf4c4e5eff02cd94ad8364"}, - {file = "ijson-3.1.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d17fd199f0d0a4ab6e0d541b4eec1b68b5bd5bb5d8104521e22243015b51049b"}, - {file = "ijson-3.1.4-cp39-cp39-win32.whl", hash = "sha256:70ee3c8fa0eba18c80c5911639c01a8de4089a4361bad2862a9949e25ec9b1c8"}, - {file = "ijson-3.1.4-cp39-cp39-win_amd64.whl", hash = "sha256:6bf2b64304321705d03fa5e403ec3f36fa5bb27bf661849ad62e0a3a49bc23e3"}, - {file = "ijson-3.1.4-pp27-pypy_73-macosx_10_9_x86_64.whl", hash = "sha256:5d7e3fcc3b6de76a9dba1e9fc6ca23dad18f0fa6b4e6499415e16b684b2e9af1"}, - {file = "ijson-3.1.4-pp27-pypy_73-manylinux1_x86_64.whl", hash = "sha256:a72eb0359ebff94754f7a2f00a6efe4c57716f860fc040c606dedcb40f49f233"}, - {file = "ijson-3.1.4-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:28fc168f5faf5759fdfa2a63f85f1f7a148bbae98f34404a6ba19f3d08e89e87"}, - {file = "ijson-3.1.4-pp36-pypy36_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2844d4a38d27583897ed73f7946e205b16926b4cab2525d1ce17e8b08064c706"}, - {file = "ijson-3.1.4-pp36-pypy36_pp73-manylinux1_x86_64.whl", hash = "sha256:252defd1f139b5fb8c764d78d5e3a6df81543d9878c58992a89b261369ea97a7"}, - {file = "ijson-3.1.4-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:15d5356b4d090c699f382c8eb6a2bcd5992a8c8e8b88c88bc6e54f686018328a"}, - {file = "ijson-3.1.4-pp36-pypy36_pp73-win32.whl", hash = "sha256:6774ec0a39647eea70d35fb76accabe3d71002a8701c0545b9120230c182b75b"}, - {file = "ijson-3.1.4-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f11da15ec04cc83ff0f817a65a3392e169be8d111ba81f24d6e09236597bb28c"}, - {file = "ijson-3.1.4-pp37-pypy37_pp73-manylinux1_x86_64.whl", hash = "sha256:ee13ceeed9b6cf81b3b8197ef15595fc43fd54276842ed63840ddd49db0603da"}, - {file = "ijson-3.1.4-pp37-pypy37_pp73-manylinux2010_x86_64.whl", hash = "sha256:97e4df67235fae40d6195711223520d2c5bf1f7f5087c2963fcde44d72ebf448"}, - {file = "ijson-3.1.4-pp37-pypy37_pp73-win32.whl", hash = "sha256:3d10eee52428f43f7da28763bb79f3d90bbbeea1accb15de01e40a00885b6e89"}, - {file = "ijson-3.1.4.tar.gz", hash = "sha256:1d1003ae3c6115ec9b587d29dd136860a81a23c7626b682e2b5b12c9fd30e4ea"}, + {file = "ijson-3.2.0.post0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5809752045ef74c26adf159ed03df7fb7e7a8d656992fd7562663ed47d6d39d9"}, + {file = "ijson-3.2.0.post0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ce4be2beece2629bd24bcab147741d1532bd5ed40fb52f2b4fcde5c5bf606df0"}, + {file = "ijson-3.2.0.post0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5d365df54d18076f1d5f2ffb1eef2ac7f0d067789838f13d393b5586fbb77b02"}, + {file = "ijson-3.2.0.post0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c93ae4d49d8cf8accfedc8a8e7815851f56ceb6e399b0c186754a68fed22844"}, + {file = "ijson-3.2.0.post0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:47a56e3628c227081a2aa58569cbf2af378bad8af648aa904080e87cd6644cfb"}, + {file = "ijson-3.2.0.post0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8af68fe579f6f0b9a8b3f033d10caacfed6a4b89b8c7a1d9478a8f5d8aba4a1"}, + {file = "ijson-3.2.0.post0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6eed1ddd3147de49226db4f213851cf7860493a7b6c7bd5e62516941c007094c"}, + {file = "ijson-3.2.0.post0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:9ecbf85a6d73fc72f6534c38f7d92ed15d212e29e0dbe9810a465d61c8a66d23"}, + {file = "ijson-3.2.0.post0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd218b338ac68213c997d4c88437c0e726f16d301616bf837e1468901934042c"}, + {file = "ijson-3.2.0.post0-cp310-cp310-win32.whl", hash = "sha256:4e7c4fdc7d24747c8cc7d528c145afda4de23210bf4054bd98cd63bf07e4882d"}, + {file = "ijson-3.2.0.post0-cp310-cp310-win_amd64.whl", hash = "sha256:4d4e143908f47307042c9678803d27706e0e2099d0a6c1988c6cae1da07760bf"}, + {file = "ijson-3.2.0.post0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:56500dac8f52989ef7c0075257a8b471cbea8ef77f1044822742b3cbf2246e8b"}, + {file = "ijson-3.2.0.post0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:535665a77408b6bea56eb828806fae125846dff2e2e0ed4cb2e0a8e36244d753"}, + {file = "ijson-3.2.0.post0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a4465c90b25ca7903410fabe4145e7b45493295cc3b84ec1216653fbe9021276"}, + {file = "ijson-3.2.0.post0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:efee1e9b4f691e1086730f3010e31c55625bc2e0f7db292a38a2cdf2774c2e13"}, + {file = "ijson-3.2.0.post0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6fd55f7a46429de95383fc0d0158c1bfb798e976d59d52830337343c2d9bda5c"}, + {file = "ijson-3.2.0.post0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25919b444426f58dcc62f763d1c6be6297f309da85ecab55f51da6ca86fc9fdf"}, + {file = "ijson-3.2.0.post0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c85892d68895ba7a0b16a0e6b7d9f9a0e30e86f2b1e0f6986243473ba8735432"}, + {file = "ijson-3.2.0.post0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:27409ba44cfd006901971063d37699f72e092b5efaa1586288b5067d80c6b5bd"}, + {file = "ijson-3.2.0.post0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:11dfd64633fe1382c4237477ac3836f682ca17e25e0d0799e84737795b0611df"}, + {file = "ijson-3.2.0.post0-cp311-cp311-win32.whl", hash = "sha256:41e955e173f77f54337fecaaa58a35c464b75e232b1f939b282497134a4d4f0e"}, + {file = "ijson-3.2.0.post0-cp311-cp311-win_amd64.whl", hash = "sha256:b3bdd2e12d9b9a18713dd6f3c5ef3734fdab25b79b177054ba9e35ecc746cb6e"}, + {file = "ijson-3.2.0.post0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:26b57838e712b8852c40ec6d74c6de8bb226446440e1af1354c077a6f81b9142"}, + {file = "ijson-3.2.0.post0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6464242f7895268d3086d7829ef031b05c77870dad1e13e51ef79d0a9cfe029"}, + {file = "ijson-3.2.0.post0-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3c6cf18b61b94db9590f86af0dd60edbccb36e151643152b8688066f677fbc9"}, + {file = "ijson-3.2.0.post0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:992e9e68003df32e2aa0f31eb82c0a94f21286203ab2f2b2c666410e17b59d2f"}, + {file = "ijson-3.2.0.post0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:d3e255ef05b434f20fc9d4b18ea15733d1038bec3e4960d772b06216fa79e82d"}, + {file = "ijson-3.2.0.post0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:424232c2bf3e8181f1b572db92c179c2376b57eba9fc8931453fba975f48cb80"}, + {file = "ijson-3.2.0.post0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bced6cd5b09d4d002dda9f37292dd58d26eb1c4d0d179b820d3708d776300bb4"}, + {file = "ijson-3.2.0.post0-cp36-cp36m-win32.whl", hash = "sha256:a8c84dff2d60ae06d5280ec87cd63050bbd74a90c02bfc7c390c803cfc8ac8fc"}, + {file = "ijson-3.2.0.post0-cp36-cp36m-win_amd64.whl", hash = "sha256:a340413a9bf307fafd99254a4dd4ac6c567b91a205bf896dde18888315fd7fcd"}, + {file = "ijson-3.2.0.post0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b3456cd5b16ec9db3ef23dd27f37bf5a14f765e8272e9af3e3de9ee9a4cba867"}, + {file = "ijson-3.2.0.post0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0eb838b4e4360e65c00aa13c78b35afc2477759d423b602b60335af5bed3de5b"}, + {file = "ijson-3.2.0.post0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fe7f414edd69dd9199b0dfffa0ada22f23d8009e10fe2a719e0993b7dcc2e6e2"}, + {file = "ijson-3.2.0.post0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:183841b8d033ca95457f61fb0719185dc7f51a616070bdf1dcaf03473bed05b2"}, + {file = "ijson-3.2.0.post0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1302dc6490da7d44c3a76a5f0b87d8bec9f918454c6d6e6bf4ed922e47da58bb"}, + {file = "ijson-3.2.0.post0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:3b21b1ecd20ed2f918f6f99cdfa68284a416c0f015ffa64b68fa933df1b24d40"}, + {file = "ijson-3.2.0.post0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:e97e6e07851cefe7baa41f1ebf5c0899d2d00d94bfef59825752e4c784bebbe8"}, + {file = "ijson-3.2.0.post0-cp37-cp37m-win32.whl", hash = "sha256:cd0450e76b9c629b7f86e7d5b91b7cc9c281dd719630160a992b19a856f7bdbd"}, + {file = "ijson-3.2.0.post0-cp37-cp37m-win_amd64.whl", hash = "sha256:bed8dcb7dbfdb98e647ad47676045e0891f610d38095dcfdae468e1e1efb2766"}, + {file = "ijson-3.2.0.post0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a7698bc480df76073067017f73ba4139dbaae20f7a6c9a0c7855b9c5e9a62124"}, + {file = "ijson-3.2.0.post0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2f204f6d4cedeb28326c230a0b046968b5263c234c65a5b18cee22865800fff7"}, + {file = "ijson-3.2.0.post0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9829a17f6f78d7f4d0aeff28c126926a1e5f86828ebb60d6a0acfa0d08457f9f"}, + {file = "ijson-3.2.0.post0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f470f3d750e00df86e03254fdcb422d2f726f4fb3a0d8eeee35e81343985e58a"}, + {file = "ijson-3.2.0.post0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb167ee21d9c413d6b0ab65ec12f3e7ea0122879da8b3569fa1063526f9f03a8"}, + {file = "ijson-3.2.0.post0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84eed88177f6c243c52b280cb094f751de600d98d2221e0dec331920894889ec"}, + {file = "ijson-3.2.0.post0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:53f1a13eb99ab514c562869513172135d4b55a914b344e6518ba09ad3ef1e503"}, + {file = "ijson-3.2.0.post0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f6785ba0f65eb64b1ce3b7fcfec101085faf98f4e77b234f14287fd4138ffb25"}, + {file = "ijson-3.2.0.post0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:79b94662c2e9d366ab362c2c5858097eae0da100dea0dfd340db09ab28c8d5e8"}, + {file = "ijson-3.2.0.post0-cp38-cp38-win32.whl", hash = "sha256:5242cb2313ba3ece307b426efa56424ac13cc291c36f292b501d412a98ad0703"}, + {file = "ijson-3.2.0.post0-cp38-cp38-win_amd64.whl", hash = "sha256:775444a3b647350158d0b3c6c39c88b4a0995643a076cb104bf25042c9aedcf8"}, + {file = "ijson-3.2.0.post0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1d64ffaab1d006a4fa9584a4c723e95cc9609bf6c3365478e250cd0bffaaadf3"}, + {file = "ijson-3.2.0.post0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:434e57e7ec5c334ccb0e67bb4d9e60c264dcb2a3843713dbeb12cb19fe42a668"}, + {file = "ijson-3.2.0.post0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:158494bfe89ccb32618d0e53b471364080ceb975462ec464d9f9f37d9832b653"}, + {file = "ijson-3.2.0.post0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f20072376e338af0e51ccecb02335b4e242d55a9218a640f545be7fc64cca99"}, + {file = "ijson-3.2.0.post0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3e8d46c1004afcf2bf513a8fb575ee2ec3d8009a2668566b5926a2dcf7f1a45"}, + {file = "ijson-3.2.0.post0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:986a0347fe19e5117a5241276b72add570839e5bcdc7a6dac4b538c5928eeff5"}, + {file = "ijson-3.2.0.post0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:535a59d61b9aef6fc2a3d01564c1151e38e5a44b92cd6583cb4e8ccf0f58043f"}, + {file = "ijson-3.2.0.post0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:830de03f391f7e72b8587bb178c22d534da31153e9ee4234d54ef82cde5ace5e"}, + {file = "ijson-3.2.0.post0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6def9ac8d73b76cb02e9e9837763f27f71e5e67ec0afae5f1f4cf8f61c39b1ac"}, + {file = "ijson-3.2.0.post0-cp39-cp39-win32.whl", hash = "sha256:11bb84a53c37e227e733c6dffad2037391cf0b3474bff78596dc4373b02008a0"}, + {file = "ijson-3.2.0.post0-cp39-cp39-win_amd64.whl", hash = "sha256:f349bee14d0a4a72ba41e1b1cce52af324ebf704f5066c09e3dd04cfa6f545f0"}, + {file = "ijson-3.2.0.post0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5418066666b25b05f2b8ae2698408daa0afa68f07b0b217f2ab24465b7e9cbd9"}, + {file = "ijson-3.2.0.post0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ccc4d4b947549f9c431651c02b95ef571412c78f88ded198612a41d5c5701a0"}, + {file = "ijson-3.2.0.post0-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dcec67fc15e5978ad286e8cc2a3f9347076e28e0e01673b5ace18c73da64e3ff"}, + {file = "ijson-3.2.0.post0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ee9537e8a8aa15dd2d0912737aeb6265e781e74f7f7cad8165048fcb5f39230"}, + {file = "ijson-3.2.0.post0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:03dfd4c8ed19e704d04b0ad4f34f598dc569fd3f73089f80eed698e7f6069233"}, + {file = "ijson-3.2.0.post0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2d50b2ad9c6c51ca160aa60de7f4dacd1357c38d0e503f51aed95c1c1945ff53"}, + {file = "ijson-3.2.0.post0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51c1db80d7791fb761ad9a6c70f521acd2c4b0e5afa2fe0d813beb2140d16c37"}, + {file = "ijson-3.2.0.post0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:13f2939db983327dd0492f6c1c0e77be3f2cbf9b620c92c7547d1d2cd6ef0486"}, + {file = "ijson-3.2.0.post0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f9d449f86f8971c24609e319811f7f3b6b734f0218c4a0e799debe19300d15b"}, + {file = "ijson-3.2.0.post0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:7e0d1713a9074a7677eb8e43f424b731589d1c689d4676e2f57a5ce59d089e89"}, + {file = "ijson-3.2.0.post0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c8646eb81eec559d7d8b1e51a5087299d06ecab3bc7da54c01f7df94350df135"}, + {file = "ijson-3.2.0.post0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fe3a53e00c59de33b825ba8d6d39f544a7d7180983cd3d6bd2c3794ae35442"}, + {file = "ijson-3.2.0.post0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:93aaec00cbde65c192f15c21f3ee44d2ab0c11eb1a35020b5c4c2676f7fe01d0"}, + {file = "ijson-3.2.0.post0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00594ed3ef2218fee8c652d9e7f862fb39f8251b67c6379ef12f7e044bf6bbf3"}, + {file = "ijson-3.2.0.post0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:1a75cfb34217b41136b714985be645f12269e4345da35d7b48aabd317c82fd10"}, + {file = "ijson-3.2.0.post0.tar.gz", hash = "sha256:80a5bd7e9923cab200701f67ad2372104328b99ddf249dbbe8834102c852d316"}, ] [[package]] @@ -2558,14 +2574,14 @@ files = [ [[package]] name = "types-jsonschema" -version = "4.17.0.2" +version = "4.17.0.3" description = "Typing stubs for jsonschema" category = "dev" optional = false python-versions = "*" files = [ - {file = "types-jsonschema-4.17.0.2.tar.gz", hash = "sha256:8b9e1140d4d780f0f19b5cab1b8a3732e8dd5e49dbc1f174cc0b499125ca6f6c"}, - {file = "types_jsonschema-4.17.0.2-py3-none-any.whl", hash = "sha256:8fd2f9aea4da54f9a811baa6963aac10fd680c18baa6237392c079b97d152738"}, + {file = "types-jsonschema-4.17.0.3.tar.gz", hash = "sha256:746aa466ffed9a1acc7bdbd0ac0b5e068f00be2ee008c1d1e14b0944a8c8b24b"}, + {file = "types_jsonschema-4.17.0.3-py3-none-any.whl", hash = "sha256:c8d5b26b7c8da6a48d7fb1ce029b97e0ff6e74db3727efb968c69f39ad013685"}, ] [[package]] @@ -2582,14 +2598,14 @@ files = [ [[package]] name = "types-pillow" -version = "9.4.0.3" +version = "9.4.0.5" description = "Typing stubs for Pillow" category = "dev" optional = false python-versions = "*" files = [ - {file = "types-Pillow-9.4.0.3.tar.gz", hash = "sha256:eba8ff24457a1b8669b6099793f3d313d034d407ee9f6e5fdf12c86cf54914cd"}, - {file = "types_Pillow-9.4.0.3-py3-none-any.whl", hash = "sha256:f8f16a54ed315144296864df11f14beca82ec0990ea83710b7eac7eb1bb38971"}, + {file = "types-Pillow-9.4.0.5.tar.gz", hash = "sha256:941cefaac2f5297d7d2a9989633c95b4063112690dc21c965d46bd5a7fff3c76"}, + {file = "types_Pillow-9.4.0.5-py3-none-any.whl", hash = "sha256:a1d2b3e070b4d852af04f76f018d12bd51abb4abca3b725d91b35e01cda7a2de"}, ] [[package]] @@ -2621,14 +2637,14 @@ types-cryptography = "*" [[package]] name = "types-pyyaml" -version = "6.0.12.2" +version = "6.0.12.3" description = "Typing stubs for PyYAML" category = "dev" optional = false python-versions = "*" files = [ - {file = "types-PyYAML-6.0.12.2.tar.gz", hash = "sha256:6840819871c92deebe6a2067fb800c11b8a063632eb4e3e755914e7ab3604e83"}, - {file = "types_PyYAML-6.0.12.2-py3-none-any.whl", hash = "sha256:1e94e80aafee07a7e798addb2a320e32956a373f376655128ae20637adb2655b"}, + {file = "types-PyYAML-6.0.12.3.tar.gz", hash = "sha256:17ce17b3ead8f06e416a3b1d5b8ddc6cb82a422bb200254dd8b469434b045ffc"}, + {file = "types_PyYAML-6.0.12.3-py3-none-any.whl", hash = "sha256:879700e9f215afb20ab5f849590418ab500989f83a57e635689e1d50ccc63f0c"}, ] [[package]] diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 8c28bb0af3..6b16a3f75b 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -13,6 +13,7 @@ // limitations under the License. #![feature(test)] +use std::collections::BTreeSet; use synapse::push::{ evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules, }; @@ -32,6 +33,8 @@ fn bench_match_exact(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, + BTreeSet::new(), + false, 10, Some(0), Default::default(), @@ -68,6 +71,8 @@ fn bench_match_word(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, + BTreeSet::new(), + false, 10, Some(0), Default::default(), @@ -104,6 +109,8 @@ fn bench_match_word_miss(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, + BTreeSet::new(), + false, 10, Some(0), Default::default(), @@ -140,6 +147,8 @@ fn bench_eval_message(b: &mut Bencher) { let eval = PushRuleEvaluator::py_new( flattened_keys, + BTreeSet::new(), + false, 10, Some(0), Default::default(), @@ -156,6 +165,7 @@ fn bench_eval_message(b: &mut Bencher) { false, false, false, + false, ); b.iter(|| eval.run(&rules, Some("bob"), Some("person"))); diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 9140a69bb6..880eed0ef4 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -132,6 +132,14 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default_enabled: true, }, PushRule { + rule_id: Cow::Borrowed(".org.matrix.msc3952.is_user_mentioned"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::IsUserMention)]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"), priority_class: 5, conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::ContainsDisplayName)]), @@ -140,6 +148,19 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default_enabled: true, }, PushRule { + rule_id: Cow::Borrowed(".org.matrix.msc3952.is_room_mentioned"), + priority_class: 5, + conditions: Cow::Borrowed(&[ + Condition::Known(KnownCondition::IsRoomMention), + Condition::Known(KnownCondition::SenderNotificationPermission { + key: Cow::Borrowed("room"), + }), + ]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, + PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.roomnotif"), priority_class: 5, conditions: Cow::Borrowed(&[ diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index 0242ee1c5f..aa71202e43 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use anyhow::{Context, Error}; use lazy_static::lazy_static; @@ -68,6 +68,11 @@ pub struct PushRuleEvaluator { /// The "content.body", if any. body: String, + /// The user mentions that were part of the message. + user_mentions: BTreeSet<String>, + /// True if the message is a room message. + room_mention: bool, + /// The number of users in the room. room_member_count: u64, @@ -100,6 +105,8 @@ impl PushRuleEvaluator { #[new] pub fn py_new( flattened_keys: BTreeMap<String, String>, + user_mentions: BTreeSet<String>, + room_mention: bool, room_member_count: u64, sender_power_level: Option<i64>, notification_power_levels: BTreeMap<String, i64>, @@ -116,6 +123,8 @@ impl PushRuleEvaluator { Ok(PushRuleEvaluator { flattened_keys, body, + user_mentions, + room_mention, room_member_count, notification_power_levels, sender_power_level, @@ -229,6 +238,14 @@ impl PushRuleEvaluator { KnownCondition::RelatedEventMatch(event_match) => { self.match_related_event_match(event_match, user_id)? } + KnownCondition::IsUserMention => { + if let Some(uid) = user_id { + self.user_mentions.contains(uid) + } else { + false + } + } + KnownCondition::IsRoomMention => self.room_mention, KnownCondition::ContainsDisplayName => { if let Some(dn) = display_name { if !dn.is_empty() { @@ -424,6 +441,8 @@ fn push_rule_evaluator() { flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); let evaluator = PushRuleEvaluator::py_new( flattened_keys, + BTreeSet::new(), + false, 10, Some(0), BTreeMap::new(), @@ -449,6 +468,8 @@ fn test_requires_room_version_supports_condition() { let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()]; let evaluator = PushRuleEvaluator::py_new( flattened_keys, + BTreeSet::new(), + false, 10, Some(0), BTreeMap::new(), @@ -483,7 +504,7 @@ fn test_requires_room_version_supports_condition() { }; let rules = PushRules::new(vec![custom_rule]); result = evaluator.run( - &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true), + &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false), None, None, ); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 842b13c88b..7e449f2433 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -269,6 +269,10 @@ pub enum KnownCondition { EventMatch(EventMatchCondition), #[serde(rename = "im.nheko.msc3664.related_event_match")] RelatedEventMatch(RelatedEventMatchCondition), + #[serde(rename = "org.matrix.msc3952.is_user_mention")] + IsUserMention, + #[serde(rename = "org.matrix.msc3952.is_room_mention")] + IsRoomMention, ContainsDisplayName, RoomMemberCount { #[serde(skip_serializing_if = "Option::is_none")] @@ -414,6 +418,7 @@ pub struct FilteredPushRules { msc1767_enabled: bool, msc3381_polls_enabled: bool, msc3664_enabled: bool, + msc3952_intentional_mentions: bool, } #[pymethods] @@ -425,6 +430,7 @@ impl FilteredPushRules { msc1767_enabled: bool, msc3381_polls_enabled: bool, msc3664_enabled: bool, + msc3952_intentional_mentions: bool, ) -> Self { Self { push_rules, @@ -432,6 +438,7 @@ impl FilteredPushRules { msc1767_enabled, msc3381_polls_enabled, msc3664_enabled, + msc3952_intentional_mentions, } } @@ -465,6 +472,11 @@ impl FilteredPushRules { return false; } + if !self.msc3952_intentional_mentions && rule.rule_id.contains("org.matrix.msc3952") + { + return false; + } + true }) .map(|r| { @@ -523,6 +535,28 @@ fn test_deserialize_unstable_msc3931_condition() { } #[test] +fn test_deserialize_unstable_msc3952_user_condition() { + let json = r#"{"kind":"org.matrix.msc3952.is_user_mention"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::IsUserMention) + )); +} + +#[test] +fn test_deserialize_unstable_msc3952_room_condition() { + let json = r#"{"kind":"org.matrix.msc3952.is_room_mention"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::IsRoomMention) + )); +} + +#[test] fn test_deserialize_custom_condition() { let json = r#"{"kind":"custom_tag"}"#; diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 6974fd7895..008a5bd965 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -438,7 +438,7 @@ def _upload(gh_token: Optional[str]) -> None: repo = get_repo_and_check_clean_checkout() tag = repo.tag(f"refs/tags/{tag_name}") if repo.head.commit != tag.commit: - click.echo("Tag {tag_name} (tag.commit) is not currently checked out!") + click.echo(f"Tag {tag_name} ({tag.commit}) is not currently checked out!") click.get_current_context().abort() # Query all the assets corresponding to this release. diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 304ed7111c..588d90c25a 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union from synapse.types import JsonDict @@ -46,6 +46,7 @@ class FilteredPushRules: msc1767_enabled: bool, msc3381_polls_enabled: bool, msc3664_enabled: bool, + msc3952_intentional_mentions: bool, ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... @@ -55,6 +56,8 @@ class PushRuleEvaluator: def __init__( self, flattened_keys: Mapping[str, str], + user_mentions: Set[str], + room_mention: bool, room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 6432d32d83..0f224b34cd 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -17,6 +17,8 @@ """Contains constants from the specification.""" +import enum + from typing_extensions import Final # the max size of a (canonical-json-encoded) event @@ -231,6 +233,9 @@ class EventContentFields: # The authorising user for joining a restricted room. AUTHORISING_USER: Final = "join_authorised_via_users_server" + # Use for mentioning users. + MSC3952_MENTIONS: Final = "org.matrix.msc3952.mentions" + # an unspecced field added to to-device messages to identify them uniquely-ish TO_DEVICE_MSGID: Final = "org.matrix.msgid" @@ -290,3 +295,8 @@ class ApprovalNoticeMedium: NONE = "org.matrix.msc3866.none" EMAIL = "org.matrix.msc3866.email" + + +class Direction(enum.Enum): + BACKWARDS = "b" + FORWARDS = "f" diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1f6362aedd..2ce60610ca 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -174,15 +174,29 @@ class Config: @staticmethod def parse_size(value: Union[str, int]) -> int: - if isinstance(value, int): + """Interpret `value` as a number of bytes. + + If an integer is provided it is treated as bytes and is unchanged. + + String byte sizes can have a suffix of 'K' or `M`, representing kibibytes and + mebibytes respectively. No suffix is understood as a plain byte count. + + Raises: + TypeError, if given something other than an integer or a string + ValueError: if given a string not of the form described above. + """ + if type(value) is int: return value - sizes = {"K": 1024, "M": 1024 * 1024} - size = 1 - suffix = value[-1] - if suffix in sizes: - value = value[:-1] - size = sizes[suffix] - return int(value) * size + elif type(value) is str: + sizes = {"K": 1024, "M": 1024 * 1024} + size = 1 + suffix = value[-1] + if suffix in sizes: + value = value[:-1] + size = sizes[suffix] + return int(value) * size + else: + raise TypeError(f"Bad byte size {value!r}") @staticmethod def parse_duration(value: Union[str, int]) -> int: @@ -198,22 +212,36 @@ class Config: Returns: The number of milliseconds in the duration. + + Raises: + TypeError, if given something other than an integer or a string + ValueError: if given a string not of the form described above. """ - if isinstance(value, int): + if type(value) is int: return value - second = 1000 - minute = 60 * second - hour = 60 * minute - day = 24 * hour - week = 7 * day - year = 365 * day - sizes = {"s": second, "m": minute, "h": hour, "d": day, "w": week, "y": year} - size = 1 - suffix = value[-1] - if suffix in sizes: - value = value[:-1] - size = sizes[suffix] - return int(value) * size + elif type(value) is str: + second = 1000 + minute = 60 * second + hour = 60 * minute + day = 24 * hour + week = 7 * day + year = 365 * day + sizes = { + "s": second, + "m": minute, + "h": hour, + "d": day, + "w": week, + "y": year, + } + size = 1 + suffix = value[-1] + if suffix in sizes: + value = value[:-1] + size = sizes[suffix] + return int(value) * size + else: + raise TypeError(f"Bad duration {value!r}") @staticmethod def abspath(file_path: str) -> str: diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 015b2a138e..05f69cb1ba 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -126,7 +126,7 @@ class CacheConfig(Config): cache_config = config.get("caches") or {} self.global_factor = cache_config.get("global_factor", _DEFAULT_FACTOR_SIZE) - if not isinstance(self.global_factor, (int, float)): + if type(self.global_factor) not in (int, float): raise ConfigError("caches.global_factor must be a number.") # Load cache factors from the config @@ -151,7 +151,7 @@ class CacheConfig(Config): ) for cache, factor in individual_factors.items(): - if not isinstance(factor, (int, float)): + if type(factor) not in (int, float): raise ConfigError( "caches.per_cache_factors.%s must be a number" % (cache,) ) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 2590c88cde..d2d0270ddd 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -168,3 +168,8 @@ class ExperimentalConfig(Config): # MSC3925: do not replace events with their edits self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False) + + # MSC3952: Intentional mentions + self.msc3952_intentional_mentions = experimental.get( + "msc3952_intentional_mentions", False + ) diff --git a/synapse/config/server.py b/synapse/config/server.py index 80bcfa4080..ecdaa2d9dd 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -904,7 +904,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig: raise ConfigError(DIRECT_TCP_ERROR, ("listeners", str(num), "type")) port = listener.get("port") - if not isinstance(port, int): + if type(port) is not int: raise ConfigError("Listener configuration is lacking a valid 'port' option") tls = listener.get("tls", False) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c4a7b16413..e0be9f88cc 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -875,11 +875,11 @@ def _check_power_levels( "kick", "invite", }: - if not isinstance(v, int): + if type(v) is not int: raise SynapseError(400, f"{v!r} must be an integer.") if k in {"events", "notifications", "users"}: if not isinstance(v, collections.abc.Mapping) or not all( - isinstance(v, int) for v in v.values() + type(v) is int for v in v.values() ): raise SynapseError( 400, diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ae57a4df5e..ebf8c7ed83 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -605,10 +605,11 @@ class EventClientSerializer: _PowerLevel = Union[str, int] +PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]] def copy_and_fixup_power_levels_contents( - old_power_levels: Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]] + old_power_levels: PowerLevelsContent, ) -> Dict[str, Union[int, Dict[str, int]]]: """Copy the content of a power_levels event, unfreezing frozendicts along the way. @@ -647,10 +648,10 @@ def _copy_power_level_value_as_integer( ) -> None: """Set `power_levels[key]` to the integer represented by `old_value`. - :raises TypeError: if `old_value` is not an integer, nor a base-10 string + :raises TypeError: if `old_value` is neither an integer nor a base-10 string representation of an integer. """ - if isinstance(old_value, int): + if type(old_value) is int: power_levels[key] = old_value return @@ -678,7 +679,7 @@ def validate_canonicaljson(value: Any) -> None: * Floats * NaN, Infinity, -Infinity """ - if isinstance(value, int): + if type(value) is int: if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value: raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index a6f0104396..fb1737b910 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -139,7 +139,7 @@ class EventValidator: max_lifetime = event.content.get("max_lifetime") if min_lifetime is not None: - if not isinstance(min_lifetime, int): + if type(min_lifetime) is not int: raise SynapseError( code=400, msg="'min_lifetime' must be an integer", @@ -147,7 +147,7 @@ class EventValidator: ) if max_lifetime is not None: - if not isinstance(max_lifetime, int): + if type(max_lifetime) is not int: raise SynapseError( code=400, msg="'max_lifetime' must be an integer", diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 6bd4742140..29fae716f5 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -280,7 +280,7 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB _strip_unsigned_values(pdu_json) depth = pdu_json["depth"] - if not isinstance(depth, int): + if type(depth) is not int: raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) if depth < 0: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index f185b6c1f9..feb32e40e5 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1864,7 +1864,7 @@ class TimestampToEventResponse: ) origin_server_ts = d.get("origin_server_ts") - if not isinstance(origin_server_ts, int): + if type(origin_server_ts) is not int: raise ValueError( "Invalid response: 'origin_server_ts' must be a int but received %r" % origin_server_ts diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 3197939a36..c9a6dfd1a4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -62,7 +62,9 @@ from synapse.logging.context import ( run_in_background, ) from synapse.logging.opentracing import ( + SynapseTags, log_kv, + set_tag, start_active_span_from_edu, tag_args, trace, @@ -678,6 +680,10 @@ class FederationServer(FederationBase): room_id: str, caller_supports_partial_state: bool = False, ) -> Dict[str, Any]: + set_tag( + SynapseTags.SEND_JOIN_RESPONSE_IS_PARTIAL_STATE, + caller_supports_partial_state, + ) await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type] requester=None, key=room_id, diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 834006356a..d500b21809 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import random -from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple from synapse.api.constants import AccountDataTypes from synapse.replication.http.account_data import ( @@ -26,7 +26,7 @@ from synapse.replication.http.account_data import ( ReplicationRemoveUserAccountDataRestServlet, ) from synapse.streams import EventSource -from synapse.types import JsonDict, StreamKeyType, UserID +from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -322,7 +322,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]): user: UserID, from_key: int, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[JsonDict], int]: diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 5bf8e86387..c81ea34758 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -16,7 +16,7 @@ import abc import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set -from synapse.api.constants import Membership +from synapse.api.constants import Direction, Membership from synapse.events import EventBase from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID from synapse.visibility import filter_events_for_client @@ -197,7 +197,7 @@ class AdminHandler: # efficient method perhaps but it does guarantee we get everything. while True: events, _ = await self.store.paginate_room_events( - room_id, from_key, to_key, limit=100, direction="f" + room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS ) if not events: break diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 58180ae2fa..5c06073901 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -18,7 +18,6 @@ from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, - Collection, Dict, Iterable, List, @@ -45,6 +44,7 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.types import ( JsonDict, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -146,7 +146,7 @@ class DeviceWorkerHandler: @cancellable async def get_device_changes_in_shared_rooms( - self, user_id: str, room_ids: Collection[str], from_token: StreamToken + self, user_id: str, room_ids: StrCollection, from_token: StreamToken ) -> Set[str]: """Get the set of users whose devices have changed who share a room with the given user. @@ -551,7 +551,7 @@ class DeviceHandler(DeviceWorkerHandler): @trace @measure_func("notify_device_update") async def notify_device_update( - self, user_id: str, device_ids: Collection[str] + self, user_id: str, device_ids: StrCollection ) -> None: """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index f91dbbecb7..a23a8ce2a1 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union +from typing import TYPE_CHECKING, List, Mapping, Optional, Union from synapse import event_auth from synapse.api.constants import ( @@ -29,7 +29,7 @@ from synapse.event_auth import ( ) from synapse.events import EventBase from synapse.events.builder import EventBuilder -from synapse.types import StateMap, get_domain_from_id +from synapse.types import StateMap, StrCollection, get_domain_from_id if TYPE_CHECKING: from synapse.server import HomeServer @@ -290,7 +290,7 @@ class EventAuthHandler: async def get_rooms_that_allow_join( self, state_ids: StateMap[str] - ) -> Collection[str]: + ) -> StrCollection: """ Generate a list of rooms in which membership allows access to a room. @@ -331,7 +331,7 @@ class EventAuthHandler: return result - async def is_user_in_rooms(self, room_ids: Collection[str], user_id: str) -> bool: + async def is_user_in_rooms(self, room_ids: StrCollection, user_id: str) -> bool: """ Check whether a user is a member of any of the provided rooms. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 233f8c113d..dc1cbf5c3d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -20,17 +20,7 @@ import itertools import logging from enum import Enum from http import HTTPStatus -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - Iterable, - List, - Optional, - Set, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union import attr from prometheus_client import Histogram @@ -70,7 +60,7 @@ from synapse.replication.http.federation import ( ) from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict, StrCollection, get_domain_from_id from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination @@ -179,7 +169,7 @@ class FederationHandler: # A dictionary mapping room IDs to (initial destination, other destinations) # tuples. self._partial_state_syncs_maybe_needing_restart: Dict[ - str, Tuple[Optional[str], Collection[str]] + str, Tuple[Optional[str], StrCollection] ] = {} # A lock guarding the partial state flag for rooms. # When the lock is held for a given room, no other concurrent code may @@ -437,7 +427,7 @@ class FederationHandler: ) ) - async def try_backfill(domains: Collection[str]) -> bool: + async def try_backfill(domains: StrCollection) -> bool: # TODO: Should we try multiple of these at a time? # Number of contacted remote homeservers that have denied our backfill @@ -1730,7 +1720,7 @@ class FederationHandler: def _start_partial_state_room_sync( self, initial_destination: Optional[str], - other_destinations: Collection[str], + other_destinations: StrCollection, room_id: str, ) -> None: """Starts the background process to resync the state of a partial state room, @@ -1812,7 +1802,7 @@ class FederationHandler: async def _sync_partial_state_room( self, initial_destination: Optional[str], - other_destinations: Collection[str], + other_destinations: StrCollection, room_id: str, ) -> None: """Background process to resync the state of a partial-state room @@ -1949,9 +1939,9 @@ class FederationHandler: def _prioritise_destinations_for_partial_state_resync( initial_destination: Optional[str], - other_destinations: Collection[str], + other_destinations: StrCollection, room_id: str, -) -> Collection[str]: +) -> StrCollection: """Work out the order in which we should ask servers to resync events. If an `initial_destination` is given, it takes top priority. Otherwise diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 904a721483..e037acbca2 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -80,6 +80,7 @@ from synapse.types import ( PersistedEventPosition, RoomStreamToken, StateMap, + StrCollection, UserID, get_domain_from_id, ) @@ -615,7 +616,7 @@ class FederationEventHandler: @trace async def backfill( - self, dest: str, room_id: str, limit: int, extremities: Collection[str] + self, dest: str, room_id: str, limit: int, extremities: StrCollection ) -> None: """Trigger a backfill request to `dest` for the given `room_id` @@ -1565,7 +1566,7 @@ class FederationEventHandler: @trace @tag_args async def _get_events_and_persist( - self, destination: str, room_id: str, event_ids: Collection[str] + self, destination: str, room_id: str, event_ids: StrCollection ) -> None: """Fetch the given events from a server, and persist them as outliers. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 8c2260ad7d..191529bd8e 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -15,7 +15,13 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple, cast -from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership +from synapse.api.constants import ( + AccountDataTypes, + Direction, + EduTypes, + EventTypes, + Membership, +) from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig @@ -57,7 +63,13 @@ class InitialSyncHandler: self.validator = EventValidator() self.snapshot_cache: ResponseCache[ Tuple[ - str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool + str, + Optional[StreamToken], + Optional[StreamToken], + Direction, + int, + bool, + bool, ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3278a695ed..e688e00575 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -377,7 +377,7 @@ class MessageHandler: """ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if not isinstance(expiry_ts, int) or event.is_state(): + if type(expiry_ts) is not int or event.is_state(): return # _schedule_expiry_for_event won't actually schedule anything if there's already @@ -1939,7 +1939,9 @@ class EventCreationHandler: if event.type == EventTypes.Message: # We don't want to block sending messages on any presence code. This # matters as sometimes presence code can take a while. - run_in_background(self._bump_active_time, requester.user) + run_as_background_process( + "bump_presence_active_time", self._bump_active_time, requester.user + ) async def _notify() -> None: try: diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 8c8ff18a1a..ceefa16b49 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set import attr from twisted.python.failure import Failure -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import Direction, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events.utils import SerializeEventConfig @@ -28,7 +28,7 @@ from synapse.logging.opentracing import trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.admin._base import assert_user_is_admin from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, StreamKeyType +from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType from synapse.types.state import StateFilter from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string @@ -391,7 +391,7 @@ class PaginationHandler: """ return self._delete_by_id.get(delete_id) - def get_delete_ids_by_room(self, room_id: str) -> Optional[Collection[str]]: + def get_delete_ids_by_room(self, room_id: str) -> Optional[StrCollection]: """Get all active delete ids by room Args: @@ -448,7 +448,7 @@ class PaginationHandler: if pagin_config.from_token: from_token = pagin_config.from_token - elif pagin_config.direction == "f": + elif pagin_config.direction == Direction.FORWARDS: from_token = ( await self.hs.get_event_sources().get_start_token_for_pagination( room_id @@ -476,7 +476,7 @@ class PaginationHandler: room_id, requester, allow_departed_users=True ) - if pagin_config.direction == "b": + if pagin_config.direction == Direction.BACKWARDS: # if we're going backwards, we might need to backfill. This # requires that we have a topo token. if room_token.topological: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 43e4e7b1b4..87af31aa27 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -64,7 +64,13 @@ from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream from synapse.storage.databases.main import DataStore from synapse.streams import EventSource -from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id +from synapse.types import ( + JsonDict, + StrCollection, + StreamKeyType, + UserID, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -320,7 +326,7 @@ class BasePresenceHandler(abc.ABC): for destination, host_states in hosts_to_states.items(): self._federation.send_presence_to_destinations(host_states, [destination]) - async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None: + async def send_full_presence_to_users(self, user_ids: StrCollection) -> None: """ Adds to the list of users who should receive a full snapshot of presence upon their next sync. Note that this only works for local users. @@ -1601,7 +1607,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): # Having a default limit doesn't match the EventSource API, but some # callers do not provide it. It is unused in this class. limit: int = 0, - room_ids: Optional[Collection[str]] = None, + room_ids: Optional[StrCollection] = None, is_guest: bool = False, explicit_room_id: Optional[str] = None, include_offline: bool = True, @@ -1688,7 +1694,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): # The set of users that we're interested in and that have had a presence update. # We'll actually pull the presence updates for these users at the end. - interested_and_updated_users: Collection[str] + interested_and_updated_users: StrCollection if from_key is not None: # First get all users that have had a presence update @@ -2120,7 +2126,7 @@ class PresenceFederationQueue: # stream_id, destinations, user_ids)`. We don't store the full states # for efficiency, and remote workers will already have the full states # cached. - self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = [] + self._queue: List[Tuple[int, int, StrCollection, Set[str]]] = [] self._next_id = 1 @@ -2142,7 +2148,7 @@ class PresenceFederationQueue: self._queue = self._queue[index:] def send_presence_to_destinations( - self, states: Collection[UserPresenceState], destinations: Collection[str] + self, states: Collection[UserPresenceState], destinations: StrCollection ) -> None: """Send the presence states to the given destinations. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index e96f9999a8..0fb15391e0 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, O import attr -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -413,7 +413,11 @@ class RelationsHandler: # Attempt to find another event to use as the latest event. potential_events, _ = await self._main_store.get_relations_for_event( - event_id, event, room_id, RelationTypes.THREAD, direction="f" + event_id, + event, + room_id, + RelationTypes.THREAD, + direction=Direction.FORWARDS, ) # Filter out ignored users. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 572c7b4db3..60a6d9cf3c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -20,16 +20,7 @@ import random import string from collections import OrderedDict from http import HTTPStatus -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Collection, - Dict, - List, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple import attr from typing_extensions import TypedDict @@ -72,6 +63,7 @@ from synapse.types import ( RoomID, RoomStreamToken, StateMap, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -1644,7 +1636,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): user: UserID, from_key: RoomStreamToken, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[EventBase], RoomStreamToken]: diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index c6b869c6f4..4472019fbc 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -36,7 +36,7 @@ from synapse.api.errors import ( ) from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase -from synapse.types import JsonDict, Requester +from synapse.types import JsonDict, Requester, StrCollection from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -870,7 +870,7 @@ class _RoomQueueEntry: # The room ID of this entry. room_id: str # The server to query if the room is not known locally. - via: Sequence[str] + via: StrCollection # The minimum number of hops necessary to get to this room (compared to the # originally requested room). depth: int = 0 diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 40f4635c4e..9bbf83047d 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -14,7 +14,7 @@ import itertools import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple import attr from unpaddedbase64 import decode_base64, encode_base64 @@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.events import EventBase -from synapse.types import JsonDict, StreamKeyType, UserID +from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID from synapse.types.state import StateFilter from synapse.visibility import filter_events_for_client @@ -418,7 +418,7 @@ class SearchHandler: async def _search_by_rank( self, user: UserID, - room_ids: Collection[str], + room_ids: StrCollection, search_term: str, keys: Iterable[str], search_filter: Filter, @@ -491,7 +491,7 @@ class SearchHandler: async def _search_by_recent( self, user: UserID, - room_ids: Collection[str], + room_ids: StrCollection, search_term: str, keys: Iterable[str], search_filter: Filter, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 44e70fc4b8..4a27c0f051 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -20,7 +20,6 @@ from typing import ( Any, Awaitable, Callable, - Collection, Dict, Iterable, List, @@ -47,6 +46,7 @@ from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest from synapse.types import ( JsonDict, + StrCollection, UserID, contains_invalid_mxid_characters, create_requester, @@ -141,7 +141,8 @@ class UserAttributes: confirm_localpart: bool = False display_name: Optional[str] = None picture: Optional[str] = None - emails: Collection[str] = attr.Factory(list) + # mypy thinks these are incompatible for some reason. + emails: StrCollection = attr.Factory(list) # type: ignore[assignment] @attr.s(slots=True, auto_attribs=True) @@ -159,7 +160,7 @@ class UsernameMappingSession: # attributes returned by the ID mapper display_name: Optional[str] - emails: Collection[str] + emails: StrCollection # An optional dictionary of extra attributes to be provided to the client in the # login response. @@ -174,7 +175,7 @@ class UsernameMappingSession: # choices made by the user chosen_localpart: Optional[str] = None use_display_name: bool = True - emails_to_use: Collection[str] = () + emails_to_use: StrCollection = () terms_accepted_version: Optional[str] = None diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5ebd3ea855..5235e29460 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -17,7 +17,6 @@ from typing import ( TYPE_CHECKING, AbstractSet, Any, - Collection, Dict, FrozenSet, List, @@ -62,6 +61,7 @@ from synapse.types import ( Requester, RoomStreamToken, StateMap, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -1179,7 +1179,7 @@ class SyncHandler: async def _find_missing_partial_state_memberships( self, room_id: str, - members_to_fetch: Collection[str], + members_to_fetch: StrCollection, events_with_membership_auth: Mapping[str, EventBase], found_state_ids: StateMap[str], ) -> StateMap[str]: diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index a705af8356..8ef9a0dda8 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -322,6 +322,11 @@ class SynapseTags: # The name of the external cache CACHE_NAME = "cache.name" + # Boolean. Present on /v2/send_join requests, omitted from all others. + # True iff partial state was requested and we provided (or intended to provide) + # partial state in the response. + SEND_JOIN_RESPONSE_IS_PARTIAL_STATE = "send_join.partial_state_response" + # Used to tag function arguments # # Tag a named arg. The name of the argument should be appended to this prefix. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 6153a48257..d22dd19d38 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1158,7 +1158,7 @@ class ModuleApi: # Send to remote destinations. destination = UserID.from_string(user).domain presence_handler.get_federation_queue().send_presence_to_destinations( - presence_events, destination + presence_events, [destination] ) def looping_background_call( diff --git a/synapse/notifier.py b/synapse/notifier.py index 2b0e52f23c..a8832a3f8e 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -46,6 +46,7 @@ from synapse.types import ( JsonDict, PersistedEventPosition, RoomStreamToken, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -716,7 +717,7 @@ class Notifier: async def _get_room_ids( self, user: UserID, explicit_room_id: Optional[str] - ) -> Tuple[Collection[str], bool]: + ) -> Tuple[StrCollection, bool]: joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index f27ba64d53..88cfc05d05 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -22,13 +22,20 @@ from typing import ( List, Mapping, Optional, + Set, Tuple, Union, ) from prometheus_client import Counter -from synapse.api.constants import MAIN_TIMELINE, EventTypes, Membership, RelationTypes +from synapse.api.constants import ( + MAIN_TIMELINE, + EventContentFields, + EventTypes, + Membership, + RelationTypes, +) from synapse.api.room_versions import PushRuleRoomFlag, RoomVersion from synapse.event_auth import auth_types_for_event, get_user_power_level from synapse.events import EventBase, relation_from_event @@ -62,6 +69,9 @@ STATE_EVENT_TYPES_TO_MARK_UNREAD = { } +SENTINEL = object() + + def _should_count_as_unread(event: EventBase, context: EventContext) -> bool: # Exclude rejected and soft-failed events. if context.rejected or event.internal_metadata.is_soft_failed(): @@ -336,14 +346,40 @@ class BulkPushRuleEvaluator: related_events = await self._related_events(event) # It's possible that old room versions have non-integer power levels (floats or - # strings). Workaround this by explicitly converting to int. + # strings; even the occasional `null`). For old rooms, we interpret these as if + # they were integers. Do this here for the `@room` power level threshold. + # Note that this is done automatically for the sender's power level by + # _get_power_levels_and_sender_level in its call to get_user_power_level + # (even for room V10.) notification_levels = power_levels.get("notifications", {}) if not event.room_version.msc3667_int_only_power_levels: - for user_id, level in notification_levels.items(): - notification_levels[user_id] = int(level) + keys = list(notification_levels.keys()) + for key in keys: + level = notification_levels.get(key, SENTINEL) + if level is not SENTINEL and type(level) is not int: + try: + notification_levels[key] = int(level) + except (TypeError, ValueError): + del notification_levels[key] + + # Pull out any user and room mentions. + mentions = event.content.get(EventContentFields.MSC3952_MENTIONS) + user_mentions: Set[str] = set() + room_mention = False + if isinstance(mentions, dict): + # Remove out any non-string items and convert to a set. + user_mentions_raw = mentions.get("user_ids") + if isinstance(user_mentions_raw, list): + user_mentions = set( + filter(lambda item: isinstance(item, str), user_mentions_raw) + ) + # Room mention is only true if the value is exactly true. + room_mention = mentions.get("room") is True evaluator = PushRuleEvaluator( _flatten_dict(event, room_version=event.room_version), + user_mentions, + room_mention, room_member_count, sender_power_level, notification_levels, diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index fb73886df0..79f22a59f1 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -152,7 +152,7 @@ class PurgeHistoryRestServlet(RestServlet): logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: ts = body["purge_up_to_ts"] - if not isinstance(ts, int): + if type(ts) is not int: raise SynapseError( HTTPStatus.BAD_REQUEST, "purge_up_to_ts must be an int", diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index af606e9252..95e751288b 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -143,7 +143,7 @@ class NewRegistrationTokenRestServlet(RestServlet): else: # Get length of token to generate (default is 16) length = body.get("length", 16) - if not isinstance(length, int): + if type(length) is not int: raise SynapseError( HTTPStatus.BAD_REQUEST, "length must be an integer", @@ -163,8 +163,7 @@ class NewRegistrationTokenRestServlet(RestServlet): uses_allowed = body.get("uses_allowed", None) if not ( - uses_allowed is None - or (isinstance(uses_allowed, int) and uses_allowed >= 0) + uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0) ): raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -173,13 +172,13 @@ class NewRegistrationTokenRestServlet(RestServlet): ) expiry_time = body.get("expiry_time", None) - if not isinstance(expiry_time, (int, type(None))): + if type(expiry_time) not in (int, type(None)): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must be an integer or null", Codes.INVALID_PARAM, ) - if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): + if type(expiry_time) is int and expiry_time < self.clock.time_msec(): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must not be in the past", @@ -284,7 +283,7 @@ class RegistrationTokenRestServlet(RestServlet): uses_allowed = body["uses_allowed"] if not ( uses_allowed is None - or (isinstance(uses_allowed, int) and uses_allowed >= 0) + or (type(uses_allowed) is int and uses_allowed >= 0) ): raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -295,13 +294,13 @@ class RegistrationTokenRestServlet(RestServlet): if "expiry_time" in body: expiry_time = body["expiry_time"] - if not isinstance(expiry_time, (int, type(None))): + if type(expiry_time) not in (int, type(None)): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must be an integer or null", Codes.INVALID_PARAM, ) - if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): + if type(expiry_time) is int and expiry_time < self.clock.time_msec(): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must not be in the past", diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 6e0c44be2a..0841b89c1a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -973,7 +973,7 @@ class UserTokenRestServlet(RestServlet): body = parse_json_object_from_request(request, allow_empty_body=True) valid_until_ms = body.get("valid_until_ms") - if valid_until_ms and not isinstance(valid_until_ms, int): + if type(valid_until_ms) not in (int, type(None)): raise SynapseError( HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int" ) @@ -1125,14 +1125,14 @@ class RateLimitRestServlet(RestServlet): messages_per_second = body.get("messages_per_second", 0) burst_count = body.get("burst_count", 0) - if not isinstance(messages_per_second, int) or messages_per_second < 0: + if type(messages_per_second) is not int or messages_per_second < 0: raise SynapseError( HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (messages_per_second,), errcode=Codes.INVALID_PARAM, ) - if not isinstance(burst_count, int) or burst_count < 0: + if type(burst_count) is not int or burst_count < 0: raise SynapseError( HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (burst_count,), diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 8191b4e32c..ad5c10c99d 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Sequence, Tuple, Union +from typing import TYPE_CHECKING, List, Tuple, Union from synapse.api.errors import ( NotFoundError, @@ -169,7 +169,7 @@ class PushRuleRestServlet(RestServlet): raise UnrecognizedRequestError() -def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec: +def _rule_spec_from_path(path: List[str]) -> RuleSpec: """Turn a sequence of path components into a rule spec Args: diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index 6e962a4532..e2b410cf32 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -54,7 +54,7 @@ class ReportEventRestServlet(RestServlet): "Param 'reason' must be a string", Codes.BAD_JSON, ) - if not isinstance(body.get("score", 0), int): + if type(body.get("score", 0)) is not int: raise SynapseError( HTTPStatus.BAD_REQUEST, "Param 'score' must be an integer", diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 61375651bc..3f40f1874a 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple from typing_extensions import ParamSpec +from twisted.internet.defer import Deferred from twisted.python.failure import Failure from twisted.web.server import Request @@ -90,7 +91,7 @@ class HttpTransactionCache: fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], *args: P.args, **kwargs: P.kwargs, - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> "Deferred[Tuple[int, JsonDict]]": """Fetches the response for this transaction, or executes the given function to produce a response for this transaction. diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py index a3738a6250..7592aa5d47 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py @@ -200,7 +200,7 @@ class OEmbedProvider: calc_description_and_urls(open_graph_response, oembed["html"]) for size in ("width", "height"): val = oembed.get(size) - if val is not None and isinstance(val, int): + if type(val) is int: open_graph_response[f"og:video:{size}"] = val elif oembed_type == "link": diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index a48a4de92a..9480cc5763 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -77,7 +77,7 @@ class Thumbnailer: image_exif = self.image._getexif() # type: ignore if image_exif is not None: image_orientation = image_exif.get(EXIF_ORIENTATION_TAG) - assert isinstance(image_orientation, int) + assert type(image_orientation) is int self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation) except Exception as e: # A lot of parsing errors can happen when parsing EXIF diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 0f097a2927..1536937b67 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1651,7 +1651,7 @@ class PersistEventsStore: if self._ephemeral_messages_enabled: # If there's an expiry timestamp on the event, store it. expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if isinstance(expiry_ts, int) and not event.is_state(): + if type(expiry_ts) is int and not event.is_state(): self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) # Insert into the room_memberships table. @@ -2133,10 +2133,10 @@ class PersistEventsStore: ): if ( "min_lifetime" in event.content - and not isinstance(event.content.get("min_lifetime"), int) + and type(event.content["min_lifetime"]) is not int ) or ( "max_lifetime" in event.content - and not isinstance(event.content.get("max_lifetime"), int) + and type(event.content["max_lifetime"]) is not int ): # Ignore the event if one of the value isn't an integer. return diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 14ca167b34..466a1145b7 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -89,6 +89,7 @@ def _load_rules( msc1767_enabled=experimental_config.msc1767_enabled, msc3664_enabled=experimental_config.msc3664_enabled, msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, + msc3952_intentional_mentions=experimental_config.msc3952_intentional_mentions, ) return filtered_rules diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 84f844b79e..0018d6f7ab 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -30,7 +30,7 @@ from typing import ( import attr -from synapse.api.constants import MAIN_TIMELINE, RelationTypes +from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore @@ -40,9 +40,13 @@ from synapse.storage.database import ( LoggingTransaction, make_in_list_sql_clause, ) -from synapse.storage.databases.main.stream import generate_pagination_where_clause +from synapse.storage.databases.main.stream import ( + generate_next_token, + generate_pagination_bounds, + generate_pagination_where_clause, +) from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken +from synapse.types import JsonDict, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -164,7 +168,7 @@ class RelationsWorkerStore(SQLBaseStore): relation_type: Optional[str] = None, event_type: Optional[str] = None, limit: int = 5, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: @@ -177,8 +181,8 @@ class RelationsWorkerStore(SQLBaseStore): relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. limit: Only fetch the most recent `limit` events. - direction: Whether to fetch the most recent first (`"b"`) or the - oldest first (`"f"`). + direction: Whether to fetch the most recent first (backwards) or the + oldest first (forwards). from_token: Fetch rows from the given token, or from the start if None. to_token: Fetch rows up to the given token, or up to the end if None. @@ -207,24 +211,23 @@ class RelationsWorkerStore(SQLBaseStore): where_clause.append("type = ?") where_args.append(event_type) + order, from_bound, to_bound = generate_pagination_bounds( + direction, + from_token.room_key if from_token else None, + to_token.room_key if to_token else None, + ) + pagination_clause = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token.room_key.as_historical_tuple() - if from_token - else None, - to_token=to_token.room_key.as_historical_tuple() if to_token else None, + from_token=from_bound, + to_token=to_bound, engine=self.database_engine, ) if pagination_clause: where_clause.append(pagination_clause) - if direction == "b": - order = "DESC" - else: - order = "ASC" - sql = """ SELECT event_id, relation_type, sender, topological_ordering, stream_ordering FROM event_relations @@ -266,16 +269,9 @@ class RelationsWorkerStore(SQLBaseStore): topo_orderings = topo_orderings[:limit] stream_orderings = stream_orderings[:limit] - topo = topo_orderings[-1] - token = stream_orderings[-1] - if direction == "b": - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - token -= 1 - next_key = RoomStreamToken(topo, token) + next_key = generate_next_token( + direction, topo_orderings[-1], stream_orderings[-1] + ) if from_token: next_token = from_token.copy_and_replace( diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index d28fc65df9..818c46182e 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -55,6 +55,7 @@ from typing_extensions import Literal from twisted.internet import defer +from synapse.api.constants import Direction from synapse.api.filtering import Filter from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -86,7 +87,6 @@ MAX_STREAM_SIZE = 1000 _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" - # Used as return values for pagination APIs @attr.s(slots=True, frozen=True, auto_attribs=True) class _EventDictReturn: @@ -104,7 +104,7 @@ class _EventsAround: def generate_pagination_where_clause( - direction: str, + direction: Direction, column_names: Tuple[str, str], from_token: Optional[Tuple[Optional[int], int]], to_token: Optional[Tuple[Optional[int], int]], @@ -130,27 +130,26 @@ def generate_pagination_where_clause( token, but include those that match the to token. Args: - direction: Whether we're paginating backwards("b") or forwards ("f"). + direction: Whether we're paginating backwards or forwards. column_names: The column names to bound. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. from_token: The start point for the pagination. This is an exclusive - minimum bound if direction is "f", and an inclusive maximum bound if - direction is "b". + minimum bound if direction is forwards, and an inclusive maximum bound if + direction is backwards. to_token: The endpoint point for the pagination. This is an inclusive - maximum bound if direction is "f", and an exclusive minimum bound if - direction is "b". + maximum bound if direction is forwards, and an exclusive minimum bound if + direction is backwards. engine: The database engine to generate the clauses for Returns: The sql expression """ - assert direction in ("b", "f") where_clause = [] if from_token: where_clause.append( _make_generic_sql_bound( - bound=">=" if direction == "b" else "<", + bound=">=" if direction == Direction.BACKWARDS else "<", column_names=column_names, values=from_token, engine=engine, @@ -160,7 +159,7 @@ def generate_pagination_where_clause( if to_token: where_clause.append( _make_generic_sql_bound( - bound="<" if direction == "b" else ">=", + bound="<" if direction == Direction.BACKWARDS else ">=", column_names=column_names, values=to_token, engine=engine, @@ -170,6 +169,104 @@ def generate_pagination_where_clause( return " AND ".join(where_clause) +def generate_pagination_bounds( + direction: Direction, + from_token: Optional[RoomStreamToken], + to_token: Optional[RoomStreamToken], +) -> Tuple[ + str, Optional[Tuple[Optional[int], int]], Optional[Tuple[Optional[int], int]] +]: + """ + Generate a start and end point for this page of events. + + Args: + direction: Whether pagination is going forwards or backwards. + from_token: The token to start pagination at, or None to start at the first value. + to_token: The token to end pagination at, or None to not limit the end point. + + Returns: + A three tuple of: + + ASC or DESC for sorting of the query. + + The starting position as a tuple of ints representing + (topological position, stream position) or None if no from_token was + provided. The topological position may be None for live tokens. + + The end position in the same format as the starting position, or None + if no to_token was provided. + """ + + # Tokens really represent positions between elements, but we use + # the convention of pointing to the event before the gap. Hence + # we have a bit of asymmetry when it comes to equalities. + if direction == Direction.BACKWARDS: + order = "DESC" + else: + order = "ASC" + + # The bounds for the stream tokens are complicated by the fact + # that we need to handle the instance_map part of the tokens. We do this + # by fetching all events between the min stream token and the maximum + # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and + # then filtering the results. + from_bound: Optional[Tuple[Optional[int], int]] = None + if from_token: + if from_token.topological is not None: + from_bound = from_token.as_historical_tuple() + elif direction == Direction.BACKWARDS: + from_bound = ( + None, + from_token.get_max_stream_pos(), + ) + else: + from_bound = ( + None, + from_token.stream, + ) + + to_bound: Optional[Tuple[Optional[int], int]] = None + if to_token: + if to_token.topological is not None: + to_bound = to_token.as_historical_tuple() + elif direction == Direction.BACKWARDS: + to_bound = ( + None, + to_token.stream, + ) + else: + to_bound = ( + None, + to_token.get_max_stream_pos(), + ) + + return order, from_bound, to_bound + + +def generate_next_token( + direction: Direction, last_topo_ordering: int, last_stream_ordering: int +) -> RoomStreamToken: + """ + Generate the next room stream token based on the currently returned data. + + Args: + direction: Whether pagination is going forwards or backwards. + last_topo_ordering: The last topological ordering being returned. + last_stream_ordering: The last stream ordering being returned. + + Returns: + A new RoomStreamToken to return to the client. + """ + if direction == Direction.BACKWARDS: + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + last_stream_ordering -= 1 + return RoomStreamToken(last_topo_ordering, last_stream_ordering) + + def _make_generic_sql_bound( bound: str, column_names: Tuple[str, str], @@ -1103,7 +1200,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn, room_id, before_token, - direction="b", + direction=Direction.BACKWARDS, limit=before_limit, event_filter=event_filter, ) @@ -1113,7 +1210,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn, room_id, after_token, - direction="f", + direction=Direction.FORWARDS, limit=after_limit, event_filter=event_filter, ) @@ -1276,7 +1373,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id: str, from_token: RoomStreamToken, to_token: Optional[RoomStreamToken] = None, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, limit: int = -1, event_filter: Optional[Filter] = None, ) -> Tuple[List[_EventDictReturn], RoomStreamToken]: @@ -1287,8 +1384,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id from_token: The token used to stream from to_token: A token which if given limits the results to only those before - direction: Either 'b' or 'f' to indicate whether we are paginating - forwards or backwards from `from_key`. + direction: Indicates whether we are paginating forwards or backwards + from `from_key`. limit: The maximum number of events to return. event_filter: If provided filters the events to those that match the filter. @@ -1300,47 +1397,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): `to_token`), or `limit` is zero. """ - # Tokens really represent positions between elements, but we use - # the convention of pointing to the event before the gap. Hence - # we have a bit of asymmetry when it comes to equalities. args = [False, room_id] - if direction == "b": - order = "DESC" - else: - order = "ASC" - - # The bounds for the stream tokens are complicated by the fact - # that we need to handle the instance_map part of the tokens. We do this - # by fetching all events between the min stream token and the maximum - # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and - # then filtering the results. - if from_token.topological is not None: - from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple() - elif direction == "b": - from_bound = ( - None, - from_token.get_max_stream_pos(), - ) - else: - from_bound = ( - None, - from_token.stream, - ) - to_bound: Optional[Tuple[Optional[int], int]] = None - if to_token: - if to_token.topological is not None: - to_bound = to_token.as_historical_tuple() - elif direction == "b": - to_bound = ( - None, - to_token.stream, - ) - else: - to_bound = ( - None, - to_token.get_max_stream_pos(), - ) + order, from_bound, to_bound = generate_pagination_bounds( + direction, from_token, to_token + ) bounds = generate_pagination_where_clause( direction=direction, @@ -1427,8 +1488,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): _EventDictReturn(event_id, topological_ordering, stream_ordering) for event_id, instance_name, topological_ordering, stream_ordering in txn if _filter_results( - lower_token=to_token if direction == "b" else from_token, - upper_token=from_token if direction == "b" else to_token, + lower_token=to_token + if direction == Direction.BACKWARDS + else from_token, + upper_token=from_token + if direction == Direction.BACKWARDS + else to_token, instance_name=instance_name, topological_ordering=topological_ordering, stream_ordering=stream_ordering, @@ -1436,16 +1501,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ][:limit] if rows: - topo = rows[-1].topological_ordering - token = rows[-1].stream_ordering - if direction == "b": - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - token -= 1 - next_token = RoomStreamToken(topo, token) + assert rows[-1].topological_ordering is not None + next_token = generate_next_token( + direction, rows[-1].topological_ordering, rows[-1].stream_ordering + ) else: # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token @@ -1458,7 +1517,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id: str, from_key: RoomStreamToken, to_key: Optional[RoomStreamToken] = None, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, limit: int = -1, event_filter: Optional[Filter] = None, ) -> Tuple[List[EventBase], RoomStreamToken]: @@ -1468,8 +1527,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id from_key: The token used to stream from to_key: A token which if given limits the results to only those before - direction: Either 'b' or 'f' to indicate whether we are paginating - forwards or backwards from `from_key`. + direction: Indicates whether we are paginating forwards or backwards + from `from_key`. limit: The maximum number of events to return. event_filter: If provided filters the events to those that match the filter. diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index bc9ca3a53c..0363cdc038 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -133,8 +133,9 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM This is not provided by DBAPI2, and so needs engine-specific support. - Some database engines may automatically COMMIT the ongoing transaction both - before and after executing the script. + Any ongoing transaction is committed before executing the script in its own + transaction. The script transaction is left open and it is the responsibility of + the caller to commit it. """ ... diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index f9f562ea45..b350f57ccb 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -220,5 +220,9 @@ class PostgresEngine( """Execute a chunk of SQL containing multiple semicolon-delimited statements. Psycopg2 seems happy to do this in DBAPI2's `execute()` function. + + For consistency with SQLite, any ongoing transaction is committed before + executing the script in its own transaction. The script transaction is + left open and it is the responsibility of the caller to commit it. """ - cursor.execute(script) + cursor.execute(f"COMMIT; BEGIN TRANSACTION; {script}") diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 2f7df85ce4..28751e89a5 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -135,14 +135,16 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]): > than one statement with it, it will raise a Warning. Use executescript() if > you want to execute multiple SQL statements with one call. - The script is wrapped in transaction control statemnets, since the docs for + The script is prefixed with a `BEGIN TRANSACTION`, since the docs for `executescript` warn: > If there is a pending transaction, an implicit COMMIT statement is executed > first. No other implicit transaction control is performed; any transaction > control must be added to sql_script. """ - cursor.executescript(f"BEGIN TRANSACTION;\n{script}\nCOMMIT;") + # The implementation of `executescript` can be found at + # https://github.com/python/cpython/blob/3.11/Modules/_sqlite/cursor.c#L1035. + cursor.executescript(f"BEGIN TRANSACTION; {script}") # Following functions taken from: https://github.com/coleifer/peewee diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py index 2dcd43d0a2..c6c8a0315c 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection, Generic, List, Optional, Tuple, TypeVar +from typing import Generic, List, Optional, Tuple, TypeVar -from synapse.types import UserID +from synapse.types import StrCollection, UserID # The key, this is either a stream token or int. K = TypeVar("K") @@ -28,7 +28,7 @@ class EventSource(Generic[K, R]): user: UserID, from_key: K, limit: int, - room_ids: Collection[str], + room_ids: StrCollection, is_guest: bool, explicit_room_id: Optional[str] = None, ) -> Tuple[List[R], K]: diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 6df2de919c..5cb7875181 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -16,6 +16,7 @@ from typing import Optional import attr +from synapse.api.constants import Direction from synapse.api.errors import SynapseError from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest @@ -34,7 +35,7 @@ class PaginationConfig: from_token: Optional[StreamToken] to_token: Optional[StreamToken] - direction: str + direction: Direction limit: int @classmethod @@ -45,9 +46,13 @@ class PaginationConfig: default_limit: int, default_dir: str = "f", ) -> "PaginationConfig": - direction = parse_string( - request, "dir", default=default_dir, allowed_values=["f", "b"] + direction_str = parse_string( + request, + "dir", + default=default_dir, + allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value], ) + direction = Direction(direction_str) from_tok_str = parse_string(request, "from") to_tok_str = parse_string(request, "to") diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index b703e4472e..a9893def74 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -16,6 +16,8 @@ from unittest.mock import Mock import attr +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EduTypes from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router from synapse.federation.units import Transaction @@ -23,11 +25,13 @@ from synapse.handlers.presence import UserPresenceState from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client import login, presence, room +from synapse.server import HomeServer from synapse.types import JsonDict, StreamToken, create_requester +from synapse.util import Clock from tests.handlers.test_sync import generate_sync_config from tests.test_utils import simple_async_mock -from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config +from tests.unittest import FederatingHomeserverTestCase, override_config @attr.s @@ -49,9 +53,7 @@ class LegacyPresenceRouterTestModule: } return users_to_state - async def get_interested_users( - self, user_id: str - ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + async def get_interested_users(self, user_id: str) -> Union[Set[str], str]: if user_id in self._config.users_who_should_receive_all_presence: return PresenceRouter.ALL_USERS @@ -71,9 +73,14 @@ class LegacyPresenceRouterTestModule: # Initialise a typed config object config = PresenceRouterTestConfig() - config.users_who_should_receive_all_presence = config_dict.get( + users_who_should_receive_all_presence = config_dict.get( "users_who_should_receive_all_presence" ) + assert isinstance(users_who_should_receive_all_presence, list) + + config.users_who_should_receive_all_presence = ( + users_who_should_receive_all_presence + ) return config @@ -96,9 +103,7 @@ class PresenceRouterTestModule: } return users_to_state - async def get_interested_users( - self, user_id: str - ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + async def get_interested_users(self, user_id: str) -> Union[Set[str], str]: if user_id in self._config.users_who_should_receive_all_presence: return PresenceRouter.ALL_USERS @@ -118,9 +123,14 @@ class PresenceRouterTestModule: # Initialise a typed config object config = PresenceRouterTestConfig() - config.users_who_should_receive_all_presence = config_dict.get( + users_who_should_receive_all_presence = config_dict.get( "users_who_should_receive_all_presence" ) + assert isinstance(users_who_should_receive_all_presence, list) + + config.users_who_should_receive_all_presence = ( + users_who_should_receive_all_presence + ) return config @@ -140,7 +150,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): presence.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. fed_transport_client = Mock(spec=["send_transaction"]) fed_transport_client.send_transaction = simple_async_mock({}) @@ -153,7 +163,9 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): return hs - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.sync_handler = self.hs.get_sync_handler() self.module_api = homeserver.get_module_api() @@ -176,7 +188,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): }, } ) - def test_receiving_all_presence_legacy(self): + def test_receiving_all_presence_legacy(self) -> None: self.receiving_all_presence_test_body() @override_config( @@ -193,10 +205,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ], } ) - def test_receiving_all_presence(self): + def test_receiving_all_presence(self) -> None: self.receiving_all_presence_test_body() - def receiving_all_presence_test_body(self): + def receiving_all_presence_test_body(self) -> None: """Test that a user that does not share a room with another other can receive presence for them, due to presence routing. """ @@ -302,7 +314,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): }, } ) - def test_send_local_online_presence_to_with_module_legacy(self): + def test_send_local_online_presence_to_with_module_legacy(self) -> None: self.send_local_online_presence_to_with_module_test_body() @override_config( @@ -321,10 +333,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ], } ) - def test_send_local_online_presence_to_with_module(self): + def test_send_local_online_presence_to_with_module(self) -> None: self.send_local_online_presence_to_with_module_test_body() - def send_local_online_presence_to_with_module_test_body(self): + def send_local_online_presence_to_with_module_test_body(self) -> None: """Tests that send_local_presence_to_users sends local online presence to a set of specified local and remote users, with a custom PresenceRouter module enabled. """ @@ -447,18 +459,18 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): continue # EDUs can contain multiple presence updates - for presence_update in edu["content"]["push"]: + for presence_edu in edu["content"]["push"]: # Check for presence updates that contain the user IDs we're after - found_users.add(presence_update["user_id"]) + found_users.add(presence_edu["user_id"]) # Ensure that no offline states are being sent out - self.assertNotEqual(presence_update["presence"], "offline") + self.assertNotEqual(presence_edu["presence"], "offline") self.assertEqual(found_users, expected_users) def send_presence_update( - testcase: TestCase, + testcase: FederatingHomeserverTestCase, user_id: str, access_token: str, presence_state: str, @@ -479,7 +491,7 @@ def send_presence_update( def sync_presence( - testcase: TestCase, + testcase: FederatingHomeserverTestCase, user_id: str, since_token: Optional[StreamToken] = None, ) -> Tuple[List[UserPresenceState], StreamToken]: @@ -500,7 +512,7 @@ def sync_presence( requester = create_requester(user_id) sync_config = generate_sync_config(requester.user.to_string()) sync_result = testcase.get_success( - testcase.sync_handler.wait_for_sync_for_user( + testcase.hs.get_sync_handler().wait_for_sync_for_user( requester, sync_config, since_token ) ) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 8ddce83b83..6687c28e8f 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -12,9 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor + +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.test_utils.event_injection import create_event @@ -27,7 +32,7 @@ class TestEventContext(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() @@ -35,7 +40,7 @@ class TestEventContext(unittest.HomeserverTestCase): self.user_tok = self.login("u1", "pass") self.room_id = self.helper.create_room_as(tok=self.user_tok) - def test_serialize_deserialize_msg(self): + def test_serialize_deserialize_msg(self) -> None: """Test that an EventContext for a message event is the same after serialize/deserialize. """ @@ -51,7 +56,7 @@ class TestEventContext(unittest.HomeserverTestCase): self._check_serialize_deserialize(event, context) - def test_serialize_deserialize_state_no_prev(self): + def test_serialize_deserialize_state_no_prev(self) -> None: """Test that an EventContext for a state event (with not previous entry) is the same after serialize/deserialize. """ @@ -67,7 +72,7 @@ class TestEventContext(unittest.HomeserverTestCase): self._check_serialize_deserialize(event, context) - def test_serialize_deserialize_state_prev(self): + def test_serialize_deserialize_state_prev(self) -> None: """Test that an EventContext for a state event (which replaces a previous entry) is the same after serialize/deserialize. """ @@ -84,7 +89,9 @@ class TestEventContext(unittest.HomeserverTestCase): self._check_serialize_deserialize(event, context) - def _check_serialize_deserialize(self, event, context): + def _check_serialize_deserialize( + self, event: EventBase, context: EventContext + ) -> None: serialized = self.get_success(context.serialize(event, self.store)) d_context = EventContext.deserialize(self._storage_controllers, serialized) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index a79256846f..ff7b349d75 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -13,21 +13,24 @@ # limitations under the License. import unittest as stdlib_unittest +from typing import Any, List, Mapping, Optional from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import ( + PowerLevelsContent, SerializeEventConfig, copy_and_fixup_power_levels_contents, maybe_upsert_event_field, prune_event, serialize_event, ) +from synapse.types import JsonDict from synapse.util.frozenutils import freeze -def MockEvent(**kwargs): +def MockEvent(**kwargs: Any) -> EventBase: if "event_id" not in kwargs: kwargs["event_id"] = "fake_event_id" if "type" not in kwargs: @@ -60,7 +63,7 @@ class TestMaybeUpsertEventField(stdlib_unittest.TestCase): class PruneEventTestCase(stdlib_unittest.TestCase): - def run_test(self, evdict, matchdict, **kwargs): + def run_test(self, evdict: JsonDict, matchdict: JsonDict, **kwargs: Any) -> None: """ Asserts that a new event constructed with `evdict` will look like `matchdict` when it is redacted. @@ -74,7 +77,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict ) - def test_minimal(self): + def test_minimal(self) -> None: self.run_test( {"type": "A", "event_id": "$test:domain"}, { @@ -86,7 +89,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): }, ) - def test_basic_keys(self): + def test_basic_keys(self) -> None: """Ensure that the keys that should be untouched are kept.""" # Note that some of the values below don't really make sense, but the # pruning of events doesn't worry about the values of any fields (with @@ -138,7 +141,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.MSC2176, ) - def test_unsigned(self): + def test_unsigned(self) -> None: """Ensure that unsigned properties get stripped (except age_ts and replaces_state).""" self.run_test( { @@ -159,7 +162,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): }, ) - def test_content(self): + def test_content(self) -> None: """The content dictionary should be stripped in most cases.""" self.run_test( {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}}, @@ -194,7 +197,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): }, ) - def test_create(self): + def test_create(self) -> None: """Create events are partially redacted until MSC2176.""" self.run_test( { @@ -223,7 +226,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.MSC2176, ) - def test_power_levels(self): + def test_power_levels(self) -> None: """Power level events keep a variety of content keys.""" self.run_test( { @@ -273,7 +276,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.MSC2176, ) - def test_alias_event(self): + def test_alias_event(self) -> None: """Alias events have special behavior up through room version 6.""" self.run_test( { @@ -302,7 +305,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.V6, ) - def test_redacts(self): + def test_redacts(self) -> None: """Redaction events have no special behaviour until MSC2174/MSC2176.""" self.run_test( @@ -328,7 +331,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.MSC2176, ) - def test_join_rules(self): + def test_join_rules(self) -> None: """Join rules events have changed behavior starting with MSC3083.""" self.run_test( { @@ -371,7 +374,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): room_version=RoomVersions.V8, ) - def test_member(self): + def test_member(self) -> None: """Member events have changed behavior starting with MSC3375.""" self.run_test( { @@ -417,12 +420,12 @@ class PruneEventTestCase(stdlib_unittest.TestCase): class SerializeEventTestCase(stdlib_unittest.TestCase): - def serialize(self, ev, fields): + def serialize(self, ev: EventBase, fields: Optional[List[str]]) -> JsonDict: return serialize_event( ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields) ) - def test_event_fields_works_with_keys(self): + def test_event_fields_works_with_keys(self) -> None: self.assertEqual( self.serialize( MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"] @@ -430,7 +433,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"room_id": "!foo:bar"}, ) - def test_event_fields_works_with_nested_keys(self): + def test_event_fields_works_with_nested_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -443,7 +446,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"content": {"body": "A message"}}, ) - def test_event_fields_works_with_dot_keys(self): + def test_event_fields_works_with_dot_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -456,7 +459,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"content": {"key.with.dots": {}}}, ) - def test_event_fields_works_with_nested_dot_keys(self): + def test_event_fields_works_with_nested_dot_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -472,7 +475,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"content": {"nested.dot.key": {"leaf.key": 42}}}, ) - def test_event_fields_nops_with_unknown_keys(self): + def test_event_fields_nops_with_unknown_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -485,7 +488,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {"content": {"foo": "bar"}}, ) - def test_event_fields_nops_with_non_dict_keys(self): + def test_event_fields_nops_with_non_dict_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -498,7 +501,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {}, ) - def test_event_fields_nops_with_array_keys(self): + def test_event_fields_nops_with_array_keys(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -511,7 +514,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): {}, ) - def test_event_fields_all_fields_if_empty(self): + def test_event_fields_all_fields_if_empty(self) -> None: self.assertEqual( self.serialize( MockEvent( @@ -531,16 +534,16 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): }, ) - def test_event_fields_fail_if_fields_not_str(self): + def test_event_fields_fail_if_fields_not_str(self) -> None: with self.assertRaises(TypeError): self.serialize( - MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] + MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] # type: ignore[list-item] ) class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): def setUp(self) -> None: - self.test_content = { + self.test_content: PowerLevelsContent = { "ban": 50, "events": {"m.room.name": 100, "m.room.power_levels": 100}, "events_default": 0, @@ -553,10 +556,11 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): "users_default": 0, } - def _test(self, input): + def _test(self, input: PowerLevelsContent) -> None: a = copy_and_fixup_power_levels_contents(input) self.assertEqual(a["ban"], 50) + assert isinstance(a["events"], Mapping) self.assertEqual(a["events"]["m.room.name"], 100) # make sure that changing the copy changes the copy and not the orig @@ -564,18 +568,19 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): a["events"]["m.room.power_levels"] = 20 self.assertEqual(input["ban"], 50) + assert isinstance(input["events"], Mapping) self.assertEqual(input["events"]["m.room.power_levels"], 100) - def test_unfrozen(self): + def test_unfrozen(self) -> None: self._test(self.test_content) - def test_frozen(self): + def test_frozen(self) -> None: input = freeze(self.test_content) self._test(input) - def test_stringy_integers(self): + def test_stringy_integers(self) -> None: """String representations of decimal integers are converted to integers.""" - input = { + input: PowerLevelsContent = { "a": "100", "b": { "foo": 99, @@ -603,9 +608,9 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): def test_invalid_types_raise_type_error(self) -> None: with self.assertRaises(TypeError): - copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]}) # type: ignore[arg-type] - copy_and_fixup_power_levels_contents({"a": None}) # type: ignore[arg-type] + copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]}) # type: ignore[dict-item] + copy_and_fixup_power_levels_contents({"a": None}) # type: ignore[dict-item] def test_invalid_nesting_raises_type_error(self) -> None: with self.assertRaises(TypeError): - copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}}) + copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}}) # type: ignore[dict-item] diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index d21c11b716..ff589c0b6c 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -23,10 +23,10 @@ from synapse.server import HomeServer from synapse.types import RoomAlias from tests.test_utils import event_injection -from tests.unittest import FederatingHomeserverTestCase, TestCase +from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase -class KnockingStrippedStateEventHelperMixin(TestCase): +class KnockingStrippedStateEventHelperMixin(HomeserverTestCase): def send_example_state_events_to_room( self, hs: "HomeServer", @@ -49,7 +49,7 @@ class KnockingStrippedStateEventHelperMixin(TestCase): # To set a canonical alias, we'll need to point an alias at the room first. canonical_alias = "#fancy_alias:test" self.get_success( - self.store.create_room_alias_association( + self.hs.get_datastores().main.create_room_alias_association( RoomAlias.from_string(canonical_alias), room_id, ["test"] ) ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index efbb5a8dbb..1fe9563c98 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -14,21 +14,22 @@ import json -from typing import Dict +from typing import Dict, List, Set from unittest.mock import ANY, Mock, call -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource from synapse.api.constants import EduTypes from synapse.api.errors import AuthError from synapse.federation.transport.server import TransportLayerServer +from synapse.handlers.typing import TypingWriterHandler from synapse.server import HomeServer from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util import Clock from tests import unittest +from tests.server import ThreadedMemoryReactorClock from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -62,7 +63,11 @@ def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes: class TypingNotificationsTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + def make_homeserver( + self, + reactor: ThreadedMemoryReactorClock, + clock: Clock, + ) -> HomeServer: # we mock out the keyring so as to skip the authentication check on the # federation API call. mock_keyring = Mock(spec=["verify_json_for_server"]) @@ -75,8 +80,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): # the tests assume that we are starting at unix time 1000 reactor.pump((1000,)) + self.mock_hs_notifier = Mock() hs = self.setup_test_homeserver( - notifier=Mock(), + notifier=self.mock_hs_notifier, federation_http_client=mock_federation_client, keyring=mock_keyring, replication_streams={}, @@ -90,32 +96,38 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): return d def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - mock_notifier = hs.get_notifier() - self.on_new_event = mock_notifier.on_new_event + self.on_new_event = self.mock_hs_notifier.on_new_event - self.handler = hs.get_typing_handler() + # hs.get_typing_handler will return a TypingWriterHandler when calling it + # from the main process, and a FollowerTypingHandler on workers. + # We rely on methods only available on the former, so assert we have the + # correct type here. We have to assign self.handler after the assert, + # otherwise mypy will treat it as a FollowerTypingHandler + handler = hs.get_typing_handler() + assert isinstance(handler, TypingWriterHandler) + self.handler = handler self.event_source = hs.get_event_sources().sources.typing self.datastore = hs.get_datastores().main + self.datastore.get_destination_retry_timings = Mock( return_value=make_awaitable(None) ) - self.datastore.get_device_updates_by_remote = Mock( + self.datastore.get_device_updates_by_remote = Mock( # type: ignore[assignment] return_value=make_awaitable((0, [])) ) - self.datastore.get_destination_last_successful_stream_ordering = Mock( + self.datastore.get_destination_last_successful_stream_ordering = Mock( # type: ignore[assignment] return_value=make_awaitable(None) ) - def get_received_txn_response(*args): - return defer.succeed(None) - - self.datastore.get_received_txn_response = get_received_txn_response + self.datastore.get_received_txn_response = Mock( # type: ignore[assignment] + return_value=make_awaitable(None) + ) - self.room_members = [] + self.room_members: List[UserID] = [] async def check_user_in_room(room_id: str, requester: Requester) -> None: if requester.user.to_string() not in [ @@ -124,47 +136,54 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): raise AuthError(401, "User is not in the room") return None - hs.get_auth().check_user_in_room = check_user_in_room + hs.get_auth().check_user_in_room = Mock( # type: ignore[assignment] + side_effect=check_user_in_room + ) async def check_host_in_room(room_id: str, server_name: str) -> bool: return room_id == ROOM_ID - hs.get_event_auth_handler().is_host_in_room = check_host_in_room + hs.get_event_auth_handler().is_host_in_room = Mock( # type: ignore[assignment] + side_effect=check_host_in_room + ) - async def get_current_hosts_in_room(room_id: str): + async def get_current_hosts_in_room(room_id: str) -> Set[str]: return {member.domain for member in self.room_members} - hs.get_storage_controllers().state.get_current_hosts_in_room = ( - get_current_hosts_in_room + hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] + side_effect=get_current_hosts_in_room ) - hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( - get_current_hosts_in_room + hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock( # type: ignore[assignment] + side_effect=get_current_hosts_in_room ) - async def get_users_in_room(room_id: str): + async def get_users_in_room(room_id: str) -> Set[str]: return {str(u) for u in self.room_members} - self.datastore.get_users_in_room = get_users_in_room + self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room) - self.datastore.get_user_directory_stream_pos = Mock( + self.datastore.get_user_directory_stream_pos = Mock( # type: ignore[assignment] side_effect=( - # we deliberately return a non-None stream pos to avoid doing an initial_spam + # we deliberately return a non-None stream pos to avoid + # doing an initial_sync lambda: make_awaitable(1) ) ) - self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) + self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[assignment] - self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = ( - lambda *args, **kargs: make_awaitable(([], 0)) + self.datastore.get_to_device_stream_token = Mock( # type: ignore[assignment] + side_effect=lambda: 0 + ) + self.datastore.get_new_device_msgs_for_remote = Mock( # type: ignore[assignment] + side_effect=lambda *args, **kargs: make_awaitable(([], 0)) ) - self.datastore.delete_device_msgs_for_remote = ( - lambda *args, **kargs: make_awaitable(None) + self.datastore.delete_device_msgs_for_remote = Mock( # type: ignore[assignment] + side_effect=lambda *args, **kargs: make_awaitable(None) ) - self.datastore.set_received_txn_response = ( - lambda *args, **kwargs: make_awaitable(None) + self.datastore.set_received_txn_response = Mock( # type: ignore[assignment] + side_effect=lambda *args, **kwargs: make_awaitable(None) ) def test_started_typing_local(self) -> None: @@ -186,7 +205,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( - user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False + user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False ) ) self.assertEqual( @@ -257,7 +276,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( - user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False + user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False ) ) self.assertEqual( @@ -298,7 +317,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source.get_new_events( user=U_APPLE, from_key=0, - limit=None, + limit=0, room_ids=[OTHER_ROOM_ID], is_guest=False, ) @@ -351,7 +370,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( self.event_source.get_new_events( - user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False + user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False ) ) self.assertEqual( @@ -387,7 +406,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source.get_new_events( user=U_APPLE, from_key=0, - limit=None, + limit=0, room_ids=[ROOM_ID], is_guest=False, ) @@ -412,7 +431,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source.get_new_events( user=U_APPLE, from_key=1, - limit=None, + limit=0, room_ids=[ROOM_ID], is_guest=False, ) @@ -447,7 +466,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.event_source.get_new_events( user=U_APPLE, from_key=0, - limit=None, + limit=0, room_ids=[ROOM_ID], is_guest=False, ) diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py index 1acf5666a8..1c5de95a80 100644 --- a/tests/logging/__init__.py +++ b/tests/logging/__init__.py @@ -13,9 +13,11 @@ # limitations under the License. import logging +from tests.unittest import TestCase -class LoggerCleanupMixin: - def get_logger(self, handler): + +class LoggerCleanupMixin(TestCase): + def get_logger(self, handler: logging.Handler) -> logging.Logger: """ Attach a handler to a logger and add clean-ups to remove revert this. """ diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py index 0917e478a5..e28ba84cc2 100644 --- a/tests/logging/test_opentracing.py +++ b/tests/logging/test_opentracing.py @@ -153,7 +153,7 @@ class LogContextScopeManagerTestCase(TestCase): scopes = [] - async def task(i: int): + async def task(i: int) -> None: scope = start_active_span( f"task{i}", tracer=self._tracer, @@ -165,7 +165,7 @@ class LogContextScopeManagerTestCase(TestCase): self.assertEqual(self._tracer.active_span, scope.span) scope.close() - async def root(): + async def root() -> None: with start_active_span("root span", tracer=self._tracer) as root_scope: self.assertEqual(self._tracer.active_span, root_scope.span) scopes.append(root_scope) diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py index b0d046fe00..c08954d887 100644 --- a/tests/logging/test_remote_handler.py +++ b/tests/logging/test_remote_handler.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.test.proto_helpers import AccumulatingProtocol +from typing import Tuple + +from twisted.internet.protocol import Protocol +from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from synapse.logging import RemoteHandler @@ -20,7 +23,9 @@ from tests.server import FakeTransport, get_clock from tests.unittest import TestCase -def connect_logging_client(reactor, client_id): +def connect_logging_client( + reactor: MemoryReactorClock, client_id: int +) -> Tuple[Protocol, AccumulatingProtocol]: # This is essentially tests.server.connect_client, but disabling autoflush on # the client transport. This is necessary to avoid an infinite loop due to # sending of data via the logging transport causing additional logs to be @@ -35,10 +40,10 @@ def connect_logging_client(reactor, client_id): class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor, _ = get_clock() - def test_log_output(self): + def test_log_output(self) -> None: """ The remote handler delivers logs over TCP. """ @@ -51,6 +56,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): client, server = connect_logging_client(self.reactor, 0) # Trigger data being sent + assert isinstance(client.transport, FakeTransport) client.transport.flush() # One log message, with a single trailing newline @@ -61,7 +67,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Ensure the data passed through properly. self.assertEqual(logs[0], "Hello there, wally!") - def test_log_backpressure_debug(self): + def test_log_backpressure_debug(self) -> None: """ When backpressure is hit, DEBUG logs will be shed. """ @@ -83,6 +89,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) + assert isinstance(client.transport, FakeTransport) client.transport.flush() # Only the 7 infos made it through, the debugs were elided @@ -90,7 +97,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(len(logs), 7) self.assertNotIn(b"debug", server.data) - def test_log_backpressure_info(self): + def test_log_backpressure_info(self) -> None: """ When backpressure is hit, DEBUG and INFO logs will be shed. """ @@ -116,6 +123,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) + assert isinstance(client.transport, FakeTransport) client.transport.flush() # The 10 warnings made it through, the debugs and infos were elided @@ -124,7 +132,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): self.assertNotIn(b"debug", server.data) self.assertNotIn(b"info", server.data) - def test_log_backpressure_cut_middle(self): + def test_log_backpressure_cut_middle(self) -> None: """ When backpressure is hit, and no more DEBUG and INFOs cannot be culled, it will cut the middle messages out. @@ -140,6 +148,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) + assert isinstance(client.transport, FakeTransport) client.transport.flush() # The first five and last five warnings made it through, the debugs and @@ -151,7 +160,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): logs, ) - def test_cancel_connection(self): + def test_cancel_connection(self) -> None: """ Gracefully handle the connection being cancelled. """ diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index 0b0d8737c1..fa27f1279a 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -14,24 +14,28 @@ import json import logging from io import BytesIO, StringIO +from typing import cast from unittest.mock import Mock, patch +from twisted.web.http import HTTPChannel from twisted.web.server import Request from synapse.http.site import SynapseRequest from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter from synapse.logging.context import LoggingContext, LoggingContextFilter +from synapse.types import JsonDict from tests.logging import LoggerCleanupMixin -from tests.server import FakeChannel +from tests.server import FakeChannel, get_clock from tests.unittest import TestCase class TerseJsonTestCase(LoggerCleanupMixin, TestCase): - def setUp(self): + def setUp(self) -> None: self.output = StringIO() + self.reactor, _ = get_clock() - def get_log_line(self): + def get_log_line(self) -> JsonDict: # One log message, with a single trailing newline. data = self.output.getvalue() logs = data.splitlines() @@ -39,7 +43,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(data.count("\n"), 1) return json.loads(logs[0]) - def test_terse_json_output(self): + def test_terse_json_output(self) -> None: """ The Terse JSON formatter converts log messages to JSON. """ @@ -61,7 +65,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") - def test_extra_data(self): + def test_extra_data(self) -> None: """ Additional information can be included in the structured logging. """ @@ -93,7 +97,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(log["int"], 3) self.assertIs(log["bool"], True) - def test_json_output(self): + def test_json_output(self) -> None: """ The Terse JSON formatter converts log messages to JSON. """ @@ -114,7 +118,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") - def test_with_context(self): + def test_with_context(self) -> None: """ The logging context should be added to the JSON response. """ @@ -139,7 +143,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(log["log"], "Hello there, wally!") self.assertEqual(log["request"], "name") - def test_with_request_context(self): + def test_with_request_context(self) -> None: """ Information from the logging context request should be added to the JSON response. """ @@ -154,11 +158,13 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): site.server_version_string = "Server v1" site.reactor = Mock() site.experimental_cors_msc3886 = False - request = SynapseRequest(FakeChannel(site, None), site) + request = SynapseRequest( + cast(HTTPChannel, FakeChannel(site, self.reactor)), site + ) # Call requestReceived to finish instantiating the object. request.content = BytesIO() - # Partially skip some of the internal processing of SynapseRequest. - request._started_processing = Mock() + # Partially skip some internal processing of SynapseRequest. + request._started_processing = Mock() # type: ignore[assignment] request.request_metrics = Mock(spec=["name"]) with patch.object(Request, "render"): request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1") @@ -200,7 +206,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): self.assertEqual(log["protocol"], "1.1") self.assertEqual(log["user_agent"], "") - def test_with_exception(self): + def test_with_exception(self) -> None: """ The logging exception type & value should be added to the JSON response. """ diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 9c17a42b65..fda48d9f61 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from unittest.mock import patch +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.rest import admin @@ -46,35 +50,84 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): self.requester = create_requester(self.alice) self.room_id = self.helper.create_room_as( - self.alice, room_version=RoomVersions.V9.identifier, tok=self.token + # This is deliberately set to V9, because we want to test the logic which + # handles stringy power levels. Stringy power levels were outlawed in V10. + self.alice, + room_version=RoomVersions.V9.identifier, + tok=self.token, ) self.event_creation_handler = self.hs.get_event_creation_handler() - def test_action_for_event_by_user_handles_noninteger_power_levels(self) -> None: - """We should convert floats and strings to integers before passing to Rust. + @parameterized.expand( + [ + # The historically-permitted bad values. Alice's notification should be + # allowed if this threshold is at or below her power level (60) + ("100", False), + ("0", True), + (12.34, True), + (60.0, True), + (67.89, False), + # Values that int(...) would not successfully cast should be ignored. + # The room notification level should then default to 50, per the spec, so + # Alice's notification is allowed. + (None, True), + # We haven't seen `"room": []` or `"room": {}` in the wild (yet), but + # let's check them for paranoia's sake. + ([], True), + ({}, True), + ] + ) + def test_action_for_event_by_user_handles_noninteger_room_power_levels( + self, bad_room_level: object, should_permit: bool + ) -> None: + """We should convert strings in `room` to integers before passing to Rust. + + Test this as follows: + - Create a room as Alice and invite two other users Bob and Charlie. + - Set PLs so that Alice has PL 60 and `notifications.room` is set to a bad value. + - Have Alice create a message notifying @room. + - Evaluate notification actions for that message. This should not raise. + - Look in the DB to see if that message triggered a highlight for Bob. + + The test is parameterised with two arguments: + - the bad power level value for "room", before JSON serisalistion + - whether Bob should expect the message to be highlighted Reproduces #14060. A lack of validation: the gift that keeps on giving. """ + # Join another user to the room, so that there is someone to see Alice's + # @room notification. + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + self.helper.join(self.room_id, bob, tok=bob_token) - # Alter the power levels in that room to include stringy and floaty levels. - # We need to suppress the validation logic or else it will reject these dodgy - # values. (Presumably this validation was not always present.) + # Alter the power levels in that room to include the bad @room notification + # level. We need to suppress + # + # - canonicaljson validation, because canonicaljson forbids floats; + # - the event jsonschema validation, because it will forbid bad values; and + # - the auth rules checks, because they stop us from creating power levels + # with `"room": null`. (We want to test this case, because we have seen it + # in the wild.) + # + # We have seen stringy and null values for "room" in the wild, so presumably + # some of this validation was missing in the past. with patch("synapse.events.validator.validate_canonicaljson"), patch( "synapse.events.validator.jsonschema.validate" - ): - self.helper.send_state( + ), patch("synapse.handlers.event_auth.check_state_dependent_auth_rules"): + pl_event_id = self.helper.send_state( self.room_id, "m.room.power_levels", { - "users": {self.alice: "100"}, # stringy - "notifications": {"room": 100.0}, # float + "users": {self.alice: 60}, + "notifications": {"room": bad_room_level}, }, self.token, state_key="", - ) + )["event_id"] # Create a new message event, and try to evaluate it under the dodgy # power level event. @@ -86,10 +139,11 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): "room_id": self.room_id, "content": { "msgtype": "m.text", - "body": "helo", + "body": "helo @room", }, "sender": self.alice, }, + prev_event_ids=[pl_event_id], ) ) @@ -97,6 +151,21 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): # should not raise self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) + # Did Bob see Alice's @room notification? + highlighted_actions = self.get_success( + self.hs.get_datastores().main.db_pool.simple_select_list( + table="event_push_actions_staging", + keyvalues={ + "event_id": event.event_id, + "user_id": bob, + "highlight": 1, + }, + retcols=("*",), + desc="get_event_push_actions_staging", + ) + ) + self.assertEqual(len(highlighted_actions), int(should_permit)) + @override_config({"push": {"enabled": False}}) def test_action_for_event_by_user_disabled_by_config(self) -> None: """Ensure that push rules are not calculated when disabled in the config""" @@ -126,3 +195,89 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): # Ensure no actions are generated! self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) bulk_evaluator._action_for_event_by_user.assert_not_called() + + @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) + def test_mentions(self) -> None: + """Test the behavior of an event which includes invalid mentions.""" + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + + sentinel = object() + + def create_and_process(mentions: Any = sentinel) -> bool: + """Returns true iff the `mentions` trigger an event push action.""" + content = {} + if mentions is not sentinel: + content[EventContentFields.MSC3952_MENTIONS] = mentions + + # Create a new message event which should cause a notification. + event, context = self.get_success( + self.event_creation_handler.create_event( + self.requester, + { + "type": "test", + "room_id": self.room_id, + "content": content, + "sender": f"@bob:{self.hs.hostname}", + }, + ) + ) + + # Ensure no actions are generated! + self.get_success( + bulk_evaluator.action_for_events_by_user([(event, context)]) + ) + + # If any actions are generated for this event, return true. + result = self.get_success( + self.hs.get_datastores().main.db_pool.simple_select_list( + table="event_push_actions_staging", + keyvalues={"event_id": event.event_id}, + retcols=("*",), + desc="get_event_push_actions_staging", + ) + ) + return len(result) > 0 + + # Not including the mentions field should not notify. + self.assertFalse(create_and_process()) + # An empty mentions field should not notify. + self.assertFalse(create_and_process({})) + + # Non-dict mentions should be ignored. + mentions: Any + for mentions in (None, True, False, 1, "foo", []): + self.assertFalse(create_and_process(mentions)) + + # A non-list should be ignored. + for mentions in (None, True, False, 1, "foo", {}): + self.assertFalse(create_and_process({"user_ids": mentions})) + + # The Matrix ID appearing anywhere in the list should notify. + self.assertTrue(create_and_process({"user_ids": [self.alice]})) + self.assertTrue(create_and_process({"user_ids": ["@another:test", self.alice]})) + + # Duplicate user IDs should notify. + self.assertTrue(create_and_process({"user_ids": [self.alice, self.alice]})) + + # Invalid entries in the list are ignored. + self.assertFalse(create_and_process({"user_ids": [None, True, False, {}, []]})) + self.assertTrue( + create_and_process({"user_ids": [None, True, False, {}, [], self.alice]}) + ) + + # Room mentions from those without power should not notify. + self.assertFalse(create_and_process({"room": True})) + + # Room mentions from those with power should notify. + self.helper.send_state( + self.room_id, + "m.room.power_levels", + {"notifications": {"room": 0}}, + self.token, + state_key="", + ) + self.assertTrue(create_and_process({"room": True})) + + # Invalid data should not notify. + for mentions in (None, False, 1, "foo", [], {}): + self.assertFalse(create_and_process({"room": mentions})) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 1b87756b75..9d01c989d4 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Union, cast +from typing import Dict, List, Optional, Set, Union, cast import frozendict @@ -39,7 +39,12 @@ from tests.test_utils.event_injection import create_event, inject_member_event class PushRuleEvaluatorTestCase(unittest.TestCase): def _get_evaluator( - self, content: JsonMapping, related_events: Optional[JsonDict] = None + self, + content: JsonMapping, + *, + user_mentions: Optional[Set[str]] = None, + room_mention: bool = False, + related_events: Optional[JsonDict] = None, ) -> PushRuleEvaluator: event = FrozenEvent( { @@ -57,13 +62,15 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): power_levels: Dict[str, Union[int, Dict[str, int]]] = {} return PushRuleEvaluator( _flatten_dict(event), + user_mentions or set(), + room_mention, room_member_count, sender_power_level, cast(Dict[str, int], power_levels.get("notifications", {})), {} if related_events is None else related_events, - True, - event.room_version.msc3931_push_features, - True, + related_event_match_enabled=True, + room_version_feature_flags=event.room_version.msc3931_push_features, + msc3931_enabled=True, ) def test_display_name(self) -> None: @@ -90,6 +97,51 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): # A display name with spaces should work fine. self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) + def test_user_mentions(self) -> None: + """Check for user mentions.""" + condition = {"kind": "org.matrix.msc3952.is_user_mention"} + + # No mentions shouldn't match. + evaluator = self._get_evaluator({}) + self.assertFalse(evaluator.matches(condition, "@user:test", None)) + + # An empty set shouldn't match + evaluator = self._get_evaluator({}, user_mentions=set()) + self.assertFalse(evaluator.matches(condition, "@user:test", None)) + + # The Matrix ID appearing anywhere in the mentions list should match + evaluator = self._get_evaluator({}, user_mentions={"@user:test"}) + self.assertTrue(evaluator.matches(condition, "@user:test", None)) + + evaluator = self._get_evaluator( + {}, user_mentions={"@another:test", "@user:test"} + ) + self.assertTrue(evaluator.matches(condition, "@user:test", None)) + + # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions + # since the BulkPushRuleEvaluator is what handles data sanitisation. + + def test_room_mentions(self) -> None: + """Check for room mentions.""" + condition = {"kind": "org.matrix.msc3952.is_room_mention"} + + # No room mention shouldn't match. + evaluator = self._get_evaluator({}) + self.assertFalse(evaluator.matches(condition, None, None)) + + # Room mention should match. + evaluator = self._get_evaluator({}, room_mention=True) + self.assertTrue(evaluator.matches(condition, None, None)) + + # A room mention and user mention is valid. + evaluator = self._get_evaluator( + {}, user_mentions={"@another:test"}, room_mention=True + ) + self.assertTrue(evaluator.matches(condition, None, None)) + + # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions + # since the BulkPushRuleEvaluator is what handles data sanitisation. + def _assert_matches( self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None ) -> None: @@ -308,7 +360,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): }, } }, - { + related_events={ "m.in_reply_to": { "event_id": "$parent_event_id", "type": "m.room.message", @@ -408,7 +460,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): }, } }, - { + related_events={ "m.in_reply_to": { "event_id": "$parent_event_id", "type": "m.room.message", diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index c9afa0f3dd..b9047194dd 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -294,9 +294,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): self.make_request("GET", sync_url % (access_token, next_batch)) -class SyncKnockTestCase( - unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin -): +class SyncKnockTestCase(KnockingStrippedStateEventHelperMixin): servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 21a1ca2a68..3086e1b565 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -13,18 +13,22 @@ # limitations under the License. from http import HTTPStatus +from typing import Any, Generator, Tuple, cast from unittest.mock import Mock, call -from twisted.internet import defer, reactor +from twisted.internet import defer, reactor as _reactor from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache +from synapse.types import ISynapseReactor, JsonDict from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable from tests.utils import MockClock +reactor = cast(ISynapseReactor, _reactor) + class HttpTransactionCacheTestCase(unittest.TestCase): def setUp(self) -> None: @@ -34,11 +38,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.hs.get_auth = Mock() self.cache = HttpTransactionCache(self.hs) - self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!") + self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"}) self.mock_key = "foo" @defer.inlineCallbacks - def test_executes_given_function(self): + def test_executes_given_function( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) res = yield self.cache.fetch_or_execute( self.mock_key, cb, "some_arg", keyword="arg" @@ -47,7 +53,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.assertEqual(res, self.mock_http_response) @defer.inlineCallbacks - def test_deduplicates_based_on_key(self): + def test_deduplicates_based_on_key( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) for i in range(3): # invoke multiple times res = yield self.cache.fetch_or_execute( @@ -58,18 +66,20 @@ class HttpTransactionCacheTestCase(unittest.TestCase): cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0) @defer.inlineCallbacks - def test_logcontexts_with_async_result(self): + def test_logcontexts_with_async_result( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: @defer.inlineCallbacks - def cb(): + def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]: yield Clock(reactor).sleep(0) - return "yay" + return 1, {} @defer.inlineCallbacks - def test(): + def test() -> Generator["defer.Deferred[Any]", object, None]: with LoggingContext("c") as c1: res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertIs(current_context(), c1) - self.assertEqual(res, "yay") + self.assertEqual(res, (1, {})) # run the test twice in parallel d = defer.gatherResults([test(), test()]) @@ -78,13 +88,15 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.assertIs(current_context(), SENTINEL_CONTEXT) @defer.inlineCallbacks - def test_does_not_cache_exceptions(self): + def test_does_not_cache_exceptions( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: """Checks that, if the callback throws an exception, it is called again for the next request. """ called = [False] - def cb(): + def cb() -> "defer.Deferred[Tuple[int, JsonDict]]": if called[0]: # return a valid result the second time return defer.succeed(self.mock_http_response) @@ -104,13 +116,15 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.assertIs(current_context(), test_context) @defer.inlineCallbacks - def test_does_not_cache_failures(self): + def test_does_not_cache_failures( + self, + ) -> Generator["defer.Deferred[Any]", object, None]: """Checks that, if the callback returns a failure, it is called again for the next request. """ called = [False] - def cb(): + def cb() -> "defer.Deferred[Tuple[int, JsonDict]]": if called[0]: # return a valid result the second time return defer.succeed(self.mock_http_response) @@ -130,7 +144,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): self.assertIs(current_context(), test_context) @defer.inlineCallbacks - def test_cleans_up(self): + def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]: cb = Mock(return_value=make_awaitable(self.mock_http_response)) yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") # should NOT have cleaned up yet diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index 543cce6b3e..8cd7c89ca2 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer from synapse.storage.database import ( DatabasePool, + LoggingDatabaseConnection, LoggingTransaction, make_tuple_comparison_clause, ) @@ -37,6 +38,101 @@ class TupleComparisonClauseTestCase(unittest.TestCase): self.assertEqual(args, [1, 2]) +class ExecuteScriptTestCase(unittest.HomeserverTestCase): + """Tests for `BaseDatabaseEngine.executescript` implementations.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + self.get_success( + self.db_pool.runInteraction( + "create", + lambda txn: txn.execute("CREATE TABLE foo (name TEXT PRIMARY KEY)"), + ) + ) + + def test_transaction(self) -> None: + """Test that all statements are run in a single transaction.""" + + def run(conn: LoggingDatabaseConnection) -> None: + cur = conn.cursor(txn_name="test_transaction") + self.db_pool.engine.executescript( + cur, + ";".join( + [ + "INSERT INTO foo (name) VALUES ('transaction test')", + # This next statement will fail. When `executescript` is not + # transactional, the previous row will be observed later. + "INSERT INTO foo (name) VALUES ('transaction test')", + ] + ), + ) + + self.get_failure( + self.db_pool.runWithConnection(run), + self.db_pool.engine.module.IntegrityError, + ) + + self.assertIsNone( + self.get_success( + self.db_pool.simple_select_one_onecol( + "foo", + keyvalues={"name": "transaction test"}, + retcol="name", + allow_none=True, + ) + ), + "executescript is not running statements inside a transaction", + ) + + def test_commit(self) -> None: + """Test that the script transaction remains open and can be committed.""" + + def run(conn: LoggingDatabaseConnection) -> None: + cur = conn.cursor(txn_name="test_commit") + self.db_pool.engine.executescript( + cur, "INSERT INTO foo (name) VALUES ('commit test')" + ) + cur.execute("COMMIT") + + self.get_success(self.db_pool.runWithConnection(run)) + + self.assertIsNotNone( + self.get_success( + self.db_pool.simple_select_one_onecol( + "foo", + keyvalues={"name": "commit test"}, + retcol="name", + allow_none=True, + ) + ), + ) + + def test_rollback(self) -> None: + """Test that the script transaction remains open and can be rolled back.""" + + def run(conn: LoggingDatabaseConnection) -> None: + cur = conn.cursor(txn_name="test_rollback") + self.db_pool.engine.executescript( + cur, "INSERT INTO foo (name) VALUES ('rollback test')" + ) + cur.execute("ROLLBACK") + + self.get_success(self.db_pool.runWithConnection(run)) + + self.assertIsNone( + self.get_success( + self.db_pool.simple_select_one_onecol( + "foo", + keyvalues={"name": "rollback test"}, + retcol="name", + allow_none=True, + ) + ), + "executescript is not leaving the script transaction open", + ) + + class CallbacksTestCase(unittest.HomeserverTestCase): """Tests for transaction callbacks.""" diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 3ba896ecf3..f1ca523d23 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -28,6 +28,7 @@ from synapse.storage.background_updates import _BackgroundUpdateHandler from synapse.storage.roommember import ProfileInfo from synapse.util import Clock +from tests.server import ThreadedMemoryReactorClock from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config @@ -138,7 +139,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): register.register_servlets, ] - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + def make_homeserver( + self, reactor: ThreadedMemoryReactorClock, clock: Clock + ) -> HomeServer: self.appservice = ApplicationService( token="i_am_an_app_service", id="1234", diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index f4d9fba0a1..0a7937f1cc 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -13,7 +13,7 @@ # limitations under the License. import unittest -from typing import Collection, Dict, Iterable, List, Optional +from typing import Any, Collection, Dict, Iterable, List, Optional from parameterized import parameterized @@ -728,6 +728,36 @@ class EventAuthTestCase(unittest.TestCase): pl_event.room_version, pl_event2, {("fake_type", "fake_key"): pl_event} ) + def test_room_v10_rejects_other_non_integer_power_levels(self) -> None: + """We should reject PLs that are non-integer, non-string JSON values. + + test_room_v10_rejects_string_power_levels above handles the string case. + """ + + def create_event(pl_event_content: Dict[str, Any]) -> EventBase: + return make_event_from_dict( + { + "room_id": TEST_ROOM_ID, + **_maybe_get_event_id_dict_for_room_version(RoomVersions.V10), + "type": "m.room.power_levels", + "sender": "@test:test.com", + "state_key": "", + "content": pl_event_content, + "signatures": {"test.com": {"ed25519:0": "some9signature"}}, + }, + room_version=RoomVersions.V10, + ) + + contents: Iterable[Dict[str, Any]] = [ + {"notifications": {"room": None}}, + {"users": {"@alice:wonderland": []}}, + {"users_default": {}}, + ] + for content in contents: + event = create_event(content) + with self.assertRaises(SynapseError): + event_auth._check_power_levels(event.room_version, event, {}) + # helpers for making events TEST_DOMAIN = "example.com" diff --git a/tests/unittest.py b/tests/unittest.py index a120c2976c..fa92dd94eb 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -75,6 +75,7 @@ from synapse.util.httpresourcetree import create_resource_tree from tests.server import ( CustomHeaderType, FakeChannel, + ThreadedMemoryReactorClock, get_clock, make_request, setup_test_homeserver, @@ -360,7 +361,7 @@ class HomeserverTestCase(TestCase): store.db_pool.updates.do_next_background_update(False), by=0.1 ) - def make_homeserver(self, reactor: MemoryReactor, clock: Clock): + def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock): """ Make and return a homeserver. |