diff --git a/Cargo.lock b/Cargo.lock
index d6f9000138..e3e63fc205 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -13,9 +13,9 @@ dependencies = [
[[package]]
name = "anyhow"
-version = "1.0.83"
+version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "25bdb32cbbdce2b519a9cd7df3a678443100e265d5e25ca763b7572a5104f5f3"
+checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da"
[[package]]
name = "arc-swap"
@@ -485,18 +485,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "serde"
-version = "1.0.201"
+version = "1.0.203"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "780f1cebed1629e4753a1a38a3c72d30b97ec044f0aef68cb26650a3c5cf363c"
+checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
-version = "1.0.201"
+version = "1.0.203"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c5e405930b9796f1c00bee880d03fc7e0bb4b9a11afc776885ffe84320da2865"
+checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba"
dependencies = [
"proc-macro2",
"quote",
diff --git a/changelog.d/17147.feature b/changelog.d/17147.feature
new file mode 100644
index 0000000000..7c2cdb6bdf
--- /dev/null
+++ b/changelog.d/17147.feature
@@ -0,0 +1 @@
+Add the ability to auto-accept invites on the behalf of users. See the [`auto_accept_invites`](https://element-hq.github.io/synapse/latest/usage/configuration/config_documentation.html#auto-accept-invites) config option for details.
diff --git a/changelog.d/17167.feature b/changelog.d/17167.feature
new file mode 100644
index 0000000000..5ad31db974
--- /dev/null
+++ b/changelog.d/17167.feature
@@ -0,0 +1 @@
+Add experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync/e2ee` endpoint for To-Device messages and device encryption info.
diff --git a/changelog.d/17176.misc b/changelog.d/17176.misc
new file mode 100644
index 0000000000..cc9f2a5202
--- /dev/null
+++ b/changelog.d/17176.misc
@@ -0,0 +1 @@
+Log exceptions when failing to auto-join new user according to the `auto_join_rooms` option.
\ No newline at end of file
diff --git a/changelog.d/17204.doc b/changelog.d/17204.doc
new file mode 100644
index 0000000000..5a5a8f5107
--- /dev/null
+++ b/changelog.d/17204.doc
@@ -0,0 +1 @@
+Update OIDC documentation: by default Matrix doesn't query userinfo endpoint, then claims should be put on id_token.
diff --git a/changelog.d/17211.misc b/changelog.d/17211.misc
new file mode 100644
index 0000000000..144db03a40
--- /dev/null
+++ b/changelog.d/17211.misc
@@ -0,0 +1 @@
+Reduce work of calculating outbound device lists updates.
diff --git a/changelog.d/17213.feature b/changelog.d/17213.feature
new file mode 100644
index 0000000000..ca60afa8f3
--- /dev/null
+++ b/changelog.d/17213.feature
@@ -0,0 +1 @@
+Support MSC3916 by adding unstable media endpoints to `_matrix/client` (#17213).
\ No newline at end of file
diff --git a/changelog.d/17216.misc b/changelog.d/17216.misc
new file mode 100644
index 0000000000..bd55eeaa33
--- /dev/null
+++ b/changelog.d/17216.misc
@@ -0,0 +1 @@
+Improve performance of calculating device lists changes in `/sync`.
diff --git a/changelog.d/17219.feature b/changelog.d/17219.feature
new file mode 100644
index 0000000000..f8277a89d8
--- /dev/null
+++ b/changelog.d/17219.feature
@@ -0,0 +1 @@
+Add logging to tasks managed by the task scheduler, showing CPU and database usage.
\ No newline at end of file
diff --git a/docs/openid.md b/docs/openid.md
index 9773a7de52..7a10b1615b 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -525,6 +525,8 @@ oidc_providers:
(`Options > Security > ID Token signature algorithm` and `Options > Security >
Access Token signature algorithm`)
- Scopes: OpenID, Email and Profile
+- Force claims into `id_token`
+ (`Options > Advanced > Force claims to be returned in ID Token`)
- Allowed redirection addresses for login (`Options > Basic > Allowed
redirection addresses for login` ) :
`[synapse public baseurl]/_synapse/client/oidc/callback`
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index e04fdfdfb0..2c917d1f8e 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -4595,3 +4595,32 @@ background_updates:
min_batch_size: 10
default_batch_size: 50
```
+---
+## Auto Accept Invites
+Configuration settings related to automatically accepting invites.
+
+---
+### `auto_accept_invites`
+
+Automatically accepting invites controls whether users are presented with an invite request or if they
+are instead automatically joined to a room when receiving an invite. Set the `enabled` sub-option to true to
+enable auto-accepting invites. Defaults to false.
+This setting has the following sub-options:
+* `enabled`: Whether to run the auto-accept invites logic. Defaults to false.
+* `only_for_direct_messages`: Whether invites should be automatically accepted for all room types, or only
+ for direct messages. Defaults to false.
+* `only_from_local_users`: Whether to only automatically accept invites from users on this homeserver. Defaults to false.
+* `worker_to_run_on`: Which worker to run this module on. This must match the "worker_name".
+
+NOTE: Care should be taken not to enable this setting if the `synapse_auto_accept_invite` module is enabled and installed.
+The two modules will compete to perform the same task and may result in undesired behaviour. For example, multiple join
+events could be generated from a single invite.
+
+Example configuration:
+```yaml
+auto_accept_invites:
+ enabled: true
+ only_for_direct_messages: true
+ only_from_local_users: true
+ worker_to_run_on: "worker_1"
+```
diff --git a/poetry.lock b/poetry.lock
index 8537f37529..73814e49d0 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -67,38 +67,38 @@ visualize = ["Twisted (>=16.1.1)", "graphviz (>0.5.1)"]
[[package]]
name = "bcrypt"
-version = "4.1.2"
+version = "4.1.3"
description = "Modern password hashing for your software and your servers"
optional = false
python-versions = ">=3.7"
files = [
- {file = "bcrypt-4.1.2-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:ac621c093edb28200728a9cca214d7e838529e557027ef0581685909acd28b5e"},
- {file = "bcrypt-4.1.2-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea505c97a5c465ab8c3ba75c0805a102ce526695cd6818c6de3b1a38f6f60da1"},
- {file = "bcrypt-4.1.2-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:57fa9442758da926ed33a91644649d3e340a71e2d0a5a8de064fb621fd5a3326"},
- {file = "bcrypt-4.1.2-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:eb3bd3321517916696233b5e0c67fd7d6281f0ef48e66812db35fc963a422a1c"},
- {file = "bcrypt-4.1.2-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:6cad43d8c63f34b26aef462b6f5e44fdcf9860b723d2453b5d391258c4c8e966"},
- {file = "bcrypt-4.1.2-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:44290ccc827d3a24604f2c8bcd00d0da349e336e6503656cb8192133e27335e2"},
- {file = "bcrypt-4.1.2-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:732b3920a08eacf12f93e6b04ea276c489f1c8fb49344f564cca2adb663b3e4c"},
- {file = "bcrypt-4.1.2-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:1c28973decf4e0e69cee78c68e30a523be441972c826703bb93099868a8ff5b5"},
- {file = "bcrypt-4.1.2-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b8df79979c5bae07f1db22dcc49cc5bccf08a0380ca5c6f391cbb5790355c0b0"},
- {file = "bcrypt-4.1.2-cp37-abi3-win32.whl", hash = "sha256:fbe188b878313d01b7718390f31528be4010fed1faa798c5a1d0469c9c48c369"},
- {file = "bcrypt-4.1.2-cp37-abi3-win_amd64.whl", hash = "sha256:9800ae5bd5077b13725e2e3934aa3c9c37e49d3ea3d06318010aa40f54c63551"},
- {file = "bcrypt-4.1.2-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:71b8be82bc46cedd61a9f4ccb6c1a493211d031415a34adde3669ee1b0afbb63"},
- {file = "bcrypt-4.1.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68e3c6642077b0c8092580c819c1684161262b2e30c4f45deb000c38947bf483"},
- {file = "bcrypt-4.1.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:387e7e1af9a4dd636b9505a465032f2f5cb8e61ba1120e79a0e1cd0b512f3dfc"},
- {file = "bcrypt-4.1.2-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:f70d9c61f9c4ca7d57f3bfe88a5ccf62546ffbadf3681bb1e268d9d2e41c91a7"},
- {file = "bcrypt-4.1.2-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:2a298db2a8ab20056120b45e86c00a0a5eb50ec4075b6142db35f593b97cb3fb"},
- {file = "bcrypt-4.1.2-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:ba55e40de38a24e2d78d34c2d36d6e864f93e0d79d0b6ce915e4335aa81d01b1"},
- {file = "bcrypt-4.1.2-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:3566a88234e8de2ccae31968127b0ecccbb4cddb629da744165db72b58d88ca4"},
- {file = "bcrypt-4.1.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:b90e216dc36864ae7132cb151ffe95155a37a14e0de3a8f64b49655dd959ff9c"},
- {file = "bcrypt-4.1.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:69057b9fc5093ea1ab00dd24ede891f3e5e65bee040395fb1e66ee196f9c9b4a"},
- {file = "bcrypt-4.1.2-cp39-abi3-win32.whl", hash = "sha256:02d9ef8915f72dd6daaef40e0baeef8a017ce624369f09754baf32bb32dba25f"},
- {file = "bcrypt-4.1.2-cp39-abi3-win_amd64.whl", hash = "sha256:be3ab1071662f6065899fe08428e45c16aa36e28bc42921c4901a191fda6ee42"},
- {file = "bcrypt-4.1.2-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d75fc8cd0ba23f97bae88a6ec04e9e5351ff3c6ad06f38fe32ba50cbd0d11946"},
- {file = "bcrypt-4.1.2-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:a97e07e83e3262599434816f631cc4c7ca2aa8e9c072c1b1a7fec2ae809a1d2d"},
- {file = "bcrypt-4.1.2-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:e51c42750b7585cee7892c2614be0d14107fad9581d1738d954a262556dd1aab"},
- {file = "bcrypt-4.1.2-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba4e4cc26610581a6329b3937e02d319f5ad4b85b074846bf4fef8a8cf51e7bb"},
- {file = "bcrypt-4.1.2.tar.gz", hash = "sha256:33313a1200a3ae90b75587ceac502b048b840fc69e7f7a0905b5f87fac7a1258"},
+ {file = "bcrypt-4.1.3-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:48429c83292b57bf4af6ab75809f8f4daf52aa5d480632e53707805cc1ce9b74"},
+ {file = "bcrypt-4.1.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a8bea4c152b91fd8319fef4c6a790da5c07840421c2b785084989bf8bbb7455"},
+ {file = "bcrypt-4.1.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d3b317050a9a711a5c7214bf04e28333cf528e0ed0ec9a4e55ba628d0f07c1a"},
+ {file = "bcrypt-4.1.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:094fd31e08c2b102a14880ee5b3d09913ecf334cd604af27e1013c76831f7b05"},
+ {file = "bcrypt-4.1.3-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:4fb253d65da30d9269e0a6f4b0de32bd657a0208a6f4e43d3e645774fb5457f3"},
+ {file = "bcrypt-4.1.3-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:193bb49eeeb9c1e2db9ba65d09dc6384edd5608d9d672b4125e9320af9153a15"},
+ {file = "bcrypt-4.1.3-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:8cbb119267068c2581ae38790e0d1fbae65d0725247a930fc9900c285d95725d"},
+ {file = "bcrypt-4.1.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:6cac78a8d42f9d120b3987f82252bdbeb7e6e900a5e1ba37f6be6fe4e3848286"},
+ {file = "bcrypt-4.1.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:01746eb2c4299dd0ae1670234bf77704f581dd72cc180f444bfe74eb80495b64"},
+ {file = "bcrypt-4.1.3-cp37-abi3-win32.whl", hash = "sha256:037c5bf7c196a63dcce75545c8874610c600809d5d82c305dd327cd4969995bf"},
+ {file = "bcrypt-4.1.3-cp37-abi3-win_amd64.whl", hash = "sha256:8a893d192dfb7c8e883c4576813bf18bb9d59e2cfd88b68b725990f033f1b978"},
+ {file = "bcrypt-4.1.3-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:0d4cf6ef1525f79255ef048b3489602868c47aea61f375377f0d00514fe4a78c"},
+ {file = "bcrypt-4.1.3-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f5698ce5292a4e4b9e5861f7e53b1d89242ad39d54c3da451a93cac17b61921a"},
+ {file = "bcrypt-4.1.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec3c2e1ca3e5c4b9edb94290b356d082b721f3f50758bce7cce11d8a7c89ce84"},
+ {file = "bcrypt-4.1.3-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:3a5be252fef513363fe281bafc596c31b552cf81d04c5085bc5dac29670faa08"},
+ {file = "bcrypt-4.1.3-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:5f7cd3399fbc4ec290378b541b0cf3d4398e4737a65d0f938c7c0f9d5e686611"},
+ {file = "bcrypt-4.1.3-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:c4c8d9b3e97209dd7111bf726e79f638ad9224b4691d1c7cfefa571a09b1b2d6"},
+ {file = "bcrypt-4.1.3-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:31adb9cbb8737a581a843e13df22ffb7c84638342de3708a98d5c986770f2834"},
+ {file = "bcrypt-4.1.3-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:551b320396e1d05e49cc18dd77d970accd52b322441628aca04801bbd1d52a73"},
+ {file = "bcrypt-4.1.3-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6717543d2c110a155e6821ce5670c1f512f602eabb77dba95717ca76af79867d"},
+ {file = "bcrypt-4.1.3-cp39-abi3-win32.whl", hash = "sha256:6004f5229b50f8493c49232b8e75726b568535fd300e5039e255d919fc3a07f2"},
+ {file = "bcrypt-4.1.3-cp39-abi3-win_amd64.whl", hash = "sha256:2505b54afb074627111b5a8dc9b6ae69d0f01fea65c2fcaea403448c503d3991"},
+ {file = "bcrypt-4.1.3-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:cb9c707c10bddaf9e5ba7cdb769f3e889e60b7d4fea22834b261f51ca2b89fed"},
+ {file = "bcrypt-4.1.3-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:9f8ea645eb94fb6e7bea0cf4ba121c07a3a182ac52876493870033141aa687bc"},
+ {file = "bcrypt-4.1.3-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:f44a97780677e7ac0ca393bd7982b19dbbd8d7228c1afe10b128fd9550eef5f1"},
+ {file = "bcrypt-4.1.3-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d84702adb8f2798d813b17d8187d27076cca3cd52fe3686bb07a9083930ce650"},
+ {file = "bcrypt-4.1.3.tar.gz", hash = "sha256:2ee15dd749f5952fe3f0430d0ff6b74082e159c50332a1413d51b5689cf06623"},
]
[package.extras]
@@ -1536,13 +1536,13 @@ files = [
[[package]]
name = "phonenumbers"
-version = "8.13.35"
+version = "8.13.37"
description = "Python version of Google's common library for parsing, formatting, storing and validating international phone numbers."
optional = false
python-versions = "*"
files = [
- {file = "phonenumbers-8.13.35-py2.py3-none-any.whl", hash = "sha256:58286a8e617bd75f541e04313b28c36398be6d4443a778c85e9617a93c391310"},
- {file = "phonenumbers-8.13.35.tar.gz", hash = "sha256:64f061a967dcdae11e1c59f3688649e697b897110a33bb74d5a69c3e35321245"},
+ {file = "phonenumbers-8.13.37-py2.py3-none-any.whl", hash = "sha256:4ea00ef5012422c08c7955c21131e7ae5baa9a3ef52cf2d561e963f023006b80"},
+ {file = "phonenumbers-8.13.37.tar.gz", hash = "sha256:bd315fed159aea0516f7c367231810fe8344d5bec26156b88fa18374c11d1cf2"},
]
[[package]]
@@ -1673,13 +1673,13 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytes
[[package]]
name = "prometheus-client"
-version = "0.19.0"
+version = "0.20.0"
description = "Python client for the Prometheus monitoring system."
optional = false
python-versions = ">=3.8"
files = [
- {file = "prometheus_client-0.19.0-py3-none-any.whl", hash = "sha256:c88b1e6ecf6b41cd8fb5731c7ae919bf66df6ec6fafa555cd6c0e16ca169ae92"},
- {file = "prometheus_client-0.19.0.tar.gz", hash = "sha256:4585b0d1223148c27a225b10dbec5ae9bc4c81a99a3fa80774fa6209935324e1"},
+ {file = "prometheus_client-0.20.0-py3-none-any.whl", hash = "sha256:cde524a85bce83ca359cc837f28b8c0db5cac7aa653a588fd7e84ba061c329e7"},
+ {file = "prometheus_client-0.20.0.tar.gz", hash = "sha256:287629d00b147a32dcb2be0b9df905da599b2d82f80377083ec8463309a4bb89"},
]
[package.extras]
@@ -1736,13 +1736,13 @@ psycopg2 = "*"
[[package]]
name = "pyasn1"
-version = "0.5.1"
+version = "0.6.0"
description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
optional = false
-python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
+python-versions = ">=3.8"
files = [
- {file = "pyasn1-0.5.1-py2.py3-none-any.whl", hash = "sha256:4439847c58d40b1d0a573d07e3856e95333f1976294494c325775aeca506eb58"},
- {file = "pyasn1-0.5.1.tar.gz", hash = "sha256:6d391a96e59b23130a5cfa74d6fd7f388dbbe26cc8f1edf39fdddf08d9d6676c"},
+ {file = "pyasn1-0.6.0-py2.py3-none-any.whl", hash = "sha256:cca4bb0f2df5504f02f6f8a775b6e416ff9b0b3b16f7ee80b5a3153d9b804473"},
+ {file = "pyasn1-0.6.0.tar.gz", hash = "sha256:3a35ab2c4b5ef98e17dfdec8ab074046fbda76e281c5a706ccd82328cfc8f64c"},
]
[[package]]
@@ -1915,12 +1915,12 @@ plugins = ["importlib-metadata"]
[[package]]
name = "pyicu"
-version = "2.13"
+version = "2.13.1"
description = "Python extension wrapping the ICU C++ API"
optional = true
python-versions = "*"
files = [
- {file = "PyICU-2.13.tar.gz", hash = "sha256:d481be888975df3097c2790241bbe8518f65c9676a74957cdbe790e559c828f6"},
+ {file = "PyICU-2.13.1.tar.gz", hash = "sha256:d4919085eaa07da12bade8ee721e7bbf7ade0151ca0f82946a26c8f4b98cdceb"},
]
[[package]]
@@ -1997,13 +1997,13 @@ tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"]
[[package]]
name = "pyopenssl"
-version = "24.0.0"
+version = "24.1.0"
description = "Python wrapper module around the OpenSSL library"
optional = false
python-versions = ">=3.7"
files = [
- {file = "pyOpenSSL-24.0.0-py3-none-any.whl", hash = "sha256:ba07553fb6fd6a7a2259adb9b84e12302a9a8a75c44046e8bb5d3e5ee887e3c3"},
- {file = "pyOpenSSL-24.0.0.tar.gz", hash = "sha256:6aa33039a93fffa4563e655b61d11364d01264be8ccb49906101e02a334530bf"},
+ {file = "pyOpenSSL-24.1.0-py3-none-any.whl", hash = "sha256:17ed5be5936449c5418d1cd269a1a9e9081bc54c17aed272b45856a3d3dc86ad"},
+ {file = "pyOpenSSL-24.1.0.tar.gz", hash = "sha256:cabed4bfaa5df9f1a16c0ef64a0cb65318b5cd077a7eda7d6970131ca2f41a6f"},
]
[package.dependencies]
@@ -2011,7 +2011,7 @@ cryptography = ">=41.0.5,<43"
[package.extras]
docs = ["sphinx (!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx-rtd-theme"]
-test = ["flaky", "pretend", "pytest (>=3.0.1)"]
+test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"]
[[package]]
name = "pysaml2"
@@ -2673,13 +2673,13 @@ docs = ["sphinx (<7.0.0)"]
[[package]]
name = "twine"
-version = "5.0.0"
+version = "5.1.0"
description = "Collection of utilities for publishing packages on PyPI"
optional = false
python-versions = ">=3.8"
files = [
- {file = "twine-5.0.0-py3-none-any.whl", hash = "sha256:a262933de0b484c53408f9edae2e7821c1c45a3314ff2df9bdd343aa7ab8edc0"},
- {file = "twine-5.0.0.tar.gz", hash = "sha256:89b0cc7d370a4b66421cc6102f269aa910fe0f1861c124f573cf2ddedbc10cf4"},
+ {file = "twine-5.1.0-py3-none-any.whl", hash = "sha256:fe1d814395bfe50cfbe27783cb74efe93abeac3f66deaeb6c8390e4e92bacb43"},
+ {file = "twine-5.1.0.tar.gz", hash = "sha256:4d74770c88c4fcaf8134d2a6a9d863e40f08255ff7d8e2acb3cbbd57d25f6e9d"},
]
[package.dependencies]
@@ -2853,13 +2853,13 @@ files = [
[[package]]
name = "types-psycopg2"
-version = "2.9.21.20240311"
+version = "2.9.21.20240417"
description = "Typing stubs for psycopg2"
optional = false
python-versions = ">=3.8"
files = [
- {file = "types-psycopg2-2.9.21.20240311.tar.gz", hash = "sha256:722945dffa6a729bebc660f14137f37edfcead5a2c15eb234212a7d017ee8072"},
- {file = "types_psycopg2-2.9.21.20240311-py3-none-any.whl", hash = "sha256:2e137ae2b516ee0dbaab6f555086b6cfb723ba4389d67f551b0336adf4efcf1b"},
+ {file = "types-psycopg2-2.9.21.20240417.tar.gz", hash = "sha256:05db256f4a459fb21a426b8e7fca0656c3539105ff0208eaf6bdaf406a387087"},
+ {file = "types_psycopg2-2.9.21.20240417-py3-none-any.whl", hash = "sha256:644d6644d64ebbe37203229b00771012fb3b3bddd507a129a2e136485990e4f8"},
]
[[package]]
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 3182608f73..67e0df1459 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -68,6 +68,7 @@ from synapse.config._base import format_config_error
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import ListenerConfig, ManholeConfig, TCPListenerConfig
from synapse.crypto import context_factory
+from synapse.events.auto_accept_invites import InviteAutoAccepter
from synapse.events.presence_router import load_legacy_presence_router
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.http.site import SynapseSite
@@ -582,6 +583,11 @@ async def start(hs: "HomeServer") -> None:
m = module(config, module_api)
logger.info("Loaded module %s", m)
+ if hs.config.auto_accept_invites.enabled:
+ # Start the local auto_accept_invites module.
+ m = InviteAutoAccepter(hs.config.auto_accept_invites, module_api)
+ logger.info("Loaded local module %s", m)
+
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
load_legacy_presence_router(hs)
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index fc51aed234..d9cb0da38b 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -23,6 +23,7 @@ from synapse.config import ( # noqa: F401
api,
appservice,
auth,
+ auto_accept_invites,
background_updates,
cache,
captcha,
@@ -120,6 +121,7 @@ class RootConfig:
federation: federation.FederationConfig
retention: retention.RetentionConfig
background_updates: background_updates.BackgroundUpdateConfig
+ auto_accept_invites: auto_accept_invites.AutoAcceptInvitesConfig
config_classes: List[Type["Config"]] = ...
config_files: List[str]
diff --git a/synapse/config/auto_accept_invites.py b/synapse/config/auto_accept_invites.py
new file mode 100644
index 0000000000..d90e13a510
--- /dev/null
+++ b/synapse/config/auto_accept_invites.py
@@ -0,0 +1,43 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+from typing import Any
+
+from synapse.types import JsonDict
+
+from ._base import Config
+
+
+class AutoAcceptInvitesConfig(Config):
+ section = "auto_accept_invites"
+
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
+ auto_accept_invites_config = config.get("auto_accept_invites") or {}
+
+ self.enabled = auto_accept_invites_config.get("enabled", False)
+
+ self.accept_invites_only_for_direct_messages = auto_accept_invites_config.get(
+ "only_for_direct_messages", False
+ )
+
+ self.accept_invites_only_from_local_users = auto_accept_invites_config.get(
+ "only_from_local_users", False
+ )
+
+ self.worker_to_run_on = auto_accept_invites_config.get("worker_to_run_on")
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 749452ce93..75fe6d7b24 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -332,6 +332,9 @@ class ExperimentalConfig(Config):
# MSC3391: Removing account data.
self.msc3391_enabled = experimental.get("msc3391_enabled", False)
+ # MSC3575 (Sliding Sync API endpoints)
+ self.msc3575_enabled: bool = experimental.get("msc3575_enabled", False)
+
# MSC3773: Thread notifications
self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
@@ -436,3 +439,7 @@ class ExperimentalConfig(Config):
self.msc4115_membership_on_events = experimental.get(
"msc4115_membership_on_events", False
)
+
+ self.msc3916_authenticated_media_enabled = experimental.get(
+ "msc3916_authenticated_media_enabled", False
+ )
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 72e93ed04f..e36c0bd6ae 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -23,6 +23,7 @@ from .account_validity import AccountValidityConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
from .auth import AuthConfig
+from .auto_accept_invites import AutoAcceptInvitesConfig
from .background_updates import BackgroundUpdateConfig
from .cache import CacheConfig
from .captcha import CaptchaConfig
@@ -105,4 +106,5 @@ class HomeServerConfig(RootConfig):
RedisConfig,
ExperimentalConfig,
BackgroundUpdateConfig,
+ AutoAcceptInvitesConfig,
]
diff --git a/synapse/events/auto_accept_invites.py b/synapse/events/auto_accept_invites.py
new file mode 100644
index 0000000000..d88ec51d9d
--- /dev/null
+++ b/synapse/events/auto_accept_invites.py
@@ -0,0 +1,196 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2021 The Matrix.org Foundation C.I.C
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+import logging
+from http import HTTPStatus
+from typing import Any, Dict, Tuple
+
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.errors import SynapseError
+from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
+from synapse.module_api import EventBase, ModuleApi, run_as_background_process
+
+logger = logging.getLogger(__name__)
+
+
+class InviteAutoAccepter:
+ def __init__(self, config: AutoAcceptInvitesConfig, api: ModuleApi):
+ # Keep a reference to the Module API.
+ self._api = api
+ self._config = config
+
+ if not self._config.enabled:
+ return
+
+ should_run_on_this_worker = config.worker_to_run_on == self._api.worker_name
+
+ if not should_run_on_this_worker:
+ logger.info(
+ "Not accepting invites on this worker (configured: %r, here: %r)",
+ config.worker_to_run_on,
+ self._api.worker_name,
+ )
+ return
+
+ logger.info(
+ "Accepting invites on this worker (here: %r)", self._api.worker_name
+ )
+
+ # Register the callback.
+ self._api.register_third_party_rules_callbacks(
+ on_new_event=self.on_new_event,
+ )
+
+ async def on_new_event(self, event: EventBase, *args: Any) -> None:
+ """Listens for new events, and if the event is an invite for a local user then
+ automatically accepts it.
+
+ Args:
+ event: The incoming event.
+ """
+ # Check if the event is an invite for a local user.
+ is_invite_for_local_user = (
+ event.type == EventTypes.Member
+ and event.is_state()
+ and event.membership == Membership.INVITE
+ and self._api.is_mine(event.state_key)
+ )
+
+ # Only accept invites for direct messages if the configuration mandates it.
+ is_direct_message = event.content.get("is_direct", False)
+ is_allowed_by_direct_message_rules = (
+ not self._config.accept_invites_only_for_direct_messages
+ or is_direct_message is True
+ )
+
+ # Only accept invites from remote users if the configuration mandates it.
+ is_from_local_user = self._api.is_mine(event.sender)
+ is_allowed_by_local_user_rules = (
+ not self._config.accept_invites_only_from_local_users
+ or is_from_local_user is True
+ )
+
+ if (
+ is_invite_for_local_user
+ and is_allowed_by_direct_message_rules
+ and is_allowed_by_local_user_rules
+ ):
+ # Make the user join the room. We run this as a background process to circumvent a race condition
+ # that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12)
+ run_as_background_process(
+ "retry_make_join",
+ self._retry_make_join,
+ event.state_key,
+ event.state_key,
+ event.room_id,
+ "join",
+ bg_start_span=False,
+ )
+
+ if is_direct_message:
+ # Mark this room as a direct message!
+ await self._mark_room_as_direct_message(
+ event.state_key, event.sender, event.room_id
+ )
+
+ async def _mark_room_as_direct_message(
+ self, user_id: str, dm_user_id: str, room_id: str
+ ) -> None:
+ """
+ Marks a room (`room_id`) as a direct message with the counterparty `dm_user_id`
+ from the perspective of the user `user_id`.
+
+ Args:
+ user_id: the user for whom the membership is changing
+ dm_user_id: the user performing the membership change
+ room_id: room id of the room the user is invited to
+ """
+
+ # This is a dict of User IDs to tuples of Room IDs
+ # (get_global will return a frozendict of tuples as it freezes the data,
+ # but we should accept either frozen or unfrozen variants.)
+ # Be careful: we convert the outer frozendict into a dict here,
+ # but the contents of the dict are still frozen (tuples in lieu of lists,
+ # etc.)
+ dm_map: Dict[str, Tuple[str, ...]] = dict(
+ await self._api.account_data_manager.get_global(
+ user_id, AccountDataTypes.DIRECT
+ )
+ or {}
+ )
+
+ if dm_user_id not in dm_map:
+ dm_map[dm_user_id] = (room_id,)
+ else:
+ dm_rooms_for_user = dm_map[dm_user_id]
+ assert isinstance(dm_rooms_for_user, (tuple, list))
+
+ dm_map[dm_user_id] = tuple(dm_rooms_for_user) + (room_id,)
+
+ await self._api.account_data_manager.put_global(
+ user_id, AccountDataTypes.DIRECT, dm_map
+ )
+
+ async def _retry_make_join(
+ self, sender: str, target: str, room_id: str, new_membership: str
+ ) -> None:
+ """
+ A function to retry sending the `make_join` request with an increasing backoff. This is
+ implemented to work around a race condition when receiving invites over federation.
+
+ Args:
+ sender: the user performing the membership change
+ target: the user for whom the membership is changing
+ room_id: room id of the room to join to
+ new_membership: the type of membership event (in this case will be "join")
+ """
+
+ sleep = 0
+ retries = 0
+ join_event = None
+
+ while retries < 5:
+ try:
+ await self._api.sleep(sleep)
+ join_event = await self._api.update_room_membership(
+ sender=sender,
+ target=target,
+ room_id=room_id,
+ new_membership=new_membership,
+ )
+ except SynapseError as e:
+ if e.code == HTTPStatus.FORBIDDEN:
+ logger.debug(
+ f"Update_room_membership was forbidden. This can sometimes be expected for remote invites. Exception: {e}"
+ )
+ else:
+ logger.warn(
+ f"Update_room_membership raised the following unexpected (SynapseError) exception: {e}"
+ )
+ except Exception as e:
+ logger.warn(
+ f"Update_room_membership raised the following unexpected exception: {e}"
+ )
+
+ sleep = 2**retries
+ retries += 1
+
+ if join_event is not None:
+ break
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 67953a3ed9..0432d97109 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -159,20 +159,32 @@ class DeviceWorkerHandler:
@cancellable
async def get_device_changes_in_shared_rooms(
- self, user_id: str, room_ids: StrCollection, from_token: StreamToken
+ self,
+ user_id: str,
+ room_ids: StrCollection,
+ from_token: StreamToken,
+ now_token: Optional[StreamToken] = None,
) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
+ now_device_lists_key = self.store.get_device_stream_token()
+ if now_token:
+ now_device_lists_key = now_token.device_list_key
+
changed_users = await self.store.get_device_list_changes_in_rooms(
- room_ids, from_token.device_list_key
+ room_ids,
+ from_token.device_list_key,
+ now_device_lists_key,
)
if changed_users is not None:
# We also check if the given user has changed their device. If
# they're in no rooms then the above query won't include them.
changed = await self.store.get_users_whose_devices_changed(
- from_token.device_list_key, [user_id]
+ from_token.device_list_key,
+ [user_id],
+ to_key=now_device_lists_key,
)
changed_users.update(changed)
return changed_users
@@ -190,7 +202,9 @@ class DeviceWorkerHandler:
tracked_users.add(user_id)
changed = await self.store.get_users_whose_devices_changed(
- from_token.device_list_key, tracked_users
+ from_token.device_list_key,
+ tracked_users,
+ to_key=now_device_lists_key,
)
return changed
@@ -892,6 +906,13 @@ class DeviceHandler(DeviceWorkerHandler):
context=opentracing_context,
)
+ await self.store.mark_redundant_device_lists_pokes(
+ user_id=user_id,
+ device_id=device_id,
+ room_id=room_id,
+ converted_upto_stream_id=stream_id,
+ )
+
# Notify replication that we've updated the device list stream.
self.notifier.notify_replication()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index e48e70db04..c200e29569 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -590,7 +590,7 @@ class RegistrationHandler:
# moving away from bare excepts is a good thing to do.
logger.error("Failed to join new user to %r: %r", r, e)
except Exception as e:
- logger.error("Failed to join new user to %r: %r", r, e)
+ logger.error("Failed to join new user to %r: %r", r, e, exc_info=True)
async def _auto_join_rooms(self, user_id: str) -> None:
"""Automatically joins users to auto join rooms - creating the room in the first place
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index f275d4f35a..ee74289b6c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -817,7 +817,7 @@ class SsoHandler:
server_name = profile["avatar_url"].split("/")[-2]
media_id = profile["avatar_url"].split("/")[-1]
if self._is_mine_server_name(server_name):
- media = await self._media_repo.store.get_local_media(media_id)
+ media = await self._media_repo.store.get_local_media(media_id) # type: ignore[has-type]
if media is not None and upload_name == media.upload_name:
logger.info("skipping saving the user avatar")
return True
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d3d40e8682..ac5bddd52f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,11 +28,14 @@ from typing import (
Dict,
FrozenSet,
List,
+ Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
+ Union,
+ overload,
)
import attr
@@ -128,6 +131,8 @@ class SyncVersion(Enum):
# Traditional `/sync` endpoint
SYNC_V2 = "sync_v2"
+ # Part of MSC3575 Sliding Sync
+ E2EE_SYNC = "e2ee_sync"
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -280,6 +285,26 @@ class SyncResult:
)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class E2eeSyncResult:
+ """
+ Attributes:
+ next_batch: Token for the next sync
+ to_device: List of direct messages for the device.
+ device_lists: List of user_ids whose devices have changed
+ device_one_time_keys_count: Dict of algorithm to count for one time keys
+ for this device
+ device_unused_fallback_key_types: List of key types that have an unused fallback
+ key
+ """
+
+ next_batch: StreamToken
+ to_device: List[JsonDict]
+ device_lists: DeviceListUpdates
+ device_one_time_keys_count: JsonMapping
+ device_unused_fallback_key_types: List[str]
+
+
class SyncHandler:
def __init__(self, hs: "HomeServer"):
self.hs_config = hs.config
@@ -322,6 +347,31 @@ class SyncHandler:
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
+ @overload
+ async def wait_for_sync_for_user(
+ self,
+ requester: Requester,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.SYNC_V2],
+ request_key: SyncRequestKey,
+ since_token: Optional[StreamToken] = None,
+ timeout: int = 0,
+ full_state: bool = False,
+ ) -> SyncResult: ...
+
+ @overload
+ async def wait_for_sync_for_user(
+ self,
+ requester: Requester,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.E2EE_SYNC],
+ request_key: SyncRequestKey,
+ since_token: Optional[StreamToken] = None,
+ timeout: int = 0,
+ full_state: bool = False,
+ ) -> E2eeSyncResult: ...
+
+ @overload
async def wait_for_sync_for_user(
self,
requester: Requester,
@@ -331,7 +381,18 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
- ) -> SyncResult:
+ ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+ async def wait_for_sync_for_user(
+ self,
+ requester: Requester,
+ sync_config: SyncConfig,
+ sync_version: SyncVersion,
+ request_key: SyncRequestKey,
+ since_token: Optional[StreamToken] = None,
+ timeout: int = 0,
+ full_state: bool = False,
+ ) -> Union[SyncResult, E2eeSyncResult]:
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
@@ -344,8 +405,10 @@ class SyncHandler:
since_token: The point in the stream to sync from.
timeout: How long to wait for new data to arrive before giving up.
full_state: Whether to return the full state for each room.
+
Returns:
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
+ When `SyncVersion.E2EE_SYNC`, returns a `E2eeSyncResult`.
"""
# If the user is not part of the mau group, then check that limits have
# not been exceeded (if not part of the group by this point, almost certain
@@ -366,6 +429,29 @@ class SyncHandler:
logger.debug("Returning sync response for %s", user_id)
return res
+ @overload
+ async def _wait_for_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.SYNC_V2],
+ since_token: Optional[StreamToken],
+ timeout: int,
+ full_state: bool,
+ cache_context: ResponseCacheContext[SyncRequestKey],
+ ) -> SyncResult: ...
+
+ @overload
+ async def _wait_for_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.E2EE_SYNC],
+ since_token: Optional[StreamToken],
+ timeout: int,
+ full_state: bool,
+ cache_context: ResponseCacheContext[SyncRequestKey],
+ ) -> E2eeSyncResult: ...
+
+ @overload
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
@@ -374,7 +460,17 @@ class SyncHandler:
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
- ) -> SyncResult:
+ ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+ async def _wait_for_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: SyncVersion,
+ since_token: Optional[StreamToken],
+ timeout: int,
+ full_state: bool,
+ cache_context: ResponseCacheContext[SyncRequestKey],
+ ) -> Union[SyncResult, E2eeSyncResult]:
"""The start of the machinery that produces a /sync response.
See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
@@ -417,14 +513,16 @@ class SyncHandler:
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
- result: SyncResult = await self.current_sync_for_user(
- sync_config, sync_version, since_token, full_state=full_state
+ result: Union[SyncResult, E2eeSyncResult] = (
+ await self.current_sync_for_user(
+ sync_config, sync_version, since_token, full_state=full_state
+ )
)
else:
# Otherwise, we wait for something to happen and report it to the user.
async def current_sync_callback(
before_token: StreamToken, after_token: StreamToken
- ) -> SyncResult:
+ ) -> Union[SyncResult, E2eeSyncResult]:
return await self.current_sync_for_user(
sync_config, sync_version, since_token
)
@@ -456,14 +554,43 @@ class SyncHandler:
return result
+ @overload
+ async def current_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.SYNC_V2],
+ since_token: Optional[StreamToken] = None,
+ full_state: bool = False,
+ ) -> SyncResult: ...
+
+ @overload
+ async def current_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.E2EE_SYNC],
+ since_token: Optional[StreamToken] = None,
+ full_state: bool = False,
+ ) -> E2eeSyncResult: ...
+
+ @overload
async def current_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken] = None,
full_state: bool = False,
- ) -> SyncResult:
- """Generates the response body of a sync result, represented as a SyncResult.
+ ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+ async def current_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: SyncVersion,
+ since_token: Optional[StreamToken] = None,
+ full_state: bool = False,
+ ) -> Union[SyncResult, E2eeSyncResult]:
+ """
+ Generates the response body of a sync result, represented as a
+ `SyncResult`/`E2eeSyncResult`.
This is a wrapper around `generate_sync_result` which starts an open tracing
span to track the sync. See `generate_sync_result` for the next part of your
@@ -474,15 +601,25 @@ class SyncHandler:
sync_version: Determines what kind of sync response to generate.
since_token: The point in the stream to sync from.p.
full_state: Whether to return the full state for each room.
+
Returns:
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
+ When `SyncVersion.E2EE_SYNC`, returns a `E2eeSyncResult`.
"""
with start_active_span("sync.current_sync_for_user"):
log_kv({"since_token": since_token})
+
# Go through the `/sync` v2 path
if sync_version == SyncVersion.SYNC_V2:
- sync_result: SyncResult = await self.generate_sync_result(
- sync_config, since_token, full_state
+ sync_result: Union[SyncResult, E2eeSyncResult] = (
+ await self.generate_sync_result(
+ sync_config, since_token, full_state
+ )
+ )
+ # Go through the MSC3575 Sliding Sync `/sync/e2ee` path
+ elif sync_version == SyncVersion.E2EE_SYNC:
+ sync_result = await self.generate_e2ee_sync_result(
+ sync_config, since_token
)
else:
raise Exception(
@@ -1691,6 +1828,96 @@ class SyncHandler:
next_batch=sync_result_builder.now_token,
)
+ async def generate_e2ee_sync_result(
+ self,
+ sync_config: SyncConfig,
+ since_token: Optional[StreamToken] = None,
+ ) -> E2eeSyncResult:
+ """
+ Generates the response body of a MSC3575 Sliding Sync `/sync/e2ee` result.
+
+ This is represented by a `E2eeSyncResult` struct, which is built from small
+ pieces using a `SyncResultBuilder`. The `sync_result_builder` is passed as a
+ mutable ("inout") parameter to various helper functions. These retrieve and
+ process the data which forms the sync body, often writing to the
+ `sync_result_builder` to store their output.
+
+ At the end, we transfer data from the `sync_result_builder` to a new `E2eeSyncResult`
+ instance to signify that the sync calculation is complete.
+ """
+ user_id = sync_config.user.to_string()
+ app_service = self.store.get_app_service_by_user_id(user_id)
+ if app_service:
+ # We no longer support AS users using /sync directly.
+ # See https://github.com/matrix-org/matrix-doc/issues/1144
+ raise NotImplementedError()
+
+ sync_result_builder = await self.get_sync_result_builder(
+ sync_config,
+ since_token,
+ full_state=False,
+ )
+
+ # 1. Calculate `to_device` events
+ await self._generate_sync_entry_for_to_device(sync_result_builder)
+
+ # 2. Calculate `device_lists`
+ # Device list updates are sent if a since token is provided.
+ device_lists = DeviceListUpdates()
+ include_device_list_updates = bool(since_token and since_token.device_list_key)
+ if include_device_list_updates:
+ # Note that _generate_sync_entry_for_rooms sets sync_result_builder.joined, which
+ # is used in calculate_user_changes below.
+ #
+ # TODO: Running `_generate_sync_entry_for_rooms()` is a lot of work just to
+ # figure out the membership changes/derived info needed for
+ # `_generate_sync_entry_for_device_list()`. In the future, we should try to
+ # refactor this away.
+ (
+ newly_joined_rooms,
+ newly_left_rooms,
+ ) = await self._generate_sync_entry_for_rooms(sync_result_builder)
+
+ # This uses the sync_result_builder.joined which is set in
+ # `_generate_sync_entry_for_rooms`, if that didn't find any joined
+ # rooms for some reason it is a no-op.
+ (
+ newly_joined_or_invited_or_knocked_users,
+ newly_left_users,
+ ) = sync_result_builder.calculate_user_changes()
+
+ device_lists = await self._generate_sync_entry_for_device_list(
+ sync_result_builder,
+ newly_joined_rooms=newly_joined_rooms,
+ newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
+ newly_left_rooms=newly_left_rooms,
+ newly_left_users=newly_left_users,
+ )
+
+ # 3. Calculate `device_one_time_keys_count` and `device_unused_fallback_key_types`
+ device_id = sync_config.device_id
+ one_time_keys_count: JsonMapping = {}
+ unused_fallback_key_types: List[str] = []
+ if device_id:
+ # TODO: We should have a way to let clients differentiate between the states of:
+ # * no change in OTK count since the provided since token
+ # * the server has zero OTKs left for this device
+ # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
+ one_time_keys_count = await self.store.count_e2e_one_time_keys(
+ user_id, device_id
+ )
+ unused_fallback_key_types = list(
+ await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+ )
+
+ return E2eeSyncResult(
+ to_device=sync_result_builder.to_device,
+ device_lists=device_lists,
+ device_one_time_keys_count=one_time_keys_count,
+ device_unused_fallback_key_types=unused_fallback_key_types,
+ next_batch=sync_result_builder.now_token,
+ )
+
async def get_sync_result_builder(
self,
sync_config: SyncConfig,
@@ -1886,38 +2113,14 @@ class SyncHandler:
# Step 1a, check for changes in devices of users we share a room
# with
- #
- # We do this in two different ways depending on what we have cached.
- # If we already have a list of all the user that have changed since
- # the last sync then it's likely more efficient to compare the rooms
- # they're in with the rooms the syncing user is in.
- #
- # If we don't have that info cached then we get all the users that
- # share a room with our user and check if those users have changed.
- cache_result = self.store.get_cached_device_list_changes(
- since_token.device_list_key
- )
- if cache_result.hit:
- changed_users = cache_result.entities
-
- result = await self.store.get_rooms_for_users(changed_users)
-
- for changed_user_id, entries in result.items():
- # Check if the changed user shares any rooms with the user,
- # or if the changed user is the syncing user (as we always
- # want to include device list updates of their own devices).
- if user_id == changed_user_id or any(
- rid in joined_room_ids for rid in entries
- ):
- users_that_have_changed.add(changed_user_id)
- else:
- users_that_have_changed = (
- await self._device_handler.get_device_changes_in_shared_rooms(
- user_id,
- sync_result_builder.joined_room_ids,
- from_token=since_token,
- )
+ users_that_have_changed = (
+ await self._device_handler.get_device_changes_in_shared_rooms(
+ user_id,
+ joined_room_ids,
+ from_token=since_token,
+ now_token=sync_result_builder.now_token,
)
+ )
# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py
index 5538020bec..cc3acf51e1 100644
--- a/synapse/media/thumbnailer.py
+++ b/synapse/media/thumbnailer.py
@@ -22,11 +22,27 @@
import logging
from io import BytesIO
from types import TracebackType
-from typing import Optional, Tuple, Type
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type
from PIL import Image
+from synapse.api.errors import Codes, SynapseError, cs_error
+from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
+from synapse.http.server import respond_with_json
+from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace
+from synapse.media._base import (
+ FileInfo,
+ ThumbnailInfo,
+ respond_404,
+ respond_with_file,
+ respond_with_responder,
+)
+from synapse.media.media_storage import MediaStorage
+
+if TYPE_CHECKING:
+ from synapse.media.media_repository import MediaRepository
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -231,3 +247,471 @@ class Thumbnailer:
def __del__(self) -> None:
# Make sure we actually do close the image, rather than leak data.
self.close()
+
+
+class ThumbnailProvider:
+ def __init__(
+ self,
+ hs: "HomeServer",
+ media_repo: "MediaRepository",
+ media_storage: MediaStorage,
+ ):
+ self.hs = hs
+ self.media_repo = media_repo
+ self.media_storage = media_storage
+ self.store = hs.get_datastores().main
+ self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
+
+ async def respond_local_thumbnail(
+ self,
+ request: SynapseRequest,
+ media_id: str,
+ width: int,
+ height: int,
+ method: str,
+ m_type: str,
+ max_timeout_ms: int,
+ ) -> None:
+ media_info = await self.media_repo.get_local_media_info(
+ request, media_id, max_timeout_ms
+ )
+ if not media_info:
+ return
+
+ thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_id,
+ media_id,
+ url_cache=bool(media_info.url_cache),
+ server_name=None,
+ )
+
+ async def select_or_generate_local_thumbnail(
+ self,
+ request: SynapseRequest,
+ media_id: str,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ max_timeout_ms: int,
+ ) -> None:
+ media_info = await self.media_repo.get_local_media_info(
+ request, media_id, max_timeout_ms
+ )
+
+ if not media_info:
+ return
+
+ thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
+ for info in thumbnail_infos:
+ t_w = info.width == desired_width
+ t_h = info.height == desired_height
+ t_method = info.method == desired_method
+ t_type = info.type == desired_type
+
+ if t_w and t_h and t_method and t_type:
+ file_info = FileInfo(
+ server_name=None,
+ file_id=media_id,
+ url_cache=bool(media_info.url_cache),
+ thumbnail=info,
+ )
+
+ responder = await self.media_storage.fetch_media(file_info)
+ if responder:
+ await respond_with_responder(
+ request, responder, info.type, info.length
+ )
+ return
+
+ logger.debug("We don't have a thumbnail of that size. Generating")
+
+ # Okay, so we generate one.
+ file_path = await self.media_repo.generate_local_exact_thumbnail(
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ url_cache=bool(media_info.url_cache),
+ )
+
+ if file_path:
+ await respond_with_file(request, desired_type, file_path)
+ else:
+ logger.warning("Failed to generate thumbnail")
+ raise SynapseError(400, "Failed to generate thumbnail.")
+
+ async def select_or_generate_remote_thumbnail(
+ self,
+ request: SynapseRequest,
+ server_name: str,
+ media_id: str,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ max_timeout_ms: int,
+ ) -> None:
+ media_info = await self.media_repo.get_remote_media_info(
+ server_name, media_id, max_timeout_ms
+ )
+ if not media_info:
+ respond_404(request)
+ return
+
+ thumbnail_infos = await self.store.get_remote_media_thumbnails(
+ server_name, media_id
+ )
+
+ file_id = media_info.filesystem_id
+
+ for info in thumbnail_infos:
+ t_w = info.width == desired_width
+ t_h = info.height == desired_height
+ t_method = info.method == desired_method
+ t_type = info.type == desired_type
+
+ if t_w and t_h and t_method and t_type:
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
+ thumbnail=info,
+ )
+
+ responder = await self.media_storage.fetch_media(file_info)
+ if responder:
+ await respond_with_responder(
+ request, responder, info.type, info.length
+ )
+ return
+
+ logger.debug("We don't have a thumbnail of that size. Generating")
+
+ # Okay, so we generate one.
+ file_path = await self.media_repo.generate_remote_exact_thumbnail(
+ server_name,
+ file_id,
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ )
+
+ if file_path:
+ await respond_with_file(request, desired_type, file_path)
+ else:
+ logger.warning("Failed to generate thumbnail")
+ raise SynapseError(400, "Failed to generate thumbnail.")
+
+ async def respond_remote_thumbnail(
+ self,
+ request: SynapseRequest,
+ server_name: str,
+ media_id: str,
+ width: int,
+ height: int,
+ method: str,
+ m_type: str,
+ max_timeout_ms: int,
+ ) -> None:
+ # TODO: Don't download the whole remote file
+ # We should proxy the thumbnail from the remote server instead of
+ # downloading the remote file and generating our own thumbnails.
+ media_info = await self.media_repo.get_remote_media_info(
+ server_name, media_id, max_timeout_ms
+ )
+ if not media_info:
+ return
+
+ thumbnail_infos = await self.store.get_remote_media_thumbnails(
+ server_name, media_id
+ )
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_id,
+ media_info.filesystem_id,
+ url_cache=False,
+ server_name=server_name,
+ )
+
+ async def _select_and_respond_with_thumbnail(
+ self,
+ request: SynapseRequest,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ thumbnail_infos: List[ThumbnailInfo],
+ media_id: str,
+ file_id: str,
+ url_cache: bool,
+ server_name: Optional[str] = None,
+ ) -> None:
+ """
+ Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+ request: The incoming request.
+ desired_width: The desired width, the returned thumbnail may be larger than this.
+ desired_height: The desired height, the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of thumbnail info of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: True if this is from a URL cache.
+ server_name: The server name, if this is a remote thumbnail.
+ """
+ logger.debug(
+ "_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ thumbnail_infos,
+ )
+
+ # If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
+ # different code path to handle it.
+ assert not self.dynamic_thumbnails
+
+ if thumbnail_infos:
+ file_info = self._select_thumbnail(
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ thumbnail_infos,
+ file_id,
+ url_cache,
+ server_name,
+ )
+ if not file_info:
+ logger.info("Couldn't find a thumbnail matching the desired inputs")
+ respond_404(request)
+ return
+
+ # The thumbnail property must exist.
+ assert file_info.thumbnail is not None
+
+ responder = await self.media_storage.fetch_media(file_info)
+ if responder:
+ await respond_with_responder(
+ request,
+ responder,
+ file_info.thumbnail.type,
+ file_info.thumbnail.length,
+ )
+ return
+
+ # If we can't find the thumbnail we regenerate it. This can happen
+ # if e.g. we've deleted the thumbnails but still have the original
+ # image somewhere.
+ #
+ # Since we have an entry for the thumbnail in the DB we a) know we
+ # have have successfully generated the thumbnail in the past (so we
+ # don't need to worry about repeatedly failing to generate
+ # thumbnails), and b) have already calculated that appropriate
+ # width/height/method so we can just call the "generate exact"
+ # methods.
+
+ # First let's check that we do actually have the original image
+ # still. This will throw a 404 if we don't.
+ # TODO: We should refetch the thumbnails for remote media.
+ await self.media_storage.ensure_media_is_in_local_cache(
+ FileInfo(server_name, file_id, url_cache=url_cache)
+ )
+
+ if server_name:
+ await self.media_repo.generate_remote_exact_thumbnail(
+ server_name,
+ file_id=file_id,
+ media_id=media_id,
+ t_width=file_info.thumbnail.width,
+ t_height=file_info.thumbnail.height,
+ t_method=file_info.thumbnail.method,
+ t_type=file_info.thumbnail.type,
+ )
+ else:
+ await self.media_repo.generate_local_exact_thumbnail(
+ media_id=media_id,
+ t_width=file_info.thumbnail.width,
+ t_height=file_info.thumbnail.height,
+ t_method=file_info.thumbnail.method,
+ t_type=file_info.thumbnail.type,
+ url_cache=url_cache,
+ )
+
+ responder = await self.media_storage.fetch_media(file_info)
+ await respond_with_responder(
+ request,
+ responder,
+ file_info.thumbnail.type,
+ file_info.thumbnail.length,
+ )
+ else:
+ # This might be because:
+ # 1. We can't create thumbnails for the given media (corrupted or
+ # unsupported file type), or
+ # 2. The thumbnailing process never ran or errored out initially
+ # when the media was first uploaded (these bugs should be
+ # reported and fixed).
+ # Note that we don't attempt to generate a thumbnail now because
+ # `dynamic_thumbnails` is disabled.
+ logger.info("Failed to find any generated thumbnails")
+
+ assert request.path is not None
+ respond_with_json(
+ request,
+ 400,
+ cs_error(
+ "Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
+ % (
+ request.path.decode(),
+ ", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
+ ),
+ code=Codes.UNKNOWN,
+ ),
+ send_cors=True,
+ )
+
+ def _select_thumbnail(
+ self,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ thumbnail_infos: List[ThumbnailInfo],
+ file_id: str,
+ url_cache: bool,
+ server_name: Optional[str],
+ ) -> Optional[FileInfo]:
+ """
+ Choose an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+ desired_width: The desired width, the returned thumbnail may be larger than this.
+ desired_height: The desired height, the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: True if this is from a URL cache.
+ server_name: The server name, if this is a remote thumbnail.
+
+ Returns:
+ The thumbnail which best matches the desired parameters.
+ """
+ desired_method = desired_method.lower()
+
+ # The chosen thumbnail.
+ thumbnail_info = None
+
+ d_w = desired_width
+ d_h = desired_height
+
+ if desired_method == "crop":
+ # Thumbnails that match equal or larger sizes of desired width/height.
+ crop_info_list: List[
+ Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
+ ] = []
+ # Other thumbnails.
+ crop_info_list2: List[
+ Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
+ ] = []
+ for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info.method != "crop":
+ continue
+
+ t_w = info.width
+ t_h = info.height
+ aspect_quality = abs(d_w * t_h - d_h * t_w)
+ min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
+ size_quality = abs((d_w - t_w) * (d_h - t_h))
+ type_quality = desired_type != info.type
+ length_quality = info.length
+ if t_w >= d_w or t_h >= d_h:
+ crop_info_list.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
+ )
+ )
+ else:
+ crop_info_list2.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
+ )
+ )
+ # Pick the most appropriate thumbnail. Some values of `desired_width` and
+ # `desired_height` may result in a tie, in which case we avoid comparing on
+ # the thumbnail info and pick the thumbnail that appears earlier
+ # in the list of candidates.
+ if crop_info_list:
+ thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
+ elif crop_info_list2:
+ thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
+ elif desired_method == "scale":
+ # Thumbnails that match equal or larger sizes of desired width/height.
+ info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
+ # Other thumbnails.
+ info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
+
+ for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info.method != "scale":
+ continue
+
+ t_w = info.width
+ t_h = info.height
+ size_quality = abs((d_w - t_w) * (d_h - t_h))
+ type_quality = desired_type != info.type
+ length_quality = info.length
+ if t_w >= d_w or t_h >= d_h:
+ info_list.append((size_quality, type_quality, length_quality, info))
+ else:
+ info_list2.append(
+ (size_quality, type_quality, length_quality, info)
+ )
+ # Pick the most appropriate thumbnail. Some values of `desired_width` and
+ # `desired_height` may result in a tie, in which case we avoid comparing on
+ # the thumbnail info and pick the thumbnail that appears earlier
+ # in the list of candidates.
+ if info_list:
+ thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
+ elif info_list2:
+ thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
+
+ if thumbnail_info:
+ return FileInfo(
+ file_id=file_id,
+ url_cache=url_cache,
+ server_name=server_name,
+ thumbnail=thumbnail_info,
+ )
+
+ # No matching thumbnail was found.
+ return None
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 5e5387fdcb..2d6d49eed7 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -112,6 +112,15 @@ class ReplicationDataHandler:
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
+ all_room_ids: Set[str] = set()
+ if stream_name == DeviceListsStream.NAME:
+ if any(row.entity.startswith("@") and not row.is_signature for row in rows):
+ prev_token = self.store.get_device_stream_token()
+ all_room_ids = await self.store.get_all_device_list_changes(
+ prev_token, token
+ )
+ self.store.device_lists_in_rooms_have_changed(all_room_ids, token)
+
self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.
@@ -146,12 +155,6 @@ class ReplicationDataHandler:
StreamKeyType.TO_DEVICE, token, users=entities
)
elif stream_name == DeviceListsStream.NAME:
- all_room_ids: Set[str] = set()
- for row in rows:
- if row.entity.startswith("@") and not row.is_signature:
- room_ids = await self.store.get_rooms_for_user(row.entity)
- all_room_ids.update(room_ids)
-
# `all_room_ids` can be large, so let's wake up those streams in batches
for batched_room_ids in batch_iter(all_room_ids, 100):
self.notifier.on_new_event(
diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py
new file mode 100644
index 0000000000..172d240783
--- /dev/null
+++ b/synapse/rest/client/media.py
@@ -0,0 +1,205 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+
+import logging
+import re
+
+from synapse.http.server import (
+ HttpServer,
+ respond_with_json,
+ respond_with_json_bytes,
+ set_corp_headers,
+ set_cors_headers,
+)
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.media._base import (
+ DEFAULT_MAX_TIMEOUT_MS,
+ MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
+ respond_404,
+)
+from synapse.media.media_repository import MediaRepository
+from synapse.media.media_storage import MediaStorage
+from synapse.media.thumbnailer import ThumbnailProvider
+from synapse.server import HomeServer
+from synapse.util.stringutils import parse_and_validate_server_name
+
+logger = logging.getLogger(__name__)
+
+
+class UnstablePreviewURLServlet(RestServlet):
+ """
+ Same as `GET /_matrix/media/r0/preview_url`, this endpoint provides a generic preview API
+ for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix
+ specific additions).
+
+ This does have trade-offs compared to other designs:
+
+ * Pros:
+ * Simple and flexible; can be used by any clients at any point
+ * Cons:
+ * If each homeserver provides one of these independently, all the homeservers in a
+ room may needlessly DoS the target URI
+ * The URL metadata must be stored somewhere, rather than just using Matrix
+ itself to store the media.
+ * Matrix cannot be used to distribute the metadata between homeservers.
+ """
+
+ PATTERNS = [
+ re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/preview_url$")
+ ]
+
+ def __init__(
+ self,
+ hs: "HomeServer",
+ media_repo: "MediaRepository",
+ media_storage: MediaStorage,
+ ):
+ super().__init__()
+
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.media_repo = media_repo
+ self.media_storage = media_storage
+ assert self.media_repo.url_previewer is not None
+ self.url_previewer = self.media_repo.url_previewer
+
+ async def on_GET(self, request: SynapseRequest) -> None:
+ requester = await self.auth.get_user_by_req(request)
+ url = parse_string(request, "url", required=True)
+ ts = parse_integer(request, "ts")
+ if ts is None:
+ ts = self.clock.time_msec()
+
+ og = await self.url_previewer.preview(url, requester.user, ts)
+ respond_with_json_bytes(request, 200, og, send_cors=True)
+
+
+class UnstableMediaConfigResource(RestServlet):
+ PATTERNS = [
+ re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/config$")
+ ]
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ config = hs.config
+ self.clock = hs.get_clock()
+ self.auth = hs.get_auth()
+ self.limits_dict = {"m.upload.size": config.media.max_upload_size}
+
+ async def on_GET(self, request: SynapseRequest) -> None:
+ await self.auth.get_user_by_req(request)
+ respond_with_json(request, 200, self.limits_dict, send_cors=True)
+
+
+class UnstableThumbnailResource(RestServlet):
+ PATTERNS = [
+ re.compile(
+ "/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
+ )
+ ]
+
+ def __init__(
+ self,
+ hs: "HomeServer",
+ media_repo: "MediaRepository",
+ media_storage: MediaStorage,
+ ):
+ super().__init__()
+
+ self.store = hs.get_datastores().main
+ self.media_repo = media_repo
+ self.media_storage = media_storage
+ self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
+ self._is_mine_server_name = hs.is_mine_server_name
+ self._server_name = hs.hostname
+ self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
+ self.thumbnailer = ThumbnailProvider(hs, media_repo, media_storage)
+ self.auth = hs.get_auth()
+
+ async def on_GET(
+ self, request: SynapseRequest, server_name: str, media_id: str
+ ) -> None:
+ # Validate the server name, raising if invalid
+ parse_and_validate_server_name(server_name)
+ await self.auth.get_user_by_req(request)
+
+ set_cors_headers(request)
+ set_corp_headers(request)
+ width = parse_integer(request, "width", required=True)
+ height = parse_integer(request, "height", required=True)
+ method = parse_string(request, "method", "scale")
+ # TODO Parse the Accept header to get an prioritised list of thumbnail types.
+ m_type = "image/png"
+ max_timeout_ms = parse_integer(
+ request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
+ )
+ max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
+
+ if self._is_mine_server_name(server_name):
+ if self.dynamic_thumbnails:
+ await self.thumbnailer.select_or_generate_local_thumbnail(
+ request, media_id, width, height, method, m_type, max_timeout_ms
+ )
+ else:
+ await self.thumbnailer.respond_local_thumbnail(
+ request, media_id, width, height, method, m_type, max_timeout_ms
+ )
+ self.media_repo.mark_recently_accessed(None, media_id)
+ else:
+ # Don't let users download media from configured domains, even if it
+ # is already downloaded. This is Trust & Safety tooling to make some
+ # media inaccessible to local users.
+ # See `prevent_media_downloads_from` config docs for more info.
+ if server_name in self.prevent_media_downloads_from:
+ respond_404(request)
+ return
+
+ remote_resp_function = (
+ self.thumbnailer.select_or_generate_remote_thumbnail
+ if self.dynamic_thumbnails
+ else self.thumbnailer.respond_remote_thumbnail
+ )
+ await remote_resp_function(
+ request,
+ server_name,
+ media_id,
+ width,
+ height,
+ method,
+ m_type,
+ max_timeout_ms,
+ )
+ self.media_repo.mark_recently_accessed(server_name, media_id)
+
+
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
+ if hs.config.experimental.msc3916_authenticated_media_enabled:
+ media_repo = hs.get_media_repository()
+ if hs.config.media.url_preview_enabled:
+ UnstablePreviewURLServlet(
+ hs, media_repo, media_repo.media_storage
+ ).register(http_server)
+ UnstableMediaConfigResource(hs).register(http_server)
+ UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register(
+ http_server
+ )
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 4a57eaf930..27ea943e31 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -567,5 +567,176 @@ class SyncRestServlet(RestServlet):
return result
+class SlidingSyncE2eeRestServlet(RestServlet):
+ """
+ API endpoint for MSC3575 Sliding Sync `/sync/e2ee`. This is being introduced as part
+ of Sliding Sync but doesn't have any sliding window component. It's just a way to
+ get E2EE events without having to sit through a big initial sync (`/sync` v2). And
+ we can avoid encryption events being backed up by the main sync response.
+
+ Having To-Device messages split out to this sync endpoint also helps when clients
+ need to have 2 or more sync streams open at a time, e.g a push notification process
+ and a main process. This can cause the two processes to race to fetch the To-Device
+ events, resulting in the need for complex synchronisation rules to ensure the token
+ is correctly and atomically exchanged between processes.
+
+ GET parameters::
+ timeout(int): How long to wait for new events in milliseconds.
+ since(batch_token): Batch token when asking for incremental deltas.
+
+ Response JSON::
+ {
+ "next_batch": // batch token for the next /sync
+ "to_device": {
+ // list of to-device events
+ "events": [
+ {
+ "content: { "algorithm": "m.olm.v1.curve25519-aes-sha2", "ciphertext": { ... }, "org.matrix.msgid": "abcd", "session_id": "abcd" },
+ "type": "m.room.encrypted",
+ "sender": "@alice:example.com",
+ }
+ // ...
+ ]
+ },
+ "device_lists": {
+ "changed": ["@alice:example.com"],
+ "left": ["@bob:example.com"]
+ },
+ "device_one_time_keys_count": {
+ "signed_curve25519": 50
+ },
+ "device_unused_fallback_key_types": [
+ "signed_curve25519"
+ ]
+ }
+ """
+
+ PATTERNS = client_patterns(
+ "/org.matrix.msc3575/sync/e2ee$", releases=[], v1=False, unstable=True
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastores().main
+ self.sync_handler = hs.get_sync_handler()
+
+ # Filtering only matters for the `device_lists` because it requires a bunch of
+ # derived information from rooms (see how `_generate_sync_entry_for_rooms()`
+ # prepares a bunch of data for `_generate_sync_entry_for_device_list()`).
+ self.only_member_events_filter_collection = FilterCollection(
+ self.hs,
+ {
+ "room": {
+ # We only care about membership events for the `device_lists`.
+ # Membership will tell us whether a user has joined/left a room and
+ # if there are new devices to encrypt for.
+ "timeline": {
+ "types": ["m.room.member"],
+ },
+ "state": {
+ "types": ["m.room.member"],
+ },
+ # We don't want any extra account_data generated because it's not
+ # returned by this endpoint. This helps us avoid work in
+ # `_generate_sync_entry_for_rooms()`
+ "account_data": {
+ "not_types": ["*"],
+ },
+ # We don't want any extra ephemeral data generated because it's not
+ # returned by this endpoint. This helps us avoid work in
+ # `_generate_sync_entry_for_rooms()`
+ "ephemeral": {
+ "not_types": ["*"],
+ },
+ },
+ # We don't want any extra account_data generated because it's not
+ # returned by this endpoint. (This is just here for good measure)
+ "account_data": {
+ "not_types": ["*"],
+ },
+ # We don't want any extra presence data generated because it's not
+ # returned by this endpoint. (This is just here for good measure)
+ "presence": {
+ "not_types": ["*"],
+ },
+ },
+ )
+
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ user = requester.user
+ device_id = requester.device_id
+
+ timeout = parse_integer(request, "timeout", default=0)
+ since = parse_string(request, "since")
+
+ sync_config = SyncConfig(
+ user=user,
+ filter_collection=self.only_member_events_filter_collection,
+ is_guest=requester.is_guest,
+ device_id=device_id,
+ )
+
+ since_token = None
+ if since is not None:
+ since_token = await StreamToken.from_string(self.store, since)
+
+ # Request cache key
+ request_key = (
+ SyncVersion.E2EE_SYNC,
+ user,
+ timeout,
+ since,
+ )
+
+ # Gather data for the response
+ sync_result = await self.sync_handler.wait_for_sync_for_user(
+ requester,
+ sync_config,
+ SyncVersion.E2EE_SYNC,
+ request_key,
+ since_token=since_token,
+ timeout=timeout,
+ full_state=False,
+ )
+
+ # The client may have disconnected by now; don't bother to serialize the
+ # response if so.
+ if request._disconnected:
+ logger.info("Client has disconnected; not serializing response.")
+ return 200, {}
+
+ response: JsonDict = defaultdict(dict)
+ response["next_batch"] = await sync_result.next_batch.to_string(self.store)
+
+ if sync_result.to_device:
+ response["to_device"] = {"events": sync_result.to_device}
+
+ if sync_result.device_lists.changed:
+ response["device_lists"]["changed"] = list(sync_result.device_lists.changed)
+ if sync_result.device_lists.left:
+ response["device_lists"]["left"] = list(sync_result.device_lists.left)
+
+ # We always include this because https://github.com/vector-im/element-android/issues/3725
+ # The spec isn't terribly clear on when this can be omitted and how a client would tell
+ # the difference between "no keys present" and "nothing changed" in terms of whole field
+ # absent / individual key type entry absent
+ # Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456
+ response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count
+
+ # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
+ # states that this field should always be included, as long as the server supports the feature.
+ response["device_unused_fallback_key_types"] = (
+ sync_result.device_unused_fallback_key_types
+ )
+
+ return 200, response
+
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
+
+ if hs.config.experimental.msc3575_enabled:
+ SlidingSyncE2eeRestServlet(hs).register(http_server)
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index 7cb335c7c3..fe8fbb06e4 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -22,23 +22,18 @@
import logging
import re
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING
-from synapse.api.errors import Codes, SynapseError, cs_error
-from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
-from synapse.http.server import respond_with_json, set_corp_headers, set_cors_headers
+from synapse.http.server import set_corp_headers, set_cors_headers
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.media._base import (
DEFAULT_MAX_TIMEOUT_MS,
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
- FileInfo,
- ThumbnailInfo,
respond_404,
- respond_with_file,
- respond_with_responder,
)
from synapse.media.media_storage import MediaStorage
+from synapse.media.thumbnailer import ThumbnailProvider
from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
@@ -66,10 +61,11 @@ class ThumbnailResource(RestServlet):
self.store = hs.get_datastores().main
self.media_repo = media_repo
self.media_storage = media_storage
- self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self._is_mine_server_name = hs.is_mine_server_name
self._server_name = hs.hostname
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
+ self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
+ self.thumbnail_provider = ThumbnailProvider(hs, media_repo, media_storage)
async def on_GET(
self, request: SynapseRequest, server_name: str, media_id: str
@@ -91,11 +87,11 @@ class ThumbnailResource(RestServlet):
if self._is_mine_server_name(server_name):
if self.dynamic_thumbnails:
- await self._select_or_generate_local_thumbnail(
+ await self.thumbnail_provider.select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms
)
else:
- await self._respond_local_thumbnail(
+ await self.thumbnail_provider.respond_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms
)
self.media_repo.mark_recently_accessed(None, media_id)
@@ -109,9 +105,9 @@ class ThumbnailResource(RestServlet):
return
remote_resp_function = (
- self._select_or_generate_remote_thumbnail
+ self.thumbnail_provider.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
- else self._respond_remote_thumbnail
+ else self.thumbnail_provider.respond_remote_thumbnail
)
await remote_resp_function(
request,
@@ -124,457 +120,3 @@ class ThumbnailResource(RestServlet):
max_timeout_ms,
)
self.media_repo.mark_recently_accessed(server_name, media_id)
-
- async def _respond_local_thumbnail(
- self,
- request: SynapseRequest,
- media_id: str,
- width: int,
- height: int,
- method: str,
- m_type: str,
- max_timeout_ms: int,
- ) -> None:
- media_info = await self.media_repo.get_local_media_info(
- request, media_id, max_timeout_ms
- )
- if not media_info:
- return
-
- thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
- await self._select_and_respond_with_thumbnail(
- request,
- width,
- height,
- method,
- m_type,
- thumbnail_infos,
- media_id,
- media_id,
- url_cache=bool(media_info.url_cache),
- server_name=None,
- )
-
- async def _select_or_generate_local_thumbnail(
- self,
- request: SynapseRequest,
- media_id: str,
- desired_width: int,
- desired_height: int,
- desired_method: str,
- desired_type: str,
- max_timeout_ms: int,
- ) -> None:
- media_info = await self.media_repo.get_local_media_info(
- request, media_id, max_timeout_ms
- )
-
- if not media_info:
- return
-
- thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
- for info in thumbnail_infos:
- t_w = info.width == desired_width
- t_h = info.height == desired_height
- t_method = info.method == desired_method
- t_type = info.type == desired_type
-
- if t_w and t_h and t_method and t_type:
- file_info = FileInfo(
- server_name=None,
- file_id=media_id,
- url_cache=bool(media_info.url_cache),
- thumbnail=info,
- )
-
- responder = await self.media_storage.fetch_media(file_info)
- if responder:
- await respond_with_responder(
- request, responder, info.type, info.length
- )
- return
-
- logger.debug("We don't have a thumbnail of that size. Generating")
-
- # Okay, so we generate one.
- file_path = await self.media_repo.generate_local_exact_thumbnail(
- media_id,
- desired_width,
- desired_height,
- desired_method,
- desired_type,
- url_cache=bool(media_info.url_cache),
- )
-
- if file_path:
- await respond_with_file(request, desired_type, file_path)
- else:
- logger.warning("Failed to generate thumbnail")
- raise SynapseError(400, "Failed to generate thumbnail.")
-
- async def _select_or_generate_remote_thumbnail(
- self,
- request: SynapseRequest,
- server_name: str,
- media_id: str,
- desired_width: int,
- desired_height: int,
- desired_method: str,
- desired_type: str,
- max_timeout_ms: int,
- ) -> None:
- media_info = await self.media_repo.get_remote_media_info(
- server_name, media_id, max_timeout_ms
- )
- if not media_info:
- respond_404(request)
- return
-
- thumbnail_infos = await self.store.get_remote_media_thumbnails(
- server_name, media_id
- )
-
- file_id = media_info.filesystem_id
-
- for info in thumbnail_infos:
- t_w = info.width == desired_width
- t_h = info.height == desired_height
- t_method = info.method == desired_method
- t_type = info.type == desired_type
-
- if t_w and t_h and t_method and t_type:
- file_info = FileInfo(
- server_name=server_name,
- file_id=file_id,
- thumbnail=info,
- )
-
- responder = await self.media_storage.fetch_media(file_info)
- if responder:
- await respond_with_responder(
- request, responder, info.type, info.length
- )
- return
-
- logger.debug("We don't have a thumbnail of that size. Generating")
-
- # Okay, so we generate one.
- file_path = await self.media_repo.generate_remote_exact_thumbnail(
- server_name,
- file_id,
- media_id,
- desired_width,
- desired_height,
- desired_method,
- desired_type,
- )
-
- if file_path:
- await respond_with_file(request, desired_type, file_path)
- else:
- logger.warning("Failed to generate thumbnail")
- raise SynapseError(400, "Failed to generate thumbnail.")
-
- async def _respond_remote_thumbnail(
- self,
- request: SynapseRequest,
- server_name: str,
- media_id: str,
- width: int,
- height: int,
- method: str,
- m_type: str,
- max_timeout_ms: int,
- ) -> None:
- # TODO: Don't download the whole remote file
- # We should proxy the thumbnail from the remote server instead of
- # downloading the remote file and generating our own thumbnails.
- media_info = await self.media_repo.get_remote_media_info(
- server_name, media_id, max_timeout_ms
- )
- if not media_info:
- return
-
- thumbnail_infos = await self.store.get_remote_media_thumbnails(
- server_name, media_id
- )
- await self._select_and_respond_with_thumbnail(
- request,
- width,
- height,
- method,
- m_type,
- thumbnail_infos,
- media_id,
- media_info.filesystem_id,
- url_cache=False,
- server_name=server_name,
- )
-
- async def _select_and_respond_with_thumbnail(
- self,
- request: SynapseRequest,
- desired_width: int,
- desired_height: int,
- desired_method: str,
- desired_type: str,
- thumbnail_infos: List[ThumbnailInfo],
- media_id: str,
- file_id: str,
- url_cache: bool,
- server_name: Optional[str] = None,
- ) -> None:
- """
- Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
-
- Args:
- request: The incoming request.
- desired_width: The desired width, the returned thumbnail may be larger than this.
- desired_height: The desired height, the returned thumbnail may be larger than this.
- desired_method: The desired method used to generate the thumbnail.
- desired_type: The desired content-type of the thumbnail.
- thumbnail_infos: A list of thumbnail info of candidate thumbnails.
- file_id: The ID of the media that a thumbnail is being requested for.
- url_cache: True if this is from a URL cache.
- server_name: The server name, if this is a remote thumbnail.
- """
- logger.debug(
- "_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
- media_id,
- desired_width,
- desired_height,
- desired_method,
- thumbnail_infos,
- )
-
- # If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
- # different code path to handle it.
- assert not self.dynamic_thumbnails
-
- if thumbnail_infos:
- file_info = self._select_thumbnail(
- desired_width,
- desired_height,
- desired_method,
- desired_type,
- thumbnail_infos,
- file_id,
- url_cache,
- server_name,
- )
- if not file_info:
- logger.info("Couldn't find a thumbnail matching the desired inputs")
- respond_404(request)
- return
-
- # The thumbnail property must exist.
- assert file_info.thumbnail is not None
-
- responder = await self.media_storage.fetch_media(file_info)
- if responder:
- await respond_with_responder(
- request,
- responder,
- file_info.thumbnail.type,
- file_info.thumbnail.length,
- )
- return
-
- # If we can't find the thumbnail we regenerate it. This can happen
- # if e.g. we've deleted the thumbnails but still have the original
- # image somewhere.
- #
- # Since we have an entry for the thumbnail in the DB we a) know we
- # have have successfully generated the thumbnail in the past (so we
- # don't need to worry about repeatedly failing to generate
- # thumbnails), and b) have already calculated that appropriate
- # width/height/method so we can just call the "generate exact"
- # methods.
-
- # First let's check that we do actually have the original image
- # still. This will throw a 404 if we don't.
- # TODO: We should refetch the thumbnails for remote media.
- await self.media_storage.ensure_media_is_in_local_cache(
- FileInfo(server_name, file_id, url_cache=url_cache)
- )
-
- if server_name:
- await self.media_repo.generate_remote_exact_thumbnail(
- server_name,
- file_id=file_id,
- media_id=media_id,
- t_width=file_info.thumbnail.width,
- t_height=file_info.thumbnail.height,
- t_method=file_info.thumbnail.method,
- t_type=file_info.thumbnail.type,
- )
- else:
- await self.media_repo.generate_local_exact_thumbnail(
- media_id=media_id,
- t_width=file_info.thumbnail.width,
- t_height=file_info.thumbnail.height,
- t_method=file_info.thumbnail.method,
- t_type=file_info.thumbnail.type,
- url_cache=url_cache,
- )
-
- responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(
- request,
- responder,
- file_info.thumbnail.type,
- file_info.thumbnail.length,
- )
- else:
- # This might be because:
- # 1. We can't create thumbnails for the given media (corrupted or
- # unsupported file type), or
- # 2. The thumbnailing process never ran or errored out initially
- # when the media was first uploaded (these bugs should be
- # reported and fixed).
- # Note that we don't attempt to generate a thumbnail now because
- # `dynamic_thumbnails` is disabled.
- logger.info("Failed to find any generated thumbnails")
-
- assert request.path is not None
- respond_with_json(
- request,
- 400,
- cs_error(
- "Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
- % (
- request.path.decode(),
- ", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
- ),
- code=Codes.UNKNOWN,
- ),
- send_cors=True,
- )
-
- def _select_thumbnail(
- self,
- desired_width: int,
- desired_height: int,
- desired_method: str,
- desired_type: str,
- thumbnail_infos: List[ThumbnailInfo],
- file_id: str,
- url_cache: bool,
- server_name: Optional[str],
- ) -> Optional[FileInfo]:
- """
- Choose an appropriate thumbnail from the previously generated thumbnails.
-
- Args:
- desired_width: The desired width, the returned thumbnail may be larger than this.
- desired_height: The desired height, the returned thumbnail may be larger than this.
- desired_method: The desired method used to generate the thumbnail.
- desired_type: The desired content-type of the thumbnail.
- thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
- file_id: The ID of the media that a thumbnail is being requested for.
- url_cache: True if this is from a URL cache.
- server_name: The server name, if this is a remote thumbnail.
-
- Returns:
- The thumbnail which best matches the desired parameters.
- """
- desired_method = desired_method.lower()
-
- # The chosen thumbnail.
- thumbnail_info = None
-
- d_w = desired_width
- d_h = desired_height
-
- if desired_method == "crop":
- # Thumbnails that match equal or larger sizes of desired width/height.
- crop_info_list: List[
- Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
- ] = []
- # Other thumbnails.
- crop_info_list2: List[
- Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
- ] = []
- for info in thumbnail_infos:
- # Skip thumbnails generated with different methods.
- if info.method != "crop":
- continue
-
- t_w = info.width
- t_h = info.height
- aspect_quality = abs(d_w * t_h - d_h * t_w)
- min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
- size_quality = abs((d_w - t_w) * (d_h - t_h))
- type_quality = desired_type != info.type
- length_quality = info.length
- if t_w >= d_w or t_h >= d_h:
- crop_info_list.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
- )
- else:
- crop_info_list2.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
- )
- # Pick the most appropriate thumbnail. Some values of `desired_width` and
- # `desired_height` may result in a tie, in which case we avoid comparing on
- # the thumbnail info and pick the thumbnail that appears earlier
- # in the list of candidates.
- if crop_info_list:
- thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
- elif crop_info_list2:
- thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
- elif desired_method == "scale":
- # Thumbnails that match equal or larger sizes of desired width/height.
- info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
- # Other thumbnails.
- info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
-
- for info in thumbnail_infos:
- # Skip thumbnails generated with different methods.
- if info.method != "scale":
- continue
-
- t_w = info.width
- t_h = info.height
- size_quality = abs((d_w - t_w) * (d_h - t_h))
- type_quality = desired_type != info.type
- length_quality = info.length
- if t_w >= d_w or t_h >= d_h:
- info_list.append((size_quality, type_quality, length_quality, info))
- else:
- info_list2.append(
- (size_quality, type_quality, length_quality, info)
- )
- # Pick the most appropriate thumbnail. Some values of `desired_width` and
- # `desired_height` may result in a tie, in which case we avoid comparing on
- # the thumbnail info and pick the thumbnail that appears earlier
- # in the list of candidates.
- if info_list:
- thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
- elif info_list2:
- thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
-
- if thumbnail_info:
- return FileInfo(
- file_id=file_id,
- url_cache=url_cache,
- server_name=server_name,
- thumbnail=thumbnail_info,
- )
-
- # No matching thumbnail was found.
- return None
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 8dbcb3f5a0..48384e238c 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -70,10 +70,7 @@ from synapse.types import (
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.stream_change_cache import (
- AllEntitiesChangedResult,
- StreamChangeCache,
-)
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -132,6 +129,20 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
prefilled_cache=device_list_prefill,
)
+ device_list_room_prefill, min_device_list_room_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_changes_in_room",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=10000,
+ )
+ self._device_list_room_stream_cache = StreamChangeCache(
+ "DeviceListRoomStreamChangeCache",
+ min_device_list_room_id,
+ prefilled_cache=device_list_room_prefill,
+ )
+
(
user_signature_stream_prefill,
user_signature_stream_list_id,
@@ -209,6 +220,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
row.entity, token
)
+ def device_lists_in_rooms_have_changed(
+ self, room_ids: StrCollection, token: int
+ ) -> None:
+ "Record that device lists have changed in rooms"
+ for room_id in room_ids:
+ self._device_list_room_stream_cache.entity_has_changed(room_id, token)
+
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
@@ -832,16 +850,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
return {device[0]: db_to_json(device[1]) for device in devices}
- def get_cached_device_list_changes(
- self,
- from_key: int,
- ) -> AllEntitiesChangedResult:
- """Get set of users whose devices have changed since `from_key`, or None
- if that information is not in our cache.
- """
-
- return self._device_list_stream_cache.get_all_entities_changed(from_key)
-
@cancellable
async def get_all_devices_changed(
self,
@@ -1457,7 +1465,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
@cancellable
async def get_device_list_changes_in_rooms(
- self, room_ids: Collection[str], from_id: int
+ self, room_ids: Collection[str], from_id: int, to_id: int
) -> Optional[Set[str]]:
"""Return the set of users whose devices have changed in the given rooms
since the given stream ID.
@@ -1473,9 +1481,15 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
if min_stream_id > from_id:
return None
+ changed_room_ids = self._device_list_room_stream_cache.get_entities_changed(
+ room_ids, from_id
+ )
+ if not changed_room_ids:
+ return set()
+
sql = """
SELECT DISTINCT user_id FROM device_lists_changes_in_room
- WHERE {clause} AND stream_id >= ?
+ WHERE {clause} AND stream_id > ? AND stream_id <= ?
"""
def _get_device_list_changes_in_rooms_txn(
@@ -1487,11 +1501,12 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return {user_id for user_id, in txn}
changes = set()
- for chunk in batch_iter(room_ids, 1000):
+ for chunk in batch_iter(changed_room_ids, 1000):
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", chunk
)
args.append(from_id)
+ args.append(to_id)
changes |= await self.db_pool.runInteraction(
"get_device_list_changes_in_rooms",
@@ -1502,6 +1517,34 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return changes
+ async def get_all_device_list_changes(self, from_id: int, to_id: int) -> Set[str]:
+ """Return the set of rooms where devices have changed since the given
+ stream ID.
+
+ Will raise an exception if the given stream ID is too old.
+ """
+
+ min_stream_id = await self._get_min_device_lists_changes_in_room()
+
+ if min_stream_id > from_id:
+ raise Exception("stream ID is too old")
+
+ sql = """
+ SELECT DISTINCT room_id FROM device_lists_changes_in_room
+ WHERE stream_id > ? AND stream_id <= ?
+ """
+
+ def _get_all_device_list_changes_txn(
+ txn: LoggingTransaction,
+ ) -> Set[str]:
+ txn.execute(sql, (from_id, to_id))
+ return {room_id for room_id, in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_all_device_list_changes",
+ _get_all_device_list_changes_txn,
+ )
+
async def get_device_list_changes_in_room(
self, room_id: str, min_stream_id: int
) -> Collection[Tuple[str, str]]:
@@ -1962,8 +2005,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
async def add_device_change_to_streams(
self,
user_id: str,
- device_ids: Collection[str],
- room_ids: Collection[str],
+ device_ids: StrCollection,
+ room_ids: StrCollection,
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
@@ -2118,12 +2161,36 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
},
)
+ async def mark_redundant_device_lists_pokes(
+ self,
+ user_id: str,
+ device_id: str,
+ room_id: str,
+ converted_upto_stream_id: int,
+ ) -> None:
+ """If we've calculated the outbound pokes for a given room/device list
+ update, mark any subsequent changes as already converted"""
+
+ sql = """
+ UPDATE device_lists_changes_in_room
+ SET converted_to_destinations = true
+ WHERE stream_id > ? AND user_id = ? AND device_id = ?
+ AND room_id = ? AND NOT converted_to_destinations
+ """
+
+ def mark_redundant_device_lists_pokes_txn(txn: LoggingTransaction) -> None:
+ txn.execute(sql, (converted_upto_stream_id, user_id, device_id, room_id))
+
+ return await self.db_pool.runInteraction(
+ "mark_redundant_device_lists_pokes", mark_redundant_device_lists_pokes_txn
+ )
+
def _add_device_outbound_room_poke_txn(
self,
txn: LoggingTransaction,
user_id: str,
- device_ids: Iterable[str],
- room_ids: Collection[str],
+ device_ids: StrCollection,
+ room_ids: StrCollection,
stream_ids: List[int],
context: Dict[str, str],
) -> None:
@@ -2161,6 +2228,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
],
)
+ txn.call_after(
+ self.device_lists_in_rooms_have_changed, room_ids, max(stream_ids)
+ )
+
async def get_uncoverted_outbound_room_pokes(
self, start_stream_id: int, start_room_id: str, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
index 01d05c9ed6..448960b297 100644
--- a/synapse/util/task_scheduler.py
+++ b/synapse/util/task_scheduler.py
@@ -24,7 +24,12 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
from twisted.python.failure import Failure
-from synapse.logging.context import nested_logging_context
+from synapse.logging.context import (
+ ContextResourceUsage,
+ LoggingContext,
+ nested_logging_context,
+ set_current_context,
+)
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import (
run_as_background_process,
@@ -81,6 +86,8 @@ class TaskScheduler:
MAX_CONCURRENT_RUNNING_TASKS = 5
# Time from the last task update after which we will log a warning
LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs
+ # Report a running task's status and usage every so often.
+ OCCASIONAL_REPORT_INTERVAL_MS = 5 * 60 * 1000 # 5 minutes
def __init__(self, hs: "HomeServer"):
self._hs = hs
@@ -346,6 +353,33 @@ class TaskScheduler:
assert task.id not in self._running_tasks
await self._store.delete_scheduled_task(task.id)
+ @staticmethod
+ def _log_task_usage(
+ state: str, task: ScheduledTask, usage: ContextResourceUsage, active_time: float
+ ) -> None:
+ """
+ Log a line describing the state and usage of a task.
+ The log line is inspired by / a copy of the request log line format,
+ but with irrelevant fields removed.
+
+ active_time: Time that the task has been running for, in seconds.
+ """
+
+ logger.info(
+ "Task %s: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
+ " [%d dbevts] %r, %r",
+ state,
+ active_time,
+ usage.ru_utime,
+ usage.ru_stime,
+ usage.db_sched_duration_sec,
+ usage.db_txn_duration_sec,
+ int(usage.db_txn_count),
+ usage.evt_db_fetch_count,
+ task.resource_id,
+ task.params,
+ )
+
async def _launch_task(self, task: ScheduledTask) -> None:
"""Launch a scheduled task now.
@@ -360,8 +394,32 @@ class TaskScheduler:
)
function = self._actions[task.action]
+ def _occasional_report(
+ task_log_context: LoggingContext, start_time: float
+ ) -> None:
+ """
+ Helper to log a 'Task continuing' line every so often.
+ """
+
+ current_time = self._clock.time()
+ calling_context = set_current_context(task_log_context)
+ try:
+ usage = task_log_context.get_resource_usage()
+ TaskScheduler._log_task_usage(
+ "continuing", task, usage, current_time - start_time
+ )
+ finally:
+ set_current_context(calling_context)
+
async def wrapper() -> None:
- with nested_logging_context(task.id):
+ with nested_logging_context(task.id) as log_context:
+ start_time = self._clock.time()
+ occasional_status_call = self._clock.looping_call(
+ _occasional_report,
+ TaskScheduler.OCCASIONAL_REPORT_INTERVAL_MS,
+ log_context,
+ start_time,
+ )
try:
(status, result, error) = await function(task)
except Exception:
@@ -383,6 +441,13 @@ class TaskScheduler:
)
self._running_tasks.remove(task.id)
+ current_time = self._clock.time()
+ usage = log_context.get_resource_usage()
+ TaskScheduler._log_task_usage(
+ status.value, task, usage, current_time - start_time
+ )
+ occasional_status_call.stop()
+
# Try launch a new task since we've finished with this one.
self._clock.call_later(0.1, self._launch_scheduled_tasks)
diff --git a/tests/events/test_auto_accept_invites.py b/tests/events/test_auto_accept_invites.py
new file mode 100644
index 0000000000..7fb4d4fa90
--- /dev/null
+++ b/tests/events/test_auto_accept_invites.py
@@ -0,0 +1,657 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2021 The Matrix.org Foundation C.I.C
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+import asyncio
+from asyncio import Future
+from http import HTTPStatus
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, TypeVar, cast
+from unittest.mock import Mock
+
+import attr
+from parameterized import parameterized
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
+from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
+from synapse.events.auto_accept_invites import InviteAutoAccepter
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.handlers.sync import JoinedSyncResult, SyncRequestKey, SyncVersion
+from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import StreamToken, create_requester
+from synapse.util import Clock
+
+from tests.handlers.test_sync import generate_sync_config
+from tests.unittest import (
+ FederatingHomeserverTestCase,
+ HomeserverTestCase,
+ TestCase,
+ override_config,
+)
+
+
+class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase):
+ """
+ Integration test cases for auto-accepting invites.
+ """
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ hs = self.setup_test_homeserver()
+ self.handler = hs.get_federation_handler()
+ self.store = hs.get_datastores().main
+ return hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.sync_handler = self.hs.get_sync_handler()
+ self.module_api = hs.get_module_api()
+
+ @parameterized.expand(
+ [
+ [False],
+ [True],
+ ]
+ )
+ @override_config(
+ {
+ "auto_accept_invites": {
+ "enabled": True,
+ },
+ }
+ )
+ def test_auto_accept_invites(self, direct_room: bool) -> None:
+ """Test that a user automatically joins a room when invited, if the
+ module is enabled.
+ """
+ # A local user who sends an invite
+ inviting_user_id = self.register_user("inviter", "pass")
+ inviting_user_tok = self.login("inviter", "pass")
+
+ # A local user who receives an invite
+ invited_user_id = self.register_user("invitee", "pass")
+ self.login("invitee", "pass")
+
+ # Create a room and send an invite to the other user
+ room_id = self.helper.create_room_as(
+ inviting_user_id,
+ is_public=False,
+ tok=inviting_user_tok,
+ )
+
+ self.helper.invite(
+ room_id,
+ inviting_user_id,
+ invited_user_id,
+ tok=inviting_user_tok,
+ extra_data={"is_direct": direct_room},
+ )
+
+ # Check that the invite receiving user has automatically joined the room when syncing
+ join_updates, _ = sync_join(self, invited_user_id)
+ self.assertEqual(len(join_updates), 1)
+
+ join_update: JoinedSyncResult = join_updates[0]
+ self.assertEqual(join_update.room_id, room_id)
+
+ @override_config(
+ {
+ "auto_accept_invites": {
+ "enabled": False,
+ },
+ }
+ )
+ def test_module_not_enabled(self) -> None:
+ """Test that a user does not automatically join a room when invited,
+ if the module is not enabled.
+ """
+ # A local user who sends an invite
+ inviting_user_id = self.register_user("inviter", "pass")
+ inviting_user_tok = self.login("inviter", "pass")
+
+ # A local user who receives an invite
+ invited_user_id = self.register_user("invitee", "pass")
+ self.login("invitee", "pass")
+
+ # Create a room and send an invite to the other user
+ room_id = self.helper.create_room_as(
+ inviting_user_id, is_public=False, tok=inviting_user_tok
+ )
+
+ self.helper.invite(
+ room_id,
+ inviting_user_id,
+ invited_user_id,
+ tok=inviting_user_tok,
+ )
+
+ # Check that the invite receiving user has not automatically joined the room when syncing
+ join_updates, _ = sync_join(self, invited_user_id)
+ self.assertEqual(len(join_updates), 0)
+
+ @override_config(
+ {
+ "auto_accept_invites": {
+ "enabled": True,
+ },
+ }
+ )
+ def test_invite_from_remote_user(self) -> None:
+ """Test that an invite from a remote user results in the invited user
+ automatically joining the room.
+ """
+ # A remote user who sends the invite
+ remote_server = "otherserver"
+ remote_user = "@otheruser:" + remote_server
+
+ # A local user who creates the room
+ creator_user_id = self.register_user("creator", "pass")
+ creator_user_tok = self.login("creator", "pass")
+
+ # A local user who receives an invite
+ invited_user_id = self.register_user("invitee", "pass")
+ self.login("invitee", "pass")
+
+ room_id = self.helper.create_room_as(
+ room_creator=creator_user_id, tok=creator_user_tok
+ )
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ invite_event = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": remote_user,
+ "state_key": invited_user_id,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+ self.get_success(
+ self.handler.on_invite_request(
+ remote_server,
+ invite_event,
+ invite_event.room_version,
+ )
+ )
+
+ # Check that the invite receiving user has automatically joined the room when syncing
+ join_updates, _ = sync_join(self, invited_user_id)
+ self.assertEqual(len(join_updates), 1)
+
+ join_update: JoinedSyncResult = join_updates[0]
+ self.assertEqual(join_update.room_id, room_id)
+
+ @parameterized.expand(
+ [
+ [False, False],
+ [True, True],
+ ]
+ )
+ @override_config(
+ {
+ "auto_accept_invites": {
+ "enabled": True,
+ "only_for_direct_messages": True,
+ },
+ }
+ )
+ def test_accept_invite_direct_message(
+ self,
+ direct_room: bool,
+ expect_auto_join: bool,
+ ) -> None:
+ """Tests that, if the module is configured to only accept DM invites, invites to DM rooms are still
+ automatically accepted. Otherwise they are rejected.
+ """
+ # A local user who sends an invite
+ inviting_user_id = self.register_user("inviter", "pass")
+ inviting_user_tok = self.login("inviter", "pass")
+
+ # A local user who receives an invite
+ invited_user_id = self.register_user("invitee", "pass")
+ self.login("invitee", "pass")
+
+ # Create a room and send an invite to the other user
+ room_id = self.helper.create_room_as(
+ inviting_user_id,
+ is_public=False,
+ tok=inviting_user_tok,
+ )
+
+ self.helper.invite(
+ room_id,
+ inviting_user_id,
+ invited_user_id,
+ tok=inviting_user_tok,
+ extra_data={"is_direct": direct_room},
+ )
+
+ if expect_auto_join:
+ # Check that the invite receiving user has automatically joined the room when syncing
+ join_updates, _ = sync_join(self, invited_user_id)
+ self.assertEqual(len(join_updates), 1)
+
+ join_update: JoinedSyncResult = join_updates[0]
+ self.assertEqual(join_update.room_id, room_id)
+ else:
+ # Check that the invite receiving user has not automatically joined the room when syncing
+ join_updates, _ = sync_join(self, invited_user_id)
+ self.assertEqual(len(join_updates), 0)
+
+ @parameterized.expand(
+ [
+ [False, True],
+ [True, False],
+ ]
+ )
+ @override_config(
+ {
+ "auto_accept_invites": {
+ "enabled": True,
+ "only_from_local_users": True,
+ },
+ }
+ )
+ def test_accept_invite_local_user(
+ self, remote_inviter: bool, expect_auto_join: bool
+ ) -> None:
+ """Tests that, if the module is configured to only accept invites from local users, invites
+ from local users are still automatically accepted. Otherwise they are rejected.
+ """
+ # A local user who sends an invite
+ creator_user_id = self.register_user("inviter", "pass")
+ creator_user_tok = self.login("inviter", "pass")
+
+ # A local user who receives an invite
+ invited_user_id = self.register_user("invitee", "pass")
+ self.login("invitee", "pass")
+
+ # Create a room and send an invite to the other user
+ room_id = self.helper.create_room_as(
+ creator_user_id, is_public=False, tok=creator_user_tok
+ )
+
+ if remote_inviter:
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ # A remote user who sends the invite
+ remote_server = "otherserver"
+ remote_user = "@otheruser:" + remote_server
+
+ invite_event = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": remote_user,
+ "state_key": invited_user_id,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+ self.get_success(
+ self.handler.on_invite_request(
+ remote_server,
+ invite_event,
+ invite_event.room_version,
+ )
+ )
+ else:
+ self.helper.invite(
+ room_id,
+ creator_user_id,
+ invited_user_id,
+ tok=creator_user_tok,
+ )
+
+ if expect_auto_join:
+ # Check that the invite receiving user has automatically joined the room when syncing
+ join_updates, _ = sync_join(self, invited_user_id)
+ self.assertEqual(len(join_updates), 1)
+
+ join_update: JoinedSyncResult = join_updates[0]
+ self.assertEqual(join_update.room_id, room_id)
+ else:
+ # Check that the invite receiving user has not automatically joined the room when syncing
+ join_updates, _ = sync_join(self, invited_user_id)
+ self.assertEqual(len(join_updates), 0)
+
+
+_request_key = 0
+
+
+def generate_request_key() -> SyncRequestKey:
+ global _request_key
+ _request_key += 1
+ return ("request_key", _request_key)
+
+
+def sync_join(
+ testcase: HomeserverTestCase,
+ user_id: str,
+ since_token: Optional[StreamToken] = None,
+) -> Tuple[List[JoinedSyncResult], StreamToken]:
+ """Perform a sync request for the given user and return the user join updates
+ they've received, as well as the next_batch token.
+
+ This method assumes testcase.sync_handler points to the homeserver's sync handler.
+
+ Args:
+ testcase: The testcase that is currently being run.
+ user_id: The ID of the user to generate a sync response for.
+ since_token: An optional token to indicate from at what point to sync from.
+
+ Returns:
+ A tuple containing a list of join updates, and the sync response's
+ next_batch token.
+ """
+ requester = create_requester(user_id)
+ sync_config = generate_sync_config(requester.user.to_string())
+ sync_result = testcase.get_success(
+ testcase.hs.get_sync_handler().wait_for_sync_for_user(
+ requester,
+ sync_config,
+ SyncVersion.SYNC_V2,
+ generate_request_key(),
+ since_token,
+ )
+ )
+
+ return sync_result.joined, sync_result.next_batch
+
+
+class InviteAutoAccepterInternalTestCase(TestCase):
+ """
+ Test cases which exercise the internals of the InviteAutoAccepter.
+ """
+
+ def setUp(self) -> None:
+ self.module = create_module()
+ self.user_id = "@peter:test"
+ self.invitee = "@lesley:test"
+ self.remote_invitee = "@thomas:remote"
+
+ # We know our module API is a mock, but mypy doesn't.
+ self.mocked_update_membership: Mock = self.module._api.update_room_membership # type: ignore[assignment]
+
+ async def test_accept_invite_with_failures(self) -> None:
+ """Tests that receiving an invite for a local user makes the module attempt to
+ make the invitee join the room. This test verifies that it works if the call to
+ update membership returns exceptions before successfully completing and returning an event.
+ """
+ invite = MockEvent(
+ sender="@inviter:test",
+ state_key="@invitee:test",
+ type="m.room.member",
+ content={"membership": "invite"},
+ )
+
+ join_event = MockEvent(
+ sender="someone",
+ state_key="someone",
+ type="m.room.member",
+ content={"membership": "join"},
+ )
+ # the first two calls raise an exception while the third call is successful
+ self.mocked_update_membership.side_effect = [
+ SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"),
+ SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"),
+ make_awaitable(join_event),
+ ]
+
+ # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+ # EventBase.
+ await self.module.on_new_event(event=invite) # type: ignore[arg-type]
+
+ await self.retry_assertions(
+ self.mocked_update_membership,
+ 3,
+ sender=invite.state_key,
+ target=invite.state_key,
+ room_id=invite.room_id,
+ new_membership="join",
+ )
+
+ async def test_accept_invite_failures(self) -> None:
+ """Tests that receiving an invite for a local user makes the module attempt to
+ make the invitee join the room. This test verifies that if the update_membership call
+ fails consistently, _retry_make_join will break the loop after the set number of retries and
+ execution will continue.
+ """
+ invite = MockEvent(
+ sender=self.user_id,
+ state_key=self.invitee,
+ type="m.room.member",
+ content={"membership": "invite"},
+ )
+ self.mocked_update_membership.side_effect = SynapseError(
+ HTTPStatus.FORBIDDEN, "Forbidden"
+ )
+
+ # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+ # EventBase.
+ await self.module.on_new_event(event=invite) # type: ignore[arg-type]
+
+ await self.retry_assertions(
+ self.mocked_update_membership,
+ 5,
+ sender=invite.state_key,
+ target=invite.state_key,
+ room_id=invite.room_id,
+ new_membership="join",
+ )
+
+ async def test_not_state(self) -> None:
+ """Tests that receiving an invite that's not a state event does nothing."""
+ invite = MockEvent(
+ sender=self.user_id, type="m.room.member", content={"membership": "invite"}
+ )
+
+ # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+ # EventBase.
+ await self.module.on_new_event(event=invite) # type: ignore[arg-type]
+
+ self.mocked_update_membership.assert_not_called()
+
+ async def test_not_invite(self) -> None:
+ """Tests that receiving a membership update that's not an invite does nothing."""
+ invite = MockEvent(
+ sender=self.user_id,
+ state_key=self.user_id,
+ type="m.room.member",
+ content={"membership": "join"},
+ )
+
+ # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+ # EventBase.
+ await self.module.on_new_event(event=invite) # type: ignore[arg-type]
+
+ self.mocked_update_membership.assert_not_called()
+
+ async def test_not_membership(self) -> None:
+ """Tests that receiving a state event that's not a membership update does
+ nothing.
+ """
+ invite = MockEvent(
+ sender=self.user_id,
+ state_key=self.user_id,
+ type="org.matrix.test",
+ content={"foo": "bar"},
+ )
+
+ # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+ # EventBase.
+ await self.module.on_new_event(event=invite) # type: ignore[arg-type]
+
+ self.mocked_update_membership.assert_not_called()
+
+ def test_config_parse(self) -> None:
+ """Tests that a correct configuration parses."""
+ config = {
+ "auto_accept_invites": {
+ "enabled": True,
+ "only_for_direct_messages": True,
+ "only_from_local_users": True,
+ }
+ }
+ parsed_config = AutoAcceptInvitesConfig()
+ parsed_config.read_config(config)
+
+ self.assertTrue(parsed_config.enabled)
+ self.assertTrue(parsed_config.accept_invites_only_for_direct_messages)
+ self.assertTrue(parsed_config.accept_invites_only_from_local_users)
+
+ def test_runs_on_only_one_worker(self) -> None:
+ """
+ Tests that the module only runs on the specified worker.
+ """
+ # By default, we run on the main process...
+ main_module = create_module(
+ config_override={"auto_accept_invites": {"enabled": True}}, worker_name=None
+ )
+ cast(
+ Mock, main_module._api.register_third_party_rules_callbacks
+ ).assert_called_once()
+
+ # ...and not on other workers (like synchrotrons)...
+ sync_module = create_module(worker_name="synchrotron42")
+ cast(
+ Mock, sync_module._api.register_third_party_rules_callbacks
+ ).assert_not_called()
+
+ # ...unless we configured them to be the designated worker.
+ specified_module = create_module(
+ config_override={
+ "auto_accept_invites": {
+ "enabled": True,
+ "worker_to_run_on": "account_data1",
+ }
+ },
+ worker_name="account_data1",
+ )
+ cast(
+ Mock, specified_module._api.register_third_party_rules_callbacks
+ ).assert_called_once()
+
+ async def retry_assertions(
+ self, mock: Mock, call_count: int, **kwargs: Any
+ ) -> None:
+ """
+ This is a hacky way to ensure that the assertions are not called before the other coroutine
+ has a chance to call `update_room_membership`. It catches the exception caused by a failure,
+ and sleeps the thread before retrying, up until 5 tries.
+
+ Args:
+ call_count: the number of times the mock should have been called
+ mock: the mocked function we want to assert on
+ kwargs: keyword arguments to assert that the mock was called with
+ """
+
+ i = 0
+ while i < 5:
+ try:
+ # Check that the mocked method is called the expected amount of times and with the right
+ # arguments to attempt to make the user join the room.
+ mock.assert_called_with(**kwargs)
+ self.assertEqual(call_count, mock.call_count)
+ break
+ except AssertionError as e:
+ i += 1
+ if i == 5:
+ # we've used up the tries, force the test to fail as we've already caught the exception
+ self.fail(e)
+ await asyncio.sleep(1)
+
+
+@attr.s(auto_attribs=True)
+class MockEvent:
+ """Mocks an event. Only exposes properties the module uses."""
+
+ sender: str
+ type: str
+ content: Dict[str, Any]
+ room_id: str = "!someroom"
+ state_key: Optional[str] = None
+
+ def is_state(self) -> bool:
+ """Checks if the event is a state event by checking if it has a state key."""
+ return self.state_key is not None
+
+ @property
+ def membership(self) -> str:
+ """Extracts the membership from the event. Should only be called on an event
+ that's a membership event, and will raise a KeyError otherwise.
+ """
+ membership: str = self.content["membership"]
+ return membership
+
+
+T = TypeVar("T")
+TV = TypeVar("TV")
+
+
+async def make_awaitable(value: T) -> T:
+ return value
+
+
+def make_multiple_awaitable(result: TV) -> Awaitable[TV]:
+ """
+ Makes an awaitable, suitable for mocking an `async` function.
+ This uses Futures as they can be awaited multiple times so can be returned
+ to multiple callers.
+ """
+ future: Future[TV] = Future()
+ future.set_result(result)
+ return future
+
+
+def create_module(
+ config_override: Optional[Dict[str, Any]] = None, worker_name: Optional[str] = None
+) -> InviteAutoAccepter:
+ # Create a mock based on the ModuleApi spec, but override some mocked functions
+ # because some capabilities are needed for running the tests.
+ module_api = Mock(spec=ModuleApi)
+ module_api.is_mine.side_effect = lambda a: a.split(":")[1] == "test"
+ module_api.worker_name = worker_name
+ module_api.sleep.return_value = make_multiple_awaitable(None)
+
+ if config_override is None:
+ config_override = {}
+
+ config = AutoAcceptInvitesConfig()
+ config.read_config(config_override)
+
+ return InviteAutoAccepter(config, module_api)
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index cae67e11c8..1bd51ceba2 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -18,6 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
+import itertools
import os
import shutil
import tempfile
@@ -46,11 +47,11 @@ from synapse.media._base import FileInfo, ThumbnailInfo
from synapse.media.filepath import MediaFilePaths
from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
from synapse.media.storage_provider import FileStorageProviderBackend
+from synapse.media.thumbnailer import ThumbnailProvider
from synapse.module_api import ModuleApi
from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
from synapse.rest import admin
-from synapse.rest.client import login
-from synapse.rest.media.thumbnail_resource import ThumbnailResource
+from synapse.rest.client import login, media
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias
from synapse.util import Clock
@@ -153,68 +154,54 @@ class _TestImage:
is_inline: bool = True
-@parameterized_class(
- ("test_image",),
- [
- # small png
- (
- _TestImage(
- SMALL_PNG,
- b"image/png",
- b".png",
- unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000020000000200806"
- b"000000737a7af40000001a49444154789cedc101010000008220"
- b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
- b"44ae426082"
- ),
- unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000d49444154789c636060606000000005"
- b"0001a5f645400000000049454e44ae426082"
- ),
- ),
- ),
- # small png with transparency.
- (
- _TestImage(
- unhexlify(
- b"89504e470d0a1a0a0000000d49484452000000010000000101000"
- b"00000376ef9240000000274524e5300010194fdae0000000a4944"
- b"4154789c636800000082008177cd72b60000000049454e44ae426"
- b"082"
- ),
- b"image/png",
- b".png",
- # Note that we don't check the output since it varies across
- # different versions of Pillow.
- ),
- ),
- # small lossless webp
- (
- _TestImage(
- unhexlify(
- b"524946461a000000574542505650384c0d0000002f0000001007"
- b"1011118888fe0700"
- ),
- b"image/webp",
- b".webp",
- ),
- ),
- # an empty file
- (
- _TestImage(
- b"",
- b"image/gif",
- b".gif",
- expected_found=False,
- unable_to_thumbnail=True,
- ),
- ),
- # An SVG.
- (
- _TestImage(
- b"""<?xml version="1.0"?>
+small_png = _TestImage(
+ SMALL_PNG,
+ b"image/png",
+ b".png",
+ unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000020000000200806"
+ b"000000737a7af40000001a49444154789cedc101010000008220"
+ b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
+ b"44ae426082"
+ ),
+ unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000d49444154789c636060606000000005"
+ b"0001a5f645400000000049454e44ae426082"
+ ),
+)
+
+small_png_with_transparency = _TestImage(
+ unhexlify(
+ b"89504e470d0a1a0a0000000d49484452000000010000000101000"
+ b"00000376ef9240000000274524e5300010194fdae0000000a4944"
+ b"4154789c636800000082008177cd72b60000000049454e44ae426"
+ b"082"
+ ),
+ b"image/png",
+ b".png",
+ # Note that we don't check the output since it varies across
+ # different versions of Pillow.
+)
+
+small_lossless_webp = _TestImage(
+ unhexlify(
+ b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700"
+ ),
+ b"image/webp",
+ b".webp",
+)
+
+empty_file = _TestImage(
+ b"",
+ b"image/gif",
+ b".gif",
+ expected_found=False,
+ unable_to_thumbnail=True,
+)
+
+SVG = _TestImage(
+ b"""<?xml version="1.0"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
@@ -223,19 +210,32 @@ class _TestImage:
<circle cx="100" cy="100" r="50" stroke="black"
stroke-width="5" fill="red" />
</svg>""",
- b"image/svg",
- b".svg",
- expected_found=False,
- unable_to_thumbnail=True,
- is_inline=False,
- ),
- ),
- ],
+ b"image/svg",
+ b".svg",
+ expected_found=False,
+ unable_to_thumbnail=True,
+ is_inline=False,
)
+test_images = [
+ small_png,
+ small_png_with_transparency,
+ small_lossless_webp,
+ empty_file,
+ SVG,
+]
+urls = [
+ "_matrix/media/r0/thumbnail",
+ "_matrix/client/unstable/org.matrix.msc3916/media/thumbnail",
+]
+
+
+@parameterized_class(("test_image", "url"), itertools.product(test_images, urls))
class MediaRepoTests(unittest.HomeserverTestCase):
+ servlets = [media.register_servlets]
test_image: ClassVar[_TestImage]
hijack_auth = True
user_id = "@test:user"
+ url: ClassVar[str]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.fetches: List[
@@ -298,6 +298,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"config": {"directory": self.storage_path},
}
config["media_storage_providers"] = [provider_config]
+ config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
hs = self.setup_test_homeserver(config=config, federation_http_client=client)
@@ -502,7 +503,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=scale"
channel = self.make_request(
"GET",
- f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
+ f"/{self.url}/{self.media_id}{params}",
shorthand=False,
await_result=False,
)
@@ -530,7 +531,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
+ f"/{self.url}/{self.media_id}{params}",
shorthand=False,
await_result=False,
)
@@ -566,12 +567,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=" + method
channel = self.make_request(
"GET",
- f"/_matrix/media/r0/thumbnail/{self.media_id}{params}",
+ f"/{self.url}/{self.media_id}{params}",
shorthand=False,
await_result=False,
)
self.pump()
-
headers = {
b"Content-Length": [b"%d" % (len(self.test_image.data))],
b"Content-Type": [self.test_image.content_type],
@@ -580,7 +580,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
(self.test_image.data, (len(self.test_image.data), headers))
)
self.pump()
-
if expected_found:
self.assertEqual(channel.code, 200)
@@ -603,7 +602,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body,
{
"errcode": "M_UNKNOWN",
- "error": "Cannot find any thumbnails for the requested media ('/_matrix/media/r0/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
+ "error": f"Cannot find any thumbnails for the requested media ('/{self.url}/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
},
)
else:
@@ -613,7 +612,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body,
{
"errcode": "M_NOT_FOUND",
- "error": "Not found '/_matrix/media/r0/thumbnail/example.com/12345'",
+ "error": f"Not found '/{self.url}/example.com/12345'",
},
)
@@ -625,12 +624,12 @@ class MediaRepoTests(unittest.HomeserverTestCase):
content_type = self.test_image.content_type.decode()
media_repo = self.hs.get_media_repository()
- thumbnail_resouce = ThumbnailResource(
+ thumbnail_provider = ThumbnailProvider(
self.hs, media_repo, media_repo.media_storage
)
self.assertIsNotNone(
- thumbnail_resouce._select_thumbnail(
+ thumbnail_provider._select_thumbnail(
desired_width=desired_size,
desired_height=desired_size,
desired_method=method,
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index 2b360732ac..a3ed12a38f 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -24,8 +24,8 @@ from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError
-from synapse.rest import admin, devices, room, sync
-from synapse.rest.client import account, keys, login, register
+from synapse.rest import admin, devices, sync
+from synapse.rest.client import keys, login, register
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@@ -33,146 +33,6 @@ from synapse.util import Clock
from tests import unittest
-class DeviceListsTestCase(unittest.HomeserverTestCase):
- """Tests regarding device list changes."""
-
- servlets = [
- admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- register.register_servlets,
- account.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- devices.register_servlets,
- ]
-
- def test_receiving_local_device_list_changes(self) -> None:
- """Tests that a local users that share a room receive each other's device list
- changes.
- """
- # Register two users
- test_device_id = "TESTDEVICE"
- alice_user_id = self.register_user("alice", "correcthorse")
- alice_access_token = self.login(
- alice_user_id, "correcthorse", device_id=test_device_id
- )
-
- bob_user_id = self.register_user("bob", "ponyponypony")
- bob_access_token = self.login(bob_user_id, "ponyponypony")
-
- # Create a room for them to coexist peacefully in
- new_room_id = self.helper.create_room_as(
- alice_user_id, is_public=True, tok=alice_access_token
- )
- self.assertIsNotNone(new_room_id)
-
- # Have Bob join the room
- self.helper.invite(
- new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
- )
- self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
-
- # Now have Bob initiate an initial sync (in order to get a since token)
- channel = self.make_request(
- "GET",
- "/sync",
- access_token=bob_access_token,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- next_batch_token = channel.json_body["next_batch"]
-
- # ...and then an incremental sync. This should block until the sync stream is woken up,
- # which we hope will happen as a result of Alice updating their device list.
- bob_sync_channel = self.make_request(
- "GET",
- f"/sync?since={next_batch_token}&timeout=30000",
- access_token=bob_access_token,
- # Start the request, then continue on.
- await_result=False,
- )
-
- # Have alice update their device list
- channel = self.make_request(
- "PUT",
- f"/devices/{test_device_id}",
- {
- "display_name": "New Device Name",
- },
- access_token=alice_access_token,
- )
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
-
- # Check that bob's incremental sync contains the updated device list.
- # If not, the client would only receive the device list update on the
- # *next* sync.
- bob_sync_channel.await_result()
- self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
-
- changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
- "changed", []
- )
- self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
-
- def test_not_receiving_local_device_list_changes(self) -> None:
- """Tests a local users DO NOT receive device updates from each other if they do not
- share a room.
- """
- # Register two users
- test_device_id = "TESTDEVICE"
- alice_user_id = self.register_user("alice", "correcthorse")
- alice_access_token = self.login(
- alice_user_id, "correcthorse", device_id=test_device_id
- )
-
- bob_user_id = self.register_user("bob", "ponyponypony")
- bob_access_token = self.login(bob_user_id, "ponyponypony")
-
- # These users do not share a room. They are lonely.
-
- # Have Bob initiate an initial sync (in order to get a since token)
- channel = self.make_request(
- "GET",
- "/sync",
- access_token=bob_access_token,
- )
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
- next_batch_token = channel.json_body["next_batch"]
-
- # ...and then an incremental sync. This should block until the sync stream is woken up,
- # which we hope will happen as a result of Alice updating their device list.
- bob_sync_channel = self.make_request(
- "GET",
- f"/sync?since={next_batch_token}&timeout=1000",
- access_token=bob_access_token,
- # Start the request, then continue on.
- await_result=False,
- )
-
- # Have alice update their device list
- channel = self.make_request(
- "PUT",
- f"/devices/{test_device_id}",
- {
- "display_name": "New Device Name",
- },
- access_token=alice_access_token,
- )
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
-
- # Check that bob's incremental sync does not contain the updated device list.
- bob_sync_channel.await_result()
- self.assertEqual(
- bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body
- )
-
- changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
- "changed", []
- )
- self.assertNotIn(
- alice_user_id, changed_device_lists, bob_sync_channel.json_body
- )
-
-
class DevicesTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py
new file mode 100644
index 0000000000..600cbf8963
--- /dev/null
+++ b/tests/rest/client/test_media.py
@@ -0,0 +1,1609 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+import base64
+import json
+import os
+import re
+from typing import Any, Dict, Optional, Sequence, Tuple, Type
+from urllib.parse import quote, urlencode
+
+from twisted.internet._resolver import HostResolution
+from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.error import DNSLookupError
+from twisted.internet.interfaces import IAddress, IResolutionReceiver
+from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
+from twisted.web.resource import Resource
+
+from synapse.config.oembed import OEmbedEndpointConfig
+from synapse.media._base import FileInfo
+from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
+from synapse.rest import admin
+from synapse.rest.client import login, media
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+from synapse.util.stringutils import parse_and_validate_mxc_uri
+
+from tests import unittest
+from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.test_utils import SMALL_PNG
+from tests.unittest import override_config
+
+try:
+ import lxml
+except ImportError:
+ lxml = None # type: ignore[assignment]
+
+
+class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
+ remote_media_id = "doesnotmatter"
+ remote_server_name = "evil.com"
+ servlets = [
+ media.register_servlets,
+ admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(
+ self, reactor: ThreadedMemoryReactorClock, clock: Clock
+ ) -> HomeServer:
+ config = self.default_config()
+
+ self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
+ os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
+ config["media_store_path"] = self.media_store_path
+
+ provider_config = {
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+
+ config["media_storage_providers"] = [provider_config]
+
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+
+ # Inject a piece of media. We'll use this to ensure we're returning a sane
+ # response when we're not supposed to block it, distinguishing a media block
+ # from a regular 404.
+ file_id = "abcdefg12345"
+ file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
+ with hs.get_media_repository().media_storage.store_into_file(file_info) as (
+ f,
+ fname,
+ finish,
+ ):
+ f.write(SMALL_PNG)
+ self.get_success(finish())
+
+ self.get_success(
+ self.store.store_cached_remote_media(
+ origin=self.remote_server_name,
+ media_id=self.remote_media_id,
+ media_type="image/png",
+ media_length=1,
+ time_now_ms=clock.time_msec(),
+ upload_name="test.png",
+ filesystem_id=file_id,
+ )
+ )
+ self.register_user("user", "password")
+ self.tok = self.login("user", "password")
+
+ @override_config(
+ {
+ # Disable downloads from the domain we'll be trying to download from.
+ # Should result in a 404.
+ "prevent_media_downloads_from": ["evil.com"],
+ "dynamic_thumbnails": True,
+ "experimental_features": {"msc3916_authenticated_media_enabled": True},
+ }
+ )
+ def test_cannot_download_blocked_media_thumbnail(self) -> None:
+ """
+ Same test as test_cannot_download_blocked_media but for thumbnails.
+ """
+ response = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/evil.com/{self.remote_media_id}?width=100&height=100",
+ shorthand=False,
+ content={"width": 100, "height": 100},
+ access_token=self.tok,
+ )
+ self.assertEqual(response.code, 404)
+
+ @override_config(
+ {
+ # Disable downloads from a domain we won't be requesting downloads from.
+ # This proves we haven't broken anything.
+ "prevent_media_downloads_from": ["not-listed.com"],
+ "dynamic_thumbnails": True,
+ "experimental_features": {"msc3916_authenticated_media_enabled": True},
+ }
+ )
+ def test_remote_media_thumbnail_normally_unblocked(self) -> None:
+ """
+ Same test as test_remote_media_normally_unblocked but for thumbnails.
+ """
+ response = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/evil.com/{self.remote_media_id}?width=100&height=100",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ self.assertEqual(response.code, 200)
+
+
+class UnstableURLPreviewTests(unittest.HomeserverTestCase):
+ if not lxml:
+ skip = "url preview feature requires lxml"
+
+ servlets = [media.register_servlets]
+ hijack_auth = True
+ user_id = "@test:user"
+ end_content = (
+ b"<html><head>"
+ b'<meta property="og:title" content="~matrix~" />'
+ b'<meta property="og:description" content="hi" />'
+ b"</head></html>"
+ )
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+ config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
+ config["url_preview_enabled"] = True
+ config["max_spider_size"] = 9999999
+ config["url_preview_ip_range_blacklist"] = (
+ "192.168.1.1",
+ "1.0.0.0/8",
+ "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
+ "2001:800::/21",
+ )
+ config["url_preview_ip_range_whitelist"] = ("1.1.1.1",)
+ config["url_preview_accept_language"] = [
+ "en-UK",
+ "en-US;q=0.9",
+ "fr;q=0.8",
+ "*;q=0.7",
+ ]
+
+ self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
+ os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
+ config["media_store_path"] = self.media_store_path
+
+ provider_config = {
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+
+ config["media_storage_providers"] = [provider_config]
+
+ hs = self.setup_test_homeserver(config=config)
+
+ # After the hs is created, modify the parsed oEmbed config (to avoid
+ # messing with files).
+ #
+ # Note that HTTP URLs are used to avoid having to deal with TLS in tests.
+ hs.config.oembed.oembed_patterns = [
+ OEmbedEndpointConfig(
+ api_endpoint="http://publish.twitter.com/oembed",
+ url_patterns=[
+ re.compile(r"http://twitter\.com/.+/status/.+"),
+ ],
+ formats=None,
+ ),
+ OEmbedEndpointConfig(
+ api_endpoint="http://www.hulu.com/api/oembed.{format}",
+ url_patterns=[
+ re.compile(r"http://www\.hulu\.com/watch/.+"),
+ ],
+ formats=["json"],
+ ),
+ ]
+
+ return hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.media_repo = hs.get_media_repository()
+ assert self.media_repo.url_previewer is not None
+ self.url_previewer = self.media_repo.url_previewer
+
+ self.lookups: Dict[str, Any] = {}
+
+ class Resolver:
+ def resolveHostName(
+ _self,
+ resolutionReceiver: IResolutionReceiver,
+ hostName: str,
+ portNumber: int = 0,
+ addressTypes: Optional[Sequence[Type[IAddress]]] = None,
+ transportSemantics: str = "TCP",
+ ) -> IResolutionReceiver:
+ resolution = HostResolution(hostName)
+ resolutionReceiver.resolutionBegan(resolution)
+ if hostName not in self.lookups:
+ raise DNSLookupError("OH NO")
+
+ for i in self.lookups[hostName]:
+ resolutionReceiver.addressResolved(i[0]("TCP", i[1], portNumber))
+ resolutionReceiver.resolutionComplete()
+ return resolutionReceiver
+
+ self.reactor.nameResolver = Resolver() # type: ignore[assignment]
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ """Create a resource tree for the test server
+
+ A resource tree is a mapping from path to twisted.web.resource.
+
+ The default implementation creates a JsonResource and calls each function in
+ `servlets` to register servlets against it.
+ """
+ resources = super().create_resource_dict()
+ resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+ return resources
+
+ def _assert_small_png(self, json_body: JsonDict) -> None:
+ """Assert properties from the SMALL_PNG test image."""
+ self.assertTrue(json_body["og:image"].startswith("mxc://"))
+ self.assertEqual(json_body["og:image:height"], 1)
+ self.assertEqual(json_body["og:image:width"], 1)
+ self.assertEqual(json_body["og:image:type"], "image/png")
+ self.assertEqual(json_body["matrix:image:size"], 67)
+
+ def test_cache_returns_correct_type(self) -> None:
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+ % (len(self.end_content),)
+ + self.end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ # Check the cache returns the correct response
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+ shorthand=False,
+ )
+
+ # Check the cache response has the same content
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ # Clear the in-memory cache
+ self.assertIn("http://matrix.org", self.url_previewer._cache)
+ self.url_previewer._cache.pop("http://matrix.org")
+ self.assertNotIn("http://matrix.org", self.url_previewer._cache)
+
+ # Check the database cache returns the correct response
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+ shorthand=False,
+ )
+
+ # Check the cache response has the same content
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ def test_non_ascii_preview_httpequiv(self) -> None:
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ end_content = (
+ b"<html><head>"
+ b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>'
+ b'<meta property="og:title" content="\xe4\xea\xe0" />'
+ b'<meta property="og:description" content="hi" />'
+ b"</head></html>"
+ )
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
+
+ def test_video_rejected(self) -> None:
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ end_content = b"anything"
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b"Content-Type: video/mp4\r\n\r\n"
+ )
+ % (len(end_content))
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 502)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "errcode": "M_UNKNOWN",
+ "error": "Requested file's content type not allowed for this operation: video/mp4",
+ },
+ )
+
+ def test_audio_rejected(self) -> None:
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ end_content = b"anything"
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b"Content-Type: audio/aac\r\n\r\n"
+ )
+ % (len(end_content))
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 502)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "errcode": "M_UNKNOWN",
+ "error": "Requested file's content type not allowed for this operation: audio/aac",
+ },
+ )
+
+ def test_non_ascii_preview_content_type(self) -> None:
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ end_content = (
+ b"<html><head>"
+ b'<meta property="og:title" content="\xe4\xea\xe0" />'
+ b'<meta property="og:description" content="hi" />'
+ b"</head></html>"
+ )
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="windows-1251"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
+
+ def test_overlong_title(self) -> None:
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ end_content = (
+ b"<html><head>"
+ b"<title>" + b"x" * 2000 + b"</title>"
+ b'<meta property="og:description" content="hi" />'
+ b"</head></html>"
+ )
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="windows-1251"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ res = channel.json_body
+ # We should only see the `og:description` field, as `title` is too long and should be stripped out
+ self.assertCountEqual(["og:description"], res.keys())
+
+ def test_ipaddr(self) -> None:
+ """
+ IP addresses can be previewed directly.
+ """
+ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+ % (len(self.end_content),)
+ + self.end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ def test_blocked_ip_specific(self) -> None:
+ """
+ Blocked IP addresses, found via DNS, are not spidered.
+ """
+ self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+ shorthand=False,
+ )
+
+ # No requests made.
+ self.assertEqual(len(self.reactor.tcpClients), 0)
+ self.assertEqual(channel.code, 502)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "errcode": "M_UNKNOWN",
+ "error": "DNS resolution failure during URL preview generation",
+ },
+ )
+
+ def test_blocked_ip_range(self) -> None:
+ """
+ Blocked IP ranges, IPs found over DNS, are not spidered.
+ """
+ self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+ shorthand=False,
+ )
+
+ self.assertEqual(channel.code, 502)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "errcode": "M_UNKNOWN",
+ "error": "DNS resolution failure during URL preview generation",
+ },
+ )
+
+ def test_blocked_ip_specific_direct(self) -> None:
+ """
+ Blocked IP addresses, accessed directly, are not spidered.
+ """
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://192.168.1.1",
+ shorthand=False,
+ )
+
+ # No requests made.
+ self.assertEqual(len(self.reactor.tcpClients), 0)
+ self.assertEqual(
+ channel.json_body,
+ {"errcode": "M_UNKNOWN", "error": "IP address blocked"},
+ )
+ self.assertEqual(channel.code, 403)
+
+ def test_blocked_ip_range_direct(self) -> None:
+ """
+ Blocked IP ranges, accessed directly, are not spidered.
+ """
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://1.1.1.2",
+ shorthand=False,
+ )
+
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(
+ channel.json_body,
+ {"errcode": "M_UNKNOWN", "error": "IP address blocked"},
+ )
+
+ def test_blocked_ip_range_whitelisted_ip(self) -> None:
+ """
+ Blocked but then subsequently whitelisted IP addresses can be
+ spidered.
+ """
+ self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+ % (len(self.end_content),)
+ + self.end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ def test_blocked_ip_with_external_ip(self) -> None:
+ """
+ If a hostname resolves a blocked IP, even if there's a non-blocked one,
+ it will be rejected.
+ """
+ # Hardcode the URL resolving to the IP we want.
+ self.lookups["example.com"] = [
+ (IPv4Address, "1.1.1.2"),
+ (IPv4Address, "10.1.2.3"),
+ ]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 502)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "errcode": "M_UNKNOWN",
+ "error": "DNS resolution failure during URL preview generation",
+ },
+ )
+
+ def test_blocked_ipv6_specific(self) -> None:
+ """
+ Blocked IP addresses, found via DNS, are not spidered.
+ """
+ self.lookups["example.com"] = [
+ (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
+ ]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+ shorthand=False,
+ )
+
+ # No requests made.
+ self.assertEqual(len(self.reactor.tcpClients), 0)
+ self.assertEqual(channel.code, 502)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "errcode": "M_UNKNOWN",
+ "error": "DNS resolution failure during URL preview generation",
+ },
+ )
+
+ def test_blocked_ipv6_range(self) -> None:
+ """
+ Blocked IP ranges, IPs found over DNS, are not spidered.
+ """
+ self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+ shorthand=False,
+ )
+
+ self.assertEqual(channel.code, 502)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "errcode": "M_UNKNOWN",
+ "error": "DNS resolution failure during URL preview generation",
+ },
+ )
+
+ def test_OPTIONS(self) -> None:
+ """
+ OPTIONS returns the OPTIONS.
+ """
+ channel = self.make_request(
+ "OPTIONS",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 204)
+
+ def test_accept_language_config_option(self) -> None:
+ """
+ Accept-Language header is sent to the remote server
+ """
+ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
+
+ # Build and make a request to the server
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ # Extract Synapse's tcp client
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+
+ # Build a fake remote server to reply with
+ server = AccumulatingProtocol()
+
+ # Connect the two together
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+
+ # Tell Synapse that it has received some data from the remote server
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+ % (len(self.end_content),)
+ + self.end_content
+ )
+
+ # Move the reactor along until we get a response on our original channel
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+ )
+
+ # Check that the server received the Accept-Language header as part
+ # of the request from Synapse
+ self.assertIn(
+ (
+ b"Accept-Language: en-UK\r\n"
+ b"Accept-Language: en-US;q=0.9\r\n"
+ b"Accept-Language: fr;q=0.8\r\n"
+ b"Accept-Language: *;q=0.7"
+ ),
+ server.data,
+ )
+
+ def test_image(self) -> None:
+ """An image should be precached if mentioned in the HTML."""
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.matrix.org"] = [(IPv4Address, "10.1.2.4")]
+
+ result = (
+ b"""<html><body><img src="http://cdn.matrix.org/foo.png"></body></html>"""
+ )
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ # Respond with the HTML.
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(result),)
+ + result
+ )
+ self.pump()
+
+ # Respond with the photo.
+ client = self.reactor.tcpClients[1][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b"Content-Type: image/png\r\n\r\n"
+ )
+ % (len(SMALL_PNG),)
+ + SMALL_PNG
+ )
+ self.pump()
+
+ # The image should be in the result.
+ self.assertEqual(channel.code, 200)
+ self._assert_small_png(channel.json_body)
+
+ def test_nonexistent_image(self) -> None:
+ """If the preview image doesn't exist, ensure some data is returned."""
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ result = (
+ b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>"""
+ )
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(result),)
+ + result
+ )
+
+ self.pump()
+
+ # There should not be a second connection.
+ self.assertEqual(len(self.reactor.tcpClients), 1)
+
+ # The image should not be in the result.
+ self.assertEqual(channel.code, 200)
+ self.assertNotIn("og:image", channel.json_body)
+
+ @unittest.override_config(
+ {"url_preview_url_blacklist": [{"netloc": "cdn.matrix.org"}]}
+ )
+ def test_image_blocked(self) -> None:
+ """If the preview image doesn't exist, ensure some data is returned."""
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.matrix.org"] = [(IPv4Address, "10.1.2.4")]
+
+ result = (
+ b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>"""
+ )
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(result),)
+ + result
+ )
+ self.pump()
+
+ # There should not be a second connection.
+ self.assertEqual(len(self.reactor.tcpClients), 1)
+
+ # The image should not be in the result.
+ self.assertEqual(channel.code, 200)
+ self.assertNotIn("og:image", channel.json_body)
+
+ def test_oembed_failure(self) -> None:
+ """If the autodiscovered oEmbed URL fails, ensure some data is returned."""
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ result = b"""
+ <title>oEmbed Autodiscovery Fail</title>
+ <link rel="alternate" type="application/json+oembed"
+ href="http://example.com/oembed?url=http%3A%2F%2Fmatrix.org&format=json"
+ title="matrixdotorg" />
+ """
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(result),)
+ + result
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ # The image should not be in the result.
+ self.assertEqual(channel.json_body["og:title"], "oEmbed Autodiscovery Fail")
+
+ def test_data_url(self) -> None:
+ """
+ Requesting to preview a data URL is not supported.
+ """
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ data = base64.b64encode(SMALL_PNG).decode()
+
+ query_params = urlencode(
+ {
+ "url": f'<html><head><img src="data:image/png;base64,{data}" /></head></html>'
+ }
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?{query_params}",
+ shorthand=False,
+ )
+ self.pump()
+
+ self.assertEqual(channel.code, 500)
+
+ def test_inline_data_url(self) -> None:
+ """
+ An inline image (as a data URL) should be parsed properly.
+ """
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ data = base64.b64encode(SMALL_PNG)
+
+ end_content = (
+ b"<html><head>" b'<img src="data:image/png;base64,%s" />' b"</head></html>"
+ ) % (data,)
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self._assert_small_png(channel.json_body)
+
+ def test_oembed_photo(self) -> None:
+ """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = {
+ "version": "1.0",
+ "type": "photo",
+ "url": "http://cdn.twitter.com/matrixdotorg",
+ }
+ oembed_content = json.dumps(result).encode("utf-8")
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+ )
+ % (len(oembed_content),)
+ + oembed_content
+ )
+
+ self.pump()
+
+ # Ensure a second request is made to the photo URL.
+ client = self.reactor.tcpClients[1][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b"Content-Type: image/png\r\n\r\n"
+ )
+ % (len(SMALL_PNG),)
+ + SMALL_PNG
+ )
+
+ self.pump()
+
+ # Ensure the URL is what was requested.
+ self.assertIn(b"/matrixdotorg", server.data)
+
+ self.assertEqual(channel.code, 200)
+ body = channel.json_body
+ self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
+ self._assert_small_png(body)
+
+ def test_oembed_rich(self) -> None:
+ """Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = {
+ "version": "1.0",
+ "type": "rich",
+ # Note that this provides the author, not the title.
+ "author_name": "Alice",
+ "html": "<div>Content Preview</div>",
+ }
+ end_content = json.dumps(result).encode("utf-8")
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+
+ # Double check that the proper host is being connected to. (Note that
+ # twitter.com can't be resolved so this is already implicitly checked.)
+ self.assertIn(b"\r\nHost: publish.twitter.com\r\n", server.data)
+
+ self.assertEqual(channel.code, 200)
+ body = channel.json_body
+ self.assertEqual(
+ body,
+ {
+ "og:url": "http://twitter.com/matrixdotorg/status/12345",
+ "og:title": "Alice",
+ "og:description": "Content Preview",
+ },
+ )
+
+ def test_oembed_format(self) -> None:
+ """Test an oEmbed endpoint which requires the format in the URL."""
+ self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = {
+ "version": "1.0",
+ "type": "rich",
+ "html": "<div>Content Preview</div>",
+ }
+ end_content = json.dumps(result).encode("utf-8")
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://www.hulu.com/watch/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+
+ # The {format} should have been turned into json.
+ self.assertIn(b"/api/oembed.json", server.data)
+ # A URL parameter of format=json should be provided.
+ self.assertIn(b"format=json", server.data)
+
+ self.assertEqual(channel.code, 200)
+ body = channel.json_body
+ self.assertEqual(
+ body,
+ {
+ "og:url": "http://www.hulu.com/watch/12345",
+ "og:description": "Content Preview",
+ },
+ )
+
+ @unittest.override_config(
+ {"url_preview_url_blacklist": [{"netloc": "publish.twitter.com"}]}
+ )
+ def test_oembed_blocked(self) -> None:
+ """The oEmbed URL should not be downloaded if the oEmbed URL is blocked."""
+ self.lookups["twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(channel.code, 403, channel.result)
+
+ def test_oembed_autodiscovery(self) -> None:
+ """
+ Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
+ 1. Request a preview of a URL which is not known to the oEmbed code.
+ 2. It returns HTML including a link to an oEmbed preview.
+ 3. The oEmbed preview is requested and returns a URL for an image.
+ 4. The image is requested for thumbnailing.
+ """
+ # This is a little cheesy in that we use the www subdomain (which isn't the
+ # list of oEmbed patterns) to get "raw" HTML response.
+ self.lookups["www.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = b"""
+ <link rel="alternate" type="application/json+oembed"
+ href="http://publish.twitter.com/oembed?url=http%3A%2F%2Fcdn.twitter.com%2Fmatrixdotorg%2Fstatus%2F12345&format=json"
+ title="matrixdotorg" />
+ """
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(result),)
+ + result
+ )
+ self.pump()
+
+ # The oEmbed response.
+ result2 = {
+ "version": "1.0",
+ "type": "photo",
+ "url": "http://cdn.twitter.com/matrixdotorg",
+ }
+ oembed_content = json.dumps(result2).encode("utf-8")
+
+ # Ensure a second request is made to the oEmbed URL.
+ client = self.reactor.tcpClients[1][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+ )
+ % (len(oembed_content),)
+ + oembed_content
+ )
+ self.pump()
+
+ # Ensure the URL is what was requested.
+ self.assertIn(b"/oembed?", server.data)
+
+ # Ensure a third request is made to the photo URL.
+ client = self.reactor.tcpClients[2][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b"Content-Type: image/png\r\n\r\n"
+ )
+ % (len(SMALL_PNG),)
+ + SMALL_PNG
+ )
+ self.pump()
+
+ # Ensure the URL is what was requested.
+ self.assertIn(b"/matrixdotorg", server.data)
+
+ self.assertEqual(channel.code, 200)
+ body = channel.json_body
+ self.assertEqual(
+ body["og:url"], "http://www.twitter.com/matrixdotorg/status/12345"
+ )
+ self._assert_small_png(body)
+
+ @unittest.override_config(
+ {"url_preview_url_blacklist": [{"netloc": "publish.twitter.com"}]}
+ )
+ def test_oembed_autodiscovery_blocked(self) -> None:
+ """
+ If the discovered oEmbed URL is blocked, it should be discarded.
+ """
+ # This is a little cheesy in that we use the www subdomain (which isn't the
+ # list of oEmbed patterns) to get "raw" HTML response.
+ self.lookups["www.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.4")]
+
+ result = b"""
+ <title>Test</title>
+ <link rel="alternate" type="application/json+oembed"
+ href="http://publish.twitter.com/oembed?url=http%3A%2F%2Fcdn.twitter.com%2Fmatrixdotorg%2Fstatus%2F12345&format=json"
+ title="matrixdotorg" />
+ """
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(result),)
+ + result
+ )
+
+ self.pump()
+
+ # Ensure there's no additional connections.
+ self.assertEqual(len(self.reactor.tcpClients), 1)
+
+ # Ensure the URL is what was requested.
+ self.assertIn(b"\r\nHost: www.twitter.com\r\n", server.data)
+
+ self.assertEqual(channel.code, 200)
+ body = channel.json_body
+ self.assertEqual(body["og:title"], "Test")
+ self.assertNotIn("og:image", body)
+
+ def _download_image(self) -> Tuple[str, str]:
+ """Downloads an image into the URL cache.
+ Returns:
+ A (host, media_id) tuple representing the MXC URI of the image.
+ """
+ self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://cdn.twitter.com/matrixdotorg",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: image/png\r\n\r\n"
+ % (len(SMALL_PNG),)
+ + SMALL_PNG
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ body = channel.json_body
+ mxc_uri = body["og:image"]
+ host, _port, media_id = parse_and_validate_mxc_uri(mxc_uri)
+ self.assertIsNone(_port)
+ return host, media_id
+
+ def test_storage_providers_exclude_files(self) -> None:
+ """Test that files are not stored in or fetched from storage providers."""
+ host, media_id = self._download_image()
+
+ rel_file_path = self.media_repo.filepaths.url_cache_filepath_rel(media_id)
+ media_store_path = os.path.join(self.media_store_path, rel_file_path)
+ storage_provider_path = os.path.join(self.storage_path, rel_file_path)
+
+ # Check storage
+ self.assertTrue(os.path.isfile(media_store_path))
+ self.assertFalse(
+ os.path.isfile(storage_provider_path),
+ "URL cache file was unexpectedly stored in a storage provider",
+ )
+
+ # Check fetching
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/media/v3/download/{host}/{media_id}",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ # Move cached file into the storage provider
+ os.makedirs(os.path.dirname(storage_provider_path), exist_ok=True)
+ os.rename(media_store_path, storage_provider_path)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/media/v3/download/{host}/{media_id}",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(
+ channel.code,
+ 404,
+ "URL cache file was unexpectedly retrieved from a storage provider",
+ )
+
+ def test_storage_providers_exclude_thumbnails(self) -> None:
+ """Test that thumbnails are not stored in or fetched from storage providers."""
+ host, media_id = self._download_image()
+
+ rel_thumbnail_path = (
+ self.media_repo.filepaths.url_cache_thumbnail_directory_rel(media_id)
+ )
+ media_store_thumbnail_path = os.path.join(
+ self.media_store_path, rel_thumbnail_path
+ )
+ storage_provider_thumbnail_path = os.path.join(
+ self.storage_path, rel_thumbnail_path
+ )
+
+ # Check storage
+ self.assertTrue(os.path.isdir(media_store_thumbnail_path))
+ self.assertFalse(
+ os.path.isdir(storage_provider_thumbnail_path),
+ "URL cache thumbnails were unexpectedly stored in a storage provider",
+ )
+
+ # Check fetching
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ # Remove the original, otherwise thumbnails will regenerate
+ rel_file_path = self.media_repo.filepaths.url_cache_filepath_rel(media_id)
+ media_store_path = os.path.join(self.media_store_path, rel_file_path)
+ os.remove(media_store_path)
+
+ # Move cached thumbnails into the storage provider
+ os.makedirs(os.path.dirname(storage_provider_thumbnail_path), exist_ok=True)
+ os.rename(media_store_thumbnail_path, storage_provider_thumbnail_path)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(
+ channel.code,
+ 404,
+ "URL cache thumbnail was unexpectedly retrieved from a storage provider",
+ )
+
+ def test_cache_expiry(self) -> None:
+ """Test that URL cache files and thumbnails are cleaned up properly on expiry."""
+ _host, media_id = self._download_image()
+
+ file_path = self.media_repo.filepaths.url_cache_filepath(media_id)
+ file_dirs = self.media_repo.filepaths.url_cache_filepath_dirs_to_delete(
+ media_id
+ )
+ thumbnail_dir = self.media_repo.filepaths.url_cache_thumbnail_directory(
+ media_id
+ )
+ thumbnail_dirs = self.media_repo.filepaths.url_cache_thumbnail_dirs_to_delete(
+ media_id
+ )
+
+ self.assertTrue(os.path.isfile(file_path))
+ self.assertTrue(os.path.isdir(thumbnail_dir))
+
+ self.reactor.advance(IMAGE_CACHE_EXPIRY_MS * 1000 + 1)
+ self.get_success(self.url_previewer._expire_url_cache_data())
+
+ for path in [file_path] + file_dirs + [thumbnail_dir] + thumbnail_dirs:
+ self.assertFalse(
+ os.path.exists(path),
+ f"{os.path.relpath(path, self.media_store_path)} was not deleted",
+ )
+
+ @unittest.override_config({"url_preview_url_blacklist": [{"port": "*"}]})
+ def test_blocked_port(self) -> None:
+ """Tests that blocking URLs with a port makes previewing such URLs
+ fail with a 403 error and doesn't impact other previews.
+ """
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ bad_url = quote("http://matrix.org:8888/foo")
+ good_url = quote("http://matrix.org/foo")
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url="
+ + bad_url,
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(channel.code, 403, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url="
+ + good_url,
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
+ % (len(self.end_content),)
+ + self.end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ @unittest.override_config(
+ {"url_preview_url_blacklist": [{"netloc": "example.com"}]}
+ )
+ def test_blocked_url(self) -> None:
+ """Tests that blocking URLs with a host makes previewing such URLs
+ fail with a 403 error.
+ """
+ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
+
+ bad_url = quote("http://example.com/foo")
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url="
+ + bad_url,
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(channel.code, 403, channel.result)
+
+
+class UnstableMediaConfigTest(unittest.HomeserverTestCase):
+ servlets = [
+ media.register_servlets,
+ admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(
+ self, reactor: ThreadedMemoryReactorClock, clock: Clock
+ ) -> HomeServer:
+ config = self.default_config()
+ config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
+
+ self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
+ os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
+ config["media_store_path"] = self.media_store_path
+
+ provider_config = {
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+
+ config["media_storage_providers"] = [provider_config]
+
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.register_user("user", "password")
+ self.tok = self.login("user", "password")
+
+ def test_media_config(self) -> None:
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc3916/media/config",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body["m.upload.size"], self.hs.config.media.max_upload_size
+ )
diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index 2f994ad553..5ef501c6d5 100644
--- a/tests/rest/client/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
@@ -18,15 +18,39 @@
# [This file includes modifications made by New Vector Limited]
#
#
+from parameterized import parameterized_class
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
+from synapse.types import JsonDict
from tests.unittest import HomeserverTestCase, override_config
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
class SendToDeviceTestCase(HomeserverTestCase):
+ """
+ Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -34,6 +58,11 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync.register_servlets,
]
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
@@ -54,7 +83,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
# check it appears
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
@@ -67,15 +96,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
}
self.assertEqual(channel.json_body["to_device"], expected_result)
- # it should re-appear if we do another sync
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ # it should re-appear if we do another sync because the to-device message is not
+ # deleted until we acknowledge it by sending a `?since=...` parameter in the
+ # next sync request corresponding to the `next_batch` value from the response.
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
- "GET", f"/sync?since={sync_token}", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
@@ -99,15 +132,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 200, chan.result)
- # now sync: we should get two of the three
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ # now sync: we should get two of the three (because burst_count=2)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
- {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}},
+ {
+ "sender": user1,
+ "type": "m.room_key_request",
+ "content": {"idx": i},
+ },
)
sync_token = channel.json_body["next_batch"]
@@ -125,7 +162,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
- "GET", f"/sync?since={sync_token}", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@@ -159,7 +198,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
# now sync: we should get two of the three
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
@@ -193,7 +232,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
- "GET", f"/sync?since={sync_token}", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@@ -217,7 +258,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
user2_tok = self.login("u2", "pass", "d2")
# Do an initial sync
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
sync_token = channel.json_body["next_batch"]
@@ -233,7 +274,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
- "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
@@ -241,7 +284,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
- "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 417a87feb2..daeb1d3ddd 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -21,7 +21,7 @@
import json
from typing import List
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
@@ -688,24 +688,180 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body)
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
class DeviceListSyncTestCase(unittest.HomeserverTestCase):
+ """
+ Tests regarding device list (`device_lists`) changes.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
+ room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
+ def test_receiving_local_device_list_changes(self) -> None:
+ """Tests that a local users that share a room receive each other's device list
+ changes.
+ """
+ # Register two users
+ test_device_id = "TESTDEVICE"
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ bob_user_id = self.register_user("bob", "ponyponypony")
+ bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+ # Create a room for them to coexist peacefully in
+ new_room_id = self.helper.create_room_as(
+ alice_user_id, is_public=True, tok=alice_access_token
+ )
+ self.assertIsNotNone(new_room_id)
+
+ # Have Bob join the room
+ self.helper.invite(
+ new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
+ )
+ self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
+
+ # Now have Bob initiate an initial sync (in order to get a since token)
+ channel = self.make_request(
+ "GET",
+ self.sync_endpoint,
+ access_token=bob_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ next_batch_token = channel.json_body["next_batch"]
+
+ # ...and then an incremental sync. This should block until the sync stream is woken up,
+ # which we hope will happen as a result of Alice updating their device list.
+ bob_sync_channel = self.make_request(
+ "GET",
+ f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
+ access_token=bob_access_token,
+ # Start the request, then continue on.
+ await_result=False,
+ )
+
+ # Have alice update their device list
+ channel = self.make_request(
+ "PUT",
+ f"/devices/{test_device_id}",
+ {
+ "display_name": "New Device Name",
+ },
+ access_token=alice_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that bob's incremental sync contains the updated device list.
+ # If not, the client would only receive the device list update on the
+ # *next* sync.
+ bob_sync_channel.await_result()
+ self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+ changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+ "changed", []
+ )
+ self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
+
+ def test_not_receiving_local_device_list_changes(self) -> None:
+ """Tests a local users DO NOT receive device updates from each other if they do not
+ share a room.
+ """
+ # Register two users
+ test_device_id = "TESTDEVICE"
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ bob_user_id = self.register_user("bob", "ponyponypony")
+ bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+ # These users do not share a room. They are lonely.
+
+ # Have Bob initiate an initial sync (in order to get a since token)
+ channel = self.make_request(
+ "GET",
+ self.sync_endpoint,
+ access_token=bob_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ next_batch_token = channel.json_body["next_batch"]
+
+ # ...and then an incremental sync. This should block until the sync stream is woken up,
+ # which we hope will happen as a result of Alice updating their device list.
+ bob_sync_channel = self.make_request(
+ "GET",
+ f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
+ access_token=bob_access_token,
+ # Start the request, then continue on.
+ await_result=False,
+ )
+
+ # Have alice update their device list
+ channel = self.make_request(
+ "PUT",
+ f"/devices/{test_device_id}",
+ {
+ "display_name": "New Device Name",
+ },
+ access_token=alice_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that bob's incremental sync does not contain the updated device list.
+ bob_sync_channel.await_result()
+ self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+ changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+ "changed", []
+ )
+ self.assertNotIn(
+ alice_user_id, changed_device_lists, bob_sync_channel.json_body
+ )
+
def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
"""Tests that a user with no rooms still receives their own device list updates"""
- device_id = "TESTDEVICE"
+ test_device_id = "TESTDEVICE"
# Register a user and login, creating a device
- self.user_id = self.register_user("kermit", "monkey")
- self.tok = self.login("kermit", "monkey", device_id=device_id)
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
# Request an initial sync
- channel = self.make_request("GET", "/sync", access_token=self.tok)
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
self.assertEqual(channel.code, 200, channel.json_body)
next_batch = channel.json_body["next_batch"]
@@ -713,19 +869,19 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
# It won't return until something has happened
incremental_sync_channel = self.make_request(
"GET",
- f"/sync?since={next_batch}&timeout=30000",
- access_token=self.tok,
+ f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
+ access_token=alice_access_token,
await_result=False,
)
# Change our device's display name
channel = self.make_request(
"PUT",
- f"devices/{device_id}",
+ f"devices/{test_device_id}",
{
"display_name": "freeze ray",
},
- access_token=self.tok,
+ access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -739,7 +895,230 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
).get("changed", [])
self.assertIn(
- self.user_id, device_list_changes, incremental_sync_channel.json_body
+ alice_user_id, device_list_changes, incremental_sync_channel.json_body
+ )
+
+
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
+class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
+ """
+ Tests regarding device one time keys (`device_one_time_keys_count`) changes.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+ def test_no_device_one_time_keys(self) -> None:
+ """
+ Tests when no one time keys set, it still has the default `signed_curve25519` in
+ `device_one_time_keys_count`
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for those one time key counts
+ self.assertDictEqual(
+ channel.json_body["device_one_time_keys_count"],
+ # Note that "signed_curve25519" is always returned in key count responses
+ # regardless of whether we uploaded any keys for it. This is necessary until
+ # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+ {"signed_curve25519": 0},
+ channel.json_body["device_one_time_keys_count"],
+ )
+
+ def test_returns_device_one_time_keys(self) -> None:
+ """
+ Tests that one time keys for the device/user are counted correctly in the `/sync`
+ response
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # Upload one time keys for the user/device
+ keys: JsonDict = {
+ "alg1:k1": "key1",
+ "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
+ "alg2:k3": {"key": "key3"},
+ }
+ res = self.get_success(
+ self.e2e_keys_handler.upload_keys_for_user(
+ alice_user_id, test_device_id, {"one_time_keys": keys}
+ )
+ )
+ # Note that "signed_curve25519" is always returned in key count responses
+ # regardless of whether we uploaded any keys for it. This is necessary until
+ # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+ self.assertDictEqual(
+ res,
+ {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}},
+ )
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for those one time key counts
+ self.assertDictEqual(
+ channel.json_body["device_one_time_keys_count"],
+ {"alg1": 1, "alg2": 2, "signed_curve25519": 0},
+ channel.json_body["device_one_time_keys_count"],
+ )
+
+
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
+class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
+ """
+ Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = self.hs.get_datastores().main
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+ def test_no_device_unused_fallback_key(self) -> None:
+ """
+ Test when no unused fallback key is set, it just returns an empty list. The MSC
+ says "The device_unused_fallback_key_types parameter must be present if the
+ server supports fallback keys.",
+ https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for those one time key counts
+ self.assertListEqual(
+ channel.json_body["device_unused_fallback_key_types"],
+ [],
+ channel.json_body["device_unused_fallback_key_types"],
+ )
+
+ def test_returns_device_one_time_keys(self) -> None:
+ """
+ Tests that device unused fallback key type is returned correctly in the `/sync`
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # We shouldn't have any unused fallback keys yet
+ res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
+ )
+ self.assertEqual(res, [])
+
+ # Upload a fallback key for the user/device
+ fallback_key = {"alg1:k1": "fallback_key1"}
+ self.get_success(
+ self.e2e_keys_handler.upload_keys_for_user(
+ alice_user_id,
+ test_device_id,
+ {"fallback_keys": fallback_key},
+ )
+ )
+ # We should now have an unused alg1 key
+ fallback_res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
+ )
+ self.assertEqual(fallback_res, ["alg1"], fallback_res)
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for the unused fallback key types
+ self.assertListEqual(
+ channel.json_body["device_unused_fallback_key_types"],
+ ["alg1"],
+ channel.json_body["device_unused_fallback_key_types"],
)
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index fe00afe198..7362bde7ab 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -170,6 +170,7 @@ class RestHelper:
targ: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
+ extra_data: Optional[dict] = None,
) -> JsonDict:
return self.change_membership(
room=room,
@@ -178,6 +179,7 @@ class RestHelper:
tok=tok,
membership=Membership.INVITE,
expect_code=expect_code,
+ extra_data=extra_data,
)
def join(
diff --git a/tests/server.py b/tests/server.py
index 434be3d22c..f3a917f835 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -85,6 +85,7 @@ from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
+from synapse.events.auto_accept_invites import InviteAutoAccepter
from synapse.events.presence_router import load_legacy_presence_router
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.http.site import SynapseRequest
@@ -1156,6 +1157,11 @@ def setup_test_homeserver(
for module, module_config in hs.config.modules.loaded_modules:
module(config=module_config, api=module_api)
+ if hs.config.auto_accept_invites.enabled:
+ # Start the local auto_accept_invites module.
+ m = InviteAutoAccepter(hs.config.auto_accept_invites, module_api)
+ logger.info("Loaded local module %s", m)
+
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
load_legacy_presence_router(hs)
|