diff --git a/CHANGES.md b/CHANGES.md
index 6d4bd23e4e..74b8e1df87 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,96 @@
+Synapse 1.18.0rc1 (2020-07-27)
+==============================
+
+Features
+--------
+
+- Include room states on invite events that are sent to application services. Contributed by @Sorunome. ([\#6455](https://github.com/matrix-org/synapse/issues/6455))
+- Add delete room admin endpoint (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). Contributed by @dklimpel. ([\#7613](https://github.com/matrix-org/synapse/issues/7613), [\#7953](https://github.com/matrix-org/synapse/issues/7953))
+- Add experimental support for running multiple federation sender processes. ([\#7798](https://github.com/matrix-org/synapse/issues/7798))
+- Add the option to validate the `iss` and `aud` claims for JWT logins. ([\#7827](https://github.com/matrix-org/synapse/issues/7827))
+- Add support for handling registration requests across multiple client reader workers. ([\#7830](https://github.com/matrix-org/synapse/issues/7830))
+- Add an admin API to list the users in a room. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#7842](https://github.com/matrix-org/synapse/issues/7842))
+- Allow email subjects to be customised through Synapse's configuration. ([\#7846](https://github.com/matrix-org/synapse/issues/7846))
+- Add the ability to re-activate an account from the admin API. ([\#7847](https://github.com/matrix-org/synapse/issues/7847), [\#7908](https://github.com/matrix-org/synapse/issues/7908))
+- Add experimental support for running multiple pusher workers. ([\#7855](https://github.com/matrix-org/synapse/issues/7855))
+- Add experimental support for moving typing handling off the master process. ([\#7869](https://github.com/matrix-org/synapse/issues/7869), [\#7959](https://github.com/matrix-org/synapse/issues/7959))
+- Report CPU metrics to Prometheus for time spent processing replication commands. ([\#7879](https://github.com/matrix-org/synapse/issues/7879))
+- Support oEmbed for media previews. ([\#7920](https://github.com/matrix-org/synapse/issues/7920))
+- Abort federation requests where the client disconnects before the ratelimiter expires. ([\#7930](https://github.com/matrix-org/synapse/issues/7930))
+- Cache responses to `/_matrix/federation/v1/state_ids` to reduce duplicated work. ([\#7931](https://github.com/matrix-org/synapse/issues/7931))
+
+
+Bugfixes
+--------
+
+- Fix detection of out of sync remote device lists when receiving events from remote users. ([\#7815](https://github.com/matrix-org/synapse/issues/7815))
+- Fix bug where Synapse fails to process an incoming event over federation if the server is missing too much of the event's auth chain. ([\#7817](https://github.com/matrix-org/synapse/issues/7817))
+- Fix a bug causing Synapse to misinterpret the value `off` for `encryption_enabled_by_default_for_room_type` in its configuration file(s) if that value isn't surrounded by quotes. This bug was introduced in v1.16.0. ([\#7822](https://github.com/matrix-org/synapse/issues/7822))
+- Fix bug where we did not always pass in `app_name` or `server_name` to email templates, including e.g. for registration emails. ([\#7829](https://github.com/matrix-org/synapse/issues/7829))
+- Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`. ([\#7844](https://github.com/matrix-org/synapse/issues/7844))
+- Fix "AttributeError: 'str' object has no attribute 'get'" error message when applying per-room message retention policies. The bug was introduced in Synapse 1.7.0. ([\#7850](https://github.com/matrix-org/synapse/issues/7850))
+- Fix a bug introduced in Synapse 1.10.0 which could cause a "no create event in auth events" error during room creation. ([\#7854](https://github.com/matrix-org/synapse/issues/7854))
+- Fix a bug which allowed empty rooms to be rejoined over federation. ([\#7859](https://github.com/matrix-org/synapse/issues/7859))
+- Fix 'Unable to find a suitable guest user ID' error when using multiple client_reader workers. ([\#7866](https://github.com/matrix-org/synapse/issues/7866))
+- Fix a long-standing bug where the tracing of async functions with opentracing was broken. ([\#7872](https://github.com/matrix-org/synapse/issues/7872), [\#7961](https://github.com/matrix-org/synapse/issues/7961))
+- Fix "TypeError in `synapse.notifier`" exceptions. ([\#7880](https://github.com/matrix-org/synapse/issues/7880))
+- Fix deprecation warning due to invalid escape sequences. ([\#7895](https://github.com/matrix-org/synapse/issues/7895))
+
+
+Updates to the Docker image
+---------------------------
+
+- Base the Docker image on Debian Buster rather than Alpine Linux. Contributed by @maquis196. ([\#7839](https://github.com/matrix-org/synapse/issues/7839))
+
+
+Improved Documentation
+----------------------
+
+- Provide instructions on using `register_new_matrix_user` via docker. ([\#7885](https://github.com/matrix-org/synapse/issues/7885))
+- Change the sample config postgres user section to use `synapse_user` instead of `synapse` to align with the documentation. ([\#7889](https://github.com/matrix-org/synapse/issues/7889))
+- Reorder database paragraphs to promote PostgreSQL over SQLite. ([\#7933](https://github.com/matrix-org/synapse/issues/7933))
+- Update the dates of ACME v1's end of life in [`ACME.md`](https://github.com/matrix-org/synapse/blob/master/docs/ACME.md). ([\#7934](https://github.com/matrix-org/synapse/issues/7934))
+
+
+Deprecations and Removals
+-------------------------
+
+- Remove unused `synapse_replication_tcp_resource_invalidate_cache` Prometheus metric. ([\#7878](https://github.com/matrix-org/synapse/issues/7878))
+- Remove Ubuntu Eoan from the list of `.deb` packages that we build as it is now end-of-life. Contributed by @gary-kim. ([\#7888](https://github.com/matrix-org/synapse/issues/7888))
+
+
+Internal Changes
+----------------
+
+- Switch parts of the codebase from `simplejson` to the standard library `json`. ([\#7802](https://github.com/matrix-org/synapse/issues/7802))
+- Add type hints to the http server code and remove an unused parameter. ([\#7813](https://github.com/matrix-org/synapse/issues/7813))
+- Add type hints to the `synapse.api.errors` module. ([\#7820](https://github.com/matrix-org/synapse/issues/7820))
+- Ensure that calls to `json.dumps` are compatible with the standard library `json` module. ([\#7836](https://github.com/matrix-org/synapse/issues/7836))
+- Remove redundant `retry_on_integrity_error` wrapper for event persistence code. ([\#7848](https://github.com/matrix-org/synapse/issues/7848))
+- Consistently use `db_to_json` to convert from database values to JSON objects. ([\#7849](https://github.com/matrix-org/synapse/issues/7849))
+- Convert various parts of the codebase to async/await. ([\#7851](https://github.com/matrix-org/synapse/issues/7851), [\#7860](https://github.com/matrix-org/synapse/issues/7860), [\#7868](https://github.com/matrix-org/synapse/issues/7868), [\#7871](https://github.com/matrix-org/synapse/issues/7871), [\#7873](https://github.com/matrix-org/synapse/issues/7873), [\#7874](https://github.com/matrix-org/synapse/issues/7874), [\#7884](https://github.com/matrix-org/synapse/issues/7884), [\#7912](https://github.com/matrix-org/synapse/issues/7912), [\#7935](https://github.com/matrix-org/synapse/issues/7935), [\#7939](https://github.com/matrix-org/synapse/issues/7939), [\#7942](https://github.com/matrix-org/synapse/issues/7942), [\#7944](https://github.com/matrix-org/synapse/issues/7944))
+- Add support for handling registration requests across multiple client reader workers. ([\#7853](https://github.com/matrix-org/synapse/issues/7853))
+- Small performance improvement in typing processing. ([\#7856](https://github.com/matrix-org/synapse/issues/7856))
+- The default value of `filter_timeline_limit` was changed from -1 (no limit) to 100. ([\#7858](https://github.com/matrix-org/synapse/issues/7858))
+- Optimise queueing of inbound replication commands. ([\#7861](https://github.com/matrix-org/synapse/issues/7861))
+- Add some type annotations to `HomeServer` and `BaseHandler`. ([\#7870](https://github.com/matrix-org/synapse/issues/7870))
+- Clean up `PreserveLoggingContext`. ([\#7877](https://github.com/matrix-org/synapse/issues/7877))
+- Change "unknown room version" logging from 'error' to 'warning'. ([\#7881](https://github.com/matrix-org/synapse/issues/7881))
+- Stop using `device_max_stream_id` table and just use `device_inbox.stream_id`. ([\#7882](https://github.com/matrix-org/synapse/issues/7882))
+- Return an empty body for OPTIONS requests. ([\#7886](https://github.com/matrix-org/synapse/issues/7886))
+- Fix typo in generated config file. Contributed by @ThiefMaster. ([\#7890](https://github.com/matrix-org/synapse/issues/7890))
+- Import ABC from `collections.abc` for Python 3.10 compatibility. ([\#7892](https://github.com/matrix-org/synapse/issues/7892))
+- Remove unused functions `time_function`, `trace_function`, `get_previous_frames`
+ and `get_previous_frame` from `synapse.logging.utils` module. ([\#7897](https://github.com/matrix-org/synapse/issues/7897))
+- Lint the `contrib/` directory in CI and linting scripts, add `synctl` to the linting script for consistency with CI. ([\#7914](https://github.com/matrix-org/synapse/issues/7914))
+- Use Element CSS and logo in notification emails when app name is Element. ([\#7919](https://github.com/matrix-org/synapse/issues/7919))
+- Optimisation to `/sync` handling: skip serializing the response if the client has already disconnected. ([\#7927](https://github.com/matrix-org/synapse/issues/7927))
+- When a client disconnects, don't log it as 'Error processing request'. ([\#7928](https://github.com/matrix-org/synapse/issues/7928))
+- Add debugging to `/sync` response generation (disabled by default). ([\#7929](https://github.com/matrix-org/synapse/issues/7929))
+- Update comments that refer to Deferreds for async functions. ([\#7945](https://github.com/matrix-org/synapse/issues/7945))
+- Simplify error handling in federation handler. ([\#7950](https://github.com/matrix-org/synapse/issues/7950))
+
+
Synapse 1.17.0 (2020-07-13)
===========================
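The new admin endpoints in this release (such as `POST /_synapse/admin/v1/rooms/<room_id>/delete`) are plain HTTP calls authenticated with an admin access token. A minimal sketch of building such a request — the endpoint path comes from the changelog entry above, while the request body field and server URL are illustrative assumptions, not taken from this changelog:

```python
import json
from urllib.request import Request

def delete_room_request(base_url: str, room_id: str, admin_token: str) -> Request:
    # Endpoint path from the changelog entry above; the body field
    # below is an illustrative assumption, not part of this changelog.
    url = f"{base_url}/_synapse/admin/v1/rooms/{room_id}/delete"
    body = json.dumps({"message": "Room closed by admin"}).encode()
    return Request(
        url,
        data=body,
        headers={
            "Authorization": f"Bearer {admin_token}",
            "Content-Type": "application/json",
        },
        method="POST",
    )

req = delete_room_request("https://synapse.example.org", "!abc:example.org", "secret")
print(req.get_full_url())
```

The same shape (admin token in the `Authorization` header, JSON body) applies to the other admin APIs listed above, such as listing the users in a room.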
diff --git a/changelog.d/6455.feature b/changelog.d/6455.feature
deleted file mode 100644
index eb286cb70f..0000000000
--- a/changelog.d/6455.feature
+++ /dev/null
@@ -1 +0,0 @@
-Include room states on invite events that are sent to application services. Contributed by @Sorunome.
diff --git a/changelog.d/7613.feature b/changelog.d/7613.feature
deleted file mode 100644
index b671dc2fcc..0000000000
--- a/changelog.d/7613.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add delete room admin endpoint (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). Contributed by @dklimpel.
diff --git a/changelog.d/7798.feature b/changelog.d/7798.feature
deleted file mode 100644
index 56ffaf0d4a..0000000000
--- a/changelog.d/7798.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add experimental support for running multiple federation sender processes.
diff --git a/changelog.d/7802.misc b/changelog.d/7802.misc
deleted file mode 100644
index d81f8875c5..0000000000
--- a/changelog.d/7802.misc
+++ /dev/null
@@ -1 +0,0 @@
- Switch from simplejson to the standard library json.
diff --git a/changelog.d/7813.misc b/changelog.d/7813.misc
deleted file mode 100644
index f3005cfd27..0000000000
--- a/changelog.d/7813.misc
+++ /dev/null
@@ -1 +0,0 @@
-Add type hints to the http server code and remove an unused parameter.
diff --git a/changelog.d/7815.bugfix b/changelog.d/7815.bugfix
deleted file mode 100644
index 3e7c7d412e..0000000000
--- a/changelog.d/7815.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix detection of out of sync remote device lists when receiving events from remote users.
diff --git a/changelog.d/7817.bugfix b/changelog.d/7817.bugfix
deleted file mode 100644
index 1c001070d5..0000000000
--- a/changelog.d/7817.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix bug where Synapse fails to process an incoming event over federation if the server is missing too much of the event's auth chain.
diff --git a/changelog.d/7820.misc b/changelog.d/7820.misc
deleted file mode 100644
index b77b5672e3..0000000000
--- a/changelog.d/7820.misc
+++ /dev/null
@@ -1 +0,0 @@
-Add type hints to synapse.api.errors module.
diff --git a/changelog.d/7822.bugfix b/changelog.d/7822.bugfix
deleted file mode 100644
index faf249a678..0000000000
--- a/changelog.d/7822.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix a bug causing Synapse to misinterpret the value `off` for `encryption_enabled_by_default_for_room_type` in its configuration file(s) if that value isn't surrounded by quotes. This bug was introduced in v1.16.0.
diff --git a/changelog.d/7827.feature b/changelog.d/7827.feature
deleted file mode 100644
index 0fd116e198..0000000000
--- a/changelog.d/7827.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add the option to validate the `iss` and `aud` claims for JWT logins.
diff --git a/changelog.d/7829.bugfix b/changelog.d/7829.bugfix
deleted file mode 100644
index dcbf385de6..0000000000
--- a/changelog.d/7829.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix bug where we did not always pass in `app_name` or `server_name` to email templates, including e.g. for registration emails.
diff --git a/changelog.d/7830.feature b/changelog.d/7830.feature
deleted file mode 100644
index b4f614084d..0000000000
--- a/changelog.d/7830.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add support for handling registration requests across multiple client reader workers.
diff --git a/changelog.d/7836.misc b/changelog.d/7836.misc
deleted file mode 100644
index a3a97c7590..0000000000
--- a/changelog.d/7836.misc
+++ /dev/null
@@ -1 +0,0 @@
-Ensure that calls to `json.dumps` are compatible with the standard library json.
diff --git a/changelog.d/7839.docker b/changelog.d/7839.docker
deleted file mode 100644
index cdf3c9631c..0000000000
--- a/changelog.d/7839.docker
+++ /dev/null
@@ -1 +0,0 @@
-Base docker image on Debian Buster rather than Alpine Linux. Contributed by @maquis196.
diff --git a/changelog.d/7842.feature b/changelog.d/7842.feature
deleted file mode 100644
index 727deb01c9..0000000000
--- a/changelog.d/7842.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add an admin API to list the users in a room. Contributed by Awesome Technologies Innovationslabor GmbH.
diff --git a/changelog.d/7844.bugfix b/changelog.d/7844.bugfix
deleted file mode 100644
index ad296f1b3c..0000000000
--- a/changelog.d/7844.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`.
diff --git a/changelog.d/7846.feature b/changelog.d/7846.feature
deleted file mode 100644
index 997376fe42..0000000000
--- a/changelog.d/7846.feature
+++ /dev/null
@@ -1 +0,0 @@
-Allow email subjects to be customised through Synapse's configuration.
diff --git a/changelog.d/7847.feature b/changelog.d/7847.feature
deleted file mode 100644
index 4b9a8d8569..0000000000
--- a/changelog.d/7847.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add the ability to re-activate an account from the admin API.
diff --git a/changelog.d/7848.misc b/changelog.d/7848.misc
deleted file mode 100644
index d9db1d8357..0000000000
--- a/changelog.d/7848.misc
+++ /dev/null
@@ -1 +0,0 @@
-Remove redundant `retry_on_integrity_error` wrapper for event persistence code.
diff --git a/changelog.d/7849.misc b/changelog.d/7849.misc
deleted file mode 100644
index e3296418c1..0000000000
--- a/changelog.d/7849.misc
+++ /dev/null
@@ -1 +0,0 @@
-Consistently use `db_to_json` to convert from database values to JSON objects.
diff --git a/changelog.d/7850.bugfix b/changelog.d/7850.bugfix
deleted file mode 100644
index 5f19a89043..0000000000
--- a/changelog.d/7850.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix "AttributeError: 'str' object has no attribute 'get'" error message when applying per-room message retention policies. The bug was introduced in Synapse 1.7.0.
diff --git a/changelog.d/7851.misc b/changelog.d/7851.misc
deleted file mode 100644
index e5cf540edf..0000000000
--- a/changelog.d/7851.misc
+++ /dev/null
@@ -1 +0,0 @@
-Convert E2E keys and room keys handlers to async/await.
diff --git a/changelog.d/7853.misc b/changelog.d/7853.misc
deleted file mode 100644
index b4f614084d..0000000000
--- a/changelog.d/7853.misc
+++ /dev/null
@@ -1 +0,0 @@
-Add support for handling registration requests across multiple client reader workers.
diff --git a/changelog.d/7854.bugfix b/changelog.d/7854.bugfix
deleted file mode 100644
index b11f9dedfe..0000000000
--- a/changelog.d/7854.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix a bug introduced in Synapse 1.10.0 which could cause a "no create event in auth events" error during room creation.
diff --git a/changelog.d/7855.feature b/changelog.d/7855.feature
deleted file mode 100644
index 2b6a9f0e71..0000000000
--- a/changelog.d/7855.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add experimental support for running multiple pusher workers.
diff --git a/changelog.d/7856.misc b/changelog.d/7856.misc
deleted file mode 100644
index 7d99fb67be..0000000000
--- a/changelog.d/7856.misc
+++ /dev/null
@@ -1 +0,0 @@
-Small performance improvement in typing processing.
diff --git a/changelog.d/7858.misc b/changelog.d/7858.misc
deleted file mode 100644
index 8f0fc2de74..0000000000
--- a/changelog.d/7858.misc
+++ /dev/null
@@ -1 +0,0 @@
-The default value of `filter_timeline_limit` was changed from -1 (no limit) to 100.
diff --git a/changelog.d/7859.bugfix b/changelog.d/7859.bugfix
deleted file mode 100644
index 19cff4b061..0000000000
--- a/changelog.d/7859.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix a bug which allowed empty rooms to be rejoined over federation.
diff --git a/changelog.d/7860.misc b/changelog.d/7860.misc
deleted file mode 100644
index fdd48b955c..0000000000
--- a/changelog.d/7860.misc
+++ /dev/null
@@ -1 +0,0 @@
-Convert _base, profile, and _receipts handlers to async/await.
diff --git a/changelog.d/7861.misc b/changelog.d/7861.misc
deleted file mode 100644
index ada616c62f..0000000000
--- a/changelog.d/7861.misc
+++ /dev/null
@@ -1 +0,0 @@
-Optimise queueing of inbound replication commands.
diff --git a/changelog.d/7866.bugfix b/changelog.d/7866.bugfix
deleted file mode 100644
index 6b5c3c4eca..0000000000
--- a/changelog.d/7866.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix 'Unable to find a suitable guest user ID' error when using multiple client_reader workers.
diff --git a/changelog.d/7868.misc b/changelog.d/7868.misc
deleted file mode 100644
index eadef5e4c2..0000000000
--- a/changelog.d/7868.misc
+++ /dev/null
@@ -1 +0,0 @@
-Convert synapse.app and federation client to async/await.
diff --git a/changelog.d/7869.feature b/changelog.d/7869.feature
deleted file mode 100644
index 1982049a52..0000000000
--- a/changelog.d/7869.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add experimental support for moving typing off master.
diff --git a/changelog.d/7870.misc b/changelog.d/7870.misc
deleted file mode 100644
index 27cce2f2f9..0000000000
--- a/changelog.d/7870.misc
+++ /dev/null
@@ -1 +0,0 @@
-Add some type annotations to `HomeServer` and `BaseHandler`.
diff --git a/changelog.d/7871.misc b/changelog.d/7871.misc
deleted file mode 100644
index 4d398a9f3a..0000000000
--- a/changelog.d/7871.misc
+++ /dev/null
@@ -1 +0,0 @@
-Convert device handler to async/await.
diff --git a/changelog.d/7872.bugfix b/changelog.d/7872.bugfix
deleted file mode 100644
index b21f8e1f14..0000000000
--- a/changelog.d/7872.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix a long standing bug where the tracing of async functions with opentracing was broken.
diff --git a/changelog.d/7874.misc b/changelog.d/7874.misc
deleted file mode 100644
index f75c8d1843..0000000000
--- a/changelog.d/7874.misc
+++ /dev/null
@@ -1 +0,0 @@
-Convert the federation agent and related code to async/await.
diff --git a/changelog.d/7876.bugfix b/changelog.d/7876.bugfix
new file mode 100644
index 0000000000..4ba2fadd58
--- /dev/null
+++ b/changelog.d/7876.bugfix
@@ -0,0 +1 @@
+Fix an `AssertionError` exception introduced in v1.18.0rc1.
diff --git a/changelog.d/7876.misc b/changelog.d/7876.misc
new file mode 100644
index 0000000000..5c78a158cd
--- /dev/null
+++ b/changelog.d/7876.misc
@@ -0,0 +1 @@
+Further optimise queueing of inbound replication commands.
diff --git a/changelog.d/7877.misc b/changelog.d/7877.misc
deleted file mode 100644
index a62aa0329c..0000000000
--- a/changelog.d/7877.misc
+++ /dev/null
@@ -1 +0,0 @@
-Clean up `PreserveLoggingContext`.
diff --git a/changelog.d/7878.removal b/changelog.d/7878.removal
deleted file mode 100644
index d5a4066624..0000000000
--- a/changelog.d/7878.removal
+++ /dev/null
@@ -1 +0,0 @@
-Remove unused `synapse_replication_tcp_resource_invalidate_cache` prometheus metric.
diff --git a/changelog.d/7879.feature b/changelog.d/7879.feature
deleted file mode 100644
index c89655f000..0000000000
--- a/changelog.d/7879.feature
+++ /dev/null
@@ -1 +0,0 @@
-Report CPU metrics to prometheus for time spent processing replication commands.
diff --git a/changelog.d/7880.bugfix b/changelog.d/7880.bugfix
deleted file mode 100644
index 356add0996..0000000000
--- a/changelog.d/7880.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix "TypeError in `synapse.notifier`" exceptions.
diff --git a/changelog.d/7881.misc b/changelog.d/7881.misc
deleted file mode 100644
index 6799117099..0000000000
--- a/changelog.d/7881.misc
+++ /dev/null
@@ -1 +0,0 @@
-Change "unknown room version" logging from 'error' to 'warning'.
diff --git a/changelog.d/7882.misc b/changelog.d/7882.misc
deleted file mode 100644
index 9002749335..0000000000
--- a/changelog.d/7882.misc
+++ /dev/null
@@ -1 +0,0 @@
-Stop using `device_max_stream_id` table and just use `device_inbox.stream_id`.
diff --git a/changelog.d/7884.misc b/changelog.d/7884.misc
deleted file mode 100644
index 36c7d4de67..0000000000
--- a/changelog.d/7884.misc
+++ /dev/null
@@ -1 +0,0 @@
-Convert the message handler to async/await.
diff --git a/changelog.d/7885.doc b/changelog.d/7885.doc
deleted file mode 100644
index cbe9de4082..0000000000
--- a/changelog.d/7885.doc
+++ /dev/null
@@ -1 +0,0 @@
-Provide instructions on using `register_new_matrix_user` via docker.
diff --git a/changelog.d/7888.misc b/changelog.d/7888.misc
deleted file mode 100644
index 5328d2dcca..0000000000
--- a/changelog.d/7888.misc
+++ /dev/null
@@ -1 +0,0 @@
-Remove Ubuntu Eoan from the list of `.deb` packages that we build as it is now end-of-life. Contributed by @gary-kim.
diff --git a/changelog.d/7889.doc b/changelog.d/7889.doc
deleted file mode 100644
index d91f62fd39..0000000000
--- a/changelog.d/7889.doc
+++ /dev/null
@@ -1 +0,0 @@
-Change the sample config postgres user section to use `synapse_user` instead of `synapse` to align with the documentation.
\ No newline at end of file
diff --git a/changelog.d/7890.misc b/changelog.d/7890.misc
deleted file mode 100644
index 8c127084bc..0000000000
--- a/changelog.d/7890.misc
+++ /dev/null
@@ -1 +0,0 @@
-Fix typo in generated config file. Contributed by @ThiefMaster.
diff --git a/changelog.d/7892.misc b/changelog.d/7892.misc
deleted file mode 100644
index ef4cfa04fd..0000000000
--- a/changelog.d/7892.misc
+++ /dev/null
@@ -1 +0,0 @@
-Import ABC from `collections.abc` for Python 3.10 compatibility.
diff --git a/changelog.d/7895.bugfix b/changelog.d/7895.bugfix
deleted file mode 100644
index 1ae7f8ca7c..0000000000
--- a/changelog.d/7895.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix deprecation warning due to invalid escape sequences.
\ No newline at end of file
diff --git a/changelog.d/7897.misc b/changelog.d/7897.misc
deleted file mode 100644
index 77772533fd..0000000000
--- a/changelog.d/7897.misc
+++ /dev/null
@@ -1,2 +0,0 @@
-Remove unused functions `time_function`, `trace_function`, `get_previous_frames`
-and `get_previous_frame` from `synapse.logging.utils` module.
\ No newline at end of file
diff --git a/changelog.d/7908.feature b/changelog.d/7908.feature
deleted file mode 100644
index 4b9a8d8569..0000000000
--- a/changelog.d/7908.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add the ability to re-activate an account from the admin API.
diff --git a/changelog.d/7912.misc b/changelog.d/7912.misc
deleted file mode 100644
index d619590070..0000000000
--- a/changelog.d/7912.misc
+++ /dev/null
@@ -1 +0,0 @@
-Convert `RoomListHandler` to async/await.
diff --git a/changelog.d/7914.misc b/changelog.d/7914.misc
deleted file mode 100644
index 710553249c..0000000000
--- a/changelog.d/7914.misc
+++ /dev/null
@@ -1 +0,0 @@
-Lint the `contrib/` directory in CI and linting scripts, add `synctl` to the linting script for consistency with CI.
diff --git a/changelog.d/7919.misc b/changelog.d/7919.misc
deleted file mode 100644
index addaa35183..0000000000
--- a/changelog.d/7919.misc
+++ /dev/null
@@ -1 +0,0 @@
-Use Element CSS and logo in notification emails when app name is Element.
diff --git a/changelog.d/7927.misc b/changelog.d/7927.misc
deleted file mode 100644
index 3b864da03d..0000000000
--- a/changelog.d/7927.misc
+++ /dev/null
@@ -1 +0,0 @@
-Optimisation to /sync handling: skip serializing the response if the client has already disconnected.
diff --git a/changelog.d/7928.misc b/changelog.d/7928.misc
deleted file mode 100644
index 5f3aa5de0a..0000000000
--- a/changelog.d/7928.misc
+++ /dev/null
@@ -1 +0,0 @@
-When a client disconnects, don't log it as 'Error processing request'.
diff --git a/changelog.d/7929.misc b/changelog.d/7929.misc
deleted file mode 100644
index d72856fe03..0000000000
--- a/changelog.d/7929.misc
+++ /dev/null
@@ -1 +0,0 @@
-Add debugging to `/sync` response generation (disabled by default).
diff --git a/changelog.d/7930.feature b/changelog.d/7930.feature
deleted file mode 100644
index a27e4812da..0000000000
--- a/changelog.d/7930.feature
+++ /dev/null
@@ -1 +0,0 @@
-Abort federation requests where the client disconnects before the ratelimiter expires.
diff --git a/changelog.d/7931.feature b/changelog.d/7931.feature
deleted file mode 100644
index 30eb33048b..0000000000
--- a/changelog.d/7931.feature
+++ /dev/null
@@ -1 +0,0 @@
-Cache responses to `/_matrix/federation/v1/state_ids` to reduce duplicated work.
diff --git a/changelog.d/7933.doc b/changelog.d/7933.doc
deleted file mode 100644
index 7022fd578b..0000000000
--- a/changelog.d/7933.doc
+++ /dev/null
@@ -1 +0,0 @@
-Reorder database paragraphs to promote postgres over sqlite.
diff --git a/changelog.d/7934.doc b/changelog.d/7934.doc
deleted file mode 100644
index 992d5358a7..0000000000
--- a/changelog.d/7934.doc
+++ /dev/null
@@ -1 +0,0 @@
-Update the dates of ACME v1's end of life in [`ACME.md`](https://github.com/matrix-org/synapse/blob/master/docs/ACME.md).
diff --git a/changelog.d/7935.misc b/changelog.d/7935.misc
deleted file mode 100644
index 3771f99bf2..0000000000
--- a/changelog.d/7935.misc
+++ /dev/null
@@ -1 +0,0 @@
-Convert the auth providers to be async/await.
diff --git a/changelog.d/7939.misc b/changelog.d/7939.misc
deleted file mode 100644
index 798833b3af..0000000000
--- a/changelog.d/7939.misc
+++ /dev/null
@@ -1 +0,0 @@
-Convert presence handler helpers to async/await.
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 8592dee179..900513499d 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -36,7 +36,7 @@ try:
except ImportError:
pass
-__version__ = "1.17.0"
+__version__ = "1.18.0rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 40dc62ef6c..b53e8451e5 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -127,8 +127,10 @@ class Auth(object):
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
- member = yield self.state.get_current_state(
- room_id=room_id, event_type=EventTypes.Member, state_key=user_id
+ member = yield defer.ensureDeferred(
+ self.state.get_current_state(
+ room_id=room_id, event_type=EventTypes.Member, state_key=user_id
+ )
)
membership = member.membership if member else None
@@ -665,8 +667,10 @@ class Auth(object):
)
return member_event.membership, member_event.event_id
except AuthError:
- visibility = yield self.state.get_current_state(
- room_id, EventTypes.RoomHistoryVisibility, ""
+ visibility = yield defer.ensureDeferred(
+ self.state.get_current_state(
+ room_id, EventTypes.RoomHistoryVisibility, ""
+ )
)
if (
visibility
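The `defer.ensureDeferred` wrappers added above are needed because `get_current_state` is now a native coroutine: calling it returns a coroutine object, which the surrounding generator-based (`yield`-style) code cannot consume directly. asyncio's `ensure_future` plays the same bridging role as Twisted's `defer.ensureDeferred`, so a runnable sketch of the idea, using asyncio purely as an analogy since Twisted may not be available:

```python
import asyncio

async def get_current_state(room_id: str) -> dict:
    # A native coroutine: calling it returns a coroutine object, not a
    # framework Deferred/Future, so legacy callback-style code cannot
    # attach to it directly.
    return {"room_id": room_id, "membership": "join"}

async def main() -> dict:
    coro = get_current_state("!room:example.org")
    # asyncio.ensure_future() is the analogue of Twisted's
    # defer.ensureDeferred(): it wraps the coroutine in a Task (a
    # Future subclass) that legacy code can yield or add callbacks to.
    task = asyncio.ensure_future(coro)
    return await task

result = asyncio.run(main())
print(result["membership"])
```

This is why each converted call site in the diff gains a wrapper rather than a plain `yield`: the wrapper converts the coroutine into an object the old-style generator machinery understands.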
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index c1b76d827b..ec0dbddb8c 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -87,7 +87,6 @@ from synapse.replication.tcp.streams import (
ReceiptsStream,
TagAccountDataStream,
ToDeviceStream,
- TypingStream,
)
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client.v1 import events
@@ -644,7 +643,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
super(GenericWorkerReplicationHandler, self).__init__(hs)
self.store = hs.get_datastore()
- self.typing_handler = hs.get_typing_handler()
self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence
self.notifier = hs.get_notifier()
@@ -681,11 +679,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
await self.pusher_pool.on_new_receipts(
token, token, {row.room_id for row in rows}
)
- elif stream_name == TypingStream.NAME:
- self.typing_handler.process_replication_rows(token, rows)
- self.notifier.on_new_event(
- "typing_key", token, rooms=[row.room_id for row in rows]
- )
elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 92aadfe7ef..0bb216419a 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -106,8 +106,8 @@ class EventBuilder(object):
Deferred[FrozenEvent]
"""
- state_ids = yield self._state.get_current_state_ids(
- self.room_id, prev_event_ids
+ state_ids = yield defer.ensureDeferred(
+ self._state.get_current_state_ids(self.room_id, prev_event_ids)
)
auth_ids = yield self._auth.compute_auth_events(self, state_ids)
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 64282abc60..fb57e42287 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -348,7 +348,9 @@ class FederationSender(object):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
- domains = yield self.state.get_current_hosts_in_room(room_id)
+ domains = yield defer.ensureDeferred(
+ self.state.get_current_hosts_in_room(room_id)
+ )
domains = [
d
for d in domains
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index f3c0aeceb6..506bb2b275 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -72,7 +72,7 @@ class AdminHandler(BaseHandler):
writer (ExfiltrationWriter)
Returns:
- defer.Deferred: Resolves when all data for a user has been written.
+ Resolves when all data for a user has been written.
The returned value is that returned by `writer.finished()`.
"""
# Get all rooms the user is in or has been in
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 361dd64cd2..84169c1022 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -16,10 +16,11 @@
# limitations under the License.
import logging
+from typing import Dict, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json, json
-from signedjson.key import decode_verify_key_bytes
+from signedjson.key import VerifyKey, decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
@@ -265,7 +266,9 @@ class E2eKeysHandler(object):
return ret
- async def get_cross_signing_keys_from_cache(self, query, from_user_id):
+ async def get_cross_signing_keys_from_cache(
+ self, query, from_user_id
+ ) -> Dict[str, Dict[str, dict]]:
"""Get cross-signing keys for users from the database
Args:
@@ -277,8 +280,7 @@ class E2eKeysHandler(object):
can see.
Returns:
- defer.Deferred[dict[str, dict[str, dict]]]: map from
- (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
+ A map from (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
"""
master_keys = {}
self_signing_keys = {}
@@ -312,16 +314,17 @@ class E2eKeysHandler(object):
}
@trace
- async def query_local_devices(self, query):
+ async def query_local_devices(
+ self, query: Dict[str, Optional[List[str]]]
+ ) -> Dict[str, Dict[str, dict]]:
"""Get E2E device keys for local users
Args:
- query (dict[string, list[string]|None): map from user_id to a list
+ query: map from user_id to a list
of devices to query (None for all devices)
Returns:
- defer.Deferred: (resolves to dict[string, dict[string, dict]]):
- map from user_id -> device_id -> device details
+ A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
local_query = []
@@ -1004,7 +1007,7 @@ class E2eKeysHandler(object):
async def _retrieve_cross_signing_keys_for_remote_user(
self, user: UserID, desired_key_type: str,
- ):
+ ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database
Only the key specified by `key_type` will be returned, while all retrieved keys
@@ -1015,8 +1018,7 @@ class E2eKeysHandler(object):
desired_key_type: The type of key to receive. One of "master", "self_signing"
Returns:
- Deferred[Tuple[Optional[Dict], Optional[str], Optional[VerifyKey]]]: A tuple
- of the retrieved key content, the key's ID and the matching VerifyKey.
+ A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
try:
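The `query_local_devices` signature above encodes the convention that a `None` device list means "all devices for that user". A minimal synchronous sketch of that contract over a hypothetical in-memory store (the real method is async and backed by the database):

```python
from typing import Dict, List, Optional

# Hypothetical in-memory device store, just to exercise the signature
# the diff introduces: None in the query means "all devices".
_DEVICE_STORE: Dict[str, Dict[str, dict]] = {
    "@alice:example.org": {"DEV1": {"keys": "..."}, "DEV2": {"keys": "..."}},
}

def query_local_devices(
    query: Dict[str, Optional[List[str]]]
) -> Dict[str, Dict[str, dict]]:
    result: Dict[str, Dict[str, dict]] = {}
    for user_id, device_ids in query.items():
        devices = _DEVICE_STORE.get(user_id, {})
        if device_ids is None:
            # None requests every device for this user.
            result[user_id] = dict(devices)
        else:
            result[user_id] = {d: devices[d] for d in device_ids if d in devices}
    return result

print(query_local_devices({"@alice:example.org": None}))
```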
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 71ac5dca99..f5f683bfd4 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1394,7 +1394,7 @@ class FederationHandler(BaseHandler):
# it's just a best-effort thing at this point. We do want to do
# them roughly in order, though, otherwise we'll end up making
# lots of requests for missing prev_events which we do actually
- # have. Hence we fire off the deferred, but don't wait for it.
+ # have. Hence we fire off the background task, but don't wait for it.
run_in_background(self._handle_queued_pdus, room_queue)
@@ -1887,9 +1887,6 @@ class FederationHandler(BaseHandler):
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
)
- # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
- # hack around with a try/finally instead.
- success = False
try:
if (
not event.internal_metadata.is_outlier()
@@ -1903,12 +1900,11 @@ class FederationHandler(BaseHandler):
await self.persist_events_and_notify(
[(event, context)], backfilled=backfilled
)
- success = True
- finally:
- if not success:
- run_in_background(
- self.store.remove_push_actions_from_staging, event.event_id
- )
+ except Exception:
+ run_in_background(
+ self.store.remove_push_actions_from_staging, event.event_id
+ )
+ raise
return context
@@ -2994,7 +2990,9 @@ class FederationHandler(BaseHandler):
else:
user_joined_room(self.distributor, user, room_id)
- async def get_room_complexity(self, remote_room_hosts, room_id):
+ async def get_room_complexity(
+ self, remote_room_hosts: List[str], room_id: str
+ ) -> Optional[dict]:
"""
Fetch the complexity of a remote room over federation.
@@ -3003,7 +3001,7 @@ class FederationHandler(BaseHandler):
room_id (str): The room ID to ask about.
Returns:
- Deferred[dict] or Deferred[None]: Dict contains the complexity
+ Dict contains the complexity
metric versions, while None means we could not fetch the complexity.
"""
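The `federation.py` hunk replaces the old `success`-flag `try/finally` workaround (needed because `reraise` under `inlineCallbacks` lost the stack trace) with a plain `try/except Exception: cleanup; raise`, which a bare `raise` lets do without losing the traceback. A minimal sketch of the pattern, with hypothetical stand-in names:

```python
cleanups = []

def remove_push_actions_from_staging(event_id: str) -> None:
    # Stand-in for the background cleanup fired on failure.
    cleanups.append(event_id)

def persist_event(event_id: str, fail: bool = False) -> str:
    try:
        if fail:
            raise RuntimeError("persist failed")
        return event_id
    except Exception:
        # On any failure, run the cleanup and re-raise; a bare
        # `raise` preserves the original traceback.
        remove_push_actions_from_staging(event_id)
        raise

print(persist_event("$ok"))        # success path: no cleanup
try:
    persist_event("$bad", fail=True)
except RuntimeError:
    pass
print(cleanups)
# prints ['$bad']
```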
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 701233ebb4..0bd2c3e37a 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -19,6 +19,7 @@
import logging
import urllib.parse
+from typing import Awaitable, Callable, Dict, List, Optional, Tuple
from canonicaljson import json
from signedjson.key import decode_verify_key_bytes
@@ -36,6 +37,7 @@ from synapse.api.errors import (
)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.client import SimpleHttpClient
+from synapse.types import JsonDict, Requester
from synapse.util.hash import sha256_and_url_safe_base64
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -59,23 +61,23 @@ class IdentityHandler(BaseHandler):
self.federation_http_client = hs.get_http_client()
self.hs = hs
- async def threepid_from_creds(self, id_server, creds):
+ async def threepid_from_creds(
+ self, id_server: str, creds: Dict[str, str]
+ ) -> Optional[JsonDict]:
"""
Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server
Args:
- id_server (str): The identity server to validate 3PIDs against. Must be a
+ id_server: The identity server to validate 3PIDs against. Must be a
complete URL including the protocol (http(s)://)
-
- creds (dict[str, str]): Dictionary containing the following keys:
+ creds: Dictionary containing the following keys:
* client_secret|clientSecret: A unique secret str provided by the client
* sid: The ID of the validation session
Returns:
- Deferred[dict[str,str|int]|None]: A dictionary consisting of response params to
- the /getValidated3pid endpoint of the Identity Service API, or None if the
- threepid was not found
+ A dictionary consisting of response params to the /getValidated3pid
+ endpoint of the Identity Service API, or None if the threepid was not found
"""
client_secret = creds.get("client_secret") or creds.get("clientSecret")
if not client_secret:
@@ -119,26 +121,27 @@ class IdentityHandler(BaseHandler):
return None
async def bind_threepid(
- self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True
- ):
+ self,
+ client_secret: str,
+ sid: str,
+ mxid: str,
+ id_server: str,
+ id_access_token: Optional[str] = None,
+ use_v2: bool = True,
+ ) -> JsonDict:
"""Bind a 3PID to an identity server
Args:
- client_secret (str): A unique secret provided by the client
-
- sid (str): The ID of the validation session
-
- mxid (str): The MXID to bind the 3PID to
-
- id_server (str): The domain of the identity server to query
-
- id_access_token (str): The access token to authenticate to the identity
+ client_secret: A unique secret provided by the client
+ sid: The ID of the validation session
+ mxid: The MXID to bind the 3PID to
+ id_server: The domain of the identity server to query
+ id_access_token: The access token to authenticate to the identity
server with, if necessary. Required if use_v2 is true
-
- use_v2 (bool): Whether to use v2 Identity Service API endpoints. Defaults to True
+ use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True
Returns:
- Deferred[dict]: The response from the identity server
+ The response from the identity server
"""
logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server)
@@ -151,7 +154,7 @@ class IdentityHandler(BaseHandler):
bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
if use_v2:
bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
- headers["Authorization"] = create_id_access_token_header(id_access_token)
+ headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore
else:
bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
@@ -187,20 +190,20 @@ class IdentityHandler(BaseHandler):
)
return res
- async def try_unbind_threepid(self, mxid, threepid):
+ async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
"""Attempt to remove a 3PID from an identity server, or if one is not provided, all
identity servers we're aware the binding is present on
Args:
- mxid (str): Matrix user ID of binding to be removed
- threepid (dict): Dict with medium & address of binding to be
+ mxid: Matrix user ID of binding to be removed
+ threepid: Dict with medium & address of binding to be
removed, and an optional id_server.
Raises:
SynapseError: If we failed to contact the identity server
Returns:
- Deferred[bool]: True on success, otherwise False if the identity
+ True on success, otherwise False if the identity
server doesn't support unbinding (or no identity server found to
contact).
"""
@@ -223,19 +226,21 @@ class IdentityHandler(BaseHandler):
return changed
- async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
+ async def try_unbind_threepid_with_id_server(
+ self, mxid: str, threepid: dict, id_server: str
+ ) -> bool:
"""Removes a binding from an identity server
Args:
- mxid (str): Matrix user ID of binding to be removed
- threepid (dict): Dict with medium & address of binding to be removed
- id_server (str): Identity server to unbind from
+ mxid: Matrix user ID of binding to be removed
+ threepid: Dict with medium & address of binding to be removed
+ id_server: Identity server to unbind from
Raises:
SynapseError: If we failed to contact the identity server
Returns:
- Deferred[bool]: True on success, otherwise False if the identity
+ True on success, otherwise False if the identity
server doesn't support unbinding
"""
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
@@ -287,23 +292,23 @@ class IdentityHandler(BaseHandler):
async def send_threepid_validation(
self,
- email_address,
- client_secret,
- send_attempt,
- send_email_func,
- next_link=None,
- ):
+ email_address: str,
+ client_secret: str,
+ send_attempt: int,
+ send_email_func: Callable[[str, str, str, str], Awaitable],
+ next_link: Optional[str] = None,
+ ) -> str:
"""Send a threepid validation email for password reset or
registration purposes
Args:
- email_address (str): The user's email address
- client_secret (str): The provided client secret
- send_attempt (int): Which send attempt this is
- send_email_func (func): A function that takes an email address, token,
- client_secret and session_id, sends an email
- and returns a Deferred.
- next_link (str|None): The URL to redirect the user to after validation
+ email_address: The user's email address
+ client_secret: The provided client secret
+ send_attempt: Which send attempt this is
+ send_email_func: A function that takes an email address, token,
+ client_secret and session_id, sends an email
+ and returns an Awaitable.
+ next_link: The URL to redirect the user to after validation
Returns:
The new session_id upon success
@@ -372,17 +377,22 @@ class IdentityHandler(BaseHandler):
return session_id
async def requestEmailToken(
- self, id_server, email, client_secret, send_attempt, next_link=None
- ):
+ self,
+ id_server: str,
+ email: str,
+ client_secret: str,
+ send_attempt: int,
+ next_link: Optional[str] = None,
+ ) -> JsonDict:
"""
Request an external server send an email on our behalf for the purposes of threepid
validation.
Args:
- id_server (str): The identity server to proxy to
- email (str): The email to send the message to
- client_secret (str): The unique client_secret sends by the user
- send_attempt (int): Which attempt this is
+ id_server: The identity server to proxy to
+ email: The email to send the message to
+ client_secret: The unique client_secret sent by the user
+ send_attempt: Which attempt this is
next_link: A link to redirect the user to once they submit the token
Returns:
@@ -419,22 +429,22 @@ class IdentityHandler(BaseHandler):
async def requestMsisdnToken(
self,
- id_server,
- country,
- phone_number,
- client_secret,
- send_attempt,
- next_link=None,
- ):
+ id_server: str,
+ country: str,
+ phone_number: str,
+ client_secret: str,
+ send_attempt: int,
+ next_link: Optional[str] = None,
+ ) -> JsonDict:
"""
Request an external server send an SMS message on our behalf for the purposes of
threepid validation.
Args:
- id_server (str): The identity server to proxy to
- country (str): The country code of the phone number
- phone_number (str): The number to send the message to
- client_secret (str): The unique client_secret sends by the user
- send_attempt (int): Which attempt this is
+ id_server: The identity server to proxy to
+ country: The country code of the phone number
+ phone_number: The number to send the message to
+ client_secret: The unique client_secret sent by the user
+ send_attempt: Which attempt this is
next_link: A link to redirect the user to once they submit the token
Returns:
@@ -480,17 +490,18 @@ class IdentityHandler(BaseHandler):
)
return data
- async def validate_threepid_session(self, client_secret, sid):
+ async def validate_threepid_session(
+ self, client_secret: str, sid: str
+ ) -> Optional[JsonDict]:
"""Validates a threepid session with only the client secret and session ID
Tries validating against any configured account_threepid_delegates as well as locally.
Args:
- client_secret (str): A secret provided by the client
-
- sid (str): The ID of the session
+ client_secret: A secret provided by the client
+ sid: The ID of the session
Returns:
- Dict[str, str|int] if validation was successful, otherwise None
+ The json response if validation was successful, otherwise None
"""
# XXX: We shouldn't need to keep wrapping and unwrapping this value
threepid_creds = {"client_secret": client_secret, "sid": sid}
@@ -523,23 +534,22 @@ class IdentityHandler(BaseHandler):
return validation_session
- async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
+ async def proxy_msisdn_submit_token(
+ self, id_server: str, client_secret: str, sid: str, token: str
+ ) -> JsonDict:
"""Proxy a POST submitToken request to an identity server for verification purposes
Args:
- id_server (str): The identity server URL to contact
-
- client_secret (str): Secret provided by the client
-
- sid (str): The ID of the session
-
- token (str): The verification token
+ id_server: The identity server URL to contact
+ client_secret: Secret provided by the client
+ sid: The ID of the session
+ token: The verification token
Raises:
SynapseError: If we failed to contact the identity server
Returns:
- Deferred[dict]: The response dict from the identity server
+ The response dict from the identity server
"""
body = {"client_secret": client_secret, "sid": sid, "token": token}
@@ -554,19 +564,25 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server")
- async def lookup_3pid(self, id_server, medium, address, id_access_token=None):
+ async def lookup_3pid(
+ self,
+ id_server: str,
+ medium: str,
+ address: str,
+ id_access_token: Optional[str] = None,
+ ) -> Optional[str]:
"""Looks up a 3pid in the passed identity server.
Args:
- id_server (str): The server name (including port, if required)
+ id_server: The server name (including port, if required)
of the identity server to use.
- medium (str): The type of the third party identifier (e.g. "email").
- address (str): The third party identifier (e.g. "foo@example.com").
- id_access_token (str|None): The access token to authenticate to the identity
+ medium: The type of the third party identifier (e.g. "email").
+ address: The third party identifier (e.g. "foo@example.com").
+ id_access_token: The access token to authenticate to the identity
server with
Returns:
- str|None: the matrix ID of the 3pid, or None if it is not recognized.
+ the matrix ID of the 3pid, or None if it is not recognized.
"""
if id_access_token is not None:
try:
@@ -591,17 +607,19 @@ class IdentityHandler(BaseHandler):
return await self._lookup_3pid_v1(id_server, medium, address)
- async def _lookup_3pid_v1(self, id_server, medium, address):
+ async def _lookup_3pid_v1(
+ self, id_server: str, medium: str, address: str
+ ) -> Optional[str]:
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
- id_server (str): The server name (including port, if required)
+ id_server: The server name (including port, if required)
of the identity server to use.
- medium (str): The type of the third party identifier (e.g. "email").
- address (str): The third party identifier (e.g. "foo@example.com").
+ medium: The type of the third party identifier (e.g. "email").
+ address: The third party identifier (e.g. "foo@example.com").
Returns:
- str: the matrix ID of the 3pid, or None if it is not recognized.
+ the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = await self.blacklisting_http_client.get_json(
@@ -621,18 +639,20 @@ class IdentityHandler(BaseHandler):
return None
- async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
+ async def _lookup_3pid_v2(
+ self, id_server: str, id_access_token: str, medium: str, address: str
+ ) -> Optional[str]:
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
- id_server (str): The server name (including port, if required)
+ id_server: The server name (including port, if required)
of the identity server to use.
- id_access_token (str): The access token to authenticate to the identity server with
- medium (str): The type of the third party identifier (e.g. "email").
- address (str): The third party identifier (e.g. "foo@example.com").
+ id_access_token: The access token to authenticate to the identity server with
+ medium: The type of the third party identifier (e.g. "email").
+ address: The third party identifier (e.g. "foo@example.com").
Returns:
- Deferred[str|None]: the matrix ID of the 3pid, or None if it is not recognised.
+ the matrix ID of the 3pid, or None if it is not recognised.
"""
# Check what hashing details are supported by this identity server
try:
@@ -757,49 +777,48 @@ class IdentityHandler(BaseHandler):
async def ask_id_server_for_third_party_invite(
self,
- requester,
- id_server,
- medium,
- address,
- room_id,
- inviter_user_id,
- room_alias,
- room_avatar_url,
- room_join_rules,
- room_name,
- inviter_display_name,
- inviter_avatar_url,
- id_access_token=None,
- ):
+ requester: Requester,
+ id_server: str,
+ medium: str,
+ address: str,
+ room_id: str,
+ inviter_user_id: str,
+ room_alias: str,
+ room_avatar_url: str,
+ room_join_rules: str,
+ room_name: str,
+ inviter_display_name: str,
+ inviter_avatar_url: str,
+ id_access_token: Optional[str] = None,
+ ) -> Tuple[str, List[Dict[str, str]], Dict[str, str], str]:
"""
Asks an identity server for a third party invite.
Args:
- requester (Requester)
- id_server (str): hostname + optional port for the identity server.
- medium (str): The literal string "email".
- address (str): The third party address being invited.
- room_id (str): The ID of the room to which the user is invited.
- inviter_user_id (str): The user ID of the inviter.
- room_alias (str): An alias for the room, for cosmetic notifications.
- room_avatar_url (str): The URL of the room's avatar, for cosmetic
+ requester
+ id_server: hostname + optional port for the identity server.
+ medium: The literal string "email".
+ address: The third party address being invited.
+ room_id: The ID of the room to which the user is invited.
+ inviter_user_id: The user ID of the inviter.
+ room_alias: An alias for the room, for cosmetic notifications.
+ room_avatar_url: The URL of the room's avatar, for cosmetic
notifications.
- room_join_rules (str): The join rules of the email (e.g. "public").
- room_name (str): The m.room.name of the room.
- inviter_display_name (str): The current display name of the
+ room_join_rules: The join rules of the room (e.g. "public").
+ room_name: The m.room.name of the room.
+ inviter_display_name: The current display name of the
inviter.
- inviter_avatar_url (str): The URL of the inviter's avatar.
+ inviter_avatar_url: The URL of the inviter's avatar.
id_access_token (str|None): The access token to authenticate to the identity
server with
Returns:
- A deferred tuple containing:
- token (str): The token which must be signed to prove authenticity.
+ A tuple containing:
+ token: The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
- display_name (str): A user-friendly name to represent the invited
- user.
+ display_name: A user-friendly name to represent the invited user.
"""
invite_config = {
"medium": medium,
@@ -896,15 +915,15 @@ class IdentityHandler(BaseHandler):
return token, public_keys, fallback_public_key, display_name
-def create_id_access_token_header(id_access_token):
+def create_id_access_token_header(id_access_token: str) -> List[str]:
"""Create an Authorization header for passing to SimpleHttpClient as the header value
of an HTTP request.
Args:
- id_access_token (str): An identity server access token.
+ id_access_token: An identity server access token.
Returns:
- list[str]: The ascii-encoded bearer token encased in a list.
+ The ascii-encoded bearer token encased in a list.
"""
# Prefix with Bearer
bearer_token = "Bearer %s" % id_access_token
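The newly typed `create_id_access_token_header` wraps a bearer token in a list, since HTTP-client header APIs commonly take a list of values per header name. A simplified sketch of the helper (the real function also ascii-encodes the value before returning it):

```python
from typing import List

def create_id_access_token_header(id_access_token: str) -> List[str]:
    # Prefix with "Bearer" and wrap in a list, matching the shape
    # header-setting APIs expect. Simplified: encoding to ascii bytes
    # is omitted here.
    return ["Bearer %s" % id_access_token]

print(create_id_access_token_header("abc123"))
# prints ['Bearer abc123']
```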
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4b9b80a36d..a28068244d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -859,9 +859,6 @@ class EventCreationHandler(object):
await self.action_generator.handle_push_actions_for_event(event, context)
- # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
- # hack around with a try/finally instead.
- success = False
try:
# If we're a worker we need to hit out to the master.
if not self._is_event_writer:
@@ -877,22 +874,20 @@ class EventCreationHandler(object):
)
stream_id = result["stream_id"]
event.internal_metadata.stream_ordering = stream_id
- success = True
return stream_id
stream_id = await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
- success = True
return stream_id
- finally:
- if not success:
- # Ensure that we actually remove the entries in the push actions
- # staging area, if we calculated them.
- run_in_background(
- self.store.remove_push_actions_from_staging, event.event_id
- )
+ except Exception:
+ # Ensure that we actually remove the entries in the push actions
+ # staging area, if we calculated them.
+ run_in_background(
+ self.store.remove_push_actions_from_staging, event.event_id
+ )
+ raise
async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 8e99c83d9d..b3a3bb8c3f 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -928,8 +928,8 @@ class PresenceHandler(BasePresenceHandler):
# TODO: Check that this is actually a new server joining the
# room.
- user_ids = await self.state.get_current_users_in_room(room_id)
- user_ids = list(filter(self.is_mine_id, user_ids))
+ users = await self.state.get_current_users_in_room(room_id)
+ user_ids = list(filter(self.is_mine_id, users))
states_d = await self.current_state_for_users(user_ids)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index fb37d371ad..0c5b99234d 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -119,7 +119,7 @@ class RoomCreationHandler(BaseHandler):
async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
- ):
+ ) -> str:
"""Replace a room with a new room with a different version
Args:
@@ -128,7 +128,7 @@ class RoomCreationHandler(BaseHandler):
new_version: the new room version to use
Returns:
- Deferred[unicode]: the new room id
+ the new room id
"""
await self.ratelimit(requester)
@@ -239,7 +239,7 @@ class RoomCreationHandler(BaseHandler):
old_room_id: str,
new_room_id: str,
old_room_state: StateMap[str],
- ):
+ ) -> None:
"""Send updated power levels in both rooms after an upgrade
Args:
@@ -247,9 +247,6 @@ class RoomCreationHandler(BaseHandler):
old_room_id: the id of the room to be replaced
new_room_id: the id of the replacement room
old_room_state: the state map for the old room
-
- Returns:
- Deferred
"""
old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))
@@ -322,7 +319,7 @@ class RoomCreationHandler(BaseHandler):
new_room_id: str,
new_room_version: RoomVersion,
tombstone_event_id: str,
- ):
+ ) -> None:
"""Populate a new room based on an old room
Args:
@@ -332,8 +329,6 @@ class RoomCreationHandler(BaseHandler):
created with _generate_room_id())
new_room_version: the new room version to use
tombstone_event_id: the ID of the tombstone event in the old room.
- Returns:
- Deferred
"""
user_id = requester.user.to_string()
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 4d40d3ac9c..9b312a1558 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -15,6 +15,7 @@
import itertools
import logging
+from typing import Iterable
from unpaddedbase64 import decode_base64, encode_base64
@@ -37,7 +38,7 @@ class SearchHandler(BaseHandler):
self.state_store = self.storage.state
self.auth = hs.get_auth()
- async def get_old_rooms_from_upgraded_room(self, room_id):
+ async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
"""Retrieves room IDs of old rooms in the history of an upgraded room.
We do so by checking the m.room.create event of the room for a
@@ -48,10 +49,10 @@ class SearchHandler(BaseHandler):
The full list of all found rooms in then returned.
Args:
- room_id (str): id of the room to search through.
+ room_id: id of the room to search through.
Returns:
- Deferred[iterable[str]]: predecessor room ids
+ Predecessor room ids
"""
historical_room_ids = []
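`get_old_rooms_from_upgraded_room` walks the chain of `m.room.create` predecessor pointers left behind by room upgrades. A self-contained sketch of that walk over a hypothetical predecessor map (the real code reads create events from the state store):

```python
from typing import Dict, Iterable, Optional

# Hypothetical map: room_id -> predecessor room_id (from m.room.create).
_PREDECESSORS: Dict[str, Optional[str]] = {
    "!v3:example.org": "!v2:example.org",
    "!v2:example.org": "!v1:example.org",
    "!v1:example.org": None,
}

def get_old_rooms_from_upgraded_room(room_id: str) -> Iterable[str]:
    historical_room_ids = []
    seen = {room_id}  # guard against predecessor cycles
    current = _PREDECESSORS.get(room_id)
    while current is not None and current not in seen:
        historical_room_ids.append(current)
        seen.add(current)
        current = _PREDECESSORS.get(current)
    return historical_room_ids

print(get_old_rooms_from_upgraded_room("!v3:example.org"))
# prints ['!v2:example.org', '!v1:example.org']
```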
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 99dd4ee948..c308647700 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -424,10 +424,6 @@ class SyncHandler(object):
potential_recents: Optional[List[EventBase]] = None,
newly_joined_room: bool = False,
) -> TimelineBatch:
- """
- Returns:
- a Deferred TimelineBatch
- """
with Measure(self.clock, "load_filtered_recents"):
timeline_limit = sync_config.filter_collection.timeline_limit()
block_all_timeline = (
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 8e003689c4..d4f9ad6e67 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -442,21 +442,6 @@ class StaticResource(File):
return super().render_GET(request)
-def _options_handler(request):
- """Request handler for OPTIONS requests
-
- This is a request handler suitable for return from
- _get_handler_for_request. It returns a 200 and an empty body.
-
- Args:
- request (twisted.web.http.Request):
-
- Returns:
- Tuple[int, dict]: http code, response body.
- """
- return 200, {}
-
-
def _unrecognised_request_handler(request):
"""Request handler for unrecognised requests
@@ -490,11 +475,12 @@ class OptionsResource(resource.Resource):
"""Responds to OPTION requests for itself and all children."""
def render_OPTIONS(self, request):
- code, response_json_object = _options_handler(request)
+ request.setResponseCode(204)
+ request.setHeader(b"Content-Length", b"0")
- return respond_with_json(
- request, code, response_json_object, send_cors=True, canonical_json=False,
- )
+ set_cors_headers(request)
+
+ return b""
def getChildWithDefault(self, path, request):
if request.method == b"OPTIONS":
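The `OptionsResource` change swaps the old 200-with-JSON-body response for a 204 No Content with CORS headers and an empty body: preflight responses need headers, not a payload. A framework-free sketch of the response being built (the exact header set comes from Synapse's `set_cors_headers` and may differ; the trio below is a typical assumption):

```python
from typing import Dict, Tuple

def render_options() -> Tuple[int, Dict[str, str], bytes]:
    # 204 No Content: CORS preflight responses carry headers, no body.
    headers = {
        "Content-Length": "0",
        # Typical CORS headers; assumed here, not copied from Synapse.
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
        "Access-Control-Allow-Headers": "Origin, X-Requested-With, Content-Type, Accept, Authorization",
    }
    return 204, headers, b""

status, headers, body = render_options()
print(status, body)
# prints 204 b''
```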
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 2101517575..21dbd9f415 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -737,24 +737,14 @@ def trace(func=None, opname=None):
@wraps(func)
async def _trace_inner(*args, **kwargs):
- if opentracing is None:
+ with start_active_span(_opname):
return await func(*args, **kwargs)
- with start_active_span(_opname) as scope:
- try:
- return await func(*args, **kwargs)
- except Exception:
- scope.span.set_tag(tags.ERROR, True)
- raise
-
else:
# The other case here handles both sync functions and those
# decorated with inlineDeferred.
@wraps(func)
def _trace_inner(*args, **kwargs):
- if opentracing is None:
- return func(*args, **kwargs)
-
scope = start_active_span(_opname)
scope.__enter__()
@@ -767,7 +757,6 @@ def trace(func=None, opname=None):
return result
def err_back(result):
- scope.span.set_tag(tags.ERROR, True)
scope.__exit__(None, None, None)
return result
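The simplified `trace` decorator above keeps separate wrappers for coroutine functions and plain (or `inlineCallbacks`-decorated) functions. A generic sketch of that branching using `inspect.iscoroutinefunction` and a context-manager span; the span manager here is a stand-in, not Synapse's opentracing API, and the real sync branch drives `scope.__enter__`/`__exit__` by hand to follow Deferreds:

```python
import asyncio
import inspect
from contextlib import contextmanager
from functools import wraps

traced = []

@contextmanager
def start_active_span(opname):
    # Stand-in for the opentracing scope manager.
    traced.append(opname)
    yield

def trace(func=None, opname=None):
    def decorator(func):
        _opname = opname or func.__name__
        if inspect.iscoroutinefunction(func):
            @wraps(func)
            async def _trace_inner(*args, **kwargs):
                with start_active_span(_opname):
                    return await func(*args, **kwargs)
        else:
            @wraps(func)
            def _trace_inner(*args, **kwargs):
                with start_active_span(_opname):
                    return func(*args, **kwargs)
        return _trace_inner
    # Support both @trace and @trace(opname=...).
    if func:
        return decorator(func)
    return decorator

@trace
def add(a, b):
    return a + b

@trace(opname="async-add")
async def aadd(a, b):
    return a + b

print(add(1, 2), asyncio.run(aadd(3, 4)), traced)
# prints 3 7 ['add', 'async-add']
```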
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index dc3ab00cbb..026854b4c7 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -116,6 +116,8 @@ class _LogContextScope(Scope):
if self._enter_logcontext:
self.logcontext.__enter__()
+ return self
+
def __exit__(self, type, value, traceback):
if type == twisted.internet.defer._DefGen_Return:
super(_LogContextScope, self).__exit__(None, None, None)
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 43ffe6faf0..472ddf9f7d 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -304,7 +304,9 @@ class RulesForRoom(object):
push_rules_delta_state_cache_metric.inc_hits()
else:
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(
+ context.get_current_state_ids()
+ )
push_rules_delta_state_cache_metric.inc_misses()
push_rules_state_size_counter.inc(len(current_state_ids))
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 4985e40b1f..fcf8ebf1e7 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -24,6 +24,7 @@ from twisted.internet.protocol import ReconnectingClientFactory
from synapse.api.constants import EventTypes
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.streams import TypingStream
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
@@ -104,6 +105,7 @@ class ReplicationDataHandler:
self._clock = hs.get_clock()
self._streams = hs.get_replication_streams()
self._instance_name = hs.get_instance_name()
+ self._typing_handler = hs.get_typing_handler()
# Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
@@ -127,6 +129,12 @@ class ReplicationDataHandler:
"""
self.store.process_replication_rows(stream_name, instance_name, token, rows)
+ if stream_name == TypingStream.NAME:
+ self._typing_handler.process_replication_rows(token, rows)
+ self.notifier.on_new_event(
+ "typing_key", token, rooms=[row.room_id for row in rows]
+ )
+
if stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.
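The `on_rdata` change forks typing-stream rows off to the worker-local typing handler before waking up listeners for the affected rooms. A toy dispatch sketch of that flow, with all class and row shapes hypothetical:

```python
class TypingHandler:
    def __init__(self):
        self.rows = []

    def process_replication_rows(self, token, rows):
        self.rows.extend(rows)

class Notifier:
    def __init__(self):
        self.events = []

    def on_new_event(self, key, token, rooms):
        self.events.append((key, token, rooms))

class ReplicationDataHandler:
    TYPING_STREAM = "typing"

    def __init__(self, typing_handler, notifier):
        self._typing_handler = typing_handler
        self.notifier = notifier

    def on_rdata(self, stream_name, token, rows):
        # Hand typing rows to the local handler, then wake up
        # streams (e.g. /sync) for the rooms they touch.
        if stream_name == self.TYPING_STREAM:
            self._typing_handler.process_replication_rows(token, rows)
            self.notifier.on_new_event(
                "typing_key", token, rooms=[row["room_id"] for row in rows]
            )

th, n = TypingHandler(), Notifier()
handler = ReplicationDataHandler(th, n)
handler.on_rdata("typing", 7, [{"room_id": "!r:example.org"}])
print(n.events)
# prints [('typing_key', 7, ['!r:example.org'])]
```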
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1de590bba2..1c303f3a46 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -16,6 +16,7 @@
import logging
from typing import (
Any,
+ Awaitable,
Dict,
Iterable,
Iterator,
@@ -33,6 +34,7 @@ from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
@@ -152,7 +154,7 @@ class ReplicationCommandHandler:
# When POSITION or RDATA commands arrive, we stick them in a queue and process
# them in order in a separate background process.
- # the streams which are currently being processed by _unsafe_process_stream
+ # the streams which are currently being processed by _unsafe_process_queue
self._processing_streams = set() # type: Set[str]
# for each stream, a queue of commands that are awaiting processing, and the
@@ -185,7 +187,7 @@ class ReplicationCommandHandler:
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
- async def _add_command_to_stream_queue(
+ def _add_command_to_stream_queue(
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
"""Queue the given received command for processing
@@ -199,33 +201,34 @@ class ReplicationCommandHandler:
logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
return
- # if we're already processing this stream, stick the new command in the
- # queue, and we're done.
+ queue.append((cmd, conn))
+
+ # if we're already processing this stream, there's nothing more to do:
+ # the new entry on the queue will get picked up in due course
if stream_name in self._processing_streams:
- queue.append((cmd, conn))
return
- # otherwise, process the new command.
+ # fire off a background process to start processing the queue.
+ run_as_background_process(
+ "process-replication-data", self._unsafe_process_queue, stream_name
+ )
- # arguably we should start off a new background process here, but nothing
- # will be too upset if we don't return for ages, so let's save the overhead
- # and use the existing logcontext.
+ async def _unsafe_process_queue(self, stream_name: str):
+ """Processes the command queue for the given stream, until it is empty
+
+ Does not check if there is already a thread processing the queue, hence "unsafe"
+ """
+ assert stream_name not in self._processing_streams
self._processing_streams.add(stream_name)
try:
- # might as well skip the queue for this one, since it must be empty
- assert not queue
- await self._process_command(cmd, conn, stream_name)
-
- # now process any other commands that have built up while we were
- # dealing with that one.
+ queue = self._command_queues_by_stream.get(stream_name)
while queue:
cmd, conn = queue.popleft()
try:
await self._process_command(cmd, conn, stream_name)
except Exception:
logger.exception("Failed to handle command %s", cmd)
-
finally:
self._processing_streams.discard(stream_name)
@@ -299,7 +302,7 @@ class ReplicationCommandHandler:
"""
return self._streams_to_replicate
- async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+ def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)
def send_positions_to_connection(self, conn: AbstractConnection):
@@ -318,57 +321,73 @@ class ReplicationCommandHandler:
)
)
- async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
+ def on_USER_SYNC(
+ self, conn: AbstractConnection, cmd: UserSyncCommand
+ ) -> Optional[Awaitable[None]]:
user_sync_counter.inc()
if self._is_master:
- await self._presence_handler.update_external_syncs_row(
+ return self._presence_handler.update_external_syncs_row(
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
+ else:
+ return None
- async def on_CLEAR_USER_SYNC(
+ def on_CLEAR_USER_SYNC(
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
- ):
+ ) -> Optional[Awaitable[None]]:
if self._is_master:
- await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+ return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+ else:
+ return None
- async def on_FEDERATION_ACK(
- self, conn: AbstractConnection, cmd: FederationAckCommand
- ):
+ def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
- async def on_REMOVE_PUSHER(
+ def on_REMOVE_PUSHER(
self, conn: AbstractConnection, cmd: RemovePusherCommand
- ):
+ ) -> Optional[Awaitable[None]]:
remove_pusher_counter.inc()
if self._is_master:
- await self._store.delete_pusher_by_app_id_pushkey_user_id(
- app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
- )
+ return self._handle_remove_pusher(cmd)
+ else:
+ return None
+
+ async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
+ await self._store.delete_pusher_by_app_id_pushkey_user_id(
+ app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
+ )
- self._notifier.on_new_replication_data()
+ self._notifier.on_new_replication_data()
- async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
+ def on_USER_IP(
+ self, conn: AbstractConnection, cmd: UserIpCommand
+ ) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc()
if self._is_master:
- await self._store.insert_client_ip(
- cmd.user_id,
- cmd.access_token,
- cmd.ip,
- cmd.user_agent,
- cmd.device_id,
- cmd.last_seen,
- )
+ return self._handle_user_ip(cmd)
+ else:
+ return None
+
+ async def _handle_user_ip(self, cmd: UserIpCommand):
+ await self._store.insert_client_ip(
+ cmd.user_id,
+ cmd.access_token,
+ cmd.ip,
+ cmd.user_agent,
+ cmd.device_id,
+ cmd.last_seen,
+ )
- if self._server_notices_sender:
- await self._server_notices_sender.on_user_ip(cmd.user_id)
+ assert self._server_notices_sender is not None
+ await self._server_notices_sender.on_user_ip(cmd.user_id)
- async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+ def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
@@ -382,7 +401,7 @@ class ReplicationCommandHandler:
# 2. so we don't race with getting a POSITION command and fetching
# missing RDATA.
- await self._add_command_to_stream_queue(conn, cmd)
+ self._add_command_to_stream_queue(conn, cmd)
async def _process_rdata(
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
@@ -459,14 +478,14 @@ class ReplicationCommandHandler:
stream_name, instance_name, token, rows
)
- async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+ def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
- await self._add_command_to_stream_queue(conn, cmd)
+ self._add_command_to_stream_queue(conn, cmd)
async def _process_position(
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
@@ -526,9 +545,7 @@ class ReplicationCommandHandler:
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
- async def on_REMOTE_SERVER_UP(
- self, conn: AbstractConnection, cmd: RemoteServerUpCommand
- ):
+ def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
"""Called when we get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
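The reworked queueing logic above boils down to: always append the command first, then start a drainer only if one is not already running. A runnable sketch of that pattern — it uses a plain asyncio task where the patch uses `run_as_background_process`, and (unlike the patch, where the background process starts the coroutine eagerly) marks the stream as in-flight at scheduling time so the toy has no startup race:

```python
import asyncio
from collections import deque
from typing import Deque, Dict, List, Set, Tuple

class StreamQueueProcessor:
    """Toy model of the pattern above: per-stream command queues, with at
    most one drainer task per stream at a time."""

    def __init__(self) -> None:
        self._queues: Dict[str, Deque[str]] = {}
        self._processing: Set[str] = set()
        self.handled: List[Tuple[str, str]] = []

    def add_command(self, stream_name: str, cmd: str) -> None:
        # Always enqueue first: if a drainer is already running, the new
        # entry will get picked up in due course.
        self._queues.setdefault(stream_name, deque()).append(cmd)
        if stream_name in self._processing:
            return
        self._processing.add(stream_name)
        asyncio.ensure_future(self._process_queue(stream_name))

    async def _process_queue(self, stream_name: str) -> None:
        try:
            queue = self._queues[stream_name]
            while queue:
                self.handled.append((stream_name, queue.popleft()))
                await asyncio.sleep(0)  # real command handling would await here
        finally:
            self._processing.discard(stream_name)

async def demo() -> List[Tuple[str, str]]:
    p = StreamQueueProcessor()
    p.add_command("events", "RDATA 1")
    p.add_command("events", "RDATA 2")  # drainer already scheduled: just queued
    p.add_command("typing", "RDATA A")  # separate stream, separate drainer
    while p._processing:
        await asyncio.sleep(0)
    return p.handled

handled = asyncio.run(demo())
```

Commands within one stream are handled strictly in order, while different streams drain concurrently — which is exactly the property the `_processing_streams` set is protecting.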
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 23191e3218..0350923898 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -50,6 +50,7 @@ import abc
import fcntl
import logging
import struct
+from inspect import isawaitable
from typing import TYPE_CHECKING, List
from prometheus_client import Counter
@@ -128,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
+ `ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine;
+ if so, that will get run as a background process.
It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command)
@@ -166,9 +169,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
- self._logging_context = BackgroundProcessLoggingContext(
- "replication_command_handler-%s" % self.conn_id
- )
+ ctx_name = "replication-conn-%s" % self.conn_id
+ self._logging_context = BackgroundProcessLoggingContext(ctx_name)
+ self._logging_context.request = ctx_name
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -246,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
- # Now lets try and call on_<CMD_NAME> function
- run_as_background_process(
- "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
- )
+ self.handle_command(cmd)
- async def handle_command(self, cmd: Command):
+ def handle_command(self, cmd: Command) -> None:
"""Handle a command we have received over the replication stream.
First calls `self.on_<COMMAND>` if it exists, then calls
- `self.command_handler.on_<COMMAND>` if it exists. This allows for
- protocol level handling of commands (e.g. PINGs), before delegating to
- the handler.
+ `self.command_handler.on_<COMMAND>` if it exists (which can optionally
+ return an Awaitable).
+
+ This allows for protocol level handling of commands (e.g. PINGs), before
+ delegating to the handler.
Args:
cmd: received command
@@ -268,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
- await cmd_func(cmd)
+ cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
- await cmd_func(self, cmd)
+ res = cmd_func(self, cmd)
+
+ # the handler may return a coroutine: if so, fire it off as a
+ # background process.
+
+ if isawaitable(res):
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(), lambda: res
+ )
+
handled = True
if not handled:
@@ -350,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending:
self.send_command(cmd)
- async def on_PING(self, line):
+ def on_PING(self, line):
self.received_ping = True
- async def on_ERROR(self, cmd):
+ def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self):
@@ -448,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ServerCommand(self.server_name))
super().connectionMade()
- async def on_NAME(self, cmd):
+ def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
@@ -477,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Once we've connected subscribe to the necessary streams
self.replicate()
- async def on_SERVER(self, cmd):
+ def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
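The dispatch change in `handle_command` means the `on_<COMMAND>` callbacks are now plain synchronous functions that may *return* an awaitable, which the caller inspects with `isawaitable`. A self-contained sketch of that calling convention — using `asyncio.run` as a stand-in where the patch uses `run_as_background_process`:

```python
import asyncio
from inspect import isawaitable
from typing import Awaitable, List, Optional

class Handler:
    def __init__(self, is_master: bool) -> None:
        self._is_master = is_master
        self.log: List[str] = []

    def on_USER_SYNC(self, cmd: str) -> Optional[Awaitable[None]]:
        # Synchronous fast path: nothing to await on non-master instances,
        # and no coroutine object is ever created.
        if not self._is_master:
            return None
        # Return (do not await) the coroutine; the caller decides how to run it.
        return self._handle_user_sync(cmd)

    async def _handle_user_sync(self, cmd: str) -> None:
        self.log.append(cmd)

def handle_command(handler: Handler, cmd: str) -> None:
    res = handler.on_USER_SYNC(cmd)
    # Only involve the event loop when the handler returned an awaitable.
    if isawaitable(res):
        asyncio.run(res)

master = Handler(is_master=True)
handle_command(master, "USER_SYNC @alice:example.com")
worker = Handler(is_master=False)
handle_command(worker, "USER_SYNC @bob:example.com")
```

The payoff is that the common case (a command with no async work to do) never leaves the current logcontext or spawns a background process.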
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index b5c533a607..f225e533de 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from inspect import isawaitable
from typing import TYPE_CHECKING
import txredisapi
@@ -124,36 +125,32 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# remote instances.
tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
- # Now lets try and call on_<CMD_NAME> function
- run_as_background_process(
- "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
- )
+ self.handle_command(cmd)
- async def handle_command(self, cmd: Command):
+ def handle_command(self, cmd: Command) -> None:
"""Handle a command we have received over the replication stream.
- By default delegates to on_<COMMAND>, which should return an awaitable.
+ Delegates to `self.handler.on_<COMMAND>` (which can optionally return an
+ Awaitable).
Args:
cmd: received command
"""
- handled = False
-
- # First call any command handlers on this instance. These are for redis
- # specific handling.
- cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
- if cmd_func:
- await cmd_func(cmd)
- handled = True
- # Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
- if cmd_func:
- await cmd_func(self, cmd)
- handled = True
-
- if not handled:
+ if not cmd_func:
logger.warning("Unhandled command: %r", cmd)
+ return
+
+ res = cmd_func(self, cmd)
+
+ # the handler may return a coroutine: if so, fire it off as a
+ # background process.
+
+ if isawaitable(res):
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(), lambda: res
+ )
def connectionLost(self, reason):
logger.info("Lost connection to redis")
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index b21538766d..f016b4f1bd 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -17,8 +17,7 @@
"""
import logging
import re
-
-from twisted.internet import defer
+from typing import Iterable, Pattern
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
@@ -27,15 +26,23 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
-def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
+def client_patterns(
+ path_regex: str,
+ releases: Iterable[int] = (0,),
+ unstable: bool = True,
+ v1: bool = False,
+) -> Iterable[Pattern]:
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
- path_regex (str): The regex string to match. This should NOT have a ^
+ path_regex: The regex string to match. This should NOT have a ^
as this will be prefixed.
+ releases: An iterable of releases to include this endpoint under.
+ unstable: If true, include this endpoint under the "unstable" prefix.
+ v1: If true, include this endpoint under the "api/v1" prefix.
Returns:
- SRE_Pattern
+ An iterable of patterns.
"""
patterns = []
@@ -73,34 +80,22 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
def interactive_auth_handler(orig):
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
- Takes a on_POST method which returns a deferred (errcode, body) response
+ Takes an on_POST method which returns an Awaitable (errcode, body) response
and adds exception handling to turn an InteractiveAuthIncompleteError into
a 401 response.
Normal usage is:
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
# ...
- yield self.auth_handler.check_auth
- """
+ await self.auth_handler.check_auth
+ """
- def wrapped(*args, **kwargs):
- res = defer.ensureDeferred(orig(*args, **kwargs))
- res.addErrback(_catch_incomplete_interactive_auth)
- return res
+ async def wrapped(*args, **kwargs):
+ try:
+ return await orig(*args, **kwargs)
+ except InteractiveAuthIncompleteError as e:
+ return 401, e.result
return wrapped
-
-
-def _catch_incomplete_interactive_auth(f):
- """helper for interactive_auth_handler
-
- Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
-
- Args:
- f (failure.Failure):
- """
- f.trap(InteractiveAuthIncompleteError)
- return 401, f.value.result
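The decorator rewrite above replaces the Deferred errback chain with a straightforward `try`/`except` around an awaited coroutine. A runnable sketch of the same shape, with a stand-in error class in place of `synapse.api.errors.InteractiveAuthIncompleteError`:

```python
import asyncio
import functools

class InteractiveAuthIncompleteError(Exception):
    """Stand-in for synapse.api.errors.InteractiveAuthIncompleteError."""

    def __init__(self, result):
        super().__init__("Interactive auth not complete")
        self.result = result

def interactive_auth_handler(orig):
    # The Deferred + errback dance collapses into a plain try/except once
    # the wrapped handler is a coroutine function.
    @functools.wraps(orig)
    async def wrapped(*args, **kwargs):
        try:
            return await orig(*args, **kwargs)
        except InteractiveAuthIncompleteError as e:
            return 401, e.result
    return wrapped

@interactive_auth_handler
async def on_POST(request: dict):
    if not request.get("auth"):
        raise InteractiveAuthIncompleteError(
            {"flows": [{"stages": ["m.login.password"]}]}
        )
    return 200, {"ok": True}

incomplete = asyncio.run(on_POST({}))
complete = asyncio.run(on_POST({"auth": {"type": "m.login.password"}}))
```

The 401 body carries the partial auth `result`, so the client learns which auth flows remain, exactly as the old `_catch_incomplete_interactive_auth` errback did.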
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 595849f9d5..9a847130c0 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -18,7 +18,6 @@ import logging
import os
import urllib
-from twisted.internet import defer
from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error
@@ -77,8 +76,9 @@ def respond_404(request):
)
-@defer.inlineCallbacks
-def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None):
+async def respond_with_file(
+ request, media_type, file_path, file_size=None, upload_name=None
+):
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
@@ -89,7 +89,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam
add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
- yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
+ await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
finish_request(request)
else:
@@ -198,8 +198,9 @@ def _can_encode_filename_as_token(x):
return True
-@defer.inlineCallbacks
-def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+async def respond_with_responder(
+ request, responder, media_type, file_size, upload_name=None
+):
"""Responds to the request with given responder. If responder is None then
returns 404.
@@ -218,7 +219,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
add_file_headers(request, media_type, file_size, upload_name)
try:
with responder:
- yield responder.write_to_consumer(request)
+ await responder.write_to_consumer(request)
except Exception as e:
# The majority of the time this will be due to the client having gone
# away. Unfortunately, Twisted simply throws a generic exception at us
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 79cb0dddbe..66bc1c3360 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -14,17 +14,18 @@
# limitations under the License.
import contextlib
+import inspect
import logging
import os
import shutil
+from typing import Optional
-from twisted.internet import defer
from twisted.protocols.basic import FileSender
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer
-from ._base import Responder
+from ._base import FileInfo, Responder
logger = logging.getLogger(__name__)
@@ -46,25 +47,24 @@ class MediaStorage(object):
self.filepaths = filepaths
self.storage_providers = storage_providers
- @defer.inlineCallbacks
- def store_file(self, source, file_info):
+ async def store_file(self, source, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
configured storage providers
Args:
source: A file like object that should be written
- file_info (FileInfo): Info about the file to store
+ file_info: Info about the file to store
Returns:
- Deferred[str]: the file path written to in the primary media store
+ the file path written to in the primary media store
"""
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
- yield defer_to_thread(
+ await defer_to_thread(
self.hs.get_reactor(), _write_file_synchronously, source, f
)
- yield finish_cb()
+ await finish_cb()
return fname
@@ -75,7 +75,7 @@ class MediaStorage(object):
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
like object that can be written to, fname is the absolute path of file
- on disk, and finish_cb is a function that returns a Deferred.
+ on disk, and finish_cb is a function that returns an awaitable.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
@@ -91,7 +91,7 @@ class MediaStorage(object):
with media_storage.store_into_file(info) as (f, fname, finish_cb):
# .. write into f ...
- yield finish_cb()
+ await finish_cb()
"""
path = self._file_info_to_path(file_info)
@@ -103,10 +103,13 @@ class MediaStorage(object):
finished_called = [False]
- @defer.inlineCallbacks
- def finish():
+ async def finish():
for provider in self.storage_providers:
- yield provider.store_file(path, file_info)
+ # store_file is supposed to return an Awaitable, but guard
+ # against improper implementations.
+ result = provider.store_file(path, file_info)
+ if inspect.isawaitable(result):
+ await result
finished_called[0] = True
@@ -123,17 +126,15 @@ class MediaStorage(object):
if not finished_called:
raise Exception("Finished callback not called")
- @defer.inlineCallbacks
- def fetch_media(self, file_info):
+ async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache
and configured storage providers.
Args:
- file_info (FileInfo)
+ file_info
Returns:
- Deferred[Responder|None]: Returns a Responder if the file was found,
- otherwise None.
+ Returns a Responder if the file was found, otherwise None.
"""
path = self._file_info_to_path(file_info)
@@ -142,23 +143,26 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers:
- res = yield provider.fetch(path, file_info)
+ res = provider.fetch(path, file_info)
+ # Fetch is supposed to return an Awaitable, but guard against
+ # improper implementations.
+ if inspect.isawaitable(res):
+ res = await res
if res:
logger.debug("Streaming %s from %s", path, provider)
return res
return None
- @defer.inlineCallbacks
- def ensure_media_is_in_local_cache(self, file_info):
+ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
"""Ensures that the given file is in the local cache. Attempts to
download it from storage providers if it isn't.
Args:
- file_info (FileInfo)
+ file_info
Returns:
- Deferred[str]: Full path to local file
+ Full path to local file
"""
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
@@ -170,14 +174,18 @@ class MediaStorage(object):
os.makedirs(dirname)
for provider in self.storage_providers:
- res = yield provider.fetch(path, file_info)
+ res = provider.fetch(path, file_info)
+ # Fetch is supposed to return an Awaitable, but guard against
+ # improper implementations.
+ if inspect.isawaitable(res):
+ res = await res
if res:
with res:
consumer = BackgroundFileConsumer(
open(local_path, "wb"), self.hs.get_reactor()
)
- yield res.write_to_consumer(consumer)
- yield consumer.wait()
+ await res.write_to_consumer(consumer)
+ await consumer.wait()
return local_path
raise Exception("file could not be found")
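The recurring `isawaitable` guard in `fetch_media` and `ensure_media_is_in_local_cache` exists because third-party storage providers may not have been updated to return awaitables. A minimal sketch of that defensive pattern, with illustrative provider classes rather than real Synapse storage providers:

```python
import asyncio
from inspect import isawaitable
from typing import Optional

class AsyncProvider:
    """A well-behaved provider whose fetch is a coroutine function."""

    async def fetch(self, path: str, file_info: object) -> Optional[str]:
        return "async:%s" % (path,)

class LegacySyncProvider:
    """An 'improper' provider that returns the value directly."""

    def fetch(self, path: str, file_info: object) -> Optional[str]:
        return "sync:%s" % (path,)

async def fetch_media(providers, path: str, file_info: object = None) -> Optional[str]:
    for provider in providers:
        res = provider.fetch(path, file_info)
        # fetch is supposed to return an Awaitable, but guard against
        # implementations that hand back a plain value.
        if isawaitable(res):
            res = await res
        if res:
            return res
    return None

from_sync = asyncio.run(fetch_media([LegacySyncProvider()], "remote/abc"))
from_async = asyncio.run(fetch_media([AsyncProvider()], "remote/abc"))
```

Either style of provider works transparently, which is what lets this migration land without breaking existing `StorageProvider` implementations.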
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index e52c86c798..13d1a6d2ed 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -26,6 +26,7 @@ import traceback
from typing import Dict, Optional
from urllib import parse as urlparse
+import attr
from canonicaljson import json
from twisted.internet import defer
@@ -56,6 +57,65 @@ _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
+ONE_HOUR = 60 * 60 * 1000
+
+# A map of globs to API endpoints.
+_oembed_globs = {
+ # Twitter.
+ "https://publish.twitter.com/oembed": [
+ "https://twitter.com/*/status/*",
+ "https://*.twitter.com/*/status/*",
+ "https://twitter.com/*/moments/*",
+ "https://*.twitter.com/*/moments/*",
+ # Include the HTTP versions too.
+ "http://twitter.com/*/status/*",
+ "http://*.twitter.com/*/status/*",
+ "http://twitter.com/*/moments/*",
+ "http://*.twitter.com/*/moments/*",
+ ],
+}
+# Convert the globs to regular expressions.
+_oembed_patterns = {}
+for endpoint, globs in _oembed_globs.items():
+ for glob in globs:
+ # Convert the glob into a sane regular expression to match against. The
+ # rules followed will be slightly different for the domain portion vs.
+ # the rest.
+ #
+ # 1. The scheme must be one of HTTP / HTTPS (and have no globs).
+ # 2. The domain can have globs, but we limit it to characters that can
+ # reasonably be a domain part.
+ # TODO: This does not attempt to handle Unicode domain names.
+ # 3. Other parts allow a glob to be any one, or more, characters.
+ results = urlparse.urlparse(glob)
+
+ # Ensure the scheme does not have wildcards (and is a sane scheme).
+ if results.scheme not in {"http", "https"}:
+ raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,))
+
+ pattern = urlparse.urlunparse(
+ [
+ results.scheme,
+ re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
+ ]
+ + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
+ )
+ _oembed_patterns[re.compile(pattern)] = endpoint
+
+
+@attr.s
+class OEmbedResult:
+ # Either HTML content or URL must be provided.
+ html = attr.ib(type=Optional[str])
+ url = attr.ib(type=Optional[str])
+ title = attr.ib(type=Optional[str])
+ # Number of seconds to cache the content.
+ cache_age = attr.ib(type=int)
+
+
+class OEmbedError(Exception):
+ """An error occurred processing the oEmbed object."""
+
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
@@ -99,7 +159,7 @@ class PreviewUrlResource(DirectServeJsonResource):
cache_name="url_previews",
clock=self.clock,
# don't spider URLs more often than once an hour
- expiry_ms=60 * 60 * 1000,
+ expiry_ms=ONE_HOUR,
)
if self._worker_run_media_background_jobs:
@@ -310,6 +370,87 @@ class PreviewUrlResource(DirectServeJsonResource):
return jsonog.encode("utf8")
+ def _get_oembed_url(self, url: str) -> Optional[str]:
+ """
+ Check whether the URL should be downloaded as oEmbed content instead.
+
+ Args:
+ url: The URL to check.
+
+ Returns:
+ A URL to use instead or None if the original URL should be used.
+ """
+ for url_pattern, endpoint in _oembed_patterns.items():
+ if url_pattern.fullmatch(url):
+ return endpoint
+
+ # No match.
+ return None
+
+ async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
+ """
+ Request content from an oEmbed endpoint.
+
+ Args:
+ endpoint: The oEmbed API endpoint.
+ url: The URL to pass to the API.
+
+ Returns:
+ An object representing the metadata returned.
+
+ Raises:
+ OEmbedError if fetching or parsing of the oEmbed information fails.
+ """
+ try:
+ logger.debug("Trying to get oEmbed content for url '%s'", url)
+ result = await self.client.get_json(
+ endpoint,
+ # TODO Specify max height / width.
+ # Note that only the JSON format is supported.
+ args={"url": url},
+ )
+
+ # Ensure the version is 1.0.
+ if result.get("version") != "1.0":
+ raise OEmbedError("Invalid version: %s" % (result.get("version"),))
+
+ oembed_type = result.get("type")
+
+ # Ensure the cache age is None or an int.
+ cache_age = result.get("cache_age")
+ if cache_age:
+ cache_age = int(cache_age)
+
+ oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
+
+ # HTML content.
+ if oembed_type == "rich":
+ oembed_result.html = result.get("html")
+ return oembed_result
+
+ if oembed_type == "photo":
+ oembed_result.url = result.get("url")
+ return oembed_result
+
+ # TODO Handle link and video types.
+
+ if "thumbnail_url" in result:
+ oembed_result.url = result.get("thumbnail_url")
+ return oembed_result
+
+ raise OEmbedError("Incompatible oEmbed information.")
+
+ except OEmbedError as e:
+ # Trap OEmbedErrors first so we can directly re-raise them.
+ logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
+ raise
+
+ except Exception as e:
+ # Trap any exception and let the code follow as usual.
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
+ raise OEmbedError() from e
+
async def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
@@ -319,54 +460,90 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ # If this URL can be accessed via oEmbed, use that instead.
+ url_to_download = url
+ oembed_url = self._get_oembed_url(url)
+ if oembed_url:
+ # The result might be a new URL to download, or it might be HTML content.
try:
- logger.debug("Trying to get preview for url '%s'", url)
- length, headers, uri, code = await self.client.get_file(
- url,
- output_stream=f,
- max_size=self.max_spider_size,
- headers={"Accept-Language": self.url_preview_accept_language},
- )
- except SynapseError:
- # Pass SynapseErrors through directly, so that the servlet
- # handler will return a SynapseError to the client instead of
- # blank data or a 500.
- raise
- except DNSLookupError:
- # DNS lookup returned no results
- # Note: This will also be the case if one of the resolved IP
- # addresses is blacklisted
- raise SynapseError(
- 502,
- "DNS resolution failure during URL preview generation",
- Codes.UNKNOWN,
- )
- except Exception as e:
- # FIXME: pass through 404s and other error messages nicely
- logger.warning("Error downloading %s: %r", url, e)
+ oembed_result = await self._get_oembed_content(oembed_url, url)
+ if oembed_result.url:
+ url_to_download = oembed_result.url
+ elif oembed_result.html:
+ url_to_download = None
+ except OEmbedError:
+ # If an error occurs, try doing a normal preview.
+ pass
- raise SynapseError(
- 500,
- "Failed to download content: %s"
- % (traceback.format_exception_only(sys.exc_info()[0], e),),
- Codes.UNKNOWN,
- )
- await finish()
+ if url_to_download:
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ try:
+ logger.debug("Trying to get preview for url '%s'", url_to_download)
+ length, headers, uri, code = await self.client.get_file(
+ url_to_download,
+ output_stream=f,
+ max_size=self.max_spider_size,
+ headers={"Accept-Language": self.url_preview_accept_language},
+ )
+ except SynapseError:
+ # Pass SynapseErrors through directly, so that the servlet
+ # handler will return a SynapseError to the client instead of
+ # blank data or a 500.
+ raise
+ except DNSLookupError:
+ # DNS lookup returned no results
+ # Note: This will also be the case if one of the resolved IP
+ # addresses is blacklisted
+ raise SynapseError(
+ 502,
+ "DNS resolution failure during URL preview generation",
+ Codes.UNKNOWN,
+ )
+ except Exception as e:
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading %s: %r", url_to_download, e)
+
+ raise SynapseError(
+ 500,
+ "Failed to download content: %s"
+ % (traceback.format_exception_only(sys.exc_info()[0], e),),
+ Codes.UNKNOWN,
+ )
+ await finish()
+
+ if b"Content-Type" in headers:
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ else:
+ media_type = "application/octet-stream"
+
+ download_name = get_filename_from_headers(headers)
+
+ # FIXME: we should calculate a proper expiration based on the
+ # Cache-Control and Expire headers. But for now, assume 1 hour.
+ expires = ONE_HOUR
+ etag = headers["ETag"][0] if "ETag" in headers else None
+ else:
+ html_bytes = oembed_result.html.encode("utf-8") # type: ignore
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ f.write(html_bytes)
+ await finish()
+
+ media_type = "text/html"
+ download_name = oembed_result.title
+ length = len(html_bytes)
+ # If a specific cache age was not given, assume 1 hour.
+ expires = oembed_result.cache_age or ONE_HOUR
+ uri = oembed_url
+ code = 200
+ etag = None
try:
- if b"Content-Type" in headers:
- media_type = headers[b"Content-Type"][0].decode("ascii")
- else:
- media_type = "application/octet-stream"
time_now_ms = self.clock.time_msec()
- download_name = get_filename_from_headers(headers)
-
await self.store.store_local_media(
media_id=file_id,
media_type=media_type,
- time_now_ms=self.clock.time_msec(),
+ time_now_ms=time_now_ms,
upload_name=download_name,
media_length=length,
user_id=user,
@@ -389,10 +566,8 @@ class PreviewUrlResource(DirectServeJsonResource):
"filename": fname,
"uri": uri,
"response_code": code,
- # FIXME: we should calculate a proper expiration based on the
- # Cache-Control and Expire headers. But for now, assume 1 hour.
- "expires": 60 * 60 * 1000,
- "etag": headers["ETag"][0] if "ETag" in headers else None,
+ "expires": expires,
+ "etag": etag,
}
def _start_expire_url_cache_data(self):
@@ -449,7 +624,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# These may be cached for a bit on the client (i.e., they
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
- expire_before = now - 2 * 24 * 60 * 60 * 1000
+ expire_before = now - 2 * 24 * ONE_HOUR
media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
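Editorial note on the hunks above: the bare `60 * 60 * 1000` literals are replaced with the named `ONE_HOUR` constant (milliseconds). A minimal standalone sketch of the resulting expiry arithmetic — the constant mirrors the diff, but the function name here is invented for illustration:

```python
ONE_HOUR = 60 * 60 * 1000  # milliseconds, as in the diff


def url_cache_expire_before(now_ms: int) -> int:
    # Clients may keep a room open with a URL preview showing for a while,
    # so cached media is kept for two days before being deleted.
    return now_ms - 2 * 24 * ONE_HOUR


cutoff = url_cache_expire_before(1_600_000_000_000)
```

The named constant makes the two-day window read as `2 * 24 * ONE_HOUR` rather than an opaque product of five numbers.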
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 90a673778f..1aba408c21 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -31,6 +31,7 @@ import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
from synapse.events.builder import EventBuilderFactory
+from synapse.handlers.typing import FollowerTypingHandler
from synapse.replication.tcp.streams import Stream
class HomeServer(object):
@@ -150,3 +151,5 @@ class HomeServer(object):
pass
def should_send_federation(self) -> bool:
pass
+ def get_typing_handler(self) -> FollowerTypingHandler:
+ pass
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 495d9f04c8..25ccef5aa5 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,14 +16,12 @@
import logging
from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Set
+from typing import Awaitable, Dict, Iterable, List, Optional, Set
import attr
from frozendict import frozendict
from prometheus_client import Histogram
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events import EventBase
@@ -31,6 +29,7 @@ from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.storage.roommember import ProfileInfo
from synapse.types import StateMap
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
@@ -108,8 +107,7 @@ class StateHandler(object):
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
- @defer.inlineCallbacks
- def get_current_state(
+ async def get_current_state(
self, room_id, event_type=None, state_key="", latest_event_ids=None
):
""" Retrieves the current state for the room. This is done by
@@ -126,20 +124,20 @@ class StateHandler(object):
map from (type, state_key) to event
"""
if not latest_event_ids:
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state")
- ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+ ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
event = None
if event_id:
- event = yield self.store.get_event(event_id, allow_none=True)
+ event = await self.store.get_event(event_id, allow_none=True)
return event
- state_map = yield self.store.get_events(
+ state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
state = {
@@ -148,8 +146,7 @@ class StateHandler(object):
return state
- @defer.inlineCallbacks
- def get_current_state_ids(self, room_id, latest_event_ids=None):
+ async def get_current_state_ids(self, room_id, latest_event_ids=None):
"""Get the current state, or the state at a set of events, for a room
Args:
@@ -164,41 +161,38 @@ class StateHandler(object):
(event_type, state_key) -> event_id
"""
if not latest_event_ids:
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state_ids")
- ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+ ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
return state
- @defer.inlineCallbacks
- def get_current_users_in_room(self, room_id, latest_event_ids=None):
+ async def get_current_users_in_room(
+ self, room_id: str, latest_event_ids: Optional[List[str]] = None
+ ) -> Dict[str, ProfileInfo]:
"""
Get the users who are currently in a room.
Args:
- room_id (str): The ID of the room.
- latest_event_ids (List[str]|None): Precomputed list of latest
- event IDs. Will be computed if None.
+ room_id: The ID of the room.
+ latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
Returns:
- Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their
- profileinfo.
+ Dictionary of user IDs to their ProfileInfo.
"""
if not latest_event_ids:
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_users_in_room")
- entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
- joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
+ entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
+ joined_users = await self.store.get_joined_users_from_state(room_id, entry)
return joined_users
- @defer.inlineCallbacks
- def get_current_hosts_in_room(self, room_id):
- event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- return (yield self.get_hosts_in_room_at_events(room_id, event_ids))
+ async def get_current_hosts_in_room(self, room_id):
+ event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+ return await self.get_hosts_in_room_at_events(room_id, event_ids)
- @defer.inlineCallbacks
- def get_hosts_in_room_at_events(self, room_id, event_ids):
+ async def get_hosts_in_room_at_events(self, room_id, event_ids):
"""Get the hosts that were in a room at the given event ids
Args:
@@ -208,12 +202,11 @@ class StateHandler(object):
Returns:
list[str]: the hosts in the room at the given events
"""
- entry = yield self.resolve_state_groups_for_events(room_id, event_ids)
- joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
+ entry = await self.resolve_state_groups_for_events(room_id, event_ids)
+ joined_hosts = await self.store.get_joined_hosts(room_id, entry)
return joined_hosts
- @defer.inlineCallbacks
- def compute_event_context(
+ async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
):
"""Build an EventContext structure for the event.
@@ -278,7 +271,7 @@ class StateHandler(object):
# otherwise, we'll need to resolve the state across the prev_events.
logger.debug("calling resolve_state_groups from compute_event_context")
- entry = yield self.resolve_state_groups_for_events(
+ entry = await self.resolve_state_groups_for_events(
event.room_id, event.prev_event_ids()
)
@@ -295,7 +288,7 @@ class StateHandler(object):
#
if not state_group_before_event:
- state_group_before_event = yield self.state_store.store_state_group(
+ state_group_before_event = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event_prev_group,
@@ -335,7 +328,7 @@ class StateHandler(object):
state_ids_after_event[key] = event.event_id
delta_ids = {key: event.event_id}
- state_group_after_event = yield self.state_store.store_state_group(
+ state_group_after_event = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event,
@@ -353,8 +346,7 @@ class StateHandler(object):
)
@measure_func()
- @defer.inlineCallbacks
- def resolve_state_groups_for_events(self, room_id, event_ids):
+ async def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@@ -373,7 +365,7 @@ class StateHandler(object):
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
- state_groups_ids = yield self.state_store.get_state_groups_ids(
+ state_groups_ids = await self.state_store.get_state_groups_ids(
room_id, event_ids
)
@@ -382,7 +374,7 @@ class StateHandler(object):
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
- prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
+ prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
return _StateCacheEntry(
state=state_list,
@@ -391,9 +383,9 @@ class StateHandler(object):
delta_ids=delta_ids,
)
- room_version = yield self.store.get_room_version_id(room_id)
+ room_version = await self.store.get_room_version_id(room_id)
- result = yield self._state_resolution_handler.resolve_state_groups(
+ result = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
state_groups_ids,
@@ -402,8 +394,7 @@ class StateHandler(object):
)
return result
- @defer.inlineCallbacks
- def resolve_events(self, room_version, state_sets, event):
+ async def resolve_events(self, room_version, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
@@ -414,7 +405,7 @@ class StateHandler(object):
state_map = {ev.event_id: ev for st in state_sets for ev in st}
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events_with_store(
+ new_state = await resolve_events_with_store(
self.clock,
event.room_id,
room_version,
@@ -451,9 +442,8 @@ class StateResolutionHandler(object):
reset_expiry_on_get=True,
)
- @defer.inlineCallbacks
@log_function
- def resolve_state_groups(
+ async def resolve_state_groups(
self, room_id, room_version, state_groups_ids, event_map, state_res_store
):
"""Resolves conflicts between a set of state groups
@@ -479,13 +469,13 @@ class StateResolutionHandler(object):
state_res_store (StateResolutionStore)
Returns:
- Deferred[_StateCacheEntry]: resolved state
+ _StateCacheEntry: resolved state
"""
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
group_names = frozenset(state_groups_ids.keys())
- with (yield self.resolve_linearizer.queue(group_names)):
+ with (await self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
@@ -517,7 +507,7 @@ class StateResolutionHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events_with_store(
+ new_state = await resolve_events_with_store(
self.clock,
room_id,
room_version,
@@ -598,7 +588,7 @@ def resolve_events_with_store(
state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
-):
+) -> Awaitable[StateMap[str]]:
"""
Args:
room_id: the room we are working in
@@ -619,8 +609,7 @@ def resolve_events_with_store(
state_res_store: a place to fetch events from
Returns:
- Deferred[dict[(str, str), str]]:
- a map from (type, state_key) to event_id.
+ a map from (type, state_key) to event_id.
"""
v = KNOWN_ROOM_VERSIONS[room_version]
if v.state_res == StateResolutionVersions.V1:
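Editorial note: the conversion applied throughout `StateHandler` above is mechanical — drop `@defer.inlineCallbacks`, mark the method `async def`, and turn each `yield` into `await`. A minimal standalone sketch of the converted shape, using an invented in-memory store rather than Synapse's real data store:

```python
import asyncio


class FakeStore:
    """Invented stand-in for Synapse's data store, for illustration only."""

    async def get_latest_event_ids_in_room(self, room_id):
        return ["$event1:test", "$event2:test"]


class StateHandler:
    def __init__(self, store):
        self.store = store

    # Previously: @defer.inlineCallbacks with `yield`; now a native coroutine.
    async def get_current_state_ids(self, room_id, latest_event_ids=None):
        if not latest_event_ids:
            latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
        # Toy state map keyed on (event_type, state_key), as in the real code.
        return {("m.room.create", ""): latest_event_ids[0]}


state = asyncio.run(StateHandler(FakeStore()).get_current_state_ids("!room:test"))
```

The visible behaviour is unchanged; the difference is that callers now get a coroutine to `await` instead of a Twisted `Deferred` to `yield`.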
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 7b531a8337..ab5e24841d 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,9 +15,7 @@
import hashlib
import logging
-from typing import Callable, Dict, List, Optional
-
-from twisted.internet import defer
+from typing import Awaitable, Callable, Dict, List, Optional
from synapse import event_auth
from synapse.api.constants import EventTypes
@@ -32,12 +30,11 @@ logger = logging.getLogger(__name__)
POWER_KEY = (EventTypes.PowerLevels, "")
-@defer.inlineCallbacks
-def resolve_events_with_store(
+async def resolve_events_with_store(
room_id: str,
state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
- state_map_factory: Callable,
+ state_map_factory: Callable[[List[str]], Awaitable],
):
"""
Args:
@@ -56,7 +53,7 @@ def resolve_events_with_store(
state_map_factory: will be called
with a list of event_ids that are needed, and should return with
- a Deferred of dict of event_id to event.
+ an Awaitable that resolves to a dict of event_id to event.
Returns:
dict[(str, str), str]:
@@ -80,7 +77,7 @@ def resolve_events_with_store(
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict (and those in event_map)
- state_map = yield state_map_factory(needed_events)
+ state_map = await state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
@@ -110,7 +107,7 @@ def resolve_events_with_store(
"Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count
)
- state_map_new = yield state_map_factory(new_needed_events)
+ state_map_new = await state_map_factory(new_needed_events)
for event in state_map_new.values():
if event.room_id != room_id:
raise Exception(
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index bf6caa0946..6634955cdc 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -18,8 +18,6 @@ import itertools
import logging
from typing import Dict, List, Optional
-from twisted.internet import defer
-
import synapse.state
from synapse import event_auth
from synapse.api.constants import EventTypes
@@ -32,14 +30,13 @@ from synapse.util import Clock
logger = logging.getLogger(__name__)
-# We want to yield to the reactor occasionally during state res when dealing
+# We want to yield control back to the reactor occasionally during state res when dealing
# with large data sets, so that we don't exhaust the reactor. This is done by
-# yielding to reactor during loops every N iterations.
-_YIELD_AFTER_ITERATIONS = 100
+# awaiting the reactor during loops every N iterations.
+_AWAIT_AFTER_ITERATIONS = 100
-@defer.inlineCallbacks
-def resolve_events_with_store(
+async def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
@@ -87,7 +84,7 @@ def resolve_events_with_store(
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
- auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store)
+ auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
full_conflicted_set = set(
itertools.chain(
@@ -95,7 +92,7 @@ def resolve_events_with_store(
)
)
- events = yield state_res_store.get_events(
+ events = await state_res_store.get_events(
[eid for eid in full_conflicted_set if eid not in event_map],
allow_rejected=True,
)
@@ -118,14 +115,14 @@ def resolve_events_with_store(
eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
)
- sorted_power_events = yield _reverse_topological_power_sort(
+ sorted_power_events = await _reverse_topological_power_sort(
clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
)
logger.debug("sorted %d power events", len(sorted_power_events))
# Now sequentially auth each one
- resolved_state = yield _iterative_auth_checks(
+ resolved_state = await _iterative_auth_checks(
clock,
room_id,
room_version,
@@ -148,13 +145,13 @@ def resolve_events_with_store(
logger.debug("sorting %d remaining events", len(leftover_events))
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
- leftover_events = yield _mainline_sort(
+ leftover_events = await _mainline_sort(
clock, room_id, leftover_events, pl, event_map, state_res_store
)
logger.debug("resolving remaining events")
- resolved_state = yield _iterative_auth_checks(
+ resolved_state = await _iterative_auth_checks(
clock,
room_id,
room_version,
@@ -174,8 +171,7 @@ def resolve_events_with_store(
return resolved_state
-@defer.inlineCallbacks
-def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
+async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
"""Return the power level of the sender of the given event according to
their auth events.
@@ -188,11 +184,11 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
Returns:
int
"""
- event = yield _get_event(room_id, event_id, event_map, state_res_store)
+ event = await _get_event(room_id, event_id, event_map, state_res_store)
pl = None
for aid in event.auth_event_ids():
- aev = yield _get_event(
+ aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
@@ -202,7 +198,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
if pl is None:
# Couldn't find power level. Check if they're the creator of the room
for aid in event.auth_event_ids():
- aev = yield _get_event(
+ aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
@@ -221,8 +217,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
return int(level)
-@defer.inlineCallbacks
-def _get_auth_chain_difference(state_sets, event_map, state_res_store):
+async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
"""Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains.
@@ -235,7 +230,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
set[str]: Set of event IDs
"""
- difference = yield state_res_store.get_auth_chain_difference(
+ difference = await state_res_store.get_auth_chain_difference(
[set(state_set.values()) for state_set in state_sets]
)
@@ -292,8 +287,7 @@ def _is_power_event(event):
return False
-@defer.inlineCallbacks
-def _add_event_and_auth_chain_to_graph(
+async def _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
):
"""Helper function for _reverse_topological_power_sort that adds the event
@@ -314,7 +308,7 @@ def _add_event_and_auth_chain_to_graph(
eid = state.pop()
graph.setdefault(eid, set())
- event = yield _get_event(room_id, eid, event_map, state_res_store)
+ event = await _get_event(room_id, eid, event_map, state_res_store)
for aid in event.auth_event_ids():
if aid in auth_diff:
if aid not in graph:
@@ -323,8 +317,7 @@ def _add_event_and_auth_chain_to_graph(
graph.setdefault(eid, set()).add(aid)
-@defer.inlineCallbacks
-def _reverse_topological_power_sort(
+async def _reverse_topological_power_sort(
clock, room_id, event_ids, event_map, state_res_store, auth_diff
):
"""Returns a list of the event_ids sorted by reverse topological ordering,
@@ -344,26 +337,26 @@ def _reverse_topological_power_sort(
graph = {}
for idx, event_id in enumerate(event_ids, start=1):
- yield _add_event_and_auth_chain_to_graph(
+ await _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
)
- # We yield occasionally when we're working with large data sets to
+ # We yield control to the reactor occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
- if idx % _YIELD_AFTER_ITERATIONS == 0:
- yield clock.sleep(0)
+ if idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
event_to_pl = {}
for idx, event_id in enumerate(graph, start=1):
- pl = yield _get_power_level_for_sender(
+ pl = await _get_power_level_for_sender(
room_id, event_id, event_map, state_res_store
)
event_to_pl[event_id] = pl
- # We yield occasionally when we're working with large data sets to
+ # We yield control to the reactor occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
- if idx % _YIELD_AFTER_ITERATIONS == 0:
- yield clock.sleep(0)
+ if idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
def _get_power_order(event_id):
ev = event_map[event_id]
@@ -378,8 +371,7 @@ def _reverse_topological_power_sort(
return sorted_events
-@defer.inlineCallbacks
-def _iterative_auth_checks(
+async def _iterative_auth_checks(
clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
):
"""Sequentially apply auth checks to each event in given list, updating the
@@ -405,7 +397,7 @@ def _iterative_auth_checks(
auth_events = {}
for aid in event.auth_event_ids():
- ev = yield _get_event(
+ ev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
@@ -420,7 +412,7 @@ def _iterative_auth_checks(
for key in event_auth.auth_types_for_event(event):
if key in resolved_state:
ev_id = resolved_state[key]
- ev = yield _get_event(room_id, ev_id, event_map, state_res_store)
+ ev = await _get_event(room_id, ev_id, event_map, state_res_store)
if ev.rejected_reason is None:
auth_events[key] = event_map[ev_id]
@@ -438,16 +430,15 @@ def _iterative_auth_checks(
except AuthError:
pass
- # We yield occasionally when we're working with large data sets to
+ # We yield control to the reactor occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
- if idx % _YIELD_AFTER_ITERATIONS == 0:
- yield clock.sleep(0)
+ if idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
return resolved_state
-@defer.inlineCallbacks
-def _mainline_sort(
+async def _mainline_sort(
clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
):
"""Returns a sorted list of event_ids sorted by mainline ordering based on
@@ -474,21 +465,21 @@ def _mainline_sort(
idx = 0
while pl:
mainline.append(pl)
- pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
+ pl_ev = await _get_event(room_id, pl, event_map, state_res_store)
auth_events = pl_ev.auth_event_ids()
pl = None
for aid in auth_events:
- ev = yield _get_event(
+ ev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
pl = aid
break
- # We yield occasionally when we're working with large data sets to
+ # We yield control to the reactor occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
- if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0:
- yield clock.sleep(0)
+ if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
idx += 1
@@ -498,23 +489,24 @@ def _mainline_sort(
order_map = {}
for idx, ev_id in enumerate(event_ids, start=1):
- depth = yield _get_mainline_depth_for_event(
+ depth = await _get_mainline_depth_for_event(
event_map[ev_id], mainline_map, event_map, state_res_store
)
order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
- # We yield occasionally when we're working with large data sets to
+ # We yield control to the reactor occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
- if idx % _YIELD_AFTER_ITERATIONS == 0:
- yield clock.sleep(0)
+ if idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
event_ids.sort(key=lambda ev_id: order_map[ev_id])
return event_ids
-@defer.inlineCallbacks
-def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store):
+async def _get_mainline_depth_for_event(
+ event, mainline_map, event_map, state_res_store
+):
"""Get the mainline depths for the given event based on the mainline map
Args:
@@ -541,7 +533,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
event = None
for aid in auth_events:
- aev = yield _get_event(
+ aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
@@ -552,8 +544,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
return 0
-@defer.inlineCallbacks
-def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
+async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
"""Helper function to look up event in event_map, falling back to looking
it up in the store
@@ -569,7 +560,7 @@ def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
Optional[FrozenEvent]
"""
if event_id not in event_map:
- events = yield state_res_store.get_events([event_id], allow_rejected=True)
+ events = await state_res_store.get_events([event_id], allow_rejected=True)
event_map.update(events)
event = event_map.get(event_id)
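Editorial note: `_AWAIT_AFTER_ITERATIONS` above implements cooperative scheduling — every N iterations of a long loop, the coroutine awaits a zero-second sleep so other work can run. The same idea in plain asyncio (Synapse awaits `clock.sleep(0)` on the Twisted reactor; `asyncio.sleep(0)` plays the equivalent role in this sketch):

```python
import asyncio

_AWAIT_AFTER_ITERATIONS = 100  # same constant as in the diff


async def reverse_topological_power_sort_sketch(event_ids):
    # Sketch of the loop shape only; the real graph/power-level work is omitted.
    processed = []
    for idx, event_id in enumerate(event_ids, start=1):
        processed.append(event_id)
        # Hand control back to the event loop periodically so a large data
        # set cannot block other tasks for too long.
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await asyncio.sleep(0)
    return processed


events = [f"$e{i}:test" for i in range(250)]
out = asyncio.run(reverse_topological_power_sort_sketch(events))
```

Note that `sleep(0)` does not delay the loop; it merely re-queues the coroutine behind any other ready callbacks.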
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index d181488db7..c229248101 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -259,7 +259,7 @@ class PushRulesWorkerStore(
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 29765890ee..a92e401e88 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -497,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
index 6b8130bf0f..942e51fd3a 100644
--- a/synapse/storage/data_stores/main/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
room_id
)
- users_with_profile = yield state.get_current_users_in_room(room_id)
+ users_with_profile = yield defer.ensureDeferred(
+ state.get_current_users_in_room(room_id)
+ )
user_ids = set(users_with_profile)
# Update each user in the user directory.
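Editorial note: the three storage hunks above keep `@defer.inlineCallbacks` callers working against the now-async state methods by wrapping the coroutine in `defer.ensureDeferred(...)` before `yield`ing it. The asyncio analogue of that bridge is wrapping a coroutine in a Task/Future so callback-style code can consume it — a sketch under that analogy, with invented names:

```python
import asyncio


async def get_current_users_in_room(room_id):
    # Stand-in for the now-async StateHandler method.
    return {"@alice:test": {"display_name": "Alice"}}


async def legacy_style_caller():
    # A bare coroutine object cannot have callbacks attached; wrapping it
    # (asyncio.ensure_future here, defer.ensureDeferred in Twisted) yields an
    # object that both `await` and callback-based code can consume.
    fut = asyncio.ensure_future(get_current_users_in_room("!room:test"))
    seen = []
    fut.add_done_callback(lambda f: seen.append(sorted(f.result())))
    users = await fut
    await asyncio.sleep(0)  # let the done-callback run
    return users, seen


users, seen = asyncio.run(legacy_style_caller())
```

In Twisted the wrapping is mandatory at each call site, which is why each `yield context.get_current_state_ids()` becomes `yield defer.ensureDeferred(context.get_current_state_ids())` in the diff.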
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index fa46041676..78fbdcdee8 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -29,7 +29,6 @@ from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.state import StateResolutionStore
from synapse.storage.data_stores import DataStores
from synapse.storage.data_stores.main.events import DeltaState
from synapse.types import StateMap
@@ -648,6 +647,10 @@ class EventsPersistenceStorage(object):
room_version = await self.main_store.get_room_version_id(room_id)
logger.debug("calling resolve_state_groups from preserve_events")
+
+ # Avoid a circular import.
+ from synapse.state import StateResolutionStore
+
res = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 1a9bd5f37d..d1bd18da39 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -26,21 +26,24 @@ from synapse.rest import admin
from synapse.rest.client.v1 import login
from synapse.types import JsonDict, ReadReceipt
+from tests.test_utils import make_awaitable
from tests.unittest import HomeserverTestCase, override_config
class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
+ mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
+ # Ensure a new Awaitable is created for each call.
+ mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_id: make_awaitable(
+ ["test", "host2"]
+ )
return self.setup_test_homeserver(
- state_handler=Mock(spec=["get_current_hosts_in_room"]),
+ state_handler=mock_state_handler,
federation_transport_client=Mock(spec=["send_transaction"]),
)
@override_config({"send_federation": True})
def test_send_receipts(self):
- mock_state_handler = self.hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
@@ -81,9 +84,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_with_backoff(self):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
- mock_state_handler = self.hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
@@ -164,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
return self.setup_test_homeserver(
- state_handler=Mock(spec=["get_current_hosts_in_room"]),
federation_transport_client=Mock(spec=["send_transaction"]),
)
@@ -174,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
return c
def prepare(self, reactor, clock, hs):
- # stub out get_current_hosts_in_room
- mock_state_handler = hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
# stub out get_users_who_share_room_with_user so that it claims that
# `@user2:host2` is in the room
def get_users_who_share_room_with_user(user_id):
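Editorial note: the test changes above stop assigning a plain list to `return_value` and instead use `side_effect` with `make_awaitable`, because an awaitable can only be consumed once — each call to the mock must produce a fresh one. A simplified, self-contained version of that setup (the real `make_awaitable` lives in `tests/test_utils`; this stand-in is illustrative):

```python
import asyncio
from unittest.mock import Mock


def make_awaitable(result):
    # Simplified stand-in for tests.test_utils.make_awaitable: a fresh
    # coroutine per call, immediately resolving to `result`.
    async def _wrap():
        return result

    return _wrap()


mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
# side_effect rather than return_value: a single coroutine raises
# "cannot reuse already awaited coroutine" if awaited twice, so build a
# new awaitable on every call.
mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_id: make_awaitable(
    ["test", "host2"]
)


async def _demo():
    first = await mock_state_handler.get_current_hosts_in_room("!a:test")
    second = await mock_state_handler.get_current_hosts_in_room("!a:test")
    return first, second


hosts_first, hosts_second = asyncio.run(_demo())
```

This is why the mock setup moved from the individual tests into `make_homeserver`: with `side_effect`, one declaration safely serves every call in every test.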
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 946f06d151..ba8552c29f 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1,1447 +1,1447 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 Dirk Klimpel
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import urllib.parse
-from typing import List, Optional
-
-from mock import Mock
-
-import synapse.rest.admin
-from synapse.api.errors import Codes
-from synapse.rest.client.v1 import directory, events, login, room
-
-from tests import unittest
-
-"""Tests admin REST events for /rooms paths."""
-
-
-class ShutdownRoomTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- events.register_servlets,
- room.register_servlets,
- room.register_deprecated_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.event_creation_handler = hs.get_event_creation_handler()
- hs.config.user_consent_version = "1"
-
- consent_uri_builder = Mock()
- consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
- self.event_creation_handler._consent_uri_builder = consent_uri_builder
-
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.other_user_token = self.login("user", "pass")
-
- # Mark the admin user as having consented
- self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
-
- def test_shutdown_room_consent(self):
- """Test that we can shutdown rooms with local users who have not
- yet accepted the privacy policy. This used to fail when we tried to
- force part the user from the old room.
- """
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Assert one user in room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([self.other_user], users_in_room)
-
- # Enable require consent to send events
- self.event_creation_handler._block_events_without_consent_error = "Error"
-
- # Assert that the user is getting consent error
- self.helper.send(
- room_id, body="foo", tok=self.other_user_token, expect_code=403
- )
-
- # Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert there is now no longer anyone in the room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([], users_in_room)
-
- def test_shutdown_room_block_peek(self):
- """Test that a world_readable room can no longer be peeked into after
- it has been shut down.
- """
-
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Enable world readable
- url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- json.dumps({"history_visibility": "world_readable"}),
- access_token=self.other_user_token,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert we can no longer peek into the room
- self._assert_peek(room_id, expect_code=403)
-
- def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room.
- """
-
- url = "rooms/%s/initialSync" % (room_id,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- url = "events?timeout=0&room_id=" + room_id
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
-
-class DeleteRoomTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- events.register_servlets,
- room.register_servlets,
- room.register_deprecated_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.event_creation_handler = hs.get_event_creation_handler()
- hs.config.user_consent_version = "1"
-
- consent_uri_builder = Mock()
- consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
- self.event_creation_handler._consent_uri_builder = consent_uri_builder
-
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.other_user_tok = self.login("user", "pass")
-
- # Mark the admin user as having consented
- self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
-
- self.room_id = self.helper.create_room_as(
- self.other_user, tok=self.other_user_tok
- )
- self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
-
- def test_requester_is_no_admin(self):
- """
- If the user is not a server admin, an error 403 is returned.
- """
-
- request, channel = self.make_request(
- "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
- )
- self.render(request)
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- def test_room_does_not_exist(self):
- """
- Check that unknown rooms/servers return a 404 error.
- """
- url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
-
- request, channel = self.make_request(
- "POST", url, json.dumps({}), access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- def test_room_is_not_valid(self):
- """
- Check that invalid room names return a 400 error.
- """
- url = "/_synapse/admin/v1/rooms/invalidroom/delete"
-
- request, channel = self.make_request(
- "POST", url, json.dumps({}), access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(
- "invalidroom is not a legal room ID", channel.json_body["error"],
- )
-
- def test_new_room_user_does_not_exist(self):
- """
- Tests that the user ID must be from the local server, but does not have to exist.
- """
- body = json.dumps({"new_room_user_id": "@unknown:test"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertIn("new_room_id", channel.json_body)
- self.assertIn("kicked_users", channel.json_body)
- self.assertIn("failed_to_kick_users", channel.json_body)
- self.assertIn("local_aliases", channel.json_body)
-
- def test_new_room_user_is_not_local(self):
- """
- Check that only a local user can create the new room that members are moved to.
- """
- body = json.dumps({"new_room_user_id": "@not:exist.bla"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(
- "User must be our own: @not:exist.bla", channel.json_body["error"],
- )
-
- def test_block_is_not_bool(self):
- """
- If the parameter `block` is not a boolean, an error 400 is returned.
- """
- body = json.dumps({"block": "NotBool"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
-
- def test_purge_room_and_block(self):
- """Test to purge a room and block it.
- Members will not be moved to a new room and will not receive a message.
- """
- # Test that room is not purged
- with self.assertRaises(AssertionError):
- self._is_purged(self.room_id)
-
- # Test that room is not blocked
- self._is_blocked(self.room_id, expect=False)
-
- # Assert one user in room
- self._is_member(room_id=self.room_id, user_id=self.other_user)
-
- body = json.dumps({"block": True})
-
- request, channel = self.make_request(
- "POST",
- self.url.encode("ascii"),
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(None, channel.json_body["new_room_id"])
- self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
- self.assertIn("failed_to_kick_users", channel.json_body)
- self.assertIn("local_aliases", channel.json_body)
-
- self._is_purged(self.room_id)
- self._is_blocked(self.room_id, expect=True)
- self._has_no_members(self.room_id)
-
- def test_purge_room_and_not_block(self):
- """Test to purge a room and do not block it.
- Members will not be moved to a new room and will not receive a message.
- """
- # Test that room is not purged
- with self.assertRaises(AssertionError):
- self._is_purged(self.room_id)
-
- # Test that room is not blocked
- self._is_blocked(self.room_id, expect=False)
-
- # Assert one user in room
- self._is_member(room_id=self.room_id, user_id=self.other_user)
-
- body = json.dumps({"block": False})
-
- request, channel = self.make_request(
- "POST",
- self.url.encode("ascii"),
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(None, channel.json_body["new_room_id"])
- self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
- self.assertIn("failed_to_kick_users", channel.json_body)
- self.assertIn("local_aliases", channel.json_body)
-
- self._is_purged(self.room_id)
- self._is_blocked(self.room_id, expect=False)
- self._has_no_members(self.room_id)
-
- def test_shutdown_room_consent(self):
- """Test that we can shutdown rooms with local users who have not
- yet accepted the privacy policy. This used to fail when we tried to
- force part the user from the old room.
- Members will be moved to a new room and will receive a message.
- """
- self.event_creation_handler._block_events_without_consent_error = None
-
- # Assert one user in room
- users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))
- self.assertEqual([self.other_user], users_in_room)
-
- # Enable require consent to send events
- self.event_creation_handler._block_events_without_consent_error = "Error"
-
- # Assert that the user is getting consent error
- self.helper.send(
- self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
- )
-
- # Test that room is not purged
- with self.assertRaises(AssertionError):
- self._is_purged(self.room_id)
-
- # Assert one user in room
- self._is_member(room_id=self.room_id, user_id=self.other_user)
-
- # Test that the admin can still send shutdown
- url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
- self.assertIn("new_room_id", channel.json_body)
- self.assertIn("failed_to_kick_users", channel.json_body)
- self.assertIn("local_aliases", channel.json_body)
-
- # Test that member has moved to new room
- self._is_member(
- room_id=channel.json_body["new_room_id"], user_id=self.other_user
- )
-
- self._is_purged(self.room_id)
- self._has_no_members(self.room_id)
-
- def test_shutdown_room_block_peek(self):
- """Test that a world_readable room can no longer be peeked into after
- it has been shut down.
- Members will be moved to a new room and will receive a message.
- """
- self.event_creation_handler._block_events_without_consent_error = None
-
- # Enable world readable
- url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- json.dumps({"history_visibility": "world_readable"}),
- access_token=self.other_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that room is not purged
- with self.assertRaises(AssertionError):
- self._is_purged(self.room_id)
-
- # Assert one user in room
- self._is_member(room_id=self.room_id, user_id=self.other_user)
-
- # Test that the admin can still send shutdown
- url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
- self.assertIn("new_room_id", channel.json_body)
- self.assertIn("failed_to_kick_users", channel.json_body)
- self.assertIn("local_aliases", channel.json_body)
-
- # Test that member has moved to new room
- self._is_member(
- room_id=channel.json_body["new_room_id"], user_id=self.other_user
- )
-
- self._is_purged(self.room_id)
- self._has_no_members(self.room_id)
-
- # Assert we can no longer peek into the room
- self._assert_peek(self.room_id, expect_code=403)
-
- def _is_blocked(self, room_id, expect=True):
- """Assert that the room is blocked or not
- """
- d = self.store.is_room_blocked(room_id)
- if expect:
- self.assertTrue(self.get_success(d))
- else:
- self.assertIsNone(self.get_success(d))
-
- def _has_no_members(self, room_id):
- """Assert there is now no longer anyone in the room
- """
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([], users_in_room)
-
- def _is_member(self, room_id, user_id):
- """Test that user is member of the room
- """
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertIn(user_id, users_in_room)
-
- def _is_purged(self, room_id):
- """Test that the following tables have been purged of all rows related to the room.
- """
- for table in (
- "current_state_events",
- "event_backward_extremities",
- "event_forward_extremities",
- "event_json",
- "event_push_actions",
- "event_search",
- "events",
- "group_rooms",
- "public_room_list_stream",
- "receipts_graph",
- "receipts_linearized",
- "room_aliases",
- "room_depth",
- "room_memberships",
- "room_stats_state",
- "room_stats_current",
- "room_stats_historical",
- "room_stats_earliest_token",
- "rooms",
- "stream_ordering_to_exterm",
- "users_in_public_rooms",
- "users_who_share_private_rooms",
- "appservice_room_list",
- "e2e_room_keys",
- "event_push_summary",
- "pusher_throttle",
- "group_summary_rooms",
- "local_invites",
- "room_account_data",
- "room_tags",
- # "state_groups", # Current impl leaves orphaned state groups around.
- "state_groups_state",
- ):
- count = self.get_success(
- self.store.db.simple_select_one_onecol(
- table=table,
- keyvalues={"room_id": room_id},
- retcol="COUNT(*)",
- desc="test_purge_room",
- )
- )
-
- self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
-
- def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room.
- """
-
- url = "rooms/%s/initialSync" % (room_id,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- url = "events?timeout=0&room_id=" + room_id
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
-
-class PurgeRoomTestCase(unittest.HomeserverTestCase):
- """Test /purge_room admin API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_purge_room(self):
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # All users have to have left the room.
- self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
-
- url = "/_synapse/admin/v1/purge_room"
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the following tables have been purged of all rows related to the room.
- for table in (
- "current_state_events",
- "event_backward_extremities",
- "event_forward_extremities",
- "event_json",
- "event_push_actions",
- "event_search",
- "events",
- "group_rooms",
- "public_room_list_stream",
- "receipts_graph",
- "receipts_linearized",
- "room_aliases",
- "room_depth",
- "room_memberships",
- "room_stats_state",
- "room_stats_current",
- "room_stats_historical",
- "room_stats_earliest_token",
- "rooms",
- "stream_ordering_to_exterm",
- "users_in_public_rooms",
- "users_who_share_private_rooms",
- "appservice_room_list",
- "e2e_room_keys",
- "event_push_summary",
- "pusher_throttle",
- "group_summary_rooms",
- "room_account_data",
- "room_tags",
- # "state_groups", # Current impl leaves orphaned state groups around.
- "state_groups_state",
- ):
- count = self.get_success(
- self.store.db.simple_select_one_onecol(
- table=table,
- keyvalues={"room_id": room_id},
- retcol="COUNT(*)",
- desc="test_purge_room",
- )
- )
-
- self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
-
-
-class RoomTestCase(unittest.HomeserverTestCase):
- """Test /room admin API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- directory.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- # Create user
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_list_rooms(self):
- """Test that we can list rooms"""
- # Create 3 test rooms
- total_rooms = 3
- room_ids = []
- for x in range(total_rooms):
- room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
- )
- room_ids.append(room_id)
-
- # Request the list of rooms
- url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
-
- # Check request completed successfully
- self.assertEqual(200, int(channel.code), msg=channel.json_body)
-
- # Check that response json body contains a "rooms" key
- self.assertTrue(
- "rooms" in channel.json_body,
- msg="Response body does not contain a 'rooms' key",
- )
-
- # Check that 3 rooms were returned
- self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
-
- # Check their room_ids match
- returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
- self.assertEqual(room_ids, returned_room_ids)
-
- # Check that all fields are available
- for r in channel.json_body["rooms"]:
- self.assertIn("name", r)
- self.assertIn("canonical_alias", r)
- self.assertIn("joined_members", r)
- self.assertIn("joined_local_members", r)
- self.assertIn("version", r)
- self.assertIn("creator", r)
- self.assertIn("encryption", r)
- self.assertIn("federatable", r)
- self.assertIn("public", r)
- self.assertIn("join_rules", r)
- self.assertIn("guest_access", r)
- self.assertIn("history_visibility", r)
- self.assertIn("state_events", r)
-
- # Check that the correct number of total rooms was returned
- self.assertEqual(channel.json_body["total_rooms"], total_rooms)
-
- # Check that the offset is correct
- # Should be 0 as we aren't paginating
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that the prev_batch parameter is not present
- self.assertNotIn("prev_batch", channel.json_body)
-
- # We shouldn't receive a next token here as there's no further rooms to show
- self.assertNotIn("next_batch", channel.json_body)
-
- def test_list_rooms_pagination(self):
- """Test that we can get a full list of rooms through pagination"""
- # Create 5 test rooms
- total_rooms = 5
- room_ids = []
- for x in range(total_rooms):
- room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
- )
- room_ids.append(room_id)
-
- # Set the name of the rooms so we get a consistent returned ordering
- for idx, room_id in enumerate(room_ids):
- self.helper.send_state(
- room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
- )
-
- # Request the list of rooms
- returned_room_ids = []
- start = 0
- limit = 2
-
- run_count = 0
- should_repeat = True
- while should_repeat:
- run_count += 1
-
- url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
- start,
- limit,
- "name",
- )
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- self.assertTrue("rooms" in channel.json_body)
- for r in channel.json_body["rooms"]:
- returned_room_ids.append(r["room_id"])
-
- # Check that the correct number of total rooms was returned
- self.assertEqual(channel.json_body["total_rooms"], total_rooms)
-
- # Check that the offset is correct
- # We're only getting 2 rooms each page, so should be 2 * last run_count
- self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
-
- if run_count > 1:
- # Check the value of prev_batch is correct
- self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
-
- if "next_batch" not in channel.json_body:
- # We have reached the end of the list
- should_repeat = False
- else:
- # Make another query with an updated start value
- start = channel.json_body["next_batch"]
-
- # We should've queried the endpoint 3 times
- self.assertEqual(
- run_count,
- 3,
- msg="Should've queried 3 times for 5 rooms with limit 2 per query",
- )
-
- # Check that we received all of the room ids
- self.assertEqual(room_ids, returned_room_ids)
-
- url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- def test_correct_room_attributes(self):
- """Test the correct attributes for a room are returned"""
- # Create a test room
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- test_alias = "#test:test"
- test_room_name = "something"
-
- # Have another user join the room
- user_2 = self.register_user("user4", "pass")
- user_tok_2 = self.login("user4", "pass")
- self.helper.join(room_id, user_2, tok=user_tok_2)
-
- # Create a new alias to this room
- url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Set this new alias as the canonical alias for this room
- self.helper.send_state(
- room_id,
- "m.room.aliases",
- {"aliases": [test_alias]},
- tok=self.admin_user_tok,
- state_key="test",
- )
- self.helper.send_state(
- room_id,
- "m.room.canonical_alias",
- {"alias": test_alias},
- tok=self.admin_user_tok,
- )
-
- # Set a name for the room
- self.helper.send_state(
- room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
- )
-
- # Request the list of rooms
- url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check that only one room was returned
- self.assertEqual(len(rooms), 1)
-
- # And that the value of the total_rooms key was correct
- self.assertEqual(channel.json_body["total_rooms"], 1)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- # Check that all provided attributes are set
- r = rooms[0]
- self.assertEqual(room_id, r["room_id"])
- self.assertEqual(test_room_name, r["name"])
- self.assertEqual(test_alias, r["canonical_alias"])
-
- def test_room_list_sort_order(self):
- """Test room list sort ordering. alphabetical name versus number of members,
- reversing the order, etc.
- """
-
- def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):
- # Create a new alias to this room
- url = "/_matrix/client/r0/directory/room/%s" % (
- urllib.parse.quote(test_alias),
- )
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=admin_user_tok,
- )
- self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- # Set this new alias as the canonical alias for this room
- self.helper.send_state(
- room_id,
- "m.room.aliases",
- {"aliases": [test_alias]},
- tok=admin_user_tok,
- state_key="test",
- )
- self.helper.send_state(
- room_id,
- "m.room.canonical_alias",
- {"alias": test_alias},
- tok=admin_user_tok,
- )
-
- def _order_test(
- order_type: str, expected_room_list: List[str], reverse: bool = False,
- ):
- """Request the list of rooms in a certain order. Assert that order is what
- we expect
-
- Args:
- order_type: The type of ordering to give the server
- expected_room_list: The list of room_ids in the order we expect to get
- back from the server
- """
- # Request the list of rooms in the given order
- url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
- if reverse:
- url += "&dir=b"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check for the correct total_rooms value
- self.assertEqual(channel.json_body["total_rooms"], 3)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- # Check that rooms were returned in the expected order
- returned_order = [r["room_id"] for r in rooms]
- self.assertListEqual(expected_room_list, returned_order) # order is checked
-
- # Create 3 test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
- )
-
- # Set room canonical room aliases
- _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
- _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
- _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
-
- # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
- user_1 = self.register_user("bob1", "pass")
- user_1_tok = self.login("bob1", "pass")
- self.helper.join(room_id_2, user_1, tok=user_1_tok)
-
- user_2 = self.register_user("bob2", "pass")
- user_2_tok = self.login("bob2", "pass")
- self.helper.join(room_id_3, user_2, tok=user_2_tok)
-
- user_3 = self.register_user("bob3", "pass")
- user_3_tok = self.login("bob3", "pass")
- self.helper.join(room_id_3, user_3, tok=user_3_tok)
-
- # Test different sort orders, with forward and reverse directions
- _order_test("name", [room_id_1, room_id_2, room_id_3])
- _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)
-
- _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
- _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
-
- _order_test("joined_members", [room_id_3, room_id_2, room_id_1])
- _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
- _order_test(
- "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
- )
-
- _order_test("version", [room_id_1, room_id_2, room_id_3])
- _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("creator", [room_id_1, room_id_2, room_id_3])
- _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("encryption", [room_id_1, room_id_2, room_id_3])
- _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("federatable", [room_id_1, room_id_2, room_id_3])
- _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("public", [room_id_1, room_id_2, room_id_3])
- # SQLite and PostgreSQL sort this field differently
- # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
-
- _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
- _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
- _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
- _order_test(
- "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
- )
-
- _order_test("state_events", [room_id_3, room_id_2, room_id_1])
- _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- def test_search_term(self):
- """Test that searching for a room works correctly"""
- # Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- room_name_1 = "something"
- room_name_2 = "else"
-
- # Set the name for each room
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
- )
-
- def _search_test(
- expected_room_id: Optional[str],
- search_term: str,
- expected_http_code: int = 200,
- ):
- """Search for a room and check that the returned room's id is a match
-
- Args:
- expected_room_id: The room_id expected to be returned by the API. Set
- to None to expect zero results for the search
- search_term: The term to search for room names with
- expected_http_code: The expected http code for the request
- """
- url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
-
- if expected_http_code != 200:
- return
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check that the expected number of rooms were returned
- expected_room_count = 1 if expected_room_id else 0
- self.assertEqual(len(rooms), expected_room_count)
- self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- if expected_room_id:
- # Check that the first returned room id is correct
- r = rooms[0]
- self.assertEqual(expected_room_id, r["room_id"])
-
- # Perform search tests
- _search_test(room_id_1, "something")
- _search_test(room_id_1, "thing")
-
- _search_test(room_id_2, "else")
- _search_test(room_id_2, "se")
-
- _search_test(None, "foo")
- _search_test(None, "bar")
- _search_test(None, "", expected_http_code=400)
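The `_search_test` helper above interpolates the raw term straight into the query string, which works for the plain ASCII terms used here; a term containing spaces or reserved characters would need percent-encoding first. A minimal sketch of that (the `build_search_url` helper is illustrative, not part of the test suite), using `urllib.parse.quote`, which the module already imports:

```python
import urllib.parse

def build_search_url(search_term: str) -> str:
    # Percent-encode the term so spaces and reserved characters
    # (e.g. "&", "?", "/") survive the query string intact.
    return "/_synapse/admin/v1/rooms?search_term=%s" % urllib.parse.quote(
        search_term, safe=""
    )

print(build_search_url("some thing"))
# -> /_synapse/admin/v1/rooms?search_term=some%20thing
```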
-
- def test_single_room(self):
- """Test that a single room can be requested correctly"""
- # Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- room_name_1 = "something"
- room_name_2 = "else"
-
- # Set the name for each room
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
- )
-
- url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- self.assertIn("room_id", channel.json_body)
- self.assertIn("name", channel.json_body)
- self.assertIn("canonical_alias", channel.json_body)
- self.assertIn("joined_members", channel.json_body)
- self.assertIn("joined_local_members", channel.json_body)
- self.assertIn("version", channel.json_body)
- self.assertIn("creator", channel.json_body)
- self.assertIn("encryption", channel.json_body)
- self.assertIn("federatable", channel.json_body)
- self.assertIn("public", channel.json_body)
- self.assertIn("join_rules", channel.json_body)
- self.assertIn("guest_access", channel.json_body)
- self.assertIn("history_visibility", channel.json_body)
- self.assertIn("state_events", channel.json_body)
-
- self.assertEqual(room_id_1, channel.json_body["room_id"])
-
- def test_room_members(self):
- """Test that room members can be requested correctly"""
- # Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # Have another user join the room
- user_1 = self.register_user("foo", "pass")
- user_tok_1 = self.login("foo", "pass")
- self.helper.join(room_id_1, user_1, tok=user_tok_1)
-
- # Have another user join the room
- user_2 = self.register_user("bar", "pass")
- user_tok_2 = self.login("bar", "pass")
- self.helper.join(room_id_1, user_2, tok=user_tok_2)
- self.helper.join(room_id_2, user_2, tok=user_tok_2)
-
- # Have another user join the room
- user_3 = self.register_user("foobar", "pass")
- user_tok_3 = self.login("foobar", "pass")
- self.helper.join(room_id_2, user_3, tok=user_tok_3)
-
- url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- self.assertCountEqual(
- ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
- )
- self.assertEqual(channel.json_body["total"], 3)
-
- url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- self.assertCountEqual(
- ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
- )
- self.assertEqual(channel.json_body["total"], 3)
-
-
-class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
-
- servlets = [
- synapse.rest.admin.register_servlets,
- room.register_servlets,
- login.register_servlets,
- ]
-
- def prepare(self, reactor, clock, homeserver):
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.creator = self.register_user("creator", "test")
- self.creator_tok = self.login("creator", "test")
-
- self.second_user_id = self.register_user("second", "test")
- self.second_tok = self.login("second", "test")
-
- self.public_room_id = self.helper.create_room_as(
- self.creator, tok=self.creator_tok, is_public=True
- )
- self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
-
- def test_requester_is_no_admin(self):
- """
- If the user is not a server admin, an error 403 is returned.
- """
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.second_tok,
- )
- self.render(request)
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- def test_invalid_parameter(self):
- """
- If a parameter is missing, return an error
- """
- body = json.dumps({"unknown_parameter": "@unknown:test"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
-
- def test_local_user_does_not_exist(self):
- """
- Tests that a lookup for a user that does not exist returns a 404
- """
- body = json.dumps({"user_id": "@unknown:test"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- def test_remote_user(self):
- """
- Check that only local users can join rooms.
- """
- body = json.dumps({"user_id": "@not:exist.bla"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(
- "This endpoint can only be used with local users",
- channel.json_body["error"],
- )
-
- def test_room_does_not_exist(self):
- """
- Check that unknown rooms/server return error 404.
- """
- body = json.dumps({"user_id": self.second_user_id})
- url = "/_synapse/admin/v1/join/!unknown:test"
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("No known servers", channel.json_body["error"])
-
- def test_room_is_not_valid(self):
- """
- Check that invalid room names return a 400 error.
- """
- body = json.dumps({"user_id": self.second_user_id})
- url = "/_synapse/admin/v1/join/invalidroom"
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(
- "invalidroom was not legal room ID or room alias",
- channel.json_body["error"],
- )
-
- def test_join_public_room(self):
- """
- Test joining a local user to a public room with "JoinRules.PUBLIC"
- """
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(self.public_room_id, channel.json_body["room_id"])
-
- # Validate if user is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
-
- def test_join_private_room_if_not_member(self):
- """
- Test joining a local user to a private room with "JoinRules.INVITE"
- when server admin is not member of this room.
- """
- private_room_id = self.helper.create_room_as(
- self.creator, tok=self.creator_tok, is_public=False
- )
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- def test_join_private_room_if_member(self):
- """
- Test joining a local user to a private room with "JoinRules.INVITE",
- when server admin is member of this room.
- """
- private_room_id = self.helper.create_room_as(
- self.creator, tok=self.creator_tok, is_public=False
- )
- self.helper.invite(
- room=private_room_id,
- src=self.creator,
- targ=self.admin_user,
- tok=self.creator_tok,
- )
- self.helper.join(
- room=private_room_id, user=self.admin_user, tok=self.admin_user_tok
- )
-
- # Validate if server admin is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
-
- # Join user to room.
-
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["room_id"])
-
- # Validate if user is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
-
- def test_join_private_room_if_owner(self):
- """
- Test joining a local user to a private room with "JoinRules.INVITE",
- when server admin is owner of this room.
- """
- private_room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok, is_public=False
- )
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["room_id"])
-
- # Validate if user is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import urllib.parse
+from typing import List, Optional
+
+from mock import Mock
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import directory, events, login, room
+
+from tests import unittest
+
+"""Tests for the admin REST endpoints under /rooms."""
+
+
+class ShutdownRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ events.register_servlets,
+ room.register_servlets,
+ room.register_deprecated_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.event_creation_handler = hs.get_event_creation_handler()
+ hs.config.user_consent_version = "1"
+
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ # Mark the admin user as having consented
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
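The `prepare` method above stubs out the consent URI builder with a `Mock` so the tests need no real consent configuration. That stubbing pattern, in isolation (the URI and user ID are illustrative values):

```python
from unittest.mock import Mock

consent_uri_builder = Mock()
consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"

# Any call, with any arguments, now returns the canned URI,
# and the call is recorded for later inspection.
uri = consent_uri_builder.build_user_consent_uri("@user:test")
consent_uri_builder.build_user_consent_uri.assert_called_once_with("@user:test")
```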
+
+ def test_shutdown_room_consent(self):
+ """Test that we can shut down rooms with local users who have not
+ yet accepted the privacy policy. This used to fail when we tried to
+ force-part the user from the old room.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Assert one user in room
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([self.other_user], users_in_room)
+
+ # Require consent before events can be sent
+ self.event_creation_handler._block_events_without_consent_error = "Error"
+
+ # Assert that the user now gets a consent error
+ self.helper.send(
+ room_id, body="foo", tok=self.other_user_token, expect_code=403
+ )
+
+ # Check that the admin can still shut the room down
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert there is now no longer anyone in the room
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([], users_in_room)
+
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ """
+
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Enable world readable
+ url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ json.dumps({"history_visibility": "world_readable"}),
+ access_token=self.other_user_token,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Check that the admin can still shut the room down
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(room_id, expect_code=403)
+
+ def _assert_peek(self, room_id, expect_code):
+ """Assert that the admin user can (or cannot) peek into the room.
+ """
+
+ url = "rooms/%s/initialSync" % (room_id,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ url = "events?timeout=0&room_id=" + room_id
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+
+class DeleteRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ events.register_servlets,
+ room.register_servlets,
+ room.register_deprecated_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.event_creation_handler = hs.get_event_creation_handler()
+ hs.config.user_consent_version = "1"
+
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ # Mark the admin user as having consented
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+ self.room_id = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok
+ )
+ self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, a 403 error is returned.
+ """
+
+ request, channel = self.make_request(
+ "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_room_does_not_exist(self):
+ """
+ Check that unknown rooms/servers return a 404 error.
+ """
+ url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
+
+ request, channel = self.make_request(
+ "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_room_is_not_valid(self):
+ """
+ Check that invalid room names return a 400 error.
+ """
+ url = "/_synapse/admin/v1/rooms/invalidroom/delete"
+
+ request, channel = self.make_request(
+ "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "invalidroom is not a legal room ID", channel.json_body["error"],
+ )
+
+ def test_new_room_user_does_not_exist(self):
+ """
+ Tests that the user ID must be from the local server, but does not have to exist.
+ """
+ body = json.dumps({"new_room_user_id": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIn("new_room_id", channel.json_body)
+ self.assertIn("kicked_users", channel.json_body)
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ def test_new_room_user_is_not_local(self):
+ """
+ Check that only a local user can be used as the creator of the new room.
+ """
+ body = json.dumps({"new_room_user_id": "@not:exist.bla"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "User must be our own: @not:exist.bla", channel.json_body["error"],
+ )
+
+ def test_block_is_not_bool(self):
+ """
+ If the parameter `block` is not a boolean, an error is returned.
+ """
+ body = json.dumps({"block": "NotBool"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ def test_purge_room_and_block(self):
+ """Test purging a room and blocking it.
+ Members will not be moved to a new room and will not receive a message.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ body = json.dumps({"block": True})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url.encode("ascii"),
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(None, channel.json_body["new_room_id"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=True)
+ self._has_no_members(self.room_id)
+
+ def test_purge_room_and_not_block(self):
+ """Test purging a room without blocking it.
+ Members will not be moved to a new room and will not receive a message.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ body = json.dumps({"block": False})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url.encode("ascii"),
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(None, channel.json_body["new_room_id"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=False)
+ self._has_no_members(self.room_id)
+
+ def test_shutdown_room_consent(self):
+ """Test that we can shut down rooms with local users who have not
+ yet accepted the privacy policy. This used to fail when we tried to
+ force-part the user from the old room.
+ Members will be moved to a new room and will receive a message.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ # Assert one user in room
+ users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))
+ self.assertEqual([self.other_user], users_in_room)
+
+ # Require consent before events can be sent
+ self.event_creation_handler._block_events_without_consent_error = "Error"
+
+ # Assert that the user now gets a consent error
+ self.helper.send(
+ self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
+ )
+
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ # Check that the admin can still shut the room down
+ url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("new_room_id", channel.json_body)
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ # Test that member has moved to new room
+ self._is_member(
+ room_id=channel.json_body["new_room_id"], user_id=self.other_user
+ )
+
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ Members will be moved to a new room and will receive a message.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ # Enable world readable
+ url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ json.dumps({"history_visibility": "world_readable"}),
+ access_token=self.other_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ # Check that the admin can still shut the room down
+ url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("new_room_id", channel.json_body)
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ # Test that member has moved to new room
+ self._is_member(
+ room_id=channel.json_body["new_room_id"], user_id=self.other_user
+ )
+
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(self.room_id, expect_code=403)
+
+ def _is_blocked(self, room_id, expect=True):
+ """Assert that the room is blocked or not
+ """
+ d = self.store.is_room_blocked(room_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertIsNone(self.get_success(d))
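Note the asymmetry in `_is_blocked`: it asserts `True` for a blocked room but `None` (rather than `False`) otherwise, because `is_room_blocked` returns the blocking record when one exists and `None` when no row is found. A toy sketch of that "one row or nothing" contract (the dict-backed store here is a stand-in, not Synapse's data store):

```python
def is_room_blocked(blocked_rooms: dict, room_id: str):
    # Returns the blocking record if present, None otherwise --
    # mirroring a "select one row or nothing" database helper.
    return blocked_rooms.get(room_id)

blocked = {"!bad:test": {"user_id": "@admin:test"}}
assert is_room_blocked(blocked, "!bad:test")          # truthy record
assert is_room_blocked(blocked, "!ok:test") is None   # no row at all
```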
+
+ def _has_no_members(self, room_id):
+ """Assert that there is no longer anyone in the room
+ """
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([], users_in_room)
+
+ def _is_member(self, room_id, user_id):
+ """Assert that the user is a member of the room
+ """
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertIn(user_id, users_in_room)
+
+ def _is_purged(self, room_id):
+ """Test that the following tables have been purged of all rows related to the room.
+ """
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups_state",
+ ):
+ count = self.get_success(
+ self.store.db.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
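The `_is_purged` loop above counts leftover rows per table via the store's `simple_select_one_onecol` with `retcol="COUNT(*)"`. The equivalent check against a plain SQLite database might look like this (the schema and helper are illustrative, not Synapse's real storage layer):

```python
import sqlite3

def rows_for_room(conn: sqlite3.Connection, table: str, room_id: str) -> int:
    # Table names cannot be bound as SQL parameters, so the name is
    # interpolated; only do this with trusted, hard-coded table names.
    cur = conn.execute(
        "SELECT COUNT(*) FROM %s WHERE room_id = ?" % table, (room_id,)
    )
    return cur.fetchone()[0]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (room_id TEXT, body TEXT)")
conn.execute("INSERT INTO events VALUES ('!a:test', 'hello')")
conn.execute("DELETE FROM events WHERE room_id = '!a:test'")  # simulate the purge
assert rows_for_room(conn, "events", "!a:test") == 0
```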
+
+ def _assert_peek(self, room_id, expect_code):
+ """Assert that the admin user can (or cannot) peek into the room.
+ """
+
+ url = "rooms/%s/initialSync" % (room_id,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ url = "events?timeout=0&room_id=" + room_id
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+
+class PurgeRoomTestCase(unittest.HomeserverTestCase):
+ """Test /purge_room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_purge_room(self):
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # All users have to have left the room.
+ self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/purge_room"
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the following tables have been purged of all rows related to the room.
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "room_account_data",
+ "room_tags",
+ # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups_state",
+ ):
+ count = self.get_success(
+ self.store.db.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+
+class RoomTestCase(unittest.HomeserverTestCase):
+ """Test /room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_list_rooms(self):
+ """Test that we can list rooms"""
+ # Create 3 test rooms
+ total_rooms = 3
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ # Check request completed successfully
+ self.assertEqual(200, int(channel.code), msg=channel.json_body)
+
+ # Check that response json body contains a "rooms" key
+ self.assertTrue(
+ "rooms" in channel.json_body,
+ msg="Response body does not contain a 'rooms' key",
+ )
+
+ # Check that 3 rooms were returned
+ self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
+
+ # Check their room_ids match
+ returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
+ self.assertEqual(room_ids, returned_room_ids)
+
+ # Check that all fields are available
+ for r in channel.json_body["rooms"]:
+ self.assertIn("name", r)
+ self.assertIn("canonical_alias", r)
+ self.assertIn("joined_members", r)
+ self.assertIn("joined_local_members", r)
+ self.assertIn("version", r)
+ self.assertIn("creator", r)
+ self.assertIn("encryption", r)
+ self.assertIn("federatable", r)
+ self.assertIn("public", r)
+ self.assertIn("join_rules", r)
+ self.assertIn("guest_access", r)
+ self.assertIn("history_visibility", r)
+ self.assertIn("state_events", r)
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # Should be 0 as we aren't paginating
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that the prev_batch parameter is not present
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # We shouldn't receive a next token here as there are no further rooms to show
+ self.assertNotIn("next_batch", channel.json_body)
+
+ def test_list_rooms_pagination(self):
+ """Test that we can get a full list of rooms through pagination"""
+ # Create 5 test rooms
+ total_rooms = 5
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Set the name of the rooms so we get a consistent returned ordering
+ for idx, room_id in enumerate(room_ids):
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ returned_room_ids = []
+ start = 0
+ limit = 2
+
+ run_count = 0
+ should_repeat = True
+ while should_repeat:
+ run_count += 1
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
+ start,
+ limit,
+ "name",
+ )
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ self.assertTrue("rooms" in channel.json_body)
+ for r in channel.json_body["rooms"]:
+ returned_room_ids.append(r["room_id"])
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # We're only getting 2 rooms per page, so the offset should be 2 * (run_count - 1)
+ self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
+
+ if run_count > 1:
+ # Check the value of prev_batch is correct
+ self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
+
+ if "next_batch" not in channel.json_body:
+ # We have reached the end of the list
+ should_repeat = False
+ else:
+ # Make another query with an updated start value
+ start = channel.json_body["next_batch"]
+
+ # We should've queried the endpoint 3 times
+ self.assertEqual(
+ run_count,
+ 3,
+ msg="Should've queried 3 times for 5 rooms with limit 2 per query",
+ )
+
+ # Check that we received all of the room ids
+ self.assertEqual(room_ids, returned_room_ids)
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_correct_room_attributes(self):
+ """Test the correct attributes for a room are returned"""
+ # Create a test room
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ test_alias = "#test:test"
+ test_room_name = "something"
+
+ # Have another user join the room
+ user_2 = self.register_user("user4", "pass")
+ user_tok_2 = self.login("user4", "pass")
+ self.helper.join(room_id, user_2, tok=user_tok_2)
+
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=self.admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=self.admin_user_tok,
+ )
+
+ # Set a name for the room
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that only one room was returned
+ self.assertEqual(len(rooms), 1)
+
+ # And that the value of the total_rooms key was correct
+ self.assertEqual(channel.json_body["total_rooms"], 1)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that all provided attributes are set
+ r = rooms[0]
+ self.assertEqual(room_id, r["room_id"])
+ self.assertEqual(test_room_name, r["name"])
+ self.assertEqual(test_alias, r["canonical_alias"])
+
+ def test_room_list_sort_order(self):
+ """Test room list sort ordering: alphabetical by name versus number of members,
+ reversing the order, etc.
+ """
+
+ def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (
+ urllib.parse.quote(test_alias),
+ )
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=admin_user_tok,
+ )
+
+ def _order_test(
+ order_type: str, expected_room_list: List[str], reverse: bool = False,
+ ):
+ """Request the list of rooms in a certain order. Assert that order is what
+ we expect
+
+ Args:
+ order_type: The type of ordering to give the server
+ expected_room_list: The list of room_ids in the order we expect to get
+ back from the server
+ """
+ # Request the list of rooms in the given order
+ url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
+ if reverse:
+ url += "&dir=b"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check for the correct total_rooms value
+ self.assertEqual(channel.json_body["total_rooms"], 3)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that rooms were returned in the expected order
+ returned_order = [r["room_id"] for r in rooms]
+ self.assertListEqual(expected_room_list, returned_order) # order is checked
+
+ # Create 3 test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
+ )
+
+ # Set the rooms' canonical aliases
+ _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
+ _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
+ _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
+
+ # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
+ user_1 = self.register_user("bob1", "pass")
+ user_1_tok = self.login("bob1", "pass")
+ self.helper.join(room_id_2, user_1, tok=user_1_tok)
+
+ user_2 = self.register_user("bob2", "pass")
+ user_2_tok = self.login("bob2", "pass")
+ self.helper.join(room_id_3, user_2, tok=user_2_tok)
+
+ user_3 = self.register_user("bob3", "pass")
+ user_3_tok = self.login("bob3", "pass")
+ self.helper.join(room_id_3, user_3, tok=user_3_tok)
+
+ # Test different sort orders, with forward and reverse directions
+ _order_test("name", [room_id_1, room_id_2, room_id_3])
+ _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
+ _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("joined_members", [room_id_3, room_id_2, room_id_1])
+ _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
+ _order_test(
+ "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
+ )
+
+ _order_test("version", [room_id_1, room_id_2, room_id_3])
+ _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("creator", [room_id_1, room_id_2, room_id_3])
+ _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("encryption", [room_id_1, room_id_2, room_id_3])
+ _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("federatable", [room_id_1, room_id_2, room_id_3])
+ _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("public", [room_id_1, room_id_2, room_id_3])
+ # Sort order differs between SQLite and PostgreSQL
+ # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
+ _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
+ _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
+ _order_test(
+ "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
+ )
+
+ _order_test("state_events", [room_id_3, room_id_2, room_id_1])
+ _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ def test_search_term(self):
+ """Test that searching for a room works correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ def _search_test(
+ expected_room_id: Optional[str],
+ search_term: str,
+ expected_http_code: int = 200,
+ ):
+ """Search for a room and check that the returned room's id is a match
+
+ Args:
+ expected_room_id: The room_id expected to be returned by the API. Set
+ to None to expect zero results for the search
+ search_term: The term to search for room names with
+ expected_http_code: The expected http code for the request
+ """
+ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+ if expected_http_code != 200:
+ return
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that the expected number of rooms were returned
+ expected_room_count = 1 if expected_room_id else 0
+ self.assertEqual(len(rooms), expected_room_count)
+ self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ if expected_room_id:
+ # Check that the first returned room id is correct
+ r = rooms[0]
+ self.assertEqual(expected_room_id, r["room_id"])
+
+ # Perform search tests
+ _search_test(room_id_1, "something")
+ _search_test(room_id_1, "thing")
+
+ _search_test(room_id_2, "else")
+ _search_test(room_id_2, "se")
+
+ _search_test(None, "foo")
+ _search_test(None, "bar")
+ _search_test(None, "", expected_http_code=400)
+
+ def test_single_room(self):
+ """Test that a single room can be requested correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertIn("room_id", channel.json_body)
+ self.assertIn("name", channel.json_body)
+ self.assertIn("canonical_alias", channel.json_body)
+ self.assertIn("joined_members", channel.json_body)
+ self.assertIn("joined_local_members", channel.json_body)
+ self.assertIn("version", channel.json_body)
+ self.assertIn("creator", channel.json_body)
+ self.assertIn("encryption", channel.json_body)
+ self.assertIn("federatable", channel.json_body)
+ self.assertIn("public", channel.json_body)
+ self.assertIn("join_rules", channel.json_body)
+ self.assertIn("guest_access", channel.json_body)
+ self.assertIn("history_visibility", channel.json_body)
+ self.assertIn("state_events", channel.json_body)
+
+ self.assertEqual(room_id_1, channel.json_body["room_id"])
+
+ def test_room_members(self):
+ """Test that room members can be requested correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Have another user join the room
+ user_1 = self.register_user("foo", "pass")
+ user_tok_1 = self.login("foo", "pass")
+ self.helper.join(room_id_1, user_1, tok=user_tok_1)
+
+ # Have another user join the room
+ user_2 = self.register_user("bar", "pass")
+ user_tok_2 = self.login("bar", "pass")
+ self.helper.join(room_id_1, user_2, tok=user_tok_2)
+ self.helper.join(room_id_2, user_2, tok=user_tok_2)
+
+ # Have another user join the room
+ user_3 = self.register_user("foobar", "pass")
+ user_tok_3 = self.login("foobar", "pass")
+ self.helper.join(room_id_2, user_3, tok=user_tok_3)
+
+ url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertCountEqual(
+ ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
+ )
+ self.assertEqual(channel.json_body["total"], 3)
+
+ url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertCountEqual(
+ ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
+ )
+ self.assertEqual(channel.json_body["total"], 3)
+
+
+class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.creator = self.register_user("creator", "test")
+ self.creator_tok = self.login("creator", "test")
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+
+ self.public_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+ self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, a 403 error is returned.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.second_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_invalid_parameter(self):
+ """
+ If a required parameter is missing, a 400 error is returned.
+ """
+ body = json.dumps({"unknown_parameter": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ def test_local_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ body = json.dumps({"user_id": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_remote_user(self):
+ """
+ Check that only local users can be joined to rooms.
+ """
+ body = json.dumps({"user_id": "@not:exist.bla"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "This endpoint can only be used with local users",
+ channel.json_body["error"],
+ )
+
+ def test_room_does_not_exist(self):
+ """
+ Check that unknown rooms/servers return a 404 error.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+ url = "/_synapse/admin/v1/join/!unknown:test"
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("No known servers", channel.json_body["error"])
+
+ def test_room_is_not_valid(self):
+ """
+ Check that invalid room names return a 400 error.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+ url = "/_synapse/admin/v1/join/invalidroom"
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "invalidroom was not legal room ID or room alias",
+ channel.json_body["error"],
+ )
+
+ def test_join_public_room(self):
+ """
+ Test joining a local user to a public room with "JoinRules.PUBLIC"
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.public_room_id, channel.json_body["room_id"])
+
+ # Validate that the user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
+
+ def test_join_private_room_if_not_member(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE"
+ when the server admin is not a member of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False
+ )
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_join_private_room_if_member(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE",
+ when the server admin is a member of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False
+ )
+ self.helper.invite(
+ room=private_room_id,
+ src=self.creator,
+ targ=self.admin_user,
+ tok=self.creator_tok,
+ )
+ self.helper.join(
+ room=private_room_id, user=self.admin_user, tok=self.admin_user_tok
+ )
+
+ # Validate that the server admin is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+ # Join user to room.
+
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+ # Validate that the user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+ def test_join_private_room_if_owner(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE",
+ when the server admin is the owner of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=False
+ )
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+ # Validate that the user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 66fa5978b2..f4f3e56777 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -26,6 +26,7 @@ import attr
from parameterized import parameterized_class
from PIL import Image as Image
+from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable
@@ -77,7 +78,9 @@ class MediaStorageTests(unittest.HomeserverTestCase):
# This uses a real blocking threadpool so we have to wait for it to be
# actually done :/
- x = self.media_storage.ensure_media_is_in_local_cache(file_info)
+ x = defer.ensureDeferred(
+ self.media_storage.ensure_media_is_in_local_cache(file_info)
+ )
# Hotloop until the threadpool does its job...
self.wait_on_thread(x)
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 2826211f32..74765a582b 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -12,8 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import os
+import re
+
+from mock import patch
import attr
@@ -131,7 +134,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.reactor.nameResolver = Resolver()
def test_cache_returns_correct_type(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
request, channel = self.make_request(
"GET", "url_preview?url=http://matrix.org", shorthand=False
@@ -187,7 +190,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
def test_non_ascii_preview_httpequiv(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
b"<html><head>"
@@ -221,7 +224,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_non_ascii_preview_content_type(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
b"<html><head>"
@@ -254,7 +257,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_overlong_title(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
b"<html><head>"
@@ -292,7 +295,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
IP addresses can be previewed directly.
"""
- self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
request, channel = self.make_request(
"GET", "url_preview?url=http://example.com", shorthand=False
@@ -439,7 +442,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Hardcode the URL resolving to the IP we want.
self.lookups["example.com"] = [
(IPv4Address, "1.1.1.2"),
- (IPv4Address, "8.8.8.8"),
+ (IPv4Address, "10.1.2.3"),
]
request, channel = self.make_request(
@@ -518,7 +521,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
Accept-Language header is sent to the remote server
"""
- self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
# Build and make a request to the server
request, channel = self.make_request(
@@ -562,3 +565,126 @@ class URLPreviewTests(unittest.HomeserverTestCase):
),
server.data,
)
+
+ def test_oembed_photo(self):
+ """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
+ # Route the HTTP version to an HTTP endpoint so that the tests work.
+ with patch.dict(
+ "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
+ {
+ re.compile(
+ r"http://twitter\.com/.+/status/.+"
+ ): "http://publish.twitter.com/oembed",
+ },
+ clear=True,
+ ):
+
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = {
+ "version": "1.0",
+ "type": "photo",
+ "url": "http://cdn.twitter.com/matrixdotorg",
+ }
+ oembed_content = json.dumps(result).encode("utf-8")
+
+ end_content = (
+ b"<html><head>"
+ b"<title>Some Title</title>"
+ b'<meta property="og:description" content="hi" />'
+ b"</head></html>"
+ )
+
+ request, channel = self.make_request(
+ "GET",
+ "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+ )
+ % (len(oembed_content),)
+ + oembed_content
+ )
+
+ self.pump()
+
+ client = self.reactor.tcpClients[1][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
+ )
+
+ def test_oembed_rich(self):
+ """Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
+ # Route the HTTP version to an HTTP endpoint so that the tests work.
+ with patch.dict(
+ "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
+ {
+ re.compile(
+ r"http://twitter\.com/.+/status/.+"
+ ): "http://publish.twitter.com/oembed",
+ },
+ clear=True,
+ ):
+
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = {
+ "version": "1.0",
+ "type": "rich",
+ "html": "<div>Content Preview</div>",
+ }
+ end_content = json.dumps(result).encode("utf-8")
+
+ request, channel = self.make_request(
+ "GET",
+ "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body,
+ {"og:title": None, "og:description": "Content Preview"},
+ )
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 38f9b423ef..f2955a9c69 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -14,6 +14,7 @@
# limitations under the License.
import itertools
+from typing import List
import attr
@@ -432,7 +433,7 @@ class StateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(event_map),
)
- state_before = self.successResultOf(state_d)
+ state_before = self.successResultOf(defer.ensureDeferred(state_d))
state_after = dict(state_before)
if fake_event.state_key is not None:
@@ -581,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(self.event_map),
)
- state = self.successResultOf(state_d)
+ state = self.successResultOf(defer.ensureDeferred(state_d))
self.assert_dict(self.expected_combined_state, state)
@@ -608,9 +609,11 @@ class TestStateResolutionStore(object):
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
"""
- return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+ return defer.succeed(
+ {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+ )
- def _get_auth_chain(self, event_ids):
+ def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -622,10 +625,10 @@ class TestStateResolutionStore(object):
presence of rejected events
Args:
- event_ids (list): The event IDs of the events to fetch the auth
+ event_ids: The event IDs of the events to fetch the auth
chain for. Must be state events.
Returns:
- Deferred[list[str]]: List of event IDs of the auth chain.
+ List of event IDs of the auth chain.
"""
# Simple DFS for auth chain
@@ -648,4 +651,4 @@ class TestStateResolutionStore(object):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
- return set(chains[0]).union(*chains[1:]) - common
+ return defer.succeed(set(chains[0]).union(*chains[1:]) - common)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index b1dceb2918..1d77b4a2d6 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -109,7 +109,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
etype=EventTypes.Name, name=name, content={"name": name}, depth=1
)
- state = yield self.store.get_current_state(room_id=self.room.to_string())
+ state = yield defer.ensureDeferred(
+ self.store.get_current_state(room_id=self.room.to_string())
+ )
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
@@ -125,7 +127,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
)
- state = yield self.store.get_current_state(room_id=self.room.to_string())
+ state = yield defer.ensureDeferred(
+ self.store.get_current_state(room_id=self.room.to_string())
+ )
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
diff --git a/tests/test_server.py b/tests/test_server.py
index 42cada8964..073b2362cc 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -193,10 +193,10 @@ class OptionsResourceTests(unittest.TestCase):
return channel
def test_unknown_options_request(self):
- """An OPTIONS requests to an unknown URL still returns 200 OK."""
+ """An OPTIONS requests to an unknown URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/foo/")
- self.assertEqual(channel.result["code"], b"200")
- self.assertEqual(channel.result["body"], b"{}")
+ self.assertEqual(channel.result["code"], b"204")
+ self.assertNotIn("body", channel.result)
# Ensure the correct CORS headers have been added
self.assertTrue(
@@ -213,10 +213,10 @@ class OptionsResourceTests(unittest.TestCase):
)
def test_known_options_request(self):
- """An OPTIONS requests to an known URL still returns 200 OK."""
+ """An OPTIONS requests to an known URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/res/")
- self.assertEqual(channel.result["code"], b"200")
- self.assertEqual(channel.result["body"], b"{}")
+ self.assertEqual(channel.result["code"], b"204")
+ self.assertNotIn("body", channel.result)
# Ensure the correct CORS headers have been added
self.assertTrue(
diff --git a/tests/test_state.py b/tests/test_state.py
index 66f22f6813..4858e8fc59 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -97,17 +97,19 @@ class StateGroupStore(object):
self._group_to_state[state_group] = dict(current_state_ids)
- return state_group
+ return defer.succeed(state_group)
def get_events(self, event_ids, **kwargs):
- return {
- e_id: self._event_id_to_event[e_id]
- for e_id in event_ids
- if e_id in self._event_id_to_event
- }
+ return defer.succeed(
+ {
+ e_id: self._event_id_to_event[e_id]
+ for e_id in event_ids
+ if e_id in self._event_id_to_event
+ }
+ )
def get_state_group_delta(self, name):
- return None, None
+ return defer.succeed((None, None))
def register_events(self, events):
for e in events:
@@ -120,7 +122,7 @@ class StateGroupStore(object):
self._event_to_state_group[event_id] = state_group
def get_room_version_id(self, room_id):
- return RoomVersions.V1.identifier
+ return defer.succeed(RoomVersions.V1.identifier)
class DictObj(dict):
@@ -202,7 +204,9 @@ class StateTestCase(unittest.TestCase):
context_store = {} # type: dict[str, EventContext]
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- context = yield self.state.compute_event_context(event, old_state=old_state)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event, old_state=old_state)
+ )
prev_state_ids = yield context.get_prev_state_ids()
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state), current_state_ids.values()
)
@@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- context = yield self.state.compute_event_context(event, old_state=old_state)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event, old_state=old_state)
+ )
prev_state_ids = yield context.get_prev_state_ids()
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state + [event]), current_state_ids.values()
)
@@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = self.store.store_state_group(
+ group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(self.state.compute_event_context(event))
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(
{e.event_id for e in old_state}, set(current_state_ids.values())
@@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = self.store.store_state_group(
+ group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@@ -503,7 +517,7 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(self.state.compute_event_context(event))
prev_state_ids = yield context.get_prev_state_ids()
@@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
@@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
+ @defer.inlineCallbacks
def _get_context(
self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
):
- sg1 = self.store.store_state_group(
+ sg1 = yield self.store.store_state_group(
prev_event_id_1,
event.room_id,
None,
@@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_1, sg1)
- sg2 = self.store.store_state_group(
+ sg2 = yield self.store.store_state_group(
prev_event_id_2,
event.room_id,
None,
@@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_2, sg2)
- return self.state.compute_event_context(event)
+ result = yield defer.ensureDeferred(self.state.compute_event_context(event))
+ return result
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 7b345b03bb..508aeba078 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,7 +17,7 @@
"""
Utilities for running the unit tests
"""
-from typing import Awaitable, TypeVar
+from typing import Any, Awaitable, TypeVar
TV = TypeVar("TV")
@@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
# if next didn't raise, the awaitable hasn't completed.
raise Exception("awaitable has not yet completed")
+
+
+async def make_awaitable(result: Any):
+ """Create an awaitable that just returns a result."""
+ return result
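
The `make_awaitable` helper added above exists so that mocked storage methods can return a plain value where callers `await` the result. A minimal, self-contained sketch of the idea (the `asyncio.run` driver and the commented `Mock` usage are illustrative stand-ins for the test harness, not part of the patch):

```python
import asyncio


async def make_awaitable(result):
    """Create an awaitable that just returns a result."""
    return result


# e.g. patching an async storage method on a Mock, whose return_value
# must itself be awaitable:
#   store.get_user_by_id = Mock(return_value=make_awaitable({"name": "@user:test"}))

print(asyncio.run(make_awaitable(42)))  # prints 42
```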
diff --git a/tox.ini b/tox.ini
index 834d68aea5..595ab3ba66 100644
--- a/tox.ini
+++ b/tox.ini
@@ -185,6 +185,7 @@ commands = mypy \
synapse/handlers/cas_handler.py \
synapse/handlers/directory.py \
synapse/handlers/federation.py \
+ synapse/handlers/identity.py \
synapse/handlers/oidc_handler.py \
synapse/handlers/presence.py \
synapse/handlers/room_member.py \