diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 64da7b0dbf..5510b73dc1 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -367,7 +367,7 @@ jobs:
# Run Complement
- run: |
set -o pipefail
- go test -v -json -p 1 -tags synapse_blacklist,msc2403,msc2716,msc3030 ./tests/... 2>&1 | gotestfmt
+ go test -v -json -tags synapse_blacklist,msc2403,msc2716,msc3030 ./tests/... 2>&1 | gotestfmt
shell: bash
name: Run Complement Tests
env:
diff --git a/CHANGES.md b/CHANGES.md
index 78498c10b5..b0244a16f0 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,8 +1,43 @@
-Synapse 1.55.0rc1 (2022-03-15)
-==============================
+Synapse 1.55.2 (2022-03-24)
+===========================
+
+This patch version reverts the earlier fixes from Synapse 1.55.1, which could cause problems in certain deployments, and instead adds a cap to the version of Jinja to be installed. Again, this is to fix an incompatibility with version 3.1.0 of the [Jinja](https://pypi.org/project/Jinja2/) library, and again, deployments of Synapse using the `matrixdotorg/synapse` Docker image or Debian packages from packages.matrix.org are not affected.
+
+Internal Changes
+----------------
+
+- Pin Jinja to <3.1.0, as Synapse fails to start with Jinja 3.1.0. ([\#12297](https://github.com/matrix-org/synapse/issues/12297))
+- Revert changes from 1.55.1 as they caused problems with older versions of Jinja ([\#12296](https://github.com/matrix-org/synapse/issues/12296))
+
+
+Synapse 1.55.1 (2022-03-24)
+===========================
+
+This is a patch release that fixes an incompatibility with version 3.1.0 of the [Jinja](https://pypi.org/project/Jinja2/) library, released on March 24th, 2022. Deployments of Synapse using the `matrixdotorg/synapse` Docker image or Debian packages from packages.matrix.org are not affected.
+
+Internal Changes
+----------------
+
+- Remove uses of the long-deprecated `jinja2.Markup` which would prevent Synapse from starting with Jinja 3.1.0 or above installed. ([\#12289](https://github.com/matrix-org/synapse/issues/12289))
+
+
+Synapse 1.55.0 (2022-03-22)
+===========================
This release removes a workaround introduced in Synapse 1.50.0 for Mjolnir compatibility. **This breaks compatibility with Mjolnir 1.3.1 and earlier. ([\#11700](https://github.com/matrix-org/synapse/issues/11700))**; Mjolnir users should upgrade Mjolnir before upgrading Synapse to this version.
+This release also moves the location of the `synctl` script; see the [upgrade notes](https://github.com/matrix-org/synapse/blob/develop/docs/upgrade.md#synctl-script-has-been-moved) for more details.
+
+
+Internal Changes
+----------------
+
+- Tweak copy for default Single Sign-On account details template to better adhere to mobile app store guidelines. ([\#12265](https://github.com/matrix-org/synapse/issues/12265), [\#12260](https://github.com/matrix-org/synapse/issues/12260))
+
+
+Synapse 1.55.0rc1 (2022-03-15)
+==============================
+
Features
--------
@@ -38,6 +73,7 @@ Deprecations and Removals
-------------------------
- **Remove workaround introduced in Synapse 1.50.0 for Mjolnir compatibility. Breaks compatibility with Mjolnir 1.3.1 and earlier. ([\#11700](https://github.com/matrix-org/synapse/issues/11700))**
+- **`synctl` has been moved into into `synapse._scripts` and is exposed as an entry point; see [upgrade notes](https://github.com/matrix-org/synapse/blob/develop/docs/upgrade.md#synctl-script-has-been-moved). ([\#12140](https://github.com/matrix-org/synapse/issues/12140))
- Remove backwards compatibilty with pagination tokens from the `/relations` and `/aggregations` endpoints generated from Synapse < v1.52.0. ([\#12138](https://github.com/matrix-org/synapse/issues/12138))
- The groups/communities feature in Synapse has been deprecated. ([\#12200](https://github.com/matrix-org/synapse/issues/12200))
@@ -56,7 +92,6 @@ Internal Changes
- Fix CI not attaching source distributions and wheels to the GitHub releases. ([\#12131](https://github.com/matrix-org/synapse/issues/12131))
- Remove unused mocks from `test_typing`. ([\#12136](https://github.com/matrix-org/synapse/issues/12136))
- Give `scripts-dev` scripts suffixes for neater CI config. ([\#12137](https://github.com/matrix-org/synapse/issues/12137))
-- Move `synctl` into `synapse._scripts` and expose as an entry point. ([\#12140](https://github.com/matrix-org/synapse/issues/12140))
- Move the snapcraft configuration file to `contrib`. ([\#12142](https://github.com/matrix-org/synapse/issues/12142))
- Enable [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) Complement tests in CI. ([\#12144](https://github.com/matrix-org/synapse/issues/12144))
- Enable [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) Complement tests in CI. ([\#12145](https://github.com/matrix-org/synapse/issues/12145))
diff --git a/changelog.d/12036.misc b/changelog.d/12036.misc
new file mode 100644
index 0000000000..d2996730cc
--- /dev/null
+++ b/changelog.d/12036.misc
@@ -0,0 +1 @@
+Rename `shared_rooms` to `mutual_rooms` (MSC2666), as per proposal changes.
\ No newline at end of file
diff --git a/changelog.d/12038.misc b/changelog.d/12038.misc
new file mode 100644
index 0000000000..e2a65726b6
--- /dev/null
+++ b/changelog.d/12038.misc
@@ -0,0 +1 @@
+Remove check on `update_user_directory` for shared rooms handler (MSC2666), and update/expand documentation.
\ No newline at end of file
diff --git a/changelog.d/12083.misc b/changelog.d/12083.misc
new file mode 100644
index 0000000000..88fd6b92ee
--- /dev/null
+++ b/changelog.d/12083.misc
@@ -0,0 +1 @@
+Refactor `create_new_client_event` to use a new parameter, `state_event_ids`, which accurately describes the usage with [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) instead of abusing `auth_event_ids`.
diff --git a/changelog.d/12087.bugfix b/changelog.d/12087.bugfix
new file mode 100644
index 0000000000..6dacdddd0d
--- /dev/null
+++ b/changelog.d/12087.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug which caused the `/_matrix/federation/v1/state` and `.../state_ids` endpoints to return incorrect or invalid data when called for an event which we have stored as an "outlier".
diff --git a/changelog.d/12091.misc b/changelog.d/12091.misc
new file mode 100644
index 0000000000..def44987b4
--- /dev/null
+++ b/changelog.d/12091.misc
@@ -0,0 +1 @@
+Refuse to start if registration is enabled without email, captcha, or token-based verification unless new config flag `enable_registration_without_verification` is set.
diff --git a/changelog.d/12227.bugfix b/changelog.d/12227.bugfix
new file mode 100644
index 0000000000..1a7dccf465
--- /dev/null
+++ b/changelog.d/12227.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where events from ignored users were still considered for relations.
diff --git a/changelog.d/12228.bugfix b/changelog.d/12228.bugfix
new file mode 100644
index 0000000000..4755777139
--- /dev/null
+++ b/changelog.d/12228.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in v1.53.0 where an unnecessary query could be performed when fetching bundled aggregations for threads.
diff --git a/changelog.d/12232.bugfix b/changelog.d/12232.bugfix
new file mode 100644
index 0000000000..1a7dccf465
--- /dev/null
+++ b/changelog.d/12232.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where events from ignored users were still considered for relations.
diff --git a/changelog.d/12232.misc b/changelog.d/12232.misc
deleted file mode 100644
index 4a4132edff..0000000000
--- a/changelog.d/12232.misc
+++ /dev/null
@@ -1 +0,0 @@
-Refactor relations tests to improve code re-use.
diff --git a/changelog.d/12227.misc b/changelog.d/12237.misc
index 41c9dcbd37..41c9dcbd37 100644
--- a/changelog.d/12227.misc
+++ b/changelog.d/12237.misc
diff --git a/changelog.d/12240.misc b/changelog.d/12240.misc
new file mode 100644
index 0000000000..c5b6356799
--- /dev/null
+++ b/changelog.d/12240.misc
@@ -0,0 +1 @@
+Add type hints to tests files.
\ No newline at end of file
diff --git a/changelog.d/12242.misc b/changelog.d/12242.misc
new file mode 100644
index 0000000000..38e7e0f7d1
--- /dev/null
+++ b/changelog.d/12242.misc
@@ -0,0 +1 @@
+Generate announcement links in the release script.
diff --git a/changelog.d/12243.doc b/changelog.d/12243.doc
new file mode 100644
index 0000000000..b2031f0a40
--- /dev/null
+++ b/changelog.d/12243.doc
@@ -0,0 +1 @@
+Remove incorrect prefixes in the worker documentation for some endpoints.
diff --git a/changelog.d/12244.misc b/changelog.d/12244.misc
new file mode 100644
index 0000000000..950d48e4c6
--- /dev/null
+++ b/changelog.d/12244.misc
@@ -0,0 +1 @@
+Improve error message when dependencies check finds a broken installation.
\ No newline at end of file
diff --git a/changelog.d/12246.doc b/changelog.d/12246.doc
new file mode 100644
index 0000000000..e7fcc1b99c
--- /dev/null
+++ b/changelog.d/12246.doc
@@ -0,0 +1 @@
+Correct `check_username_for_spam` annotations and docs.
\ No newline at end of file
diff --git a/changelog.d/12248.misc b/changelog.d/12248.misc
new file mode 100644
index 0000000000..2b1290d1e1
--- /dev/null
+++ b/changelog.d/12248.misc
@@ -0,0 +1 @@
+Add missing type hints for storage.
\ No newline at end of file
diff --git a/changelog.d/12250.feature b/changelog.d/12250.feature
new file mode 100644
index 0000000000..29a2724457
--- /dev/null
+++ b/changelog.d/12250.feature
@@ -0,0 +1 @@
+Allow registering admin users using the module API. Contributed by Famedly.
diff --git a/changelog.d/12256.misc b/changelog.d/12256.misc
new file mode 100644
index 0000000000..c5b6356799
--- /dev/null
+++ b/changelog.d/12256.misc
@@ -0,0 +1 @@
+Add type hints to tests files.
\ No newline at end of file
diff --git a/changelog.d/12258.misc b/changelog.d/12258.misc
new file mode 100644
index 0000000000..80024c8e91
--- /dev/null
+++ b/changelog.d/12258.misc
@@ -0,0 +1 @@
+Compress metrics HTTP resource when enabled. Contributed by Nick @ Beeper.
diff --git a/changelog.d/12261.bugfix b/changelog.d/12261.bugfix
new file mode 100644
index 0000000000..1bfde4c380
--- /dev/null
+++ b/changelog.d/12261.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse 1.52 where admins could not deactivate and GDPR-erase a user if Synapse was configured with limits on avatars.
diff --git a/changelog.d/12262.misc b/changelog.d/12262.misc
new file mode 100644
index 0000000000..574ac4752c
--- /dev/null
+++ b/changelog.d/12262.misc
@@ -0,0 +1 @@
+Refuse to start if DB has non-`C` locale, unless config flag `allow_unsafe_db_locale` is set to true.
\ No newline at end of file
diff --git a/changelog.d/12266.misc b/changelog.d/12266.misc
new file mode 100644
index 0000000000..59e2718370
--- /dev/null
+++ b/changelog.d/12266.misc
@@ -0,0 +1 @@
+Optionally include account validity expiration information to experimental [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) account status responses.
diff --git a/changelog.d/12269.misc b/changelog.d/12269.misc
new file mode 100644
index 0000000000..ed79cbb528
--- /dev/null
+++ b/changelog.d/12269.misc
@@ -0,0 +1 @@
+Use type stubs for `psycopg2`.
diff --git a/changelog.d/12272.misc b/changelog.d/12272.misc
new file mode 100644
index 0000000000..95589f3361
--- /dev/null
+++ b/changelog.d/12272.misc
@@ -0,0 +1 @@
+Add a new cache `_get_membership_from_event_id` to speed up push rule calculations in large rooms.
diff --git a/changelog.d/12275.doc b/changelog.d/12275.doc
new file mode 100644
index 0000000000..2e26ad21eb
--- /dev/null
+++ b/changelog.d/12275.doc
@@ -0,0 +1 @@
+Corrected Authentik OpenID typo, added helpful note for troubleshooting. Contributed by @IronTooch.
diff --git a/changelog.d/12283.misc b/changelog.d/12283.misc
new file mode 100644
index 0000000000..e9f2208500
--- /dev/null
+++ b/changelog.d/12283.misc
@@ -0,0 +1 @@
+Re-enable Complement concurrency in CI.
diff --git a/changelog.d/12285.bugfix b/changelog.d/12285.bugfix
new file mode 100644
index 0000000000..1a7dccf465
--- /dev/null
+++ b/changelog.d/12285.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where events from ignored users were still considered for relations.
diff --git a/changelog.d/12288.misc b/changelog.d/12288.misc
new file mode 100644
index 0000000000..ee8fbfd290
--- /dev/null
+++ b/changelog.d/12288.misc
@@ -0,0 +1 @@
+Refuse to start if DB has non-`C` locale, unless config flag `allow_unsafe_db_locale` is set to true.
diff --git a/changelog.d/12291.misc b/changelog.d/12291.misc
new file mode 100644
index 0000000000..b55dd68f92
--- /dev/null
+++ b/changelog.d/12291.misc
@@ -0,0 +1 @@
+Remove unused test utilities.
diff --git a/changelog.d/12301.misc b/changelog.d/12301.misc
new file mode 100644
index 0000000000..a4cd94ee5e
--- /dev/null
+++ b/changelog.d/12301.misc
@@ -0,0 +1 @@
+Enhance logging for inbound federation events.
diff --git a/debian/changelog b/debian/changelog
index 09ef24ebb0..3c899e6024 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,21 @@
+matrix-synapse-py3 (1.55.2) stable; urgency=medium
+
+ * New synapse release 1.55.2.
+
+ -- Synapse Packaging team <packages@matrix.org> Thu, 24 Mar 2022 19:07:11 +0000
+
+matrix-synapse-py3 (1.55.1) stable; urgency=medium
+
+ * New synapse release 1.55.1.
+
+ -- Synapse Packaging team <packages@matrix.org> Thu, 24 Mar 2022 17:44:23 +0000
+
+matrix-synapse-py3 (1.55.0) stable; urgency=medium
+
+ * New synapse release 1.55.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Tue, 22 Mar 2022 13:59:26 +0000
+
matrix-synapse-py3 (1.55.0~rc1) stable; urgency=medium
* New synapse release 1.55.0~rc1.
diff --git a/demo/start.sh b/demo/start.sh
index 55e69685e3..5a9972d24c 100755
--- a/demo/start.sh
+++ b/demo/start.sh
@@ -38,6 +38,7 @@ for port in 8080 8081 8082; do
printf '\n\n# Customisation made by demo/start.sh\n\n'
echo "public_baseurl: http://localhost:$port/"
echo 'enable_registration: true'
+ echo 'enable_registration_without_verification: true'
echo ''
# Warning, this heredoc depends on the interaction of tabs and spaces.
diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md
index 2b672b78f9..472d957180 100644
--- a/docs/modules/spam_checker_callbacks.md
+++ b/docs/modules/spam_checker_callbacks.md
@@ -172,7 +172,7 @@ any of the subsequent implementations of this callback.
_First introduced in Synapse v1.37.0_
```python
-async def check_username_for_spam(user_profile: Dict[str, str]) -> bool
+async def check_username_for_spam(user_profile: synapse.module_api.UserProfile) -> bool
```
Called when computing search results in the user directory. The module must return a
@@ -182,9 +182,11 @@ search results; otherwise return `False`.
The profile is represented as a dictionary with the following keys:
-* `user_id`: The Matrix ID for this user.
-* `display_name`: The user's display name.
-* `avatar_url`: The `mxc://` URL to the user's avatar.
+* `user_id: str`. The Matrix ID for this user.
+* `display_name: Optional[str]`. The user's display name, or `None` if this user
+ has not set a display name.
+* `avatar_url: Optional[str]`. The `mxc://` URL to the user's avatar, or `None`
+ if this user has not set an avatar.
The module is given a copy of the original dictionary, so modifying it from within the
module cannot modify a user's profile when included in user directory search results.
diff --git a/docs/openid.md b/docs/openid.md
index 171ea3b712..19cacaafef 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -225,6 +225,8 @@ oidc_providers:
3. Create an application for synapse in Authentik and link it to the provider.
4. Note the slug of your application, Client ID and Client Secret.
+Note: RSA keys must be used for signing for Authentik, ECC keys do not work.
+
Synapse config:
```yaml
oidc_providers:
@@ -240,7 +242,7 @@ oidc_providers:
- "email"
user_mapping_provider:
config:
- localpart_template: "{{ user.preferred_username }}}"
+ localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.preferred_username|capitalize }}" # TO BE FILLED: If your users have names in Authentik and you want those in Synapse, this should be replaced with user.name|capitalize.
```
diff --git a/docs/postgres.md b/docs/postgres.md
index de4e2ba4b7..cbc32e1836 100644
--- a/docs/postgres.md
+++ b/docs/postgres.md
@@ -234,12 +234,13 @@ host all all ::1/128 ident
### Fixing incorrect `COLLATE` or `CTYPE`
Synapse will refuse to set up a new database if it has the wrong values of
-`COLLATE` and `CTYPE` set, and will log warnings on existing databases. Using
-different locales can cause issues if the locale library is updated from
+`COLLATE` and `CTYPE` set. Synapse will also refuse to start an existing database with incorrect values
+of `COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the
+`database` section of the config, is set to true. Using different locales can cause issues if the locale library is updated from
underneath the database, or if a different version of the locale is used on any
replicas.
-The safest way to fix the issue is to dump the database and recreate it with
+If you have a databse with an unsafe locale, the safest way to fix the issue is to dump the database and recreate it with
the correct locale parameter (as shown above). It is also possible to change the
parameters on a live database and run a `REINDEX` on the entire database,
however extreme care must be taken to avoid database corruption.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 36c6c56e58..a21b48ab2e 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -783,6 +783,12 @@ caches:
# 'txn_limit' gives the maximum number of transactions to run per connection
# before reconnecting. Defaults to 0, which means no limit.
#
+# 'allow_unsafe_locale' is an option specific to Postgres. Under the default behavior, Synapse will refuse to
+# start if the postgres db is set to a non-C locale. You can override this behavior (which is *not* recommended)
+# by setting 'allow_unsafe_locale' to true. Note that doing so may corrupt your database. You can find more information
+# here: https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype and here:
+# https://wiki.postgresql.org/wiki/Locale_data_changes
+#
# 'args' gives options which are passed through to the database engine,
# except for options starting 'cp_', which are used to configure the Twisted
# connection pool. For a reference to valid arguments, see:
@@ -1212,10 +1218,18 @@ oembed:
# Registration can be rate-limited using the parameters in the "Ratelimiting"
# section of this file.
-# Enable registration for new users.
+# Enable registration for new users. Defaults to 'false'. It is highly recommended that if you enable registration,
+# you use either captcha, email, or token-based verification to verify that new users are not bots. In order to enable registration
+# without any verification, you must also set `enable_registration_without_verification`, found below.
#
#enable_registration: false
+# Enable registration without email or captcha verification. Note: this option is *not* recommended,
+# as registration without verification is a known vector for spam and abuse. Defaults to false. Has no effect
+# unless `enable_registration` is also enabled.
+#
+#enable_registration_without_verification: true
+
# Time that a user's session remains valid for, after they log in.
#
# Note that this is not currently compatible with guest logins.
diff --git a/docs/upgrade.md b/docs/upgrade.md
index f9ac605e7b..062e823333 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -99,8 +99,21 @@ experimental_features:
groups_enabled: false
```
+## Change in behaviour for PostgreSQL databases with unsafe locale
+
+Synapse now refuses to start when using PostgreSQL with non-`C` values for `COLLATE` and
+`CTYPE` unless the config flag `allow_unsafe_locale`, found in the database section of
+the configuration file, is set to `true`. See the [PostgreSQL documentation](https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype)
+for more information and instructions on how to fix a database with incorrect values.
+
# Upgrading to v1.55.0
+## Open registration without verification is now disabled by default
+
+Synapse will refuse to start if registration is enabled without email, captcha, or token-based verification unless the new config
+flag `enable_registration_without_verification` is set to "true".
+
+
## `synctl` script has been moved
The `synctl` script
diff --git a/docs/workers.md b/docs/workers.md
index 8751134e65..8ac95e39bb 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -185,8 +185,8 @@ worker: refer to the [stream writers](#stream-writers) section below for further
information.
# Sync requests
- ^/_matrix/client/(v2_alpha|r0|v3)/sync$
- ^/_matrix/client/(api/v1|v2_alpha|r0|v3)/events$
+ ^/_matrix/client/(r0|v3)/sync$
+ ^/_matrix/client/(api/v1|r0|v3)/events$
^/_matrix/client/(api/v1|r0|v3)/initialSync$
^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$
@@ -200,13 +200,9 @@ information.
^/_matrix/federation/v1/query/
^/_matrix/federation/v1/make_join/
^/_matrix/federation/v1/make_leave/
- ^/_matrix/federation/v1/send_join/
- ^/_matrix/federation/v2/send_join/
- ^/_matrix/federation/v1/send_leave/
- ^/_matrix/federation/v2/send_leave/
- ^/_matrix/federation/v1/invite/
- ^/_matrix/federation/v2/invite/
- ^/_matrix/federation/v1/query_auth/
+ ^/_matrix/federation/(v1|v2)/send_join/
+ ^/_matrix/federation/(v1|v2)/send_leave/
+ ^/_matrix/federation/(v1|v2)/invite/
^/_matrix/federation/v1/event_auth/
^/_matrix/federation/v1/exchange_third_party_invite/
^/_matrix/federation/v1/user/devices/
@@ -274,6 +270,8 @@ information.
Additionally, the following REST endpoints can be handled for GET requests:
^/_matrix/federation/v1/groups/
+ ^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
+ ^/_matrix/client/(r0|v3|unstable)/groups/
Pagination requests can also be handled, but all requests for a given
room must be routed to the same instance. Additionally, care must be taken to
@@ -397,23 +395,23 @@ the stream writer for the `typing` stream:
The following endpoints should be routed directly to the worker configured as
the stream writer for the `to_device` stream:
- ^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/
+ ^/_matrix/client/(r0|v3|unstable)/sendToDevice/
##### The `account_data` stream
The following endpoints should be routed directly to the worker configured as
the stream writer for the `account_data` stream:
- ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags
- ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data
+ ^/_matrix/client/(r0|v3|unstable)/.*/tags
+ ^/_matrix/client/(r0|v3|unstable)/.*/account_data
##### The `receipts` stream
The following endpoints should be routed directly to the worker configured as
the stream writer for the `receipts` stream:
- ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt
- ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers
+ ^/_matrix/client/(r0|v3|unstable)/rooms/.*/receipt
+ ^/_matrix/client/(r0|v3|unstable)/rooms/.*/read_markers
##### The `presence` stream
@@ -528,19 +526,28 @@ Note that if a reverse proxy is used , then `/_matrix/media/` must be routed for
Handles searches in the user directory. It can handle REST endpoints matching
the following regular expressions:
- ^/_matrix/client/(api/v1|r0|v3|unstable)/user_directory/search$
+ ^/_matrix/client/(r0|v3|unstable)/user_directory/search$
-When using this worker you must also set `update_user_directory: False` in the
+When using this worker you must also set `update_user_directory: false` in the
shared configuration file to stop the main synapse running background
jobs related to updating the user directory.
+Above endpoint is not *required* to be routed to this worker. By default,
+`update_user_directory` is set to `true`, which means the main process
+will handle updates. All workers configured with `client` can handle the above
+endpoint as long as either this worker or the main process are configured to
+handle it, and are online.
+
+If `update_user_directory` is set to `false`, and this worker is not running,
+the above endpoint may give outdated results.
+
### `synapse.app.frontend_proxy`
Proxies some frequently-requested client endpoints to add caching and remove
load from the main synapse. It can handle REST endpoints matching the following
regular expressions:
- ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload
+ ^/_matrix/client/(r0|v3|unstable)/keys/upload
If `use_presence` is False in the homeserver config, it can also handle REST
endpoints matching the following regular expressions:
diff --git a/mypy.ini b/mypy.ini
index 651d245329..0c6f1c0c30 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -41,9 +41,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py
- |synapse/storage/databases/main/group_server.py
- |synapse/storage/databases/main/metrics.py
- |synapse/storage/databases/main/monthly_active_users.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/roommember.py
@@ -65,9 +62,6 @@ exclude = (?x)
|tests/federation/test_federation_server.py
|tests/federation/transport/test_knocking.py
|tests/federation/transport/test_server.py
- |tests/handlers/test_cas.py
- |tests/handlers/test_federation.py
- |tests/handlers/test_presence.py
|tests/handlers/test_typing.py
|tests/http/federation/test_matrix_federation_agent.py
|tests/http/federation/test_srv_resolver.py
@@ -79,7 +73,6 @@ exclude = (?x)
|tests/logging/test_terse_json.py
|tests/module_api/test_api.py
|tests/push/test_email.py
- |tests/push/test_http.py
|tests/push/test_presentable_names.py
|tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_transactions.py
@@ -88,12 +81,7 @@ exclude = (?x)
|tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py
|tests/state/test_v2.py
- |tests/storage/test_background_update.py
|tests/storage/test_base.py
- |tests/storage/test_client_ips.py
- |tests/storage/test_database.py
- |tests/storage/test_event_federation.py
- |tests/storage/test_id_generators.py
|tests/storage/test_roommember.py
|tests/test_metrics.py
|tests/test_phone_home.py
diff --git a/poetry.lock b/poetry.lock
index 8ca8a86b6d..c6d569129d 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1405,6 +1405,14 @@ optional = false
python-versions = "*"
[[package]]
+name = "types-psycopg2"
+version = "2.9.9"
+description = "Typing stubs for psycopg2"
+category = "dev"
+optional = false
+python-versions = "*"
+
+[[package]]
name = "types-pyopenssl"
version = "22.0.0"
description = "Typing stubs for pyOpenSSL"
@@ -1582,7 +1590,7 @@ url_preview = ["lxml"]
[metadata]
lock-version = "1.1"
python-versions = "^3.7"
-content-hash = "0715b0d20c731251e080051ef7c3fbb21d8b15e013670591b4517939cf8ce76b"
+content-hash = "438c6786f123522daedffe0f44fefc49689cbb3e1bc5dbf05c3ff5a8ab72ad16"
[metadata.files]
appdirs = [
@@ -2634,6 +2642,10 @@ types-pillow = [
{file = "types-Pillow-9.0.6.tar.gz", hash = "sha256:79b350b1188c080c27558429f1e119e69c9f020b877a82df761d9283070e0185"},
{file = "types_Pillow-9.0.6-py3-none-any.whl", hash = "sha256:bd1e0a844fc718398aa265bf50fcad550fc520cc54f80e5ffeb7b3226b3cc507"},
]
+types-psycopg2 = [
+ {file = "types-psycopg2-2.9.9.tar.gz", hash = "sha256:4f9d4d52eeb343dc00fd5ed4f1513a8a5c18efba0a072eb82706d15cf4f20a2e"},
+ {file = "types_psycopg2-2.9.9-py3-none-any.whl", hash = "sha256:cec9291d4318ad70b407310f8304b3d40f6d0358f09870448f7a65e3027c80af"},
+]
types-pyopenssl = [
{file = "types-pyOpenSSL-22.0.0.tar.gz", hash = "sha256:d86dde7f6fe2f1ac9fe0b6282e489f649f480364bdaa9d6a4696d52505f4477e"},
{file = "types_pyOpenSSL-22.0.0-py3-none-any.whl", hash = "sha256:da685f57b864979f36df0157895139c8244ad4aad19b551f1678206fbad0108a"},
diff --git a/pyproject.toml b/pyproject.toml
index 9a515f48b7..335eb75c7b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -141,7 +141,8 @@ prometheus-client = ">=0.4.0"
attrs = ">=19.2.0,!=21.1.0"
netaddr = ">=0.7.18"
# We require 3.x or higher for type hints. Previous guard was >= 2.9.
-Jinja2 = ">=3.0.0"
+# Jinja2 3.1.0 removes the deprecated jinja2.Markup class, which we rely on.
+Jinja2 = ">=3.0.0,<3.1.0"
bleach = ">=1.4.3"
# We use `ParamSpec`, which was added in `typing-extensions` 3.10.0.0.
typing-extensions = ">=3.10.0"
@@ -249,6 +250,7 @@ types-bleach = ">=4.1.0"
types-jsonschema = ">=3.2.0"
types-opentracing = ">=2.4.2"
types-Pillow = ">=8.3.4"
+types-psycopg2 = ">=2.9.9"
types-pyOpenSSL = ">=20.0.7"
types-PyYAML = ">=5.4.10"
types-requests = ">=2.26.0"
diff --git a/scripts-dev/release.py b/scripts-dev/release.py
index 046453e65f..685fa32b03 100755
--- a/scripts-dev/release.py
+++ b/scripts-dev/release.py
@@ -66,11 +66,15 @@ def cli():
./scripts-dev/release.py tag
- # ... wait for asssets to build ...
+ # ... wait for assets to build ...
./scripts-dev/release.py publish
./scripts-dev/release.py upload
+ # Optional: generate some nice links for the announcement
+
+ ./scripts-dev/release.py upload
+
If the env var GH_TOKEN (or GITHUB_TOKEN) is set, or passed into the
`tag`/`publish` command, then a new draft release will be created/published.
"""
@@ -415,6 +419,41 @@ def upload():
)
+@cli.command()
+def announce():
+ """Generate markdown to announce the release."""
+
+ current_version, _, _ = parse_version_from_module()
+ tag_name = f"v{current_version}"
+
+ click.echo(
+ f"""
+Hi everyone. Synapse {current_version} has just been released.
+
+[notes](https://github.com/matrix-org/synapse/releases/tag/{tag_name}) |\
+[docker](https://hub.docker.com/r/matrixdotorg/synapse/tags?name={tag_name}) | \
+[debs](https://packages.matrix.org/debian/) | \
+[pypi](https://pypi.org/project/matrix-synapse/{current_version}/)"""
+ )
+
+ if "rc" in tag_name:
+ click.echo(
+ """
+Announce the RC in
+- #homeowners:matrix.org (Synapse Announcements)
+- #synapse-dev:matrix.org"""
+ )
+ else:
+ click.echo(
+ """
+Announce the release in
+- #homeowners:matrix.org (Synapse Announcements), bumping the version in the topic
+- #synapse:matrix.org (Synapse Admins), bumping the version in the topic
+- #synapse-dev:matrix.org
+- #synapse-package-maintainers:matrix.org"""
+ )
+
+
def parse_version_from_module() -> Tuple[
version.Version, redbaron.RedBaron, redbaron.Node
]:
diff --git a/setup.py b/setup.py
new file mode 100755
index 0000000000..63da71ad7b
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,183 @@
+#!/usr/bin/env python
+
+# Copyright 2014-2017 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2017-2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Any, Dict
+
+from setuptools import Command, find_packages, setup
+
+here = os.path.abspath(os.path.dirname(__file__))
+
+
+# Some notes on `setup.py test`:
+#
+# Once upon a time we used to try to make `setup.py test` run `tox` to run the
+# tests. That's a bad idea for three reasons:
+#
+# 1: `setup.py test` is supposed to find out whether the tests work in the
+# *current* environmentt, not whatever tox sets up.
+# 2: Empirically, trying to install tox during the test run wasn't working ("No
+# module named virtualenv").
+# 3: The tox documentation advises against it[1].
+#
+# Even further back in time, we used to use setuptools_trial [2]. That has its
+# own set of issues: for instance, it requires installation of Twisted to build
+# an sdist (because the recommended mode of usage is to add it to
+# `setup_requires`). That in turn means that in order to successfully run tox
+# you have to have the python header files installed for whichever version of
+# python tox uses (which is python3 on recent ubuntus, for example).
+#
+# So, for now at least, we stick with what appears to be the convention among
+# Twisted projects, and don't attempt to do anything when someone runs
+# `setup.py test`; instead we direct people to run `trial` directly if they
+# care.
+#
+# [1]: http://tox.readthedocs.io/en/2.5.0/example/basic.html#integration-with-setup-py-test-command
+# [2]: https://pypi.python.org/pypi/setuptools_trial
+class TestCommand(Command):
+ def initialize_options(self):
+ pass
+
+ def finalize_options(self):
+ pass
+
+ def run(self):
+ print(
+ """Synapse's tests cannot be run via setup.py. To run them, try:
+ PYTHONPATH="." trial tests
+"""
+ )
+
+
+def read_file(path_segments):
+ """Read a file from the package. Takes a list of strings to join to
+ make the path"""
+ file_path = os.path.join(here, *path_segments)
+ with open(file_path) as f:
+ return f.read()
+
+
+def exec_file(path_segments):
+ """Execute a single python file to get the variables defined in it"""
+ result: Dict[str, Any] = {}
+ code = read_file(path_segments)
+ exec(code, result)
+ return result
+
+
+version = exec_file(("synapse", "__init__.py"))["__version__"]
+dependencies = exec_file(("synapse", "python_dependencies.py"))
+long_description = read_file(("README.rst",))
+
+REQUIREMENTS = dependencies["REQUIREMENTS"]
+CONDITIONAL_REQUIREMENTS = dependencies["CONDITIONAL_REQUIREMENTS"]
+ALL_OPTIONAL_REQUIREMENTS = dependencies["ALL_OPTIONAL_REQUIREMENTS"]
+
+# Make `pip install matrix-synapse[all]` install all the optional dependencies.
+CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
+
+# Developer dependencies should not get included in "all".
+#
+# We pin black so that our tests don't start failing on new releases.
+CONDITIONAL_REQUIREMENTS["lint"] = [
+ "isort==5.7.0",
+ "black==21.12b0",
+ "flake8-comprehensions",
+ "flake8-bugbear==21.3.2",
+ "flake8",
+]
+
+CONDITIONAL_REQUIREMENTS["mypy"] = [
+ "mypy==0.931",
+ "mypy-zope==0.3.5",
+ "types-bleach>=4.1.0",
+ "types-jsonschema>=3.2.0",
+ "types-opentracing>=2.4.2",
+ "types-Pillow>=8.3.4",
+ "types-psycopg2>=2.9.9",
+ "types-pyOpenSSL>=20.0.7",
+ "types-PyYAML>=5.4.10",
+ "types-requests>=2.26.0",
+ "types-setuptools>=57.4.0",
+]
+
+# Dependencies which are exclusively required by unit test code. This is
+# NOT a list of all modules that are necessary to run the unit tests.
+# Tests assume that all optional dependencies are installed.
+#
+# parameterized_class decorator was introduced in parameterized 0.7.0
+CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"]
+
+CONDITIONAL_REQUIREMENTS["dev"] = (
+ CONDITIONAL_REQUIREMENTS["lint"]
+ + CONDITIONAL_REQUIREMENTS["mypy"]
+ + CONDITIONAL_REQUIREMENTS["test"]
+ + [
+ # The following are used by the release script
+ "click==7.1.2",
+ "redbaron==0.9.2",
+ "GitPython==3.1.14",
+ "commonmark==0.9.1",
+ "pygithub==1.55",
+ # The following are executed as commands by the release script.
+ "twine",
+ "towncrier",
+ ]
+)
+
+setup(
+ name="matrix-synapse",
+ version=version,
+ packages=find_packages(exclude=["tests", "tests.*"]),
+ description="Reference homeserver for the Matrix decentralised comms protocol",
+ install_requires=REQUIREMENTS,
+ extras_require=CONDITIONAL_REQUIREMENTS,
+ include_package_data=True,
+ zip_safe=False,
+ long_description=long_description,
+ long_description_content_type="text/x-rst",
+ python_requires="~=3.7",
+ entry_points={
+ "console_scripts": [
+ # Application
+ "synapse_homeserver = synapse.app.homeserver:main",
+ "synapse_worker = synapse.app.generic_worker:main",
+ "synctl = synapse._scripts.synctl:main",
+ # Scripts
+ "export_signing_key = synapse._scripts.export_signing_key:main",
+ "generate_config = synapse._scripts.generate_config:main",
+ "generate_log_config = synapse._scripts.generate_log_config:main",
+ "generate_signing_key = synapse._scripts.generate_signing_key:main",
+ "hash_password = synapse._scripts.hash_password:main",
+ "register_new_matrix_user = synapse._scripts.register_new_matrix_user:main",
+ "synapse_port_db = synapse._scripts.synapse_port_db:main",
+ "synapse_review_recent_signups = synapse._scripts.review_recent_signups:main",
+ "update_synapse_database = synapse._scripts.update_synapse_database:main",
+ ]
+ },
+ classifiers=[
+ "Development Status :: 5 - Production/Stable",
+ "Topic :: Communications :: Chat",
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3 :: Only",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ ],
+ cmdclass={"test": TestCommand},
+)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e4dc04c0b4..0f75e7b9d4 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -261,7 +261,10 @@ class SynapseHomeServer(HomeServer):
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
if name == "metrics" and self.config.metrics.enable_metrics:
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
+ metrics_resource: Resource = MetricsResource(RegistryProxy)
+ if compress:
+ metrics_resource = gz_wrap(metrics_resource)
+ resources[METRICS_PREFIX] = metrics_resource
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
@@ -348,6 +351,23 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
if config.server.gc_seconds:
synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
+ if (
+ config.registration.enable_registration
+ and not config.registration.enable_registration_without_verification
+ ):
+ if (
+ not config.captcha.enable_registration_captcha
+ and not config.registration.registrations_require_3pid
+ and not config.registration.registration_requires_token
+ ):
+
+ raise ConfigError(
+ "You have enabled open registration without any verification. This is a known vector for "
+ "spam and abuse. If you would like to allow public registration, please consider adding email, "
+ "captcha, or token-based verification. Otherwise this check can be removed by setting the "
+ "`enable_registration_without_verification` config option to `true`."
+ )
+
hs = SynapseHomeServer(
config.server.server_name,
config=config,
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 06ccf15cd9..d7f2219f53 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -37,6 +37,12 @@ DEFAULT_CONFIG = """\
# 'txn_limit' gives the maximum number of transactions to run per connection
# before reconnecting. Defaults to 0, which means no limit.
#
+# 'allow_unsafe_locale' is an option specific to Postgres. Under the default behavior, Synapse will refuse to
+# start if the postgres db is set to a non-C locale. You can override this behavior (which is *not* recommended)
+# by setting 'allow_unsafe_locale' to true. Note that doing so may corrupt your database. You can find more information
+# here: https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype and here:
+# https://wiki.postgresql.org/wiki/Locale_data_changes
+#
# 'args' gives options which are passed through to the database engine,
# except for options starting 'cp_', which are used to configure the Twisted
# connection pool. For a reference to valid arguments, see:
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index ea9b50fe97..40fb329a7f 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -33,6 +33,10 @@ class RegistrationConfig(Config):
str(config["disable_registration"])
)
+ self.enable_registration_without_verification = strtobool(
+ str(config.get("enable_registration_without_verification", False))
+ )
+
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
@@ -207,10 +211,18 @@ class RegistrationConfig(Config):
# Registration can be rate-limited using the parameters in the "Ratelimiting"
# section of this file.
- # Enable registration for new users.
+ # Enable registration for new users. Defaults to 'false'. It is highly recommended that if you enable registration,
+ # you use either captcha, email, or token-based verification to verify that new users are not bots. In order to enable registration
+ # without any verification, you must also set `enable_registration_without_verification`, found below.
#
#enable_registration: false
+ # Enable registration without email or captcha verification. Note: this option is *not* recommended,
+ # as registration without verification is a known vector for spam and abuse. Defaults to false. Has no effect
+ # unless `enable_registration` is also enabled.
+ #
+ #enable_registration_without_verification: true
+
# Time that a user's session remains valid for, after they log in.
#
# Note that this is not currently compatible with guest logins.
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 49cd0a4f19..38de4b8000 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -676,6 +676,10 @@ class ServerConfig(Config):
):
raise ConfigError("'custom_template_directory' must be a string")
+ self.use_account_validity_in_account_status: bool = (
+ config.get("use_account_validity_in_account_status") or False
+ )
+
def has_tls_listener(self) -> bool:
return any(listener.tls for listener in self.listeners)
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 60904a55f5..cd80fcf9d1 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -21,7 +21,6 @@ from typing import (
Awaitable,
Callable,
Collection,
- Dict,
List,
Optional,
Tuple,
@@ -31,7 +30,7 @@ from typing import (
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour
-from synapse.types import RoomAlias
+from synapse.types import RoomAlias, UserProfile
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
@@ -50,7 +49,7 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bo
USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
-CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]]
+CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
[
Optional[dict],
@@ -383,7 +382,7 @@ class SpamChecker:
return True
- async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
+ async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index a0520068e0..7120062127 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze
from . import EventBase
if TYPE_CHECKING:
+ from synapse.handlers.relations import BundledAggregations
from synapse.server import HomeServer
- from synapse.storage.databases.main.relations import BundledAggregations
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 482bbdd867..c7400c737b 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,7 +22,6 @@ from typing import (
Callable,
Collection,
Dict,
- Iterable,
List,
Optional,
Tuple,
@@ -577,10 +576,10 @@ class FederationServer(FederationBase):
async def _on_context_state_request_compute(
self, room_id: str, event_id: Optional[str]
) -> Dict[str, list]:
+ pdus: Collection[EventBase]
if event_id:
- pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu(
- room_id, event_id
- )
+ event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
+ pdus = await self.store.get_events_as_list(event_ids)
else:
pdus = (await self.state.get_current_state(room_id)).values()
@@ -1093,7 +1092,7 @@ class FederationServer(FederationBase):
# has started processing).
while True:
async with lock:
- logger.info("handling received PDU: %s", event)
+ logger.info("handling received PDU in room %s: %s", room_id, event)
try:
with nested_logging_context(event.event_id):
await self._federation_event_handler.on_receive_pdu(
diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py
index d5badf635b..c05a14304c 100644
--- a/synapse/handlers/account.py
+++ b/synapse/handlers/account.py
@@ -26,6 +26,10 @@ class AccountHandler:
self._main_store = hs.get_datastores().main
self._is_mine = hs.is_mine
self._federation_client = hs.get_federation_client()
+ self._use_account_validity_in_account_status = (
+ hs.config.server.use_account_validity_in_account_status
+ )
+ self._account_validity_handler = hs.get_account_validity_handler()
async def get_account_statuses(
self,
@@ -106,6 +110,13 @@ class AccountHandler:
"deactivated": userinfo.is_deactivated,
}
+ if self._use_account_validity_in_account_status:
+ status[
+ "org.matrix.expired"
+ ] = await self._account_validity_handler.is_user_expired(
+ user_id.to_string()
+ )
+
return status
async def _get_remote_account_statuses(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index db39aeabde..350ec9c03a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -950,54 +950,35 @@ class FederationHandler:
return event
- async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
- """Returns the state at the event. i.e. not including said event."""
-
- event = await self.store.get_event(event_id, check_room_id=room_id)
-
- state_groups = await self.state_store.get_state_groups(room_id, [event_id])
-
- if state_groups:
- _, state = list(state_groups.items()).pop()
- results = {(e.type, e.state_key): e for e in state}
-
- if event.is_state():
- # Get previous state
- if "replaces_state" in event.unsigned:
- prev_id = event.unsigned["replaces_state"]
- if prev_id != event.event_id:
- prev_event = await self.store.get_event(prev_id)
- results[(event.type, event.state_key)] = prev_event
- else:
- del results[(event.type, event.state_key)]
-
- res = list(results.values())
- return res
- else:
- return []
-
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
+ if event.internal_metadata.outlier:
+ raise NotFoundError("State not known at event %s" % (event_id,))
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
- if state_groups:
- _, state = list(state_groups.items()).pop()
- results = state
+ # get_state_groups_ids should return exactly one result
+ assert len(state_groups) == 1
- if event.is_state():
- # Get previous state
- if "replaces_state" in event.unsigned:
- prev_id = event.unsigned["replaces_state"]
- if prev_id != event.event_id:
- results[(event.type, event.state_key)] = prev_id
- else:
- results.pop((event.type, event.state_key), None)
+ state_map = next(iter(state_groups.values()))
- return list(results.values())
- else:
- return []
+ state_key = event.get_state_key()
+ if state_key is not None:
+ # the event was not rejected (get_event raises a NotFoundError for rejected
+ # events) so the state at the event should include the event itself.
+ assert (
+ state_map.get((event.type, state_key)) == event.event_id
+ ), "State at event did not include event itself"
+
+ # ... but we need the state *before* that event
+ if "replaces_state" in event.unsigned:
+ prev_id = event.unsigned["replaces_state"]
+ state_map[(event.type, state_key)] = prev_id
+ else:
+ del state_map[(event.type, state_key)]
+
+ return list(state_map.values())
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index f9544fe7fb..1c4fb4360a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -493,6 +493,7 @@ class EventCreationHandler:
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
+ state_event_ids: Optional[List[str]] = None,
require_consent: bool = True,
outlier: bool = False,
historical: bool = False,
@@ -527,6 +528,15 @@ class EventCreationHandler:
If non-None, prev_event_ids must also be provided.
+ state_event_ids:
+ The full state at a given event. This is used particularly by the MSC2716
+ /batch_send endpoint. One use case is with insertion events which float at
+ the beginning of a historical batch and don't have any `prev_events` to
+ derive from; we add all of these state events as the explicit state so the
+ rest of the historical batch can inherit the same state and state_group.
+ This should normally be left as None, which will cause the auth_event_ids
+ to be calculated based on the room state at the prev_events.
+
require_consent: Whether to check if the requester has
consented to the privacy policy.
@@ -612,6 +622,7 @@ class EventCreationHandler:
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
+ state_event_ids=state_event_ids,
depth=depth,
)
@@ -772,6 +783,7 @@ class EventCreationHandler:
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
+ state_event_ids: Optional[List[str]] = None,
ratelimit: bool = True,
txn_id: Optional[str] = None,
ignore_shadow_ban: bool = False,
@@ -801,6 +813,14 @@ class EventCreationHandler:
based on the room state at the prev_events.
If non-None, prev_event_ids must also be provided.
+ state_event_ids:
+ The full state at a given event. This is used particularly by the MSC2716
+ /batch_send endpoint. One use case is with insertion events which float at
+ the beginning of a historical batch and don't have any `prev_events` to
+ derive from; we add all of these state events as the explicit state so the
+ rest of the historical batch can inherit the same state and state_group.
+ This should normally be left as None, which will cause the auth_event_ids
+ to be calculated based on the room state at the prev_events.
ratelimit: Whether to rate limit this send.
txn_id: The transaction ID.
ignore_shadow_ban: True if shadow-banned users should be allowed to
@@ -856,8 +876,10 @@ class EventCreationHandler:
requester,
event_dict,
txn_id=txn_id,
+ allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
+ state_event_ids=state_event_ids,
outlier=outlier,
historical=historical,
depth=depth,
@@ -893,6 +915,7 @@ class EventCreationHandler:
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
+ state_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client
@@ -915,6 +938,15 @@ class EventCreationHandler:
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
+ state_event_ids:
+ The full state at a given event. This is used particularly by the MSC2716
+ /batch_send endpoint. One use case is with insertion events which float at
+ the beginning of a historical batch and don't have any `prev_events` to
+ derive from; we add all of these state events as the explicit state so the
+ rest of the historical batch can inherit the same state and state_group.
+ This should normally be left as None, which will cause the auth_event_ids
+ to be calculated based on the room state at the prev_events.
+
depth: Override the depth used to order the event in the DAG.
Should normally be set to None, which will cause the depth to be calculated
based on the prev_events.
@@ -922,31 +954,26 @@ class EventCreationHandler:
Returns:
Tuple of created event, context
"""
- # Strip down the auth_event_ids to only what we need to auth the event.
+ # Strip down the state_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member that don't match event.sender
- full_state_ids_at_event = None
- if auth_event_ids is not None:
- # If auth events are provided, prev events must be also.
+ if state_event_ids is not None:
+ # Do a quick check to make sure that prev_event_ids is present to
+ # make the type-checking around `builder.build` happy.
# prev_event_ids could be an empty array though.
assert prev_event_ids is not None
- # Copy the full auth state before it stripped down
- full_state_ids_at_event = auth_event_ids.copy()
-
temp_event = await builder.build(
prev_event_ids=prev_event_ids,
- auth_event_ids=auth_event_ids,
+ auth_event_ids=state_event_ids,
depth=depth,
)
- auth_events = await self.store.get_events_as_list(auth_event_ids)
+ state_events = await self.store.get_events_as_list(state_event_ids)
# Create a StateMap[str]
- auth_event_state_map = {
- (e.type, e.state_key): e.event_id for e in auth_events
- }
- # Actually strip down and use the necessary auth events
+ state_map = {(e.type, e.state_key): e.event_id for e in state_events}
+ # Actually strip down and only use the necessary auth events
auth_event_ids = self._event_auth_handler.compute_auth_events(
event=temp_event,
- current_state_ids=auth_event_state_map,
+ current_state_ids=state_map,
for_verification=False,
)
@@ -989,12 +1016,16 @@ class EventCreationHandler:
context = EventContext.for_outlier()
elif (
event.type == EventTypes.MSC2716_INSERTION
- and full_state_ids_at_event
+ and state_event_ids
and builder.internal_metadata.is_historical()
):
+ # Add explicit state to the insertion event so it has state to derive
+ # from even though it's floating with no `prev_events`. The rest of
+ # the batch can derive from this state and state_group.
+ #
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
- old_state = await self.store.get_events_as_list(full_state_ids_at_event)
+ old_state = await self.store.get_events_as_list(state_event_ids)
context = await self.state.compute_event_context(event, old_state=old_state)
else:
context = await self.state.compute_event_context(event)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 41679f7f86..876b879483 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -134,6 +134,7 @@ class PaginationHandler:
self.clock = hs.get_clock()
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
+ self._relations_handler = hs.get_relations_handler()
self.pagination_lock = ReadWriteLock()
# IDs of rooms in which there currently an active purge *or delete* operation.
@@ -539,7 +540,9 @@ class PaginationHandler:
state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values()
- aggregations = await self.store.get_bundled_aggregations(events, user_id)
+ aggregations = await self._relations_handler.get_bundled_aggregations(
+ events, user_id
+ )
time_now = self.clock.time_msec()
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 6554c0d3c2..239b0aa744 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -336,12 +336,18 @@ class ProfileHandler:
"""Check that the size and content type of the avatar at the given MXC URI are
within the configured limits.
+ If the given `mxc` is empty, no checks are performed. (Users are always able to
+ unset their avatar.)
+
Args:
mxc: The MXC URI at which the avatar can be found.
Returns:
A boolean indicating whether the file can be allowed to be set as an avatar.
"""
+ if mxc == "":
+ return True
+
if not self.max_avatar_size and not self.allowed_avatar_mimetypes:
return True
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 8e475475ad..73217d135d 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -12,21 +12,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
+import attr
+from frozendict import frozendict
+
+from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
+from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StreamToken
+from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.server import HomeServer
+ from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:
+ # The latest event in the thread.
+ latest_event: EventBase
+ # The latest edit to the latest event in the thread.
+ latest_edit: Optional[EventBase]
+ # The total number of events in the thread.
+ count: int
+ # True if the current user has sent an event to the thread.
+ current_user_participated: bool
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+ """
+ The bundled aggregations for an event.
+
+ Some values require additional processing during serialization.
+ """
+
+ annotations: Optional[JsonDict] = None
+ references: Optional[JsonDict] = None
+ replace: Optional[EventBase] = None
+ thread: Optional[_ThreadAggregation] = None
+
+ def __bool__(self) -> bool:
+ return bool(self.annotations or self.references or self.replace or self.thread)
+
+
class RelationsHandler:
def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
+ self._storage = hs.get_storage()
self._auth = hs.get_auth()
self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler()
@@ -68,7 +105,8 @@ class RelationsHandler:
user_id = requester.user.to_string()
- await self._auth.check_user_in_room_or_world_readable(
+ # TODO Properly handle a user leaving a room.
+ (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
@@ -95,6 +133,10 @@ class RelationsHandler:
[c["event_id"] for c in pagination_chunk.chunk]
)
+ events = await filter_events_for_client(
+ self._storage, user_id, events, is_peeking=(member_event_id is None)
+ )
+
now = self._clock.time_msec()
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
@@ -103,7 +145,7 @@ class RelationsHandler:
)
# The relations returned for the requested event do include their
# bundled aggregations.
- aggregations = await self._main_store.get_bundled_aggregations(
+ aggregations = await self.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
@@ -115,3 +157,115 @@ class RelationsHandler:
return_value["original_event"] = original_event
return return_value
+
+ async def _get_bundled_aggregation_for_event(
+ self, event: EventBase, user_id: str
+ ) -> Optional[BundledAggregations]:
+ """Generate bundled aggregations for an event.
+
+ Note that this does not use a cache, but depends on cached methods.
+
+ Args:
+ event: The event to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
+
+ Returns:
+ The bundled aggregations for an event, if bundled aggregations are
+ enabled and the event can have bundled aggregations.
+ """
+
+ # Do not bundle aggregations for an event which represents an edit or an
+ # annotation. It does not make sense for them to have related events.
+ relates_to = event.content.get("m.relates_to")
+ if isinstance(relates_to, (dict, frozendict)):
+ relation_type = relates_to.get("rel_type")
+ if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+ return None
+
+ event_id = event.event_id
+ room_id = event.room_id
+
+ # The bundled aggregations to include, a mapping of relation type to a
+ # type-specific value. Some types include the direct return type here
+ # while others need more processing during serialization.
+ aggregations = BundledAggregations()
+
+ annotations = await self._main_store.get_aggregation_groups_for_event(
+ event_id, room_id
+ )
+ if annotations.chunk:
+ aggregations.annotations = await annotations.to_dict(
+ cast("DataStore", self)
+ )
+
+ references = await self._main_store.get_relations_for_event(
+ event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
+ )
+ if references.chunk:
+ aggregations.references = await references.to_dict(cast("DataStore", self))
+
+ # Store the bundled aggregations in the event metadata for later use.
+ return aggregations
+
+ async def get_bundled_aggregations(
+ self, events: Iterable[EventBase], user_id: str
+ ) -> Dict[str, BundledAggregations]:
+ """Generate bundled aggregations for events.
+
+ Args:
+ events: The iterable of events to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
+
+ Returns:
+ A map of event ID to the bundled aggregation for the event. Not all
+ events may have bundled aggregations in the results.
+ """
+ # De-duplicate events by ID to handle the same event requested multiple times.
+ #
+ # State events do not get bundled aggregations.
+ events_by_id = {
+ event.event_id: event for event in events if not event.is_state()
+ }
+
+ # event ID -> bundled aggregation in non-serialized form.
+ results: Dict[str, BundledAggregations] = {}
+
+ # Fetch other relations per event.
+ for event in events_by_id.values():
+ event_result = await self._get_bundled_aggregation_for_event(event, user_id)
+ if event_result:
+ results[event.event_id] = event_result
+
+ # Fetch any edits (but not for redacted events).
+ edits = await self._main_store.get_applicable_edits(
+ [
+ event_id
+ for event_id, event in events_by_id.items()
+ if not event.internal_metadata.is_redacted()
+ ]
+ )
+ for event_id, edit in edits.items():
+ results.setdefault(event_id, BundledAggregations()).replace = edit
+
+ # Fetch thread summaries.
+ summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
+ # Only fetch participated for a limited selection based on what had
+ # summaries.
+ participated = await self._main_store.get_threads_participated(
+ [event_id for event_id, summary in summaries.items() if summary], user_id
+ )
+ for event_id, summary in summaries.items():
+ if summary:
+ thread_count, latest_thread_event, edit = summary
+ results.setdefault(
+ event_id, BundledAggregations()
+ ).thread = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ latest_edit=edit,
+ count=thread_count,
+ # If there's a thread summary it must also exist in the
+ # participated dictionary.
+ current_user_participated=participated[event_id],
+ )
+
+ return results
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index b9735631fc..092e185c99 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -60,8 +60,8 @@ from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state
+from synapse.handlers.relations import BundledAggregations
from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
@@ -1118,6 +1118,7 @@ class RoomContextHandler:
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.state_store = self.storage.state
+ self._relations_handler = hs.get_relations_handler()
async def get_event_context(
self,
@@ -1190,7 +1191,7 @@ class RoomContextHandler:
event = filtered[0]
# Fetch the aggregations.
- aggregations = await self.store.get_bundled_aggregations(
+ aggregations = await self._relations_handler.get_bundled_aggregations(
itertools.chain(events_before, (event,), events_after),
user.to_string(),
)
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index abbf7b7b27..a0255bd143 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -121,12 +121,11 @@ class RoomBatchHandler:
return create_requester(user_id, app_service=app_service)
- async def get_most_recent_auth_event_ids_from_event_id_list(
+ async def get_most_recent_full_state_ids_from_event_id_list(
self, event_ids: List[str]
) -> List[str]:
- """Find the most recent auth event ids (derived from state events) that
- allowed that message to be sent. We will use this as a base
- to auth our historical messages against.
+ """Find the most recent event_id and grab the full state at that event.
+ We will use this as a base to auth our historical messages against.
Args:
event_ids: List of event ID's to look at
@@ -136,38 +135,37 @@ class RoomBatchHandler:
"""
(
- most_recent_prev_event_id,
+ most_recent_event_id,
_,
) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id
prev_state_map = await self.state_store.get_state_ids_for_event(
- most_recent_prev_event_id
+ most_recent_event_id
)
# List of state event ID's
- prev_state_ids = list(prev_state_map.values())
- auth_event_ids = prev_state_ids
+ full_state_ids = list(prev_state_map.values())
- return auth_event_ids
+ return full_state_ids
async def persist_state_events_at_start(
self,
state_events_at_start: List[JsonDict],
room_id: str,
- initial_auth_event_ids: List[str],
+ initial_state_event_ids: List[str],
app_service_requester: Requester,
) -> List[str]:
"""Takes all `state_events_at_start` event dictionaries and creates/persists
- them as floating state events which don't resolve into the current room state.
- They are floating because they reference a fake prev_event which doesn't connect
- to the normal DAG at all.
+ them in a floating state event chain which don't resolve into the current room
+ state. They are floating because they reference no prev_events and are marked
+ as outliers which disconnects them from the normal DAG.
Args:
state_events_at_start:
room_id: Room where you want the events persisted in.
- initial_auth_event_ids: These will be the auth_events for the first
- state event created. Each event created afterwards will be
- added to the list of auth events for the next state event
- created.
+ initial_state_event_ids:
+ The base set of state for the historical batch which the floating
+ state chain will derive from. This should probably be the state
+ from the `prev_event` defined by `/batch_send?prev_event_id=$abc`.
app_service_requester: The requester of an application service.
Returns:
@@ -176,7 +174,7 @@ class RoomBatchHandler:
assert app_service_requester.app_service
state_event_ids_at_start = []
- auth_event_ids = initial_auth_event_ids.copy()
+ state_event_ids = initial_state_event_ids.copy()
# Make the state events float off on their own by specifying no
# prev_events for the first one in the chain so we don't have a bunch of
@@ -189,9 +187,7 @@ class RoomBatchHandler:
)
logger.debug(
- "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s",
- state_event,
- auth_event_ids,
+ "RoomBatchSendEventRestServlet inserting state_event=%s", state_event
)
event_dict = {
@@ -217,16 +213,26 @@ class RoomBatchHandler:
room_id=room_id,
action=membership,
content=event_dict["content"],
+ # Mark as an outlier to disconnect it from the normal DAG
+ # and not show up between batches of history.
outlier=True,
historical=True,
- # Only the first event in the chain should be floating.
+ # Only the first event in the state chain should be floating.
# The rest should hang off each other in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=prev_event_ids_for_state_chain,
+ # Since each state event is marked as an outlier, the
+ # `EventContext.for_outlier()` won't have any `state_ids`
+ # set and therefore can't derive any state even though the
+ # prev_events are set. Also since the first event in the
+ # state chain is floating with no `prev_events`, it can't
+ # derive state from anywhere automatically. So we need to
+ # set some state explicitly.
+ #
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
# reference and also update in the event when we append later.
- auth_event_ids=auth_event_ids.copy(),
+ state_event_ids=state_event_ids.copy(),
)
else:
# TODO: Add some complement tests that adds state that is not member joins
@@ -240,21 +246,31 @@ class RoomBatchHandler:
state_event["sender"], app_service_requester.app_service
),
event_dict,
+ # Mark as an outlier to disconnect it from the normal DAG
+ # and not show up between batches of history.
outlier=True,
historical=True,
- # Only the first event in the chain should be floating.
+ # Only the first event in the state chain should be floating.
# The rest should hang off each other in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=prev_event_ids_for_state_chain,
+ # Since each state event is marked as an outlier, the
+ # `EventContext.for_outlier()` won't have any `state_ids`
+ # set and therefore can't derive any state even though the
+ # prev_events are set. Also since the first event in the
+ # state chain is floating with no `prev_events`, it can't
+ # derive state from anywhere automatically. So we need to
+ # set some state explicitly.
+ #
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
# reference and also update in the event when we append later.
- auth_event_ids=auth_event_ids.copy(),
+ state_event_ids=state_event_ids.copy(),
)
event_id = event.event_id
state_event_ids_at_start.append(event_id)
- auth_event_ids.append(event_id)
+ state_event_ids.append(event_id)
# Connect all the state in a floating chain
prev_event_ids_for_state_chain = [event_id]
@@ -265,7 +281,7 @@ class RoomBatchHandler:
events_to_create: List[JsonDict],
room_id: str,
inherited_depth: int,
- auth_event_ids: List[str],
+ initial_state_event_ids: List[str],
app_service_requester: Requester,
) -> List[str]:
"""Create and persists all events provided sequentially. Handles the
@@ -281,8 +297,10 @@ class RoomBatchHandler:
room_id: Room where you want the events persisted in.
inherited_depth: The depth to create the events at (you will
probably by calling inherit_depth_from_prev_ids(...)).
- auth_event_ids: Define which events allow you to create the given
- event in the room.
+ initial_state_event_ids:
+ This is used to set explicit state for the insertion event at
+ the start of the historical batch since it's floating with no
+ prev_events to derive state from automatically.
app_service_requester: The requester of an application service.
Returns:
@@ -290,6 +308,11 @@ class RoomBatchHandler:
"""
assert app_service_requester.app_service
+ # We expect the first event in a historical batch to be an insertion event
+ assert events_to_create[0]["type"] == EventTypes.MSC2716_INSERTION
+ # We expect the last event in a historical batch to be an batch event
+ assert events_to_create[-1]["type"] == EventTypes.MSC2716_BATCH
+
# Make the historical event chain float off on its own by specifying no
# prev_events for the first event in the chain which causes the HS to
# ask for the state at the start of the batch later.
@@ -321,11 +344,16 @@ class RoomBatchHandler:
ev["sender"], app_service_requester.app_service
),
event_dict,
- # Only the first event in the chain should be floating.
- # The rest should hang off each other in a chain.
+ # Only the first event (which is the insertion event) in the
+ # chain should be floating. The rest should hang off each other
+ # in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=event_dict.get("prev_events"),
- auth_event_ids=auth_event_ids,
+ # Since the first event (which is the insertion event) in the
+ # chain is floating with no `prev_events`, it can't derive state
+ # from anywhere automatically. So we need to set some state
+ # explicitly.
+ state_event_ids=initial_state_event_ids if index == 0 else None,
historical=True,
depth=inherited_depth,
)
@@ -343,10 +371,9 @@ class RoomBatchHandler:
)
logger.debug(
- "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s",
+ "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s",
event,
prev_event_ids,
- auth_event_ids,
)
events_to_persist.append((event, context))
@@ -376,12 +403,12 @@ class RoomBatchHandler:
room_id: str,
batch_id_to_connect_to: str,
inherited_depth: int,
- auth_event_ids: List[str],
+ initial_state_event_ids: List[str],
app_service_requester: Requester,
) -> Tuple[List[str], str]:
"""
- Handles creating and persisting all of the historical events as well
- as insertion and batch meta events to make the batch navigable in the DAG.
+ Handles creating and persisting all of the historical events as well as
+ insertion and batch meta events to make the batch navigable in the DAG.
Args:
events_to_create: List of historical events to create in JSON
@@ -391,8 +418,13 @@ class RoomBatchHandler:
want this batch to connect to.
inherited_depth: The depth to create the events at (you will
probably by calling inherit_depth_from_prev_ids(...)).
- auth_event_ids: Define which events allow you to create the given
- event in the room.
+ initial_state_event_ids:
+ This is used to set explicit state for the insertion event at
+ the start of the historical batch since it's floating with no
+ prev_events to derive state from automatically. This should
+ probably be the state from the `prev_event` defined by
+ `/batch_send?prev_event_id=$abc` plus the outcome of
+ `persist_state_events_at_start`
app_service_requester: The requester of an application service.
Returns:
@@ -438,7 +470,7 @@ class RoomBatchHandler:
events_to_create=events_to_create,
room_id=room_id,
inherited_depth=inherited_depth,
- auth_event_ids=auth_event_ids,
+ initial_state_event_ids=initial_state_event_ids,
app_service_requester=app_service_requester,
)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7cbc484b06..a33fa34aa8 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -272,6 +272,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
+ state_event_ids: Optional[List[str]] = None,
txn_id: Optional[str] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
@@ -298,6 +299,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
+ state_event_ids:
+ The full state at a given event. This is used particularly by the MSC2716
+ /batch_send endpoint. One use case is the historical `state_events_at_start`;
+ since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
+ have any `state_ids` set and therefore can't derive any state even though the
+ prev_events are set so we need to set them ourself via this argument.
+ This should normally be left as None, which will cause the auth_event_ids
+ to be calculated based on the room state at the prev_events.
txn_id:
ratelimit:
@@ -353,6 +362,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
+ state_event_ids=state_event_ids,
require_consent=require_consent,
outlier=outlier,
historical=historical,
@@ -456,6 +466,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
+ state_event_ids: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""Update a user's membership in a room.
@@ -487,6 +498,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
+ state_event_ids:
+ The full state at a given event. This is used particularly by the MSC2716
+ /batch_send endpoint. One use case is the historical `state_events_at_start`;
+ since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
+ have any `state_ids` set and therefore can't derive any state even though the
+ prev_events are set so we need to set them ourself via this argument.
+ This should normally be left as None, which will cause the auth_event_ids
+ to be calculated based on the room state at the prev_events.
Returns:
A tuple of the new event ID and stream ID.
@@ -526,6 +545,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
+ state_event_ids=state_event_ids,
)
return result
@@ -548,6 +568,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
+ state_event_ids: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""Helper for update_membership.
@@ -581,6 +602,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
+ state_event_ids:
+ The full state at a given event. This is used particularly by the MSC2716
+ /batch_send endpoint. One use case is the historical `state_events_at_start`;
+ since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
+ have any `state_ids` set and therefore can't derive any state even though the
+ prev_events are set so we need to set them ourself via this argument.
+ This should normally be left as None, which will cause the auth_event_ids
+ to be calculated based on the room state at the prev_events.
Returns:
A tuple of the new event ID and stream ID.
@@ -708,6 +737,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
+ state_event_ids=state_event_ids,
content=content,
require_consent=require_consent,
outlier=outlier,
@@ -932,6 +962,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
auth_event_ids=auth_event_ids,
+ state_event_ids=state_event_ids,
content=content,
require_consent=require_consent,
outlier=outlier,
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index aa16e417eb..30eddda65f 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -54,6 +54,7 @@ class SearchHandler:
self.clock = hs.get_clock()
self.hs = hs
self._event_serializer = hs.get_event_client_serializer()
+ self._relations_handler = hs.get_relations_handler()
self.storage = hs.get_storage()
self.state_store = self.storage.state
self.auth = hs.get_auth()
@@ -354,7 +355,7 @@ class SearchHandler:
aggregations = None
if self._msc3666_enabled:
- aggregations = await self.store.get_bundled_aggregations(
+ aggregations = await self._relations_handler.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools.chain(
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c9d6a18bd7..6c569cfb1c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -33,11 +33,11 @@ from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
+from synapse.handlers.relations import BundledAggregations
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
-from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
@@ -269,6 +269,7 @@ class SyncHandler:
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.presence_handler = hs.get_presence_handler()
+ self._relations_handler = hs.get_relations_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
@@ -638,8 +639,10 @@ class SyncHandler:
# as clients will have all the necessary information.
bundled_aggregations = None
if limited or newly_joined_room:
- bundled_aggregations = await self.store.get_bundled_aggregations(
- recents, sync_config.user.to_string()
+ bundled_aggregations = (
+ await self._relations_handler.get_bundled_aggregations(
+ recents, sync_config.user.to_string()
+ )
)
return TimelineBatch(
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index d27ed2be6a..048fd4bb82 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -19,8 +19,8 @@ import synapse.metrics
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.user_directory import SearchResult
from synapse.storage.roommember import ProfileInfo
-from synapse.types import JsonDict
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -78,7 +78,7 @@ class UserDirectoryHandler(StateDeltasHandler):
async def search_users(
self, user_id: str, search_term: str, limit: int
- ) -> JsonDict:
+ ) -> SearchResult:
"""Searches for users in directory
Returns:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d735c1d461..a628faaf65 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -111,6 +111,7 @@ from synapse.types import (
StateMap,
UserID,
UserInfo,
+ UserProfile,
create_requester,
)
from synapse.util import Clock
@@ -150,6 +151,7 @@ __all__ = [
"EventBase",
"StateMap",
"ProfileInfo",
+ "UserProfile",
]
logger = logging.getLogger(__name__)
@@ -609,15 +611,18 @@ class ModuleApi:
localpart: str,
displayname: Optional[str] = None,
emails: Optional[List[str]] = None,
+ admin: bool = False,
) -> "defer.Deferred[str]":
"""Registers a new user with given localpart and optional displayname, emails.
Added in Synapse v1.2.0.
+ Changed in Synapse v1.56.0: add 'admin' argument to register the user as admin.
Args:
localpart: The localpart of the new user.
displayname: The displayname of the new user.
emails: Emails to bind to the new user.
+ admin: True if the user should be registered as a server admin.
Raises:
SynapseError if there is an error performing the registration. Check the
@@ -631,6 +636,7 @@ class ModuleApi:
localpart=localpart,
default_display_name=displayname,
bind_emails=emails or [],
+ admin=admin,
)
)
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 030898e4d0..a402a3e403 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -24,6 +24,7 @@ from synapse.event_auth import get_user_power_level
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
+from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache
@@ -292,7 +293,7 @@ def _condition_checker(
return True
-MemberMap = Dict[str, Tuple[str, str]]
+MemberMap = Dict[str, Optional[EventIdMembership]]
Rule = Dict[str, dict]
RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int]
@@ -306,7 +307,7 @@ class RulesForRoomData:
*only* include data, and not references to e.g. the data stores.
"""
- # event_id -> (user_id, state)
+ # event_id -> EventIdMembership
member_map: MemberMap = attr.Factory(dict)
# user_id -> rules
rules_by_user: RulesByUser = attr.Factory(dict)
@@ -447,11 +448,10 @@ class RulesForRoom:
res = self.data.member_map.get(event_id, None)
if res:
- user_id, state = res
- if state == Membership.JOIN:
- rules = self.data.rules_by_user.get(user_id, None)
+ if res.membership == Membership.JOIN:
+ rules = self.data.rules_by_user.get(res.user_id, None)
if rules:
- ret_rules_by_user[user_id] = rules
+ ret_rules_by_user[res.user_id] = rules
continue
# If a user has left a room we remove their push rule. If they
@@ -502,24 +502,26 @@ class RulesForRoom:
"""
sequence = self.data.sequence
- rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
-
- members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
+ members = await self.store.get_membership_from_event_ids(
+ member_event_ids.values()
+ )
- # If the event is a join event then it will be in current state evnts
+ # If the event is a join event then it will be in current state events
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
for event_id in member_event_ids.values():
if event_id == event.event_id:
- members[event_id] = (event.state_key, event.membership)
+ members[event_id] = EventIdMembership(
+ user_id=event.state_key, membership=event.membership
+ )
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
joined_user_ids = {
- user_id
- for user_id, membership in members.values()
- if membership == Membership.JOIN
+ entry.user_id
+ for entry in members.values()
+ if entry and entry.membership == Membership.JOIN
}
logger.debug("Joined: %r", joined_user_ids)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
new file mode 100644
index 0000000000..79ae06ce5d
--- /dev/null
+++ b/synapse/python_dependencies.py
@@ -0,0 +1,147 @@
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+from typing import Set
+
+logger = logging.getLogger(__name__)
+
+
+# REQUIREMENTS is a simple list of requirement specifiers[1], and must be
+# installed. It is passed to setup() as install_requires in setup.py.
+#
+# CONDITIONAL_REQUIREMENTS is the optional dependencies, represented as a dict
+# of lists. The dict key is the optional dependency name and can be passed to
+# pip when installing. The list is a series of requirement specifiers[1] to be
+# installed when that optional dependency requirement is specified. It is passed
+# to setup() as extras_require in setup.py
+#
+# Note that these both represent runtime dependencies (and the versions
+# installed are checked at runtime).
+#
+# Also note that we replicate these constraints in the Synapse Dockerfile while
+# pre-installing dependencies. If these constraints are updated here, the same
+# change should be made in the Dockerfile.
+#
+# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
+
+REQUIREMENTS = [
+ # we use the TYPE_CHECKER.redefine method added in jsonschema 3.0.0
+ "jsonschema>=3.0.0",
+ # frozendict 2.1.2 is broken on Debian 10: https://github.com/Marco-Sulla/python-frozendict/issues/41
+ "frozendict>=1,!=2.1.2",
+ "unpaddedbase64>=1.1.0",
+ "canonicaljson>=1.4.0",
+ # we use the type definitions added in signedjson 1.1.
+ "signedjson>=1.1.0",
+ "pynacl>=1.2.1",
+ "idna>=2.5",
+ # validating SSL certs for IP addresses requires service_identity 18.1.
+ "service_identity>=18.1.0",
+ # Twisted 18.9 introduces some logger improvements that the structured
+ # logger utilises
+ "Twisted>=18.9.0",
+ "treq>=15.1",
+ # Twisted has required pyopenssl 16.0 since about Twisted 16.6.
+ "pyopenssl>=16.0.0",
+ "pyyaml>=3.11",
+ "pyasn1>=0.1.9",
+ "pyasn1-modules>=0.0.7",
+ "bcrypt>=3.1.0",
+ "pillow>=5.4.0",
+ "sortedcontainers>=1.4.4",
+ "pymacaroons>=0.13.0",
+ "msgpack>=0.5.2",
+ "phonenumbers>=8.2.0",
+ # we use GaugeHistogramMetric, which was added in prom-client 0.4.0.
+ "prometheus_client>=0.4.0",
+ # we use `order`, which arrived in attrs 19.2.0.
+ # Note: 21.1.0 broke `/sync`, see #9936
+ "attrs>=19.2.0,!=21.1.0",
+ "netaddr>=0.7.18",
+ # Jinja2 3.1.0 removes the deprecated jinja2.Markup class, which we rely on.
+ "Jinja2<3.1.0",
+ "bleach>=1.4.3",
+ # We use `ParamSpec`, which was added in `typing-extensions` 3.10.0.0.
+ "typing-extensions>=3.10.0",
+ # We enforce that we have a `cryptography` version that bundles an `openssl`
+ # with the latest security patches.
+ "cryptography>=3.4.7",
+ # ijson 3.1.4 fixes a bug with "." in property names
+ "ijson>=3.1.4",
+ "matrix-common~=1.1.0",
+ # We need packaging.requirements.Requirement, added in 16.1.
+ "packaging>=16.1",
+]
+
+CONDITIONAL_REQUIREMENTS = {
+ "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
+ "postgres": [
+ # we use execute_values with the fetch param, which arrived in psycopg 2.8.
+ "psycopg2>=2.8 ; platform_python_implementation != 'PyPy'",
+ "psycopg2cffi>=2.8 ; platform_python_implementation == 'PyPy'",
+ "psycopg2cffi-compat==1.1 ; platform_python_implementation == 'PyPy'",
+ ],
+ "saml2": [
+ "pysaml2>=4.5.0",
+ ],
+ "oidc": ["authlib>=0.14.0"],
+ # systemd-python is necessary for logging to the systemd journal via
+ # `systemd.journal.JournalHandler`, as is documented in
+ # `contrib/systemd/log_config.yaml`.
+ "systemd": ["systemd-python>=231"],
+ "url_preview": ["lxml>=4.2.0"],
+ "sentry": ["sentry-sdk>=0.7.2"],
+ "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
+ "jwt": ["pyjwt>=1.6.4"],
+ # hiredis is not a *strict* dependency, but it makes things much faster.
+ # (if it is not installed, we fall back to slow code.)
+ "redis": ["txredisapi>=1.4.7", "hiredis"],
+ # Required to use experimental `caches.track_memory_usage` config option.
+ "cache_memory": ["pympler"],
+}
+
+ALL_OPTIONAL_REQUIREMENTS: Set[str] = set()
+
+for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
+ # Exclude systemd as it's a system-based requirement.
+ # Exclude lint as it's a dev-based requirement.
+ if name not in ["systemd"]:
+ ALL_OPTIONAL_REQUIREMENTS = set(optional_deps) | ALL_OPTIONAL_REQUIREMENTS
+
+
+# ensure there are no double-quote characters in any of the deps (otherwise the
+# 'pip install' incantation in DependencyException will break)
+for dep in itertools.chain(
+ REQUIREMENTS,
+ *CONDITIONAL_REQUIREMENTS.values(),
+):
+ if '"' in dep:
+ raise Exception(
+ "Dependency `%s` contains double-quote; use single-quotes instead" % (dep,)
+ )
+
+
+def list_requirements():
+ return list(set(REQUIREMENTS) | ALL_OPTIONAL_REQUIREMENTS)
+
+
+if __name__ == "__main__":
+ import sys
+
+ sys.stdout.writelines(req + "\n" for req in list_requirements())
diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html
index 41315e4fd4..1ba850369a 100644
--- a/synapse/res/templates/sso_auth_account_details.html
+++ b/synapse/res/templates/sso_auth_account_details.html
@@ -130,13 +130,13 @@
</head>
<body>
<header>
- <h1>Choose your user name</h1>
- <p>This is required to create your account on {{ server_name }}, and you can't change this later.</p>
+ <h1>Create your account</h1>
+ <p>This is required. Continue to create your account on {{ server_name }}. You can't change this later.</p>
</header>
<main>
<form method="post" class="form__input" id="form">
<div class="username_input" id="username_input">
- <label for="field-username">Username</label>
+ <label for="field-username">Username (required)</label>
<div class="prefix">@</div>
<input type="text" name="username" id="field-username" value="{{ user_attributes.localpart }}" autofocus>
<div class="postfix">:{{ server_name }}</div>
@@ -145,7 +145,7 @@
<input type="submit" value="Continue" class="primary-button">
{% if user_attributes.avatar_url or user_attributes.display_name or user_attributes.emails %}
<section class="idp-pick-details">
- <h2>{% if idp.idp_icon %}<img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>{% endif %}Information from {{ idp.idp_name }}</h2>
+ <h2>{% if idp.idp_icon %}<img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>{% endif %}Optional data from {{ idp.idp_name }}</h2>
{% if user_attributes.avatar_url %}
<label class="idp-detail idp-avatar" for="idp-avatar">
<div class="check-row">
diff --git a/synapse/res/templates/sso_auth_account_details.js b/synapse/res/templates/sso_auth_account_details.js
index 3c45df9078..82438519a2 100644
--- a/synapse/res/templates/sso_auth_account_details.js
+++ b/synapse/res/templates/sso_auth_account_details.js
@@ -62,7 +62,7 @@ function validateUsername(username) {
usernameField.parentElement.classList.remove("invalid");
usernameOutput.classList.remove("error");
if (!username) {
- return reportError("Please provide a username");
+ return reportError("This is required. Please provide a username");
}
if (username.length > 255) {
return reportError("Too long, please choose something shorter");
diff --git a/synapse/res/templates/sso_footer.html b/synapse/res/templates/sso_footer.html
index 588a3d508d..b46e0d83fe 100644
--- a/synapse/res/templates/sso_footer.html
+++ b/synapse/res/templates/sso_footer.html
@@ -15,5 +15,5 @@
</g>
</g>
</svg>
- <p>An open network for secure, decentralized communication.<br>© 2021 The Matrix.org Foundation C.I.C.</p>
+ <p>An open network for secure, decentralized communication.<br>© 2022 The Matrix.org Foundation C.I.C.</p>
</footer>
\ No newline at end of file
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 762808a571..57c4773edc 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -32,6 +32,7 @@ from synapse.rest.client import (
knock,
login as v1_login,
logout,
+ mutual_rooms,
notifications,
openid,
password_policy,
@@ -49,7 +50,6 @@ from synapse.rest.client import (
room_keys,
room_upgrade_rest_servlet,
sendtodevice,
- shared_rooms,
sync,
tags,
thirdparty,
@@ -132,4 +132,4 @@ class ClientRestResource(JsonResource):
admin.register_servlets_for_client_rest_resource(hs, client_resource)
# unstable
- shared_rooms.register_servlets(hs, client_resource)
+ mutual_rooms.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/mutual_rooms.py
index e669fa7890..27bfaf0b29 100644
--- a/synapse/rest/client/shared_rooms.py
+++ b/synapse/rest/client/mutual_rooms.py
@@ -28,13 +28,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class UserSharedRoomsServlet(RestServlet):
+class UserMutualRoomsServlet(RestServlet):
"""
- GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
+ GET /uk.half-shot.msc2666/user/mutual_rooms/{user_id} HTTP/1.1
"""
PATTERNS = client_patterns(
- "/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
+ "/uk.half-shot.msc2666/user/mutual_rooms/(?P<user_id>[^/]*)",
releases=(), # This is an unstable feature
)
@@ -42,17 +42,19 @@ class UserSharedRoomsServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.user_directory_active = hs.config.server.update_user_directory
+ self.user_directory_search_enabled = (
+ hs.config.userdirectory.user_directory_search_enabled
+ )
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
- if not self.user_directory_active:
+ if not self.user_directory_search_enabled:
raise SynapseError(
code=400,
- msg="The user directory is disabled on this server. Cannot determine shared rooms.",
- errcode=Codes.FORBIDDEN,
+ msg="User directory searching is disabled. Cannot determine shared rooms.",
+ errcode=Codes.UNKNOWN,
)
UserID.from_string(user_id)
@@ -64,7 +66,8 @@ class UserSharedRoomsServlet(RestServlet):
msg="You cannot request a list of shared rooms with yourself",
errcode=Codes.FORBIDDEN,
)
- rooms = await self.store.get_shared_rooms_for_users(
+
+ rooms = await self.store.get_mutual_rooms_for_users(
requester.user.to_string(), user_id
)
@@ -72,4 +75,4 @@ class UserSharedRoomsServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
- UserSharedRoomsServlet(hs).register(http_server)
+ UserMutualRoomsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 8a06ab8c5f..47e152c8cc 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -645,6 +645,7 @@ class RoomEventServlet(RestServlet):
self._store = hs.get_datastores().main
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
+ self._relations_handler = hs.get_relations_handler()
self.auth = hs.get_auth()
async def on_GET(
@@ -663,7 +664,7 @@ class RoomEventServlet(RestServlet):
if event:
# Ensure there are bundled aggregations available.
- aggregations = await self._store.get_bundled_aggregations(
+ aggregations = await self._relations_handler.get_bundled_aggregations(
[event], requester.user.to_string()
)
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 0048973e59..0780485322 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -124,14 +124,14 @@ class RoomBatchSendEventRestServlet(RestServlet):
)
# For the event we are inserting next to (`prev_event_ids_from_query`),
- # find the most recent auth events (derived from state events) that
- # allowed that message to be sent. We will use that as a base
- # to auth our historical messages against.
- auth_event_ids = await self.room_batch_handler.get_most_recent_auth_event_ids_from_event_id_list(
+ # find the most recent state events that allowed that message to be
+ # sent. We will use that as a base to auth our historical messages
+ # against.
+ state_event_ids = await self.room_batch_handler.get_most_recent_full_state_ids_from_event_id_list(
prev_event_ids_from_query
)
- if not auth_event_ids:
+ if not state_event_ids:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"No auth events found for given prev_event query parameter. The prev_event=%s probably does not exist."
@@ -148,13 +148,13 @@ class RoomBatchSendEventRestServlet(RestServlet):
await self.room_batch_handler.persist_state_events_at_start(
state_events_at_start=body["state_events_at_start"],
room_id=room_id,
- initial_auth_event_ids=auth_event_ids,
+ initial_state_event_ids=state_event_ids,
app_service_requester=requester,
)
)
# Update our ongoing auth event ID list with all of the new state we
# just created
- auth_event_ids.extend(state_event_ids_at_start)
+ state_event_ids.extend(state_event_ids_at_start)
inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids(
prev_event_ids_from_query
@@ -196,7 +196,12 @@ class RoomBatchSendEventRestServlet(RestServlet):
),
base_insertion_event_dict,
prev_event_ids=base_insertion_event_dict.get("prev_events"),
- auth_event_ids=auth_event_ids,
+ # Also set the explicit state here because we want to resolve
+ # any `state_events_at_start` here too. It's not strictly
+ # necessary to accomplish anything but if someone asks for the
+ # state at this point, we probably want to show them the
+ # historical state that was part of this batch.
+ state_event_ids=state_event_ids,
historical=True,
depth=inherited_depth,
)
@@ -212,7 +217,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
room_id=room_id,
batch_id_to_connect_to=batch_id_to_connect_to,
inherited_depth=inherited_depth,
- auth_event_ids=auth_event_ids,
+ initial_state_event_ids=state_event_ids,
app_service_requester=requester,
)
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
index a47d9bd01d..116c982ce6 100644
--- a/synapse/rest/client/user_directory.py
+++ b/synapse/rest/client/user_directory.py
@@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict
+from synapse.types import JsonMapping
from ._base import client_patterns
@@ -38,7 +38,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
- async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]:
"""Searches for users in directory
Returns:
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 9749f0c06e..367709a1a7 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -288,7 +288,7 @@ class LoggingTransaction:
"""
if isinstance(self.database_engine, PostgresEngine):
- from psycopg2.extras import execute_batch # type: ignore
+ from psycopg2.extras import execute_batch
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
@@ -302,10 +302,18 @@ class LoggingTransaction:
rows (e.g. INSERTs).
"""
assert isinstance(self.database_engine, PostgresEngine)
- from psycopg2.extras import execute_values # type: ignore
+ from psycopg2.extras import execute_values
return self._do_execute(
- lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args
+ # Type ignore: mypy is unhappy because if `x` is a 5-tuple, then there will
+ # be two values for `fetch`: one given positionally, and another given
+ # as a keyword argument. We might be able to fix this by
+ # - propagating the signature of psycopg2.extras.execute_values to this
+ # function, or
+ # - changing `*args: Any` to `values: T` for some appropriate T.
+ lambda *x: execute_values(self.txn, *x, fetch=fetch), # type: ignore[misc]
+ sql,
+ *args,
)
def execute(self, sql: str, *args: Any) -> None:
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2d7511d613..dd4e83a2ad 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -192,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
+ # The `_get_membership_from_event_id` is immutable, except for the
+ # case where we look up an event *before* persisting it.
+ self._get_membership_from_event_id.invalidate((event_id,))
+
if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1f60aef180..d253243125 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1745,6 +1745,13 @@ class PersistEventsStore:
(event.state_key,),
)
+ # The `_get_membership_from_event_id` is immutable, except for the
+ # case where we look up an event *before* persisting it.
+ txn.call_after(
+ self.store._get_membership_from_event_id.invalidate,
+ (event.event_id,),
+ )
+
# We update the local_current_membership table only if the event is
# "current", i.e., its something that has just happened.
#
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 3f6086050b..0aef121d83 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,13 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from typing_extensions import TypedDict
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore):
) -> List[Dict[str, Any]]:
# TODO: Pagination
- keyvalues = {"group_id": group_id}
+ keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore):
# TODO: Pagination
- def _get_rooms_in_group_txn(txn):
+ def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]:
sql = """
SELECT room_id, is_public FROM group_rooms
WHERE group_id = ?
@@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore):
* "order": int, the sort order of rooms in this category
"""
- def _get_rooms_for_summary_txn(txn):
- keyvalues = {"group_id": group_id}
+ def _get_rooms_for_summary_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+ keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_rooms_for_summary", _get_rooms_for_summary_txn
)
- async def get_group_categories(self, group_id):
+ async def get_group_categories(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list(
table="group_room_categories",
keyvalues={"group_id": group_id},
@@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows
}
- async def get_group_category(self, group_id, category_id):
+ async def get_group_category(self, group_id: str, category_id: str) -> JsonDict:
category = await self.db_pool.simple_select_one(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return category
- async def get_group_roles(self, group_id):
+ async def get_group_roles(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list(
table="group_roles",
keyvalues={"group_id": group_id},
@@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows
}
- async def get_group_role(self, group_id, role_id):
+ async def get_group_role(self, group_id: str, role_id: str) -> JsonDict:
role = await self.db_pool.simple_select_one(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_local_groups_for_room",
)
- async def get_users_for_summary_by_role(self, group_id, include_private=False):
+ async def get_users_for_summary_by_role(
+ self, group_id: str, include_private: bool = False
+ ) -> Tuple[List[JsonDict], JsonDict]:
"""Get the users and roles that should be included in a summary request
Returns:
([users], [roles])
"""
- def _get_users_for_summary_txn(txn):
- keyvalues = {"group_id": group_id}
+ def _get_users_for_summary_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], JsonDict]:
+ keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore):
allow_none=True,
)
- async def get_users_membership_info_in_group(self, group_id, user_id):
+ async def get_users_membership_info_in_group(
+ self, group_id: str, user_id: str
+ ) -> JsonDict:
"""Get a dict describing the membership of a user in a group.
Example if joined:
@@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
An empty dict if the user is not join/invite/etc
"""
- def _get_users_membership_in_group_txn(txn):
+ def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict:
row = self.db_pool.simple_select_one_txn(
txn,
table="group_users",
@@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_publicised_groups_for_user",
)
- async def get_attestations_need_renewals(self, valid_until_ms):
+ async def get_attestations_need_renewals(
+ self, valid_until_ms: int
+ ) -> List[Dict[str, Any]]:
"""Get all attestations that need to be renewed until givent time"""
- def _get_attestations_need_renewals_txn(txn):
+ def _get_attestations_need_renewals_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, Any]]:
sql = """
SELECT group_id, user_id FROM group_attestations_renewals
WHERE valid_until_ms <= ?
@@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)
- async def get_remote_attestation(self, group_id, user_id):
+ async def get_remote_attestation(
+ self, group_id: str, user_id: str
+ ) -> Optional[JsonDict]:
"""Get the attestation that proves the remote agrees that the user is
in the group.
"""
@@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_joined_groups",
)
- async def get_all_groups_for_user(self, user_id, now_token):
- def _get_all_groups_for_user_txn(txn):
+ async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
+ def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """
SELECT group_id, type, membership, u.content
FROM local_group_updates AS u
@@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_all_groups_for_user", _get_all_groups_for_user_txn
)
- async def get_groups_changes_for_user(self, user_id, from_token, to_token):
- from_token = int(from_token)
- has_changed = self._group_updates_stream_cache.has_entity_changed(
+ async def get_groups_changes_for_user(
+ self, user_id: str, from_token: int, to_token: int
+ ) -> List[JsonDict]:
+ has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined]
user_id, from_token
)
if not has_changed:
return []
- def _get_groups_changes_for_user_txn(txn):
+ def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """
SELECT group_id, membership, type, u.content
FROM local_group_updates AS u
@@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore):
"""
last_id = int(last_id)
- has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
+ has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined]
if not has_changed:
return [], current_id, False
- def _get_all_groups_changes_txn(txn):
+ def _get_all_groups_changes_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """
SELECT stream_id, group_id, user_id, type, content
FROM local_group_updates
@@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [
- (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
- for stream_id, group_id, user_id, gtype, content_json in txn
- ]
+ updates = cast(
+ List[Tuple[int, tuple]],
+ [
+ (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
+ for stream_id, group_id, user_id, gtype, content_json in txn
+ ],
+ )
limited = False
upto_token = current_id
@@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore):
self,
group_id: str,
room_id: str,
- category_id: str,
- order: int,
+ category_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
@@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_room_to_summary_txn(
self,
- txn,
+ txn: LoggingTransaction,
group_id: str,
room_id: str,
- category_id: str,
- order: int,
+ category_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
@@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND category_id = ?
"""
txn.execute(sql, (group_id, category_id))
- (order,) = txn.fetchone()
+ (order,) = cast(Tuple[int], txn.fetchone())
if existing:
to_update = {}
@@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore):
"category_id": category_id,
"room_id": room_id,
},
- values=to_update,
+ updatevalues=to_update,
)
else:
if is_public is None:
@@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_room_from_summary(
- self, group_id: str, room_id: str, category_id: str
+ self, group_id: str, room_id: str, category_id: Optional[str]
) -> int:
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
@@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool],
) -> None:
"""Add/update room category for group"""
- insertion_values = {}
- update_values = {"category_id": category_id} # This cannot be empty
+ insertion_values: JsonDict = {}
+ update_values: JsonDict = {"category_id": category_id} # This cannot be empty
if profile is None:
insertion_values["profile"] = "{}"
@@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool],
) -> None:
"""Add/remove user role"""
- insertion_values = {}
- update_values = {"role_id": role_id} # This cannot be empty
+ insertion_values: JsonDict = {}
+ update_values: JsonDict = {"role_id": role_id} # This cannot be empty
if profile is None:
insertion_values["profile"] = "{}"
@@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore):
self,
group_id: str,
user_id: str,
- role_id: str,
- order: int,
+ role_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) user's entry in summary.
@@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_user_to_summary_txn(
self,
- txn,
+ txn: LoggingTransaction,
group_id: str,
user_id: str,
- role_id: str,
- order: int,
+ role_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
- ):
+ ) -> None:
"""Add (or update) user's entry in summary.
Args:
@@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND role_id = ?
"""
txn.execute(sql, (group_id, role_id))
- (order,) = txn.fetchone()
+ (order,) = cast(Tuple[int], txn.fetchone())
if existing:
to_update = {}
@@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore):
"role_id": role_id,
"user_id": user_id,
},
- values=to_update,
+ updatevalues=to_update,
)
else:
if is_public is None:
@@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_user_from_summary(
- self, group_id: str, user_id: str, role_id: str
+ self, group_id: str, user_id: str, role_id: Optional[str]
) -> int:
if role_id is None:
role_id = _DEFAULT_ROLE_ID
@@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore):
Optional if the user and group are on the same server
"""
- def _add_user_to_group_txn(txn):
+ def _add_user_to_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_insert_txn(
txn,
table="group_users",
@@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore):
await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
- def _remove_user_from_group_txn(txn):
+ def _remove_user_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="group_users",
@@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
- def _remove_room_from_group_txn(txn):
+ def _remove_room_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="group_rooms",
@@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore):
content = content or {}
- def _register_user_group_membership_txn(txn, next_id):
+ def _register_user_group_membership_txn(
+ txn: LoggingTransaction, next_id: int
+ ) -> int:
# TODO: Upsert?
self.db_pool.simple_delete_txn(
txn,
@@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore):
),
},
)
- self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+ self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined]
# TODO: Insert profile to ensure it comes down stream if its a join.
@@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- async with self._group_updates_id_gen.get_next() as next_id:
+ async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined]
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
@@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore):
return res
async def create_group(
- self, group_id, user_id, name, avatar_url, short_description, long_description
+ self,
+ group_id: str,
+ user_id: str,
+ name: str,
+ avatar_url: str,
+ short_description: str,
+ long_description: str,
) -> None:
await self.db_pool.simple_insert(
table="groups",
@@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore):
desc="create_group",
)
- async def update_group_profile(self, group_id, profile):
+ async def update_group_profile(self, group_id: str, profile: JsonDict) -> None:
await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
@@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_attestation_renewal",
)
- def get_group_stream_token(self):
- return self._group_updates_id_gen.get_current_token()
+ def get_group_stream_token(self) -> int:
+ return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined]
async def delete_group(self, group_id: str) -> None:
"""Deletes a group fully from the database.
@@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore):
group_id: The group ID to delete.
"""
- def _delete_group_txn(txn):
+ def _delete_group_txn(txn: LoggingTransaction) -> None:
tables = [
"groups",
"group_users",
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e9a0cdc6be..216622964a 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
+ LoggingTransaction,
make_in_list_sql_clause,
)
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
from synapse.util.caches.descriptors import cached
from synapse.util.threepids import canonicalise_email
@@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
Number of current monthly active users
"""
- def _count_users(txn):
+ def _count_users(txn: LoggingTransaction) -> int:
# Exclude app service users
sql = """
SELECT COUNT(*)
@@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
"""
txn.execute(sql)
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_users", _count_users)
@@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
- def _count_users_by_service(txn):
+ def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]:
sql = """
SELECT COALESCE(appservice_id, 'native'), COUNT(*)
FROM monthly_active_users
@@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
txn.execute(sql)
- result = txn.fetchall()
+ result = cast(List[Tuple[str, int]], txn.fetchall())
return dict(result)
return await self.db_pool.runInteraction(
@@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
)
@wrap_as_background_process("reap_monthly_active_users")
- async def reap_monthly_active_users(self):
+ async def reap_monthly_active_users(self) -> None:
"""Cleans out monthly active user table to ensure that no stale
entries exist.
"""
- def _reap_users(txn, reserved_users):
+ def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None:
"""
Args:
reserved_users (tuple): reserved users to preserve
@@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
# is racy.
# Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant
- self._invalidate_all_cache_and_stream(
+ self._invalidate_all_cache_and_stream( # type: ignore[attr-defined]
txn, self.user_last_seen_monthly_active
)
- self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+ self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined]
reserved_users = await self.get_registered_reserved_users()
await self.db_pool.runInteraction(
@@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
)
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
- def _initialise_reserved_users(self, txn, threepids):
+ def _initialise_reserved_users(
+ self, txn: LoggingTransaction, threepids: List[dict]
+ ) -> None:
"""Ensures that reserved threepids are accounted for in the MAU table, should
be called on start up.
Args:
- txn (cursor):
- threepids (list[dict]): List of threepid dicts to reserve
+ txn:
+ threepids: List of threepid dicts to reserve
"""
# XXX what is this function trying to achieve? It upserts into
@@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
- def upsert_monthly_active_user_txn(self, txn, user_id):
+ def upsert_monthly_active_user_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> None:
"""Updates or inserts monthly active user member
We consciously do not call is_support_txn from this method because it
@@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
txn, self.user_last_seen_monthly_active, (user_id,)
)
- async def populate_monthly_active_users(self, user_id):
+ async def populate_monthly_active_users(self, user_id: str) -> None:
"""Checks on the state of monthly active user limits and optionally
add the user to the monthly active tables
@@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"""
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
- is_guest = await self.is_guest(user_id)
+ is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
if is_guest:
return
is_trial = await self.is_trial_user(user_id)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index c4869d64e6..b2295fd51f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -27,7 +27,6 @@ from typing import (
)
import attr
-from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
@@ -41,45 +40,15 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.server import HomeServer
- from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _ThreadAggregation:
- # The latest event in the thread.
- latest_event: EventBase
- # The latest edit to the latest event in the thread.
- latest_edit: Optional[EventBase]
- # The total number of events in the thread.
- count: int
- # True if the current user has sent an event to the thread.
- current_user_participated: bool
-
-
-@attr.s(slots=True, auto_attribs=True)
-class BundledAggregations:
- """
- The bundled aggregations for an event.
-
- Some values require additional processing during serialization.
- """
-
- annotations: Optional[JsonDict] = None
- references: Optional[JsonDict] = None
- replace: Optional[EventBase] = None
- thread: Optional[_ThreadAggregation] = None
-
- def __bool__(self) -> bool:
- return bool(self.annotations or self.references or self.replace or self.thread)
-
-
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
- async def _get_applicable_edits(
+ async def get_applicable_edits(
self, event_ids: Collection[str]
) -> Dict[str, Optional[EventBase]]:
"""Get the most recent edit (if any) that has happened for the given
@@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
- async def _get_thread_summaries(
+ async def get_thread_summaries(
self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
"""Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
@@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
# Check to see if any of those events are edited.
- latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+ latest_edits = await self.get_applicable_edits(latest_event_ids.values())
# Map to the event IDs to the thread summary.
#
@@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
- async def _get_threads_participated(
+ async def get_threads_participated(
self, event_ids: Collection[str], user_id: str
) -> Dict[str, bool]:
"""Get whether the requesting user participated in the given threads.
@@ -766,114 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
- async def _get_bundled_aggregation_for_event(
- self, event: EventBase, user_id: str
- ) -> Optional[BundledAggregations]:
- """Generate bundled aggregations for an event.
-
- Note that this does not use a cache, but depends on cached methods.
-
- Args:
- event: The event to calculate bundled aggregations for.
- user_id: The user requesting the bundled aggregations.
-
- Returns:
- The bundled aggregations for an event, if bundled aggregations are
- enabled and the event can have bundled aggregations.
- """
-
- # Do not bundle aggregations for an event which represents an edit or an
- # annotation. It does not make sense for them to have related events.
- relates_to = event.content.get("m.relates_to")
- if isinstance(relates_to, (dict, frozendict)):
- relation_type = relates_to.get("rel_type")
- if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
- return None
-
- event_id = event.event_id
- room_id = event.room_id
-
- # The bundled aggregations to include, a mapping of relation type to a
- # type-specific value. Some types include the direct return type here
- # while others need more processing during serialization.
- aggregations = BundledAggregations()
-
- annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
- if annotations.chunk:
- aggregations.annotations = await annotations.to_dict(
- cast("DataStore", self)
- )
-
- references = await self.get_relations_for_event(
- event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
- )
- if references.chunk:
- aggregations.references = await references.to_dict(cast("DataStore", self))
-
- # Store the bundled aggregations in the event metadata for later use.
- return aggregations
-
- async def get_bundled_aggregations(
- self, events: Iterable[EventBase], user_id: str
- ) -> Dict[str, BundledAggregations]:
- """Generate bundled aggregations for events.
-
- Args:
- events: The iterable of events to calculate bundled aggregations for.
- user_id: The user requesting the bundled aggregations.
-
- Returns:
- A map of event ID to the bundled aggregation for the event. Not all
- events may have bundled aggregations in the results.
- """
- # De-duplicate events by ID to handle the same event requested multiple times.
- #
- # State events do not get bundled aggregations.
- events_by_id = {
- event.event_id: event for event in events if not event.is_state()
- }
-
- # event ID -> bundled aggregation in non-serialized form.
- results: Dict[str, BundledAggregations] = {}
-
- # Fetch other relations per event.
- for event in events_by_id.values():
- event_result = await self._get_bundled_aggregation_for_event(event, user_id)
- if event_result:
- results[event.event_id] = event_result
-
- # Fetch any edits (but not for redacted events).
- edits = await self._get_applicable_edits(
- [
- event_id
- for event_id, event in events_by_id.items()
- if not event.internal_metadata.is_redacted()
- ]
- )
- for event_id, edit in edits.items():
- results.setdefault(event_id, BundledAggregations()).replace = edit
-
- # Fetch thread summaries.
- summaries = await self._get_thread_summaries(events_by_id.keys())
- # Only fetch participated for a limited selection based on what had
- # summaries.
- participated = await self._get_threads_participated(summaries.keys(), user_id)
- for event_id, summary in summaries.items():
- if summary:
- thread_count, latest_thread_event, edit = summary
- results.setdefault(
- event_id, BundledAggregations()
- ).thread = _ThreadAggregation(
- latest_event=latest_thread_event,
- latest_edit=edit,
- count=thread_count,
- # If there's a thread summary it must also exist in the
- # participated dictionary.
- current_user_participated=participated[event_id],
- )
-
- return results
-
class RelationsStore(RelationsWorkerStore):
pass
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index bef675b845..3248da5356 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class EventIdMembership:
+ """Returned by `get_membership_from_event_ids`"""
+
+ user_id: str
+ membership: str
+
+
class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(
self,
@@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
- desc="_get_membership_from_event_ids",
+ desc="_get_joined_profiles_from_event_ids",
)
return {
@@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
+ @cached(max_entries=5000)
+ async def _get_membership_from_event_id(
+ self, member_event_id: str
+ ) -> Optional[EventIdMembership]:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
+ )
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
- ) -> List[dict]:
- """Get user_id and membership of a set of event IDs."""
+ ) -> Dict[str, Optional[EventIdMembership]]:
+ """Get user_id and membership of a set of event IDs.
+
+ Returns:
+ Mapping from event ID to `EventIdMembership` if the event is a
+ membership event, otherwise the value is None.
+ """
- return await self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
@@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
desc="get_membership_from_event_ids",
)
+ return {
+ row["event_id"]: EventIdMembership(
+ membership=row["membership"], user_id=row["user_id"]
+ )
+ for row in rows
+ }
+
async def is_local_host_in_room_ignoring_users(
self, room_id: str, ignore_users: Collection[str]
) -> bool:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e7fddd2426..0595df01d3 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -26,6 +26,8 @@ from typing import (
cast,
)
+from typing_extensions import TypedDict
+
from synapse.api.errors import StoreError
if TYPE_CHECKING:
@@ -40,7 +42,12 @@ from synapse.storage.database import (
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
+from synapse.types import (
+ JsonDict,
+ UserProfile,
+ get_domain_from_id,
+ get_localpart_from_id,
+)
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
+class SearchResult(TypedDict):
+ limited: bool
+ results: List[UserProfile]
+
+
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
@@ -718,7 +730,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- async def get_shared_rooms_for_users(
+ async def get_mutual_rooms_for_users(
self, user_id: str, other_user_id: str
) -> Set[str]:
"""
@@ -732,7 +744,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
A set of room ID's that the users share.
"""
- def _get_shared_rooms_for_users_txn(
+ def _get_mutual_rooms_for_users_txn(
txn: LoggingTransaction,
) -> List[Dict[str, str]]:
txn.execute(
@@ -756,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return rows
rows = await self.db_pool.runInteraction(
- "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
+ "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn
)
return {row["room_id"] for row in rows}
@@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
async def search_user_dir(
self, user_id: str, search_term: str, limit: int
- ) -> JsonDict:
+ ) -> SearchResult:
"""Searches for users in directory
Returns:
@@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- results = await self.db_pool.execute(
- "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+ results = cast(
+ List[UserProfile],
+ await self.db_pool.execute(
+ "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+ ),
)
limited = len(results) > limit
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 9abc02046e..afb7d5054d 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -27,7 +27,7 @@ def create_engine(database_config) -> BaseDatabaseEngine:
if name == "psycopg2":
# Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
- import psycopg2 # type: ignore
+ import psycopg2
return PostgresEngine(psycopg2, database_config)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 808342fafb..e8d29e2870 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -47,17 +47,26 @@ class PostgresEngine(BaseDatabaseEngine):
self.default_isolation_level = (
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
+ self.config = database_config
@property
def single_threaded(self) -> bool:
return False
+ def get_db_locale(self, txn):
+ txn.execute(
+ "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+ )
+ collation, ctype = txn.fetchone()
+ return collation, ctype
+
def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version
+ allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
# Are we on a supported PostgreSQL version?
if not allow_outdated_version and self._version < 100000:
@@ -72,33 +81,39 @@ class PostgresEngine(BaseDatabaseEngine):
"See docs/postgres.md for more information." % (rows[0][0],)
)
- txn.execute(
- "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
- )
- collation, ctype = txn.fetchone()
+ collation, ctype = self.get_db_locale(txn)
if collation != "C":
logger.warning(
- "Database has incorrect collation of %r. Should be 'C'\n"
- "See docs/postgres.md for more information.",
+ "Database has incorrect collation of %r. Should be 'C'",
collation,
)
+ if not allow_unsafe_locale:
+ raise IncorrectDatabaseSetup(
+ "Database has incorrect collation of %r. Should be 'C'\n"
+ "See docs/postgres.md for more information. You can override this check by"
+ "setting 'allow_unsafe_locale' to true in the database config.",
+ collation,
+ )
if ctype != "C":
- logger.warning(
- "Database has incorrect ctype of %r. Should be 'C'\n"
- "See docs/postgres.md for more information.",
- ctype,
- )
+ if not allow_unsafe_locale:
+ logger.warning(
+ "Database has incorrect ctype of %r. Should be 'C'",
+ ctype,
+ )
+ raise IncorrectDatabaseSetup(
+ "Database has incorrect ctype of %r. Should be 'C'\n"
+ "See docs/postgres.md for more information. You can override this check by"
+ "setting 'allow_unsafe_locale' to true in the database config.",
+ ctype,
+ )
def check_new_database(self, txn):
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
- txn.execute(
- "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
- )
- collation, ctype = txn.fetchone()
+ collation, ctype = self.get_db_locale(txn)
errors = []
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 7d543fdbe0..b402922817 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -1023,8 +1023,13 @@ class EventsPersistenceStorage:
# Check if any of the changes that we don't have events for are joins.
if events_to_check:
- rows = await self.main_store.get_membership_from_event_ids(events_to_check)
- is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
+ members = await self.main_store.get_membership_from_event_ids(
+ events_to_check
+ )
+ is_still_joined = any(
+ member and member.membership == Membership.JOIN
+ for member in members.values()
+ )
if is_still_joined:
return True
@@ -1060,9 +1065,11 @@ class EventsPersistenceStorage:
), event_id in current_state.items()
if typ == EventTypes.Member and not self.is_mine_id(state_key)
]
- rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
+ members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
potentially_left_users.update(
- row["user_id"] for row in rows if row["membership"] == Membership.JOIN
+ member.user_id
+ for member in members.values()
+ if member and member.membership == Membership.JOIN
)
return False
diff --git a/synapse/types.py b/synapse/types.py
index 53be3583a0..5ce2a5b0a5 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -34,6 +34,7 @@ from typing import (
import attr
from frozendict import frozendict
from signedjson.key import decode_verify_key_bytes
+from typing_extensions import TypedDict
from unpaddedbase64 import decode_base64
from zope.interface import Interface
@@ -63,6 +64,10 @@ MutableStateMap = MutableMapping[StateKey, T]
# JSON types. These could be made stronger, but will do for now.
# A JSON-serialisable dict.
JsonDict = Dict[str, Any]
+# A JSON-serialisable mapping; roughly speaking an immutable JSONDict.
+# Useful when you have a TypedDict which isn't going to be mutated and you don't want
+# to cast to JsonDict everywhere.
+JsonMapping = Mapping[str, Any]
# A JSON-serialisable object.
JsonSerializable = object
@@ -791,3 +796,9 @@ class UserInfo:
is_deactivated: bool
is_guest: bool
is_shadow_banned: bool
+
+
+class UserProfile(TypedDict):
+ user_id: str
+ display_name: Optional[str]
+ avatar_url: Optional[str]
diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py
index 12cd804939..66f1da7502 100644
--- a/synapse/util/check_dependencies.py
+++ b/synapse/util/check_dependencies.py
@@ -128,6 +128,19 @@ def _incorrect_version(
)
+def _no_reported_version(requirement: Requirement, extra: Optional[str] = None) -> str:
+ if extra:
+ return (
+ f"Synapse {VERSION} needs {requirement} for {extra}, "
+ f"but can't determine {requirement.name}'s version"
+ )
+ else:
+ return (
+ f"Synapse {VERSION} needs {requirement}, "
+ f"but can't determine {requirement.name}'s version"
+ )
+
+
def check_requirements(extra: Optional[str] = None) -> None:
"""Check Synapse's dependencies are present and correctly versioned.
@@ -163,8 +176,17 @@ def check_requirements(extra: Optional[str] = None) -> None:
deps_unfulfilled.append(requirement.name)
errors.append(_not_installed(requirement, extra))
else:
+ if dist.version is None:
+ # This shouldn't happen---it suggests a borked virtualenv. (See #12223)
+ # Try to give a vaguely helpful error message anyway.
+ # Type-ignore: the annotations don't reflect reality: see
+ # https://github.com/python/typeshed/issues/7513
+ # https://bugs.python.org/issue47060
+ deps_unfulfilled.append(requirement.name) # type: ignore[unreachable]
+ errors.append(_no_reported_version(requirement, extra))
+
# We specify prereleases=True to allow prereleases such as RCs.
- if not requirement.specifier.contains(dist.version, prereleases=True):
+ elif not requirement.specifier.contains(dist.version, prereleases=True):
deps_unfulfilled.append(requirement.name)
errors.append(_incorrect_version(requirement, dist.version, extra))
diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py
index 17a84d20d8..2acdb6ac61 100644
--- a/tests/config/test_registration_config.py
+++ b/tests/config/test_registration_config.py
@@ -11,14 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+import synapse.app.homeserver
from synapse.config import ConfigError
from synapse.config.homeserver import HomeServerConfig
-from tests.unittest import TestCase
+from tests.config.utils import ConfigFileTestCase
from tests.utils import default_config
-class RegistrationConfigTestCase(TestCase):
+class RegistrationConfigTestCase(ConfigFileTestCase):
def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self):
"""
session_lifetime should logically be larger than, or at least as large as,
@@ -76,3 +78,19 @@ class RegistrationConfigTestCase(TestCase):
HomeServerConfig().parse_config_dict(
{"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict}
)
+
+ def test_refuse_to_start_if_open_registration_and_no_verification(self):
+ self.generate_config()
+ self.add_lines_to_config(
+ [
+ " ",
+ "enable_registration: true",
+ "registrations_require_3pid: []",
+ "enable_registration_captcha: false",
+ "registration_requires_token: false",
+ ]
+ )
+
+ # Test that allowing open registration without verification raises an error
+ with self.assertRaises(ConfigError):
+ synapse.app.homeserver.setup(["-c", self.config_file])
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index a267228846..a54aa29cf1 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -11,9 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.handlers.cas import CasResponse
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@@ -24,7 +29,7 @@ SERVER_URL = "https://issuer/"
class CasHandlerTestCase(HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
cas_config = {
@@ -40,7 +45,7 @@ class CasHandlerTestCase(HomeserverTestCase):
return config
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()
self.handler = hs.get_cas_handler()
@@ -51,7 +56,7 @@ class CasHandlerTestCase(HomeserverTestCase):
return hs
- def test_map_cas_user_to_user(self):
+ def test_map_cas_user_to_user(self) -> None:
"""Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
# stub out the auth handler
@@ -75,7 +80,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
- def test_map_cas_user_to_existing_user(self):
+ def test_map_cas_user_to_existing_user(self) -> None:
"""Existing users can log in with CAS account."""
store = self.hs.get_datastores().main
self.get_success(
@@ -119,7 +124,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
- def test_map_cas_user_to_invalid_localpart(self):
+ def test_map_cas_user_to_invalid_localpart(self) -> None:
"""CAS automaps invalid characters to base-64 encoding."""
# stub out the auth handler
@@ -150,7 +155,7 @@ class CasHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_required_attributes(self):
+ def test_required_attributes(self) -> None:
"""The required attributes must be met from the CAS response."""
# stub out the auth handler
@@ -166,7 +171,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_handler.complete_sso_login.assert_not_called()
# The response doesn't have any department.
- cas_response = CasResponse("test_user", {"userGroup": "staff"})
+ cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
request.reset_mock()
self.get_success(
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index e8b4e39d1a..89078fc637 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List
+from typing import List, cast
from unittest import TestCase
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
@@ -23,7 +25,9 @@ from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.types import create_requester
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
@@ -42,7 +46,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main
@@ -50,7 +54,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self._event_auth_handler = hs.get_event_auth_handler()
return hs
- def test_exchange_revoked_invite(self):
+ def test_exchange_revoked_invite(self) -> None:
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
@@ -96,7 +100,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
self.assertEqual(failure.msg, "You are not invited to this room.")
- def test_rejected_message_event_state(self):
+ def test_rejected_message_event_state(self) -> None:
"""
Check that we store the state group correctly for rejected non-state events.
@@ -126,7 +130,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"content": {},
"room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER,
- "depth": join_event["depth"] + 1,
+ "depth": cast(int, join_event["depth"]) + 1,
"prev_events": [join_event.event_id],
"auth_events": [],
"origin_server_ts": self.clock.time_msec(),
@@ -149,7 +153,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
- def test_rejected_state_event_state(self):
+ def test_rejected_state_event_state(self) -> None:
"""
Check that we store the state group correctly for rejected state events.
@@ -180,7 +184,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"content": {},
"room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER,
- "depth": join_event["depth"] + 1,
+ "depth": cast(int, join_event["depth"]) + 1,
"prev_events": [join_event.event_id],
"auth_events": [],
"origin_server_ts": self.clock.time_msec(),
@@ -203,7 +207,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
- def test_backfill_with_many_backward_extremities(self):
+ def test_backfill_with_many_backward_extremities(self) -> None:
"""
Check that we can backfill with many backward extremities.
The goal is to make sure that when we only use a portion
@@ -262,7 +266,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
- def test_backfill_floating_outlier_membership_auth(self):
+ def test_backfill_floating_outlier_membership_auth(self) -> None:
"""
As the local homeserver, check that we can properly process a federated
event from the OTHER_SERVER with auth_events that include a floating
@@ -377,7 +381,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
for ae in auth_events
]
- self.handler.federation_client.get_event_auth = get_event_auth
+ self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment]
with LoggingContext("receive_pdu"):
# Fake the OTHER_SERVER federating the message event over to our local homeserver
@@ -397,7 +401,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
)
- def test_invite_by_user_ratelimit(self):
+ def test_invite_by_user_ratelimit(self) -> None:
"""Tests that invites from federation to a particular user are
actually rate-limited.
"""
@@ -446,7 +450,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
exc=LimitExceededError,
)
- def _build_and_send_join_event(self, other_server, other_user, room_id):
+ def _build_and_send_join_event(
+ self, other_server: str, other_user: str, room_id: str
+ ) -> EventBase:
join_event = self.get_success(
self.handler.on_make_join_request(other_server, room_id, other_user)
)
@@ -469,7 +475,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
class EventFromPduTestCase(TestCase):
- def test_valid_json(self):
+ def test_valid_json(self) -> None:
"""Valid JSON should be turned into an event."""
ev = event_from_pdu_json(
{
@@ -487,7 +493,7 @@ class EventFromPduTestCase(TestCase):
self.assertIsInstance(ev, EventBase)
- def test_invalid_numbers(self):
+ def test_invalid_numbers(self) -> None:
"""Invalid values for an integer should be rejected, all floats should be rejected."""
for value in [
-(2 ** 53),
@@ -512,7 +518,7 @@ class EventFromPduTestCase(TestCase):
RoomVersions.V6,
)
- def test_invalid_nested(self):
+ def test_invalid_nested(self) -> None:
"""List and dictionaries are recursively searched."""
with self.assertRaises(SynapseError):
event_from_pdu_json(
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 6ddec9ecf1..b2ed9cbe37 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -331,11 +331,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
# Extract presence update user ID and state information into lists of tuples
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
- presence_states = [(ps.user_id, ps.state) for ps in presence_states]
+ presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
# Compare what we put into the storage with what we got out.
# They should be identical.
- self.assertEqual(presence_states, db_presence_states)
+ self.assertEqual(presence_states_compare, db_presence_states)
class PresenceTimeoutTestCase(unittest.TestCase):
@@ -357,6 +357,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -380,6 +381,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.BUSY)
self.assertEqual(new_state.status_msg, status_msg)
@@ -399,6 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -420,6 +423,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.ONLINE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -477,6 +481,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -653,13 +658,13 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
def _set_presencestate_with_status_msg(
- self, user_id: str, state: PresenceState, status_msg: Optional[str]
+ self, user_id: str, state: str, status_msg: Optional[str]
):
"""Set a PresenceState and status_msg and check the result.
Args:
user_id: User for that the status is to be set.
- PresenceState: The new PresenceState.
+ state: The new PresenceState.
status_msg: Status message that is to be set.
"""
self.get_success(
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 08733a9f2d..1ec105c373 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -268,6 +268,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertTrue(res)
@unittest.override_config({"max_avatar_size": 50})
+ def test_avatar_constraints_allow_empty_avatar_url(self) -> None:
+ """An empty avatar is always permitted."""
+ res = self.get_success(self.handler.check_avatar_size_and_mime_type(""))
+ self.assertTrue(res)
+
+ @unittest.override_config({"max_avatar_size": 50})
def test_avatar_constraints_missing(self) -> None:
"""Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
be found.
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f91a80b9fa..ffd5c4cb93 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -18,11 +18,14 @@ from typing import Dict
from unittest.mock import ANY, Mock, call
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
-from synapse.types import UserID, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, create_requester
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -42,7 +45,9 @@ ROOM_ID = "a-room"
OTHER_ROOM_ID = "another-room"
-def _expect_edu_transaction(edu_type, content, origin="test"):
+def _expect_edu_transaction(
+ edu_type: str, content: JsonDict, origin: str = "test"
+) -> JsonDict:
return {
"origin": origin,
"origin_server_ts": 1000000,
@@ -51,12 +56,12 @@ def _expect_edu_transaction(edu_type, content, origin="test"):
}
-def _make_edu_transaction_json(edu_type, content):
+def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes:
return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
@@ -83,7 +88,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
d["/_matrix/federation"] = TransportLayerServer(self.hs)
return d
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
@@ -111,24 +116,24 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- async def check_user_in_room(room_id, user_id):
+ async def check_user_in_room(room_id: str, user_id: str) -> None:
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
return None
hs.get_auth().check_user_in_room = check_user_in_room
- async def check_host_in_room(room_id, server_name):
+ async def check_host_in_room(room_id: str, server_name: str) -> bool:
return room_id == ROOM_ID
hs.get_event_auth_handler().check_host_in_room = check_host_in_room
- def get_joined_hosts_for_room(room_id):
+ def get_joined_hosts_for_room(room_id: str):
return {member.domain for member in self.room_members}
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
- async def get_users_in_room(room_id):
+ async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members}
self.datastore.get_users_in_room = get_users_in_room
@@ -153,7 +158,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
lambda *args, **kwargs: make_awaitable(None)
)
- def test_started_typing_local(self):
+ def test_started_typing_local(self) -> None:
self.room_members = [U_APPLE, U_BANANA]
self.assertEqual(self.event_source.get_current_key(), 0)
@@ -187,7 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
@override_config({"send_federation": True})
- def test_started_typing_remote_send(self):
+ def test_started_typing_remote_send(self) -> None:
self.room_members = [U_APPLE, U_ONION]
self.get_success(
@@ -217,7 +222,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
try_trailing_slash_on_400=True,
)
- def test_started_typing_remote_recv(self):
+ def test_started_typing_remote_recv(self) -> None:
self.room_members = [U_APPLE, U_ONION]
self.assertEqual(self.event_source.get_current_key(), 0)
@@ -256,7 +261,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
- def test_started_typing_remote_recv_not_in_room(self):
+ def test_started_typing_remote_recv_not_in_room(self) -> None:
self.room_members = [U_APPLE, U_ONION]
self.assertEqual(self.event_source.get_current_key(), 0)
@@ -292,7 +297,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[1], 0)
@override_config({"send_federation": True})
- def test_stopped_typing(self):
+ def test_stopped_typing(self) -> None:
self.room_members = [U_APPLE, U_BANANA, U_ONION]
# Gut-wrenching
@@ -343,7 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
)
- def test_typing_timeout(self):
+ def test_typing_timeout(self) -> None:
self.room_members = [U_APPLE, U_BANANA]
self.assertEqual(self.event_source.get_current_key(), 0)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index c3f20f9692..10dd94b549 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -86,6 +86,16 @@ class ModuleApiTestCase(HomeserverTestCase):
displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino")
+ def test_can_register_admin_user(self):
+ user_id = self.get_success(
+ self.register_user(
+ "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
+ )
+ )
+ found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ self.assertEqual(found_user.user_id.to_string(), user_id)
+ self.assertIdentical(found_user.is_admin, True)
+
def test_get_userinfo_by_id(self):
user_id = self.register_user("alice", "1234")
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 6691e07128..ba158f5d93 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -11,15 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import Mock
from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfigException
from synapse.rest.client import login, push_rule, receipts, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config
@@ -35,13 +39,13 @@ class HTTPPusherTests(HomeserverTestCase):
user_id = True
hijack_auth = False
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["start_pushers"] = True
return config
- def make_homeserver(self, reactor, clock):
- self.push_attempts: List[tuple[Deferred, str, dict]] = []
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.push_attempts: List[Tuple[Deferred, str, dict]] = []
m = Mock()
@@ -56,7 +60,7 @@ class HTTPPusherTests(HomeserverTestCase):
return hs
- def test_invalid_configuration(self):
+ def test_invalid_configuration(self) -> None:
"""Invalid push configurations should be rejected."""
# Register the user who gets notified
user_id = self.register_user("user", "pass")
@@ -68,7 +72,7 @@ class HTTPPusherTests(HomeserverTestCase):
)
token_id = user_tuple.token_id
- def test_data(data):
+ def test_data(data: Optional[JsonDict]) -> None:
self.get_failure(
self.hs.get_pusherpool().add_pusher(
user_id=user_id,
@@ -95,7 +99,7 @@ class HTTPPusherTests(HomeserverTestCase):
# A url with an incorrect path isn't accepted.
test_data({"url": "http://example.com/foo"})
- def test_sends_http(self):
+ def test_sends_http(self) -> None:
"""
The HTTP pusher will send pushes for each message to a HTTP endpoint
when configured to do so.
@@ -200,7 +204,7 @@ class HTTPPusherTests(HomeserverTestCase):
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
- def test_sends_high_priority_for_encrypted(self):
+ def test_sends_high_priority_for_encrypted(self) -> None:
"""
The HTTP pusher will send pushes at high priority if they correspond
to an encrypted message.
@@ -321,7 +325,7 @@ class HTTPPusherTests(HomeserverTestCase):
)
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
- def test_sends_high_priority_for_one_to_one_only(self):
+ def test_sends_high_priority_for_one_to_one_only(self) -> None:
"""
The HTTP pusher will send pushes at high priority if they correspond
to a message in a one-to-one room.
@@ -404,7 +408,7 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
- def test_sends_high_priority_for_mention(self):
+ def test_sends_high_priority_for_mention(self) -> None:
"""
The HTTP pusher will send pushes at high priority if they correspond
to a message containing the user's display name.
@@ -480,7 +484,7 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
- def test_sends_high_priority_for_atroom(self):
+ def test_sends_high_priority_for_atroom(self) -> None:
"""
The HTTP pusher will send pushes at high priority if they correspond
to a message that contains @room.
@@ -563,7 +567,7 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
- def test_push_unread_count_group_by_room(self):
+ def test_push_unread_count_group_by_room(self) -> None:
"""
The HTTP pusher will group unread count by number of unread rooms.
"""
@@ -576,7 +580,7 @@ class HTTPPusherTests(HomeserverTestCase):
self._check_push_attempt(6, 1)
@override_config({"push": {"group_unread_count_by_room": False}})
- def test_push_unread_count_message_count(self):
+ def test_push_unread_count_message_count(self) -> None:
"""
The HTTP pusher will send the total unread message count.
"""
@@ -589,7 +593,7 @@ class HTTPPusherTests(HomeserverTestCase):
# last read receipt
self._check_push_attempt(6, 3)
- def _test_push_unread_count(self):
+ def _test_push_unread_count(self) -> None:
"""
Tests that the correct unread count appears in sent push notifications
@@ -681,7 +685,7 @@ class HTTPPusherTests(HomeserverTestCase):
self.helper.send(room_id, body="HELLO???", tok=other_access_token)
- def _advance_time_and_make_push_succeed(self, expected_push_attempts):
+ def _advance_time_and_make_push_succeed(self, expected_push_attempts: int) -> None:
self.pump()
self.push_attempts[expected_push_attempts - 1][0].callback({})
@@ -708,7 +712,9 @@ class HTTPPusherTests(HomeserverTestCase):
expected_unread_count_last_push,
)
- def _send_read_request(self, access_token, message_event_id, room_id):
+ def _send_read_request(
+ self, access_token: str, message_event_id: str, room_id: str
+ ) -> None:
# Now set the user's read receipt position to the first event
#
# This will actually trigger a new notification to be sent out so that
@@ -748,7 +754,7 @@ class HTTPPusherTests(HomeserverTestCase):
return user_id, access_token
- def test_dont_notify_rule_overrides_message(self):
+ def test_dont_notify_rule_overrides_message(self) -> None:
"""
The override push rule will suppress notification
"""
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 3849beb9d6..5dba187076 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from typing import Dict, Optional, Union
import frozendict
@@ -20,12 +20,13 @@ from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
from synapse.push import push_rule_evaluator
from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
+from synapse.types import JsonDict
from tests import unittest
class PushRuleEvaluatorTestCase(unittest.TestCase):
- def _get_evaluator(self, content):
+ def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
event = FrozenEvent(
{
"event_id": "$event_id",
@@ -39,12 +40,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
)
room_member_count = 0
sender_power_level = 0
- power_levels = {}
+ power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluatorForEvent(
event, room_member_count, sender_power_level, power_levels
)
- def test_display_name(self):
+ def test_display_name(self) -> None:
"""Check for a matching display name in the body of the event."""
evaluator = self._get_evaluator({"body": "foo bar baz"})
@@ -71,20 +72,20 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
def _assert_matches(
- self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
+ self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
) -> None:
evaluator = self._get_evaluator(content)
self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
def _assert_not_matches(
- self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
+ self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
) -> None:
evaluator = self._get_evaluator(content)
self.assertFalse(
evaluator.matches(condition, "@user:test", "display_name"), msg
)
- def test_event_match_body(self):
+ def test_event_match_body(self) -> None:
"""Check that event_match conditions on content.body work as expected"""
# if the key is `content.body`, the pattern matches substrings.
@@ -165,7 +166,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
r"? after \ should match any character",
)
- def test_event_match_non_body(self):
+ def test_event_match_non_body(self) -> None:
"""Check that event_match conditions on other keys work as expected"""
# if the key is anything other than 'content.body', the pattern must match the
@@ -241,7 +242,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"pattern should not match before a newline",
)
- def test_no_body(self):
+ def test_no_body(self) -> None:
"""Not having a body shouldn't break the evaluator."""
evaluator = self._get_evaluator({})
@@ -250,7 +251,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
}
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
- def test_invalid_body(self):
+ def test_invalid_body(self) -> None:
"""A non-string body should not break the evaluator."""
condition = {
"kind": "contains_display_name",
@@ -260,7 +261,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
evaluator = self._get_evaluator({"body": body})
self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
- def test_tweaks_for_actions(self):
+ def test_tweaks_for_actions(self) -> None:
"""
This tests the behaviour of tweaks_for_actions.
"""
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index a60ea0a563..bef911d5df 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1050,6 +1050,25 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self._is_erased("@user:test", True)
+ @override_config({"max_avatar_size": 1234})
+ def test_deactivate_user_erase_true_avatar_nonnull_but_empty(self) -> None:
+ """Check we can erase a user whose avatar is the empty string.
+
+ Reproduces #12257.
+ """
+ # Patch `self.other_user` to have an empty string as their avatar.
+ self.get_success(self.store.set_profile_avatar_url("user", ""))
+
+ # Check we can still erase them.
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"erase": True},
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self._is_erased("@user:test", True)
+
def test_deactivate_user_erase_false(self) -> None:
"""
Test deactivating a user and set `erase` to `false`
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index def836054d..27946febff 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -31,7 +31,7 @@ from synapse.rest import admin
from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
@@ -1222,6 +1222,62 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[users[2]],
)
+ @unittest.override_config(
+ {
+ "use_account_validity_in_account_status": True,
+ }
+ )
+ def test_no_account_validity(self) -> None:
+ """Tests that if we decide to include account validity in the response but no
+ account validity 'is_user_expired' callback is provided, we default to marking all
+ users as not expired.
+ """
+ user = self.register_user("someuser", "password")
+
+ self._test_status(
+ users=[user],
+ expected_statuses={
+ user: {
+ "exists": True,
+ "deactivated": False,
+ "org.matrix.expired": False,
+ },
+ },
+ expected_failures=[],
+ )
+
+ @unittest.override_config(
+ {
+ "use_account_validity_in_account_status": True,
+ }
+ )
+ def test_account_validity_expired(self) -> None:
+ """Test that if we decide to include account validity in the response and the user
+ is expired, we return the correct info.
+ """
+ user = self.register_user("someuser", "password")
+
+ async def is_expired(user_id: str) -> bool:
+ # We can't blindly say everyone is expired, otherwise the request to get the
+ # account status will fail.
+ return UserID.from_string(user_id).localpart == "someuser"
+
+ self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
+ is_expired
+ )
+
+ self._test_status(
+ users=[user],
+ expected_statuses={
+ user: {
+ "exists": True,
+ "deactivated": False,
+ "org.matrix.expired": True,
+ },
+ },
+ expected_failures=[],
+ )
+
def _test_status(
self,
users: Optional[List[str]],
diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_mutual_rooms.py
index 3818b7b14b..7b7d283bb6 100644
--- a/tests/rest/client/test_shared_rooms.py
+++ b/tests/rest/client/test_mutual_rooms.py
@@ -14,7 +14,7 @@
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.rest.client import login, room, shared_rooms
+from synapse.rest.client import login, mutual_rooms, room
from synapse.server import HomeServer
from synapse.util import Clock
@@ -22,16 +22,16 @@ from tests import unittest
from tests.server import FakeChannel
-class UserSharedRoomsTest(unittest.HomeserverTestCase):
+class UserMutualRoomsTest(unittest.HomeserverTestCase):
"""
- Tests the UserSharedRoomsServlet.
+ Tests the UserMutualRoomsServlet.
"""
servlets = [
login.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
- shared_rooms.register_servlets,
+ mutual_rooms.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -43,10 +43,10 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.handler = hs.get_user_directory_handler()
- def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel:
+ def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel:
return self.make_request(
"GET",
- "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
+ "/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/%s"
% other_user,
access_token=token,
)
@@ -56,14 +56,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
A room should show up in the shared list of rooms between two users
if it is public.
"""
- self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True)
+ self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True)
def test_shared_room_list_private(self) -> None:
"""
A room should show up in the shared list of rooms between two users
if it is private.
"""
- self._check_shared_rooms_with(
+ self._check_mutual_rooms_with(
room_one_is_public=False, room_two_is_public=False
)
@@ -72,9 +72,9 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
The shared room list between two users should contain both public and private
rooms.
"""
- self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False)
+ self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False)
- def _check_shared_rooms_with(
+ def _check_mutual_rooms_with(
self, room_one_is_public: bool, room_two_is_public: bool
) -> None:
"""Checks that shared public or private rooms between two users appear in
@@ -94,7 +94,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
# Check shared rooms from user1's perspective.
# We should see the one room in common
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 1)
self.assertEqual(channel.json_body["joined"][0], room_id_one)
@@ -107,7 +107,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room_id_two, user=u2, tok=u2_token)
# Check shared rooms again. We should now see both rooms.
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 2)
for room_id_id in channel.json_body["joined"]:
@@ -128,7 +128,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room, user=u2, tok=u2_token)
# Assert user directory is not empty
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 1)
self.assertEqual(channel.json_body["joined"][0], room)
@@ -136,11 +136,11 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.leave(room, user=u1, tok=u1_token)
# Check user1's view of shared rooms with user2
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 0)
# Check user2's view of shared rooms with user1
- channel = self._get_shared_rooms(u2_token, u1)
+ channel = self._get_mutual_rooms(u2_token, u1)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index f3741b3001..fe97a0b3dd 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -15,12 +15,12 @@
import itertools
import urllib.parse
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import AccountDataTypes, EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync
from synapse.server import HomeServer
@@ -155,6 +155,16 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, channel.json_body)
return channel.json_body["chunk"]
+ def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
+ """
+ Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
+ """
+ for event in events:
+ if event["event_id"] == self.parent_id:
+ return event
+
+ raise AssertionError(f"Event {self.parent_id} not found in chunk")
+
class RelationsTestCase(BaseRelationsTestCase):
def test_send_relation(self) -> None:
@@ -291,202 +301,6 @@ class RelationsTestCase(BaseRelationsTestCase):
)
self.assertEqual(400, channel.code, channel.json_body)
- @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
- def test_bundled_aggregations(self) -> None:
- """
- Test that annotations, references, and threads get correctly bundled.
-
- Note that this doesn't test against /relations since only thread relations
- get bundled via that API. See test_aggregation_get_event_for_thread.
-
- See test_edit for a similar test for edits.
- """
- # Setup by sending a variety of relations.
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-
- channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
- reply_1 = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
- reply_2 = channel.json_body["event_id"]
-
- self._send_relation(RelationTypes.THREAD, "m.room.test")
-
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- thread_2 = channel.json_body["event_id"]
-
- def assert_bundle(event_json: JsonDict) -> None:
- """Assert the expected values of the bundled aggregations."""
- relations_dict = event_json["unsigned"].get("m.relations")
-
- # Ensure the fields are as expected.
- self.assertCountEqual(
- relations_dict.keys(),
- (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.THREAD,
- ),
- )
-
- # Check the values of each field.
- self.assertEqual(
- {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- ]
- },
- relations_dict[RelationTypes.ANNOTATION],
- )
-
- self.assertEqual(
- {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
- relations_dict[RelationTypes.REFERENCE],
- )
-
- self.assertEqual(
- 2,
- relations_dict[RelationTypes.THREAD].get("count"),
- )
- self.assertTrue(
- relations_dict[RelationTypes.THREAD].get("current_user_participated")
- )
- # The latest thread event has some fields that don't matter.
- self.assert_dict(
- {
- "content": {
- "m.relates_to": {
- "event_id": self.parent_id,
- "rel_type": RelationTypes.THREAD,
- }
- },
- "event_id": thread_2,
- "sender": self.user_id,
- "type": "m.room.test",
- },
- relations_dict[RelationTypes.THREAD].get("latest_event"),
- )
-
- # Request the event directly.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body)
-
- # Request the room messages.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/messages?dir=b",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
-
- # Request the room context.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/context/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["event"])
-
- # Request sync.
- channel = self.make_request("GET", "/sync", access_token=self.user_token)
- self.assertEqual(200, channel.code, channel.json_body)
- room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- self.assertTrue(room_timeline["limited"])
- assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
-
- # Request search.
- channel = self.make_request(
- "POST",
- "/search",
- # Search term matches the parent message.
- content={"search_categories": {"room_events": {"search_term": "Hi"}}},
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- chunk = [
- result["result"]
- for result in channel.json_body["search_categories"]["room_events"][
- "results"
- ]
- ]
- assert_bundle(self._find_event_in_chunk(chunk))
-
- def test_aggregation_get_event_for_annotation(self) -> None:
- """Test that annotations do not get bundled aggregations included
- when directly requested.
- """
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- annotation_id = channel.json_body["event_id"]
-
- # Annotate the annotation.
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
- )
-
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{annotation_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
-
- def test_aggregation_get_event_for_thread(self) -> None:
- """Test that threads get bundled aggregations included when directly requested."""
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- thread_id = channel.json_body["event_id"]
-
- # Annotate the annotation.
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
- )
-
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{thread_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["unsigned"].get("m.relations"),
- {
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
- },
- )
-
- # It should also be included when the entire thread is requested.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(len(channel.json_body["chunk"]), 1)
-
- thread_message = channel.json_body["chunk"][0]
- self.assertEqual(
- thread_message["unsigned"].get("m.relations"),
- {
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
- },
- )
-
def test_ignore_invalid_room(self) -> None:
"""Test that we ignore invalid relations over federation."""
# Create another room and send a message in it.
@@ -796,7 +610,7 @@ class RelationsTestCase(BaseRelationsTestCase):
threaded_event_id = channel.json_body["event_id"]
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
@@ -836,7 +650,7 @@ class RelationsTestCase(BaseRelationsTestCase):
edit_event_id = channel.json_body["event_id"]
# Edit the edit event.
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={
@@ -912,16 +726,6 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
- def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
- """
- Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
- """
- for event in events:
- if event["event_id"] == self.parent_id:
- return event
-
- raise AssertionError(f"Event {self.parent_id} not found in chunk")
-
def test_background_update(self) -> None:
"""Test the event_arbitrary_relations background update."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
@@ -981,34 +785,6 @@ class RelationsTestCase(BaseRelationsTestCase):
[annotation_event_id_good, thread_event_id],
)
- def test_bundled_aggregations_with_filter(self) -> None:
- """
- If "unsigned" is an omitted field (due to filtering), adding the bundled
- aggregations should not break.
-
- Note that the spec allows for a server to return additional fields beyond
- what is specified.
- """
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-
- # Note that the sync filter does not include "unsigned" as a field.
- filter = urllib.parse.quote_plus(
- b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}'
- )
- channel = self.make_request(
- "GET", f"/sync?filter={filter}", access_token=self.user_token
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # Ensure the timeline is limited, find the parent event.
- room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- self.assertTrue(room_timeline["limited"])
- parent_event = self._find_event_in_chunk(room_timeline["events"])
-
- # Ensure there's bundled aggregations on it.
- self.assertIn("unsigned", parent_event)
- self.assertIn("m.relations", parent_event["unsigned"])
-
class RelationPaginationTestCase(BaseRelationsTestCase):
def test_basic_paginate_relations(self) -> None:
@@ -1255,7 +1031,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
idx += 1
# Also send a different type of reaction so that we test we don't see it
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
prev_token = ""
found_event_ids: List[str] = []
@@ -1291,6 +1067,341 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
self.assertEqual(found_event_ids, expected_event_ids)
+class BundledAggregationsTestCase(BaseRelationsTestCase):
+ """
+ See RelationsTestCase.test_edit for a similar test for edits.
+
+ Note that this doesn't test against /relations since only thread relations
+ get bundled via that API. See test_aggregation_get_event_for_thread.
+ """
+
+ def _test_bundled_aggregations(
+ self,
+ relation_type: str,
+ assertion_callable: Callable[[JsonDict], None],
+ expected_db_txn_for_event: int,
+ ) -> None:
+ """
+ Makes requests to various endpoints which should include bundled aggregations
+ and then calls an assertion function on the bundled aggregations.
+
+ Args:
+ relation_type: The field to search for in the `m.relations` field in unsigned.
+ assertion_callable: Called with the contents of unsigned["m.relations"][relation_type]
+ for relation-specific assertions.
+ expected_db_txn_for_event: The number of database transactions which
+ are expected for a call to /event/.
+ """
+
+ def assert_bundle(event_json: JsonDict) -> None:
+ """Assert the expected values of the bundled aggregations."""
+ relations_dict = event_json["unsigned"].get("m.relations")
+
+ # Ensure the fields are as expected.
+ self.assertCountEqual(relations_dict.keys(), (relation_type,))
+ assertion_callable(relations_dict[relation_type])
+
+ # Request the event directly.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body)
+ assert channel.resource_usage is not None
+ self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event)
+
+ # Request the room messages.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
+
+ # Request the room context.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/context/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body["event"])
+
+ # Request sync.
+ filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
+
+ # Request search.
+ channel = self.make_request(
+ "POST",
+ "/search",
+ # Search term matches the parent message.
+ content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ chunk = [
+ result["result"]
+ for result in channel.json_body["search_categories"]["room_events"][
+ "results"
+ ]
+ ]
+ assert_bundle(self._find_event_in_chunk(chunk))
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_annotation(self) -> None:
+ """
+ Test that annotations get correctly bundled.
+ """
+ # Setup by sending a variety of relations.
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(
+ {
+ "chunk": [
+ {"type": "m.reaction", "key": "a", "count": 2},
+ {"type": "m.reaction", "key": "b", "count": 1},
+ ]
+ },
+ bundled_aggregations,
+ )
+
+ self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_reference(self) -> None:
+ """
+ Test that references get correctly bundled.
+ """
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ reply_1 = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ reply_2 = channel.json_body["event_id"]
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(
+ {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
+ bundled_aggregations,
+ )
+
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_thread(self) -> None:
+ """
+ Test that threads get correctly bundled.
+ """
+ self._send_relation(RelationTypes.THREAD, "m.room.test")
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ thread_2 = channel.json_body["event_id"]
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(2, bundled_aggregations.get("count"))
+ self.assertTrue(bundled_aggregations.get("current_user_participated"))
+ # The latest thread event has some fields that don't matter.
+ self.assert_dict(
+ {
+ "content": {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "rel_type": RelationTypes.THREAD,
+ }
+ },
+ "event_id": thread_2,
+ "sender": self.user_id,
+ "type": "m.room.test",
+ },
+ bundled_aggregations.get("latest_event"),
+ )
+
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9)
+
+ def test_aggregation_get_event_for_annotation(self) -> None:
+ """Test that annotations do not get bundled aggregations included
+ when directly requested.
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ annotation_id = channel.json_body["event_id"]
+
+ # Annotate the annotation.
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{annotation_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
+
+ def test_aggregation_get_event_for_thread(self) -> None:
+ """Test that threads get bundled aggregations included when directly requested."""
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ thread_id = channel.json_body["event_id"]
+
+ # Annotate the annotation.
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{thread_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertEqual(
+ channel.json_body["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+ },
+ },
+ )
+
+ # It should also be included when the entire thread is requested.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+
+ thread_message = channel.json_body["chunk"][0]
+ self.assertEqual(
+ thread_message["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+ },
+ },
+ )
+
+ def test_bundled_aggregations_with_filter(self) -> None:
+ """
+ If "unsigned" is an omitted field (due to filtering), adding the bundled
+ aggregations should not break.
+
+ Note that the spec allows for a server to return additional fields beyond
+ what is specified.
+ """
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+
+ # Note that the sync filter does not include "unsigned" as a field.
+ filter = urllib.parse.quote_plus(
+ b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}'
+ )
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # Ensure the timeline is limited, find the parent event.
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ parent_event = self._find_event_in_chunk(room_timeline["events"])
+
+ # Ensure there's bundled aggregations on it.
+ self.assertIn("unsigned", parent_event)
+ self.assertIn("m.relations", parent_event["unsigned"])
+
+
+class RelationIgnoredUserTestCase(BaseRelationsTestCase):
+ """Relations sent from an ignored user should be ignored."""
+
+ def _test_ignored_user(
+ self, allowed_event_ids: List[str], ignored_event_ids: List[str]
+ ) -> None:
+ """
+ Fetch the relations and ensure they're all there, then ignore user2, and
+ repeat.
+ """
+ # Get the relations.
+ event_ids = self._get_related_events()
+ self.assertCountEqual(event_ids, allowed_event_ids + ignored_event_ids)
+
+ # Ignore user2 and re-do the requests.
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.user_id,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {"ignored_users": {self.user2_id: {}}},
+ )
+ )
+
+ # Get the relations.
+ event_ids = self._get_related_events()
+ self.assertCountEqual(event_ids, allowed_event_ids)
+
+ def test_annotation(self) -> None:
+ """Annotations should ignore"""
+ # Send 2 from us, 2 from the to be ignored user.
+ allowed_event_ids = []
+ ignored_event_ids = []
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+ allowed_event_ids.append(channel.json_body["event_id"])
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b")
+ allowed_event_ids.append(channel.json_body["event_id"])
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key="a",
+ access_token=self.user2_token,
+ )
+ ignored_event_ids.append(channel.json_body["event_id"])
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key="c",
+ access_token=self.user2_token,
+ )
+ ignored_event_ids.append(channel.json_body["event_id"])
+
+ self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+ def test_reference(self) -> None:
+ """Annotations should ignore"""
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ allowed_event_ids = [channel.json_body["event_id"]]
+
+ channel = self._send_relation(
+ RelationTypes.REFERENCE, "m.room.test", access_token=self.user2_token
+ )
+ ignored_event_ids = [channel.json_body["event_id"]]
+
+ self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+ def test_thread(self) -> None:
+ """Annotations should ignore"""
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ allowed_event_ids = [channel.json_body["event_id"]]
+
+ channel = self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
+ ignored_event_ids = [channel.json_body["event_id"]]
+
+ self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+
class RelationRedactionTestCase(BaseRelationsTestCase):
"""
Test the behaviour of relations when the parent or child event is redacted.
diff --git a/tests/server.py b/tests/server.py
index 82990c2eb9..6ce2a17bf4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -54,13 +54,18 @@ from twisted.internet.interfaces import (
ITransport,
)
from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
+from twisted.test.proto_helpers import (
+ AccumulatingProtocol,
+ MemoryReactor,
+ MemoryReactorClock,
+)
from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
from synapse.http.site import SynapseRequest
+from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
@@ -88,18 +93,19 @@ class TimedOutException(Exception):
"""
-@attr.s
+@attr.s(auto_attribs=True)
class FakeChannel:
"""
A fake Twisted Web Channel (the part that interfaces with the
wire).
"""
- site = attr.ib(type=Union[Site, "FakeSite"])
- _reactor = attr.ib()
- result = attr.ib(type=dict, default=attr.Factory(dict))
- _ip = attr.ib(type=str, default="127.0.0.1")
+ site: Union[Site, "FakeSite"]
+ _reactor: MemoryReactor
+ result: dict = attr.Factory(dict)
+ _ip: str = "127.0.0.1"
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
+ resource_usage: Optional[ContextResourceUsage] = None
@property
def json_body(self):
@@ -168,6 +174,8 @@ class FakeChannel:
def requestDone(self, _self):
self.result["done"] = True
+ if isinstance(_self, SynapseRequest):
+ self.resource_usage = _self.logcontext.get_resource_usage()
def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 5cf18b690e..fd619b64d4 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -17,8 +17,12 @@ from unittest.mock import Mock
import yaml
from twisted.internet.defer import Deferred, ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
+from synapse.server import HomeServer
from synapse.storage.background_updates import BackgroundUpdater
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable, simple_async_mock
@@ -26,7 +30,7 @@ from tests.unittest import override_config
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
# the base test class should have run the real bg updates for us
self.assertTrue(
@@ -39,7 +43,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
)
self.store = self.hs.get_datastores().main
- async def update(self, progress, count):
+ async def update(self, progress: JsonDict, count: int) -> int:
duration_ms = 10
await self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1}
@@ -51,7 +55,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
)
return count
- def test_do_background_update(self):
+ def test_do_background_update(self) -> None:
# the time we claim it takes to update one item when running the update
duration_ms = 10
@@ -80,7 +84,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# second step: complete the update
# we should now get run with a much bigger number of items to update
- async def update(progress, count):
+ async def update(progress: JsonDict, count: int) -> int:
self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual(
count,
@@ -110,7 +114,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
"""
)
)
- def test_background_update_default_batch_set_by_config(self):
+ def test_background_update_default_batch_set_by_config(self) -> None:
"""
Test that the background update is run with the default_batch_size set by the config
"""
@@ -133,7 +137,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# on the first call, we should get run with the default background update size specified in the config
self.update_handler.assert_called_once_with({"my_key": 1}, 20)
- def test_background_update_default_sleep_behavior(self):
+ def test_background_update_default_sleep_behavior(self) -> None:
"""
Test default background update behavior, which is to sleep
"""
@@ -147,7 +151,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = self.update
self.update_handler.reset_mock()
- self.updates.start_doing_background_updates(),
+ self.updates.start_doing_background_updates()
# 2: advance the reactor less than the default sleep duration (1000ms)
self.reactor.pump([0.5])
@@ -167,7 +171,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
"""
)
)
- def test_background_update_sleep_set_in_config(self):
+ def test_background_update_sleep_set_in_config(self) -> None:
"""
Test that changing the sleep time in the config changes how long it sleeps
"""
@@ -181,7 +185,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = self.update
self.update_handler.reset_mock()
- self.updates.start_doing_background_updates(),
+ self.updates.start_doing_background_updates()
# 2: advance the reactor less than the configured sleep duration (500ms)
self.reactor.pump([0.45])
@@ -201,7 +205,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
"""
)
)
- def test_disabling_background_update_sleep(self):
+ def test_disabling_background_update_sleep(self) -> None:
"""
Test that disabling sleep in the config results in bg update not sleeping
"""
@@ -215,7 +219,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = self.update
self.update_handler.reset_mock()
- self.updates.start_doing_background_updates(),
+ self.updates.start_doing_background_updates()
# 2: advance the reactor very little
self.reactor.pump([0.025])
@@ -230,7 +234,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
"""
)
)
- def test_background_update_duration_set_in_config(self):
+ def test_background_update_duration_set_in_config(self) -> None:
"""
Test that the desired duration set in the config is used in determining batch size
"""
@@ -254,7 +258,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# the first update was run with the default batch size, this should be run with 500ms as the
# desired duration
- async def update(progress, count):
+ async def update(progress: JsonDict, count: int) -> int:
self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual(
count,
@@ -275,7 +279,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
"""
)
)
- def test_background_update_min_batch_set_in_config(self):
+ def test_background_update_min_batch_set_in_config(self) -> None:
"""
Test that the minimum batch size set in the config is used
"""
@@ -290,7 +294,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
)
# Run the update with the long-running update item
- async def update(progress, count):
+ async def update_long(progress: JsonDict, count: int) -> int:
await self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1}
await self.store.db_pool.runInteraction(
@@ -301,7 +305,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
)
return count
- self.update_handler.side_effect = update
+ self.update_handler.side_effect = update_long
self.update_handler.reset_mock()
res = self.get_success(
self.updates.do_next_background_update(False),
@@ -311,25 +315,25 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# the first update was run with the default batch size, this should be run with minimum batch size
# as the first items took a very long time
- async def update(progress, count):
+ async def update_short(progress: JsonDict, count: int) -> int:
self.assertEqual(progress, {"my_key": 2})
self.assertEqual(count, 5)
await self.updates._end_background_update("test_update")
return count
- self.update_handler.side_effect = update
+ self.update_handler.side_effect = update_short
self.get_success(self.updates.do_next_background_update(False))
class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
# the base test class should have run the real bg updates for us
self.assertTrue(
self.get_success(self.updates.has_completed_background_updates())
)
- self.update_deferred = Deferred()
+ self.update_deferred: Deferred[int] = Deferred()
self.update_handler = Mock(return_value=self.update_deferred)
self.updates.register_background_update_handler(
"test_update", self.update_handler
@@ -358,7 +362,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
),
)
- def test_controller(self):
+ def test_controller(self) -> None:
store = self.hs.get_datastores().main
self.get_success(
store.db_pool.simple_insert(
@@ -368,7 +372,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
)
# Set the return value for the context manager.
- enter_defer = Deferred()
+ enter_defer: Deferred[int] = Deferred()
self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
# Start the background update.
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 6ac4b93f98..395396340b 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -13,9 +13,13 @@
# limitations under the License.
from typing import List, Optional
-from synapse.storage.database import DatabasePool
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS
@@ -25,13 +29,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
- def _setup_db(self, txn):
+ def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
@@ -59,12 +63,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
- def _insert_rows(self, instance_name: str, number: int):
+ def _insert_rows(self, instance_name: str, number: int) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""
- def _insert(txn):
+ def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
txn.execute(
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
@@ -80,12 +84,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
- def _insert_row_with_id(self, instance_name: str, stream_id: int):
+ def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id, updating
the postgres sequence position to match.
"""
- def _insert(txn):
+ def _insert(txn: LoggingTransaction) -> None:
txn.execute(
"INSERT INTO foobar VALUES (?, ?)",
(
@@ -104,7 +108,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
- def test_empty(self):
+ def test_empty(self) -> None:
"""Test an ID generator against an empty database gives sensible
current positions.
"""
@@ -114,7 +118,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# The table is empty so we expect an empty map for positions
self.assertEqual(id_gen.get_positions(), {})
- def test_single_instance(self):
+ def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled
correctly.
"""
@@ -130,7 +134,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
@@ -142,7 +146,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
- def test_out_of_order_finish(self):
+ def test_out_of_order_finish(self) -> None:
"""Test that IDs persisted out of order are correctly handled"""
# Prefill table with 7 rows written by 'master'
@@ -191,7 +195,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 11})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
- def test_multi_instance(self):
+ def test_multi_instance(self) -> None:
"""Test that reads and writes from multiple processes are handled
correctly.
"""
@@ -215,7 +219,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with first_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
@@ -233,7 +237,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# ... but calling `get_next` on the second instance should give a unique
# stream ID
- async def _get_next_async():
+ async def _get_next_async2() -> None:
async with second_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 9)
@@ -241,7 +245,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.get_positions(), {"first": 3, "second": 7}
)
- self.get_success(_get_next_async())
+ self.get_success(_get_next_async2())
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
@@ -249,7 +253,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.advance("first", 8)
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
- def test_get_next_txn(self):
+ def test_get_next_txn(self) -> None:
"""Test that the `get_next_txn` function works correctly."""
# Prefill table with 7 rows written by 'master'
@@ -263,7 +267,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
- def _get_next_txn(txn):
+ def _get_next_txn(txn: LoggingTransaction) -> None:
stream_id = id_gen.get_next_txn(txn)
self.assertEqual(stream_id, 8)
@@ -275,7 +279,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
- def test_get_persisted_upto_position(self):
+ def test_get_persisted_upto_position(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions.
"""
@@ -317,7 +321,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen.advance("second", 15)
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
- def test_get_persisted_upto_position_get_next(self):
+ def test_get_persisted_upto_position_get_next(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions when `get_next` is called.
"""
@@ -331,7 +335,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
@@ -344,7 +348,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
- def test_restart_during_out_of_order_persistence(self):
+ def test_restart_during_out_of_order_persistence(self) -> None:
"""Test that restarting a process while another process is writing out
of order updates are handled correctly.
"""
@@ -388,7 +392,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen_worker.advance("master", 9)
self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
- def test_writer_config_change(self):
+ def test_writer_config_change(self) -> None:
"""Test that changing the writer config correctly works."""
self._insert_row_with_id("first", 3)
@@ -421,7 +425,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Check that we get a sane next stream ID with this new config.
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen_3.get_next() as stream_id:
self.assertEqual(stream_id, 6)
@@ -435,7 +439,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6)
self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
- def test_sequence_consistency(self):
+ def test_sequence_consistency(self) -> None:
"""Test that we error out if the table and sequence diverges."""
# Prefill with some rows
@@ -458,13 +462,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
- def _setup_db(self, txn):
+ def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
@@ -493,10 +497,10 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success(self.db_pool.runWithConnection(_create))
- def _insert_row(self, instance_name: str, stream_id: int):
+ def _insert_row(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id."""
- def _insert(txn):
+ def _insert(txn: LoggingTransaction) -> None:
txn.execute(
"INSERT INTO foobar VALUES (?, ?)",
(
@@ -514,13 +518,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
- def test_single_instance(self):
+ def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled
correctly.
"""
id_gen = self._create_id_generator()
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id:
self._insert_row("master", stream_id)
@@ -530,7 +534,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
- async def _get_next_async2():
+ async def _get_next_async2() -> None:
async with id_gen.get_next_mult(3) as stream_ids:
for stream_id in stream_ids:
self._insert_row("master", stream_id)
@@ -548,14 +552,14 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
- def test_multiple_instance(self):
+ def test_multiple_instance(self) -> None:
"""Tests that having multiple instances that get advanced over
federation works corretly.
"""
id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
- async def _get_next_async():
+ async def _get_next_async() -> None:
async with id_gen_1.get_next() as stream_id:
self._insert_row("first", stream_id)
id_gen_2.advance("first", stream_id)
@@ -567,7 +571,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
- async def _get_next_async2():
+ async def _get_next_async2() -> None:
async with id_gen_2.get_next() as stream_id:
self._insert_row("second", stream_id)
id_gen_1.advance("second", stream_id)
@@ -584,13 +588,13 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
- def _setup_db(self, txn):
+ def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
@@ -642,7 +646,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
from the postgres sequence.
"""
- def _insert(txn):
+ def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
txn.execute(
"INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
@@ -659,7 +663,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
- def test_load_existing_stream(self):
+ def test_load_existing_stream(self) -> None:
"""Test creating ID gens with multiple tables that have rows from after
the position in `stream_positions` table.
"""
diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py
new file mode 100644
index 0000000000..ba53c22818
--- /dev/null
+++ b/tests/storage/test_unsafe_locale.py
@@ -0,0 +1,46 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import MagicMock, patch
+
+from synapse.storage.database import make_conn
+from synapse.storage.engines._base import IncorrectDatabaseSetup
+
+from tests.unittest import HomeserverTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+
+class UnsafeLocaleTest(HomeserverTestCase):
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ @patch("synapse.storage.engines.postgres.PostgresEngine.get_db_locale")
+ def test_unsafe_locale(self, mock_db_locale: MagicMock) -> None:
+ mock_db_locale.return_value = ("B", "B")
+ database = self.hs.get_datastores().databases[0]
+
+ db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
+ with self.assertRaises(IncorrectDatabaseSetup):
+ database.engine.check_database(db_conn)
+ with self.assertRaises(IncorrectDatabaseSetup):
+ database.engine.check_new_database(db_conn)
+ db_conn.close()
+
+ def test_safe_locale(self) -> None:
+ database = self.hs.get_datastores().databases[0]
+
+ db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
+ with db_conn.cursor() as txn:
+ res = database.engine.get_db_locale(txn)
+ self.assertEqual(res, ("C", "C"))
+ db_conn.close()
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 38e9f58ac6..5d1aa025d1 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -12,7 +12,7 @@ from tests.unittest import TestCase
class DummyDistribution(metadata.Distribution):
- def __init__(self, version: str):
+ def __init__(self, version: object):
self._version = version
@property
@@ -30,6 +30,7 @@ old = DummyDistribution("0.1.2")
old_release_candidate = DummyDistribution("0.1.2rc3")
new = DummyDistribution("1.2.3")
new_release_candidate = DummyDistribution("1.2.3rc4")
+distribution_with_no_version = DummyDistribution(None)
# could probably use stdlib TestCase --- no need for twisted here
@@ -67,6 +68,18 @@ class TestDependencyChecker(TestCase):
# should not raise
check_requirements()
+ def test_version_reported_as_none(self) -> None:
+ """Complain if importlib.metadata.version() returns None.
+
+ This shouldn't normally happen, but it was seen in the wild (#12223).
+ """
+ with patch(
+ "synapse.util.check_dependencies.metadata.requires",
+ return_value=["dummypkg >= 1"],
+ ):
+ with self.mock_installed_package(distribution_with_no_version):
+ self.assertRaises(DependencyException, check_requirements)
+
def test_checks_ignore_dev_dependencies(self) -> None:
"""Bot generic and per-extra checks should ignore dev dependencies."""
with patch(
diff --git a/tests/utils.py b/tests/utils.py
index ef99c72e0b..f6b1d60371 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,13 +15,8 @@
import atexit
import os
-from unittest.mock import Mock, patch
-from urllib import parse as urlparse
-
-from twisted.internet import defer
from synapse.api.constants import EventTypes
-from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
@@ -187,111 +182,6 @@ def mock_getRawHeaders(headers=None):
return getRawHeaders
-# This is a mock /resource/ not an entire server
-class MockHttpResource:
- def __init__(self, prefix=""):
- self.callbacks = [] # 3-tuple of method/pattern/function
- self.prefix = prefix
-
- def trigger_get(self, path):
- return self.trigger(b"GET", path, None)
-
- @patch("twisted.web.http.Request")
- @defer.inlineCallbacks
- def trigger(
- self, http_method, path, content, mock_request, federation_auth_origin=None
- ):
- """Fire an HTTP event.
-
- Args:
- http_method : The HTTP method
- path : The HTTP path
- content : The HTTP body
- mock_request : Mocked request to pass to the event so it can get
- content.
- federation_auth_origin (bytes|None): domain to authenticate as, for federation
- Returns:
- A tuple of (code, response)
- Raises:
- KeyError If no event is found which will handle the path.
- """
- path = self.prefix + path
-
- # annoyingly we return a twisted http request which has chained calls
- # to get at the http content, hence mock it here.
- mock_content = Mock()
- config = {"read.return_value": content}
- mock_content.configure_mock(**config)
- mock_request.content = mock_content
-
- mock_request.method = http_method.encode("ascii")
- mock_request.uri = path.encode("ascii")
-
- mock_request.getClientIP.return_value = "-"
-
- headers = {}
- if federation_auth_origin is not None:
- headers[b"Authorization"] = [
- b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
- ]
- mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
-
- # return the right path if the event requires it
- mock_request.path = path
-
- # add in query params to the right place
- try:
- mock_request.args = urlparse.parse_qs(path.split("?")[1])
- mock_request.path = path.split("?")[0]
- path = mock_request.path
- except Exception:
- pass
-
- if isinstance(path, bytes):
- path = path.decode("utf8")
-
- for (method, pattern, func) in self.callbacks:
- if http_method != method:
- continue
-
- matcher = pattern.match(path)
- if matcher:
- try:
- args = [urlparse.unquote(u) for u in matcher.groups()]
-
- (code, response) = yield defer.ensureDeferred(
- func(mock_request, *args)
- )
- return code, response
- except CodeMessageException as e:
- return e.code, cs_error(e.msg, code=e.errcode)
-
- raise KeyError("No event can handle %s" % path)
-
- def register_paths(self, method, path_patterns, callback, servlet_name):
- for path_pattern in path_patterns:
- self.callbacks.append((method, path_pattern, callback))
-
-
-class MockKey:
- alg = "mock_alg"
- version = "mock_version"
- signature = b"\x9a\x87$"
-
- @property
- def verify_key(self):
- return self
-
- def sign(self, message):
- return self
-
- def verify(self, message, sig):
- assert sig == b"\x9a\x87$"
-
- def encode(self):
- return b"<fake_encoded_key>"
-
-
class MockClock:
now = 1000
|