diff --git a/.buildkite/test_db.db b/.buildkite/test_db.db
index f20567ba73..361369a581 100644
--- a/.buildkite/test_db.db
+++ b/.buildkite/test_db.db
Binary files differ
diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist
index fd98cbbaf6..5975cb98cf 100644
--- a/.buildkite/worker-blacklist
+++ b/.buildkite/worker-blacklist
@@ -1,41 +1,10 @@
# This file serves as a blacklist for SyTest tests that we expect will fail in
# Synapse when run under worker mode. For more details, see sytest-blacklist.
-Message history can be paginated
-
Can re-join room if re-invited
-The only membership state included in an initial sync is for all the senders in the timeline
-
-Local device key changes get to remote servers
-
-If remote user leaves room we no longer receive device updates
-
-Forgotten room messages cannot be paginated
-
-Inbound federation can get public room list
-
-Members from the gap are included in gappy incr LL sync
-
-Leaves are present in non-gapped incremental syncs
-
-Old leaves are present in gapped incremental syncs
-
-User sees updates to presence from other users in the incremental sync.
-
-Gapped incremental syncs include all state changes
-
-Old members are included in gappy incr LL sync if they start speaking
-
# new failures as of https://github.com/matrix-org/sytest/pull/732
Device list doesn't change if remote server is down
-Remote servers cannot set power levels in rooms without existing powerlevels
-Remote servers should reject attempts by non-creators to set the power levels
# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5
Server correctly handles incoming m.device_list_update
-
-# this fails reliably with a torture level of 100 due to https://github.com/matrix-org/synapse/issues/6536
-Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state
-
-Can get rooms/{roomId}/members at a given point
diff --git a/CHANGES.md b/CHANGES.md
index b44248e264..dfdd8aa68a 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,196 @@
+Synapse 1.21.0rc3 (2020-10-08)
+==============================
+
+Bugfixes
+--------
+
+- Fix duplication of events on high traffic servers, caused by PostgreSQL `could not serialize access due to concurrent update` errors. ([\#8456](https://github.com/matrix-org/synapse/issues/8456))
+
+
+Internal Changes
+----------------
+
+- Add Groovy Gorilla to the list of distributions we build `.deb`s for. ([\#8475](https://github.com/matrix-org/synapse/issues/8475))
+
+
+Synapse 1.21.0rc2 (2020-10-02)
+==============================
+
+Features
+--------
+
+- Convert additional templates from inline HTML to Jinja2 templates. ([\#8444](https://github.com/matrix-org/synapse/issues/8444))
+
+Bugfixes
+--------
+
+- Fix a regression in v1.21.0rc1 which broke thumbnails of remote media. ([\#8438](https://github.com/matrix-org/synapse/issues/8438))
+- Do not expose the experimental `uk.half-shot.msc2778.login.application_service` flow in the login API, which caused a compatibility problem with Element iOS. ([\#8440](https://github.com/matrix-org/synapse/issues/8440))
+- Fix malformed log line in new federation "catch up" logic. ([\#8442](https://github.com/matrix-org/synapse/issues/8442))
+- Fix DB query on startup for negative streams which caused long startup times. Introduced in [\#8374](https://github.com/matrix-org/synapse/issues/8374). ([\#8447](https://github.com/matrix-org/synapse/issues/8447))
+
+
+Synapse 1.21.0rc1 (2020-10-01)
+==============================
+
+Features
+--------
+
+- Require the user to confirm that their password should be reset after clicking the email confirmation link. ([\#8004](https://github.com/matrix-org/synapse/issues/8004))
+- Add an admin API `GET /_synapse/admin/v1/event_reports` to read entries of table `event_reports`. Contributed by @dklimpel. ([\#8217](https://github.com/matrix-org/synapse/issues/8217))
+- Consolidate the SSO error template across all configuration. ([\#8248](https://github.com/matrix-org/synapse/issues/8248), [\#8405](https://github.com/matrix-org/synapse/issues/8405))
+- Add a configuration option to specify a whitelist of domains that a user can be redirected to after validating their email or phone number. ([\#8275](https://github.com/matrix-org/synapse/issues/8275), [\#8417](https://github.com/matrix-org/synapse/issues/8417))
+- Add experimental support for sharding event persister. ([\#8294](https://github.com/matrix-org/synapse/issues/8294), [\#8387](https://github.com/matrix-org/synapse/issues/8387), [\#8396](https://github.com/matrix-org/synapse/issues/8396), [\#8419](https://github.com/matrix-org/synapse/issues/8419))
+- Add the room topic and avatar to the room details admin API. ([\#8305](https://github.com/matrix-org/synapse/issues/8305))
+- Add an admin API for querying rooms where a user is a member. Contributed by @dklimpel. ([\#8306](https://github.com/matrix-org/synapse/issues/8306))
+- Add `uk.half-shot.msc2778.login.application_service` login type to allow appservices to login. ([\#8320](https://github.com/matrix-org/synapse/issues/8320))
+- Add a configuration option that allows existing users to log in with OpenID Connect. Contributed by @BBBSnowball and @OmmyZhang. ([\#8345](https://github.com/matrix-org/synapse/issues/8345))
+- Add prometheus metrics for replication requests. ([\#8406](https://github.com/matrix-org/synapse/issues/8406))
+- Support passing additional single sign-on parameters to the client. ([\#8413](https://github.com/matrix-org/synapse/issues/8413))
+- Add experimental reporting of metrics on expensive rooms for state-resolution. ([\#8420](https://github.com/matrix-org/synapse/issues/8420))
+- Add experimental prometheus metric to track numbers of "large" rooms for state resolution. ([\#8425](https://github.com/matrix-org/synapse/issues/8425))
+- Add prometheus metrics to track federation delays. ([\#8430](https://github.com/matrix-org/synapse/issues/8430))
+
+
+Bugfixes
+--------
+
+- Fix a bug in the media repository where remote thumbnails with the same size but different crop methods would overwrite each other. Contributed by @deepbluev7. ([\#7124](https://github.com/matrix-org/synapse/issues/7124))
+- Fix inconsistent handling of non-existent push rules, and stop tracking the `enabled` state of removed push rules. ([\#7796](https://github.com/matrix-org/synapse/issues/7796))
+- Fix a longstanding bug when storing a media file with an empty `upload_name`. ([\#7905](https://github.com/matrix-org/synapse/issues/7905))
+- Fix messages not being sent over federation until an event is sent into the same room. ([\#8230](https://github.com/matrix-org/synapse/issues/8230), [\#8247](https://github.com/matrix-org/synapse/issues/8247), [\#8258](https://github.com/matrix-org/synapse/issues/8258), [\#8272](https://github.com/matrix-org/synapse/issues/8272), [\#8322](https://github.com/matrix-org/synapse/issues/8322))
+- Fix a longstanding bug where files that could not be thumbnailed would result in an Internal Server Error. ([\#8236](https://github.com/matrix-org/synapse/issues/8236), [\#8435](https://github.com/matrix-org/synapse/issues/8435))
+- Upgrade minimum version of `canonicaljson` to version 1.4.0, to fix a unicode encoding issue. ([\#8262](https://github.com/matrix-org/synapse/issues/8262))
+- Fix longstanding bug which could lead to incomplete database upgrades on SQLite. ([\#8265](https://github.com/matrix-org/synapse/issues/8265))
+- Fix stack overflow when stderr is redirected to the logging system, and the logging system encounters an error. ([\#8268](https://github.com/matrix-org/synapse/issues/8268))
+- Fix a bug which caused the logging system to report errors if `DEBUG` was enabled and no `context` filter was applied. ([\#8278](https://github.com/matrix-org/synapse/issues/8278))
+- Fix edge case where push could get delayed for a user until a later event was pushed. ([\#8287](https://github.com/matrix-org/synapse/issues/8287))
+- Fix fetching malformed events from remote servers. ([\#8324](https://github.com/matrix-org/synapse/issues/8324))
+- Fix `UnboundLocalError` from occurring when appservices send a malformed register request. ([\#8329](https://github.com/matrix-org/synapse/issues/8329))
+- Don't send push notifications to expired user accounts. ([\#8353](https://github.com/matrix-org/synapse/issues/8353))
+- Fix a regression in v1.19.0 with reactivating users through the admin API. ([\#8362](https://github.com/matrix-org/synapse/issues/8362))
+- Fix a bug where during device registration the length of the device name wasn't limited. ([\#8364](https://github.com/matrix-org/synapse/issues/8364))
+- Include `guest_access` in the fields that are checked for null bytes when updating `room_stats_state`. Broke in v1.7.2. ([\#8373](https://github.com/matrix-org/synapse/issues/8373))
+- Fix theoretical race condition where events are not sent down `/sync` if the synchrotron worker is restarted without restarting other workers. ([\#8374](https://github.com/matrix-org/synapse/issues/8374))
+- Fix a bug which could cause errors in rooms with malformed membership events, on servers using sqlite. ([\#8385](https://github.com/matrix-org/synapse/issues/8385))
+- Fix "Re-starting finished log context" warning when receiving an event we already had over federation. ([\#8398](https://github.com/matrix-org/synapse/issues/8398))
+- Fix incorrect handling of timeouts on outgoing HTTP requests. ([\#8400](https://github.com/matrix-org/synapse/issues/8400))
+- Fix a regression in v1.20.0 in the `synapse_port_db` script regarding the `ui_auth_sessions_ips` table. ([\#8410](https://github.com/matrix-org/synapse/issues/8410))
+- Remove unnecessary 3PID registration check when resetting password via an email address. Bug introduced in v0.34.0rc2. ([\#8414](https://github.com/matrix-org/synapse/issues/8414))
+
+
+Improved Documentation
+----------------------
+
+- Add `/_synapse/client` to the reverse proxy documentation. ([\#8227](https://github.com/matrix-org/synapse/issues/8227))
+- Add note to the reverse proxy settings documentation about disabling Apache's mod_security2. Contributed by Julian Fietkau (@jfietkau). ([\#8375](https://github.com/matrix-org/synapse/issues/8375))
+- Improve description of `server_name` config option in `homeserver.yaml`. ([\#8415](https://github.com/matrix-org/synapse/issues/8415))
+
+
+Deprecations and Removals
+-------------------------
+
+- Drop support for `prometheus_client` older than 0.4.0. ([\#8426](https://github.com/matrix-org/synapse/issues/8426))
+
+
+Internal Changes
+----------------
+
+- Fix tests on distros which disable TLSv1.0. Contributed by @danc86. ([\#8208](https://github.com/matrix-org/synapse/issues/8208))
+- Simplify the distributor code to avoid unnecessary work. ([\#8216](https://github.com/matrix-org/synapse/issues/8216))
+- Remove the `populate_stats_process_rooms_2` background job and restore functionality to `populate_stats_process_rooms`. ([\#8243](https://github.com/matrix-org/synapse/issues/8243))
+- Clean up type hints for `PaginationConfig`. ([\#8250](https://github.com/matrix-org/synapse/issues/8250), [\#8282](https://github.com/matrix-org/synapse/issues/8282))
+- Track the latest event for every destination and room for catch-up after federation outage. ([\#8256](https://github.com/matrix-org/synapse/issues/8256))
+- Fix non-user visible bug in implementation of `MultiWriterIdGenerator.get_current_token_for_writer`. ([\#8257](https://github.com/matrix-org/synapse/issues/8257))
+- Switch to the JSON implementation from the standard library. ([\#8259](https://github.com/matrix-org/synapse/issues/8259))
+- Add type hints to `synapse.util.async_helpers`. ([\#8260](https://github.com/matrix-org/synapse/issues/8260))
+- Simplify tests that mock asynchronous functions. ([\#8261](https://github.com/matrix-org/synapse/issues/8261))
+- Add type hints to `StreamToken` and `RoomStreamToken` classes. ([\#8279](https://github.com/matrix-org/synapse/issues/8279))
+- Change `StreamToken.room_key` to be a `RoomStreamToken` instance. ([\#8281](https://github.com/matrix-org/synapse/issues/8281))
+- Refactor notifier code to correctly use the max event stream position. ([\#8288](https://github.com/matrix-org/synapse/issues/8288))
+- Use slotted classes where possible. ([\#8296](https://github.com/matrix-org/synapse/issues/8296))
+- Support testing the local Synapse checkout against the [Complement homeserver test suite](https://github.com/matrix-org/complement/). ([\#8317](https://github.com/matrix-org/synapse/issues/8317))
+- Update outdated usages of `metaclass` to python 3 syntax. ([\#8326](https://github.com/matrix-org/synapse/issues/8326))
+- Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this. ([\#8330](https://github.com/matrix-org/synapse/issues/8330), [\#8377](https://github.com/matrix-org/synapse/issues/8377))
+- Use the `admin_patterns` helper in additional locations. ([\#8331](https://github.com/matrix-org/synapse/issues/8331))
+- Fix test logging to allow braces in log output. ([\#8335](https://github.com/matrix-org/synapse/issues/8335))
+- Remove `__future__` imports related to Python 2 compatibility. ([\#8337](https://github.com/matrix-org/synapse/issues/8337))
+- Simplify `super()` calls to Python 3 syntax. ([\#8344](https://github.com/matrix-org/synapse/issues/8344))
+- Fix bad merge from `release-v1.20.0` branch to `develop`. ([\#8354](https://github.com/matrix-org/synapse/issues/8354))
+- Factor out a `_send_dummy_event_for_room` method. ([\#8370](https://github.com/matrix-org/synapse/issues/8370))
+- Improve logging of state resolution. ([\#8371](https://github.com/matrix-org/synapse/issues/8371))
+- Add type annotations to `SimpleHttpClient`. ([\#8372](https://github.com/matrix-org/synapse/issues/8372))
+- Refactor ID generators to use `async with` syntax. ([\#8383](https://github.com/matrix-org/synapse/issues/8383))
+- Add `EventStreamPosition` type. ([\#8388](https://github.com/matrix-org/synapse/issues/8388))
+- Create a mechanism for marking tests "logcontext clean". ([\#8399](https://github.com/matrix-org/synapse/issues/8399))
+- A pair of tiny cleanups in the federation request code. ([\#8401](https://github.com/matrix-org/synapse/issues/8401))
+- Add checks on startup that PostgreSQL sequences are consistent with their associated tables. ([\#8402](https://github.com/matrix-org/synapse/issues/8402))
+- Do not include appservice users when calculating the total MAU for a server. ([\#8404](https://github.com/matrix-org/synapse/issues/8404))
+- Typing fixes for `synapse.handlers.federation`. ([\#8422](https://github.com/matrix-org/synapse/issues/8422))
+- Various refactors to simplify stream token handling. ([\#8423](https://github.com/matrix-org/synapse/issues/8423))
+- Make stream token serializing/deserializing async. ([\#8427](https://github.com/matrix-org/synapse/issues/8427))
+
+
+Synapse 1.20.1 (2020-09-24)
+===========================
+
+Bugfixes
+--------
+
+- Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. ([\#8386](https://github.com/matrix-org/synapse/issues/8386))
+- Fix a bug introduced in v1.20.0 which caused variables to be incorrectly escaped in Jinja2 templates. ([\#8394](https://github.com/matrix-org/synapse/issues/8394))
+
+
+Synapse 1.20.0 (2020-09-22)
+===========================
+
+No significant changes since v1.20.0rc5.
+
+Removal warning
+---------------
+
+Historically, the [Synapse Admin
+API](https://github.com/matrix-org/synapse/tree/master/docs) has been
+accessible under the `/_matrix/client/api/v1/admin`,
+`/_matrix/client/unstable/admin`, `/_matrix/client/r0/admin` and
+`/_synapse/admin` prefixes. In a future release, we will be dropping support
+for accessing Synapse's Admin API using the `/_matrix/client/*` prefixes. This
+makes it easier for homeserver admins to lock down external access to the Admin
+API endpoints.
+
+Synapse 1.20.0rc5 (2020-09-18)
+==============================
+
+In addition to the below, Synapse 1.20.0rc5 also includes the bug fix that was included in 1.19.3.
+
+Features
+--------
+
+- Add flags to the `/versions` endpoint for whether new rooms default to using E2EE. ([\#8343](https://github.com/matrix-org/synapse/issues/8343))
+
+
+Bugfixes
+--------
+
+- Fix rate limiting of federation `/send` requests. ([\#8342](https://github.com/matrix-org/synapse/issues/8342))
+- Fix a longstanding bug where back pagination over federation could get stuck if it failed to handle a received event. ([\#8349](https://github.com/matrix-org/synapse/issues/8349))
+
+
+Internal Changes
+----------------
+
+- Blacklist [MSC2753](https://github.com/matrix-org/matrix-doc/pull/2753) SyTests until it is implemented. ([\#8285](https://github.com/matrix-org/synapse/issues/8285))
+
+
+Synapse 1.19.3 (2020-09-18)
+===========================
+
+Bugfixes
+--------
+
+- Partially mitigate bug where newly joined servers couldn't get past events in a room when there is a malformed event. ([\#8350](https://github.com/matrix-org/synapse/issues/8350))
+
+
Synapse 1.20.0rc4 (2020-09-16)
==============================
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 062413e925..524f82433d 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -17,9 +17,9 @@ https://help.github.com/articles/using-pull-requests/) to ask us to pull your
changes into our repo.
Some other points to follow:
-
+
* Please base your changes on the `develop` branch.
-
+
* Please follow the [code style requirements](#code-style).
* Please include a [changelog entry](#changelog) with each PR.
@@ -46,7 +46,7 @@ locally. You'll need python 3.6 or later, and to install a number of tools:
```
# Install the dependencies
-pip install -U black flake8 flake8-comprehensions isort
+pip install -e ".[lint]"
# Run the linter script
./scripts-dev/lint.sh
diff --git a/README.rst b/README.rst
index 4a189c8bc4..d609b4b62e 100644
--- a/README.rst
+++ b/README.rst
@@ -1,10 +1,6 @@
-================
-Synapse |shield|
-================
-
-.. |shield| image:: https://img.shields.io/matrix/synapse:matrix.org?label=support&logo=matrix
- :alt: (get support on #synapse:matrix.org)
- :target: https://matrix.to/#/#synapse:matrix.org
+=========================================================
+Synapse |support| |development| |license| |pypi| |python|
+=========================================================
.. contents::
@@ -290,19 +286,6 @@ Testing with SyTest is recommended for verifying that changes related to the
Client-Server API are functioning correctly. See the `installation instructions
<https://github.com/matrix-org/sytest#installing>`_ for details.
-Building Internal API Documentation
-===================================
-
-Before building internal API documentation install sphinx and
-sphinxcontrib-napoleon::
-
- pip install sphinx
- pip install sphinxcontrib-napoleon
-
-Building internal API documentation::
-
- python setup.py build_sphinx
-
Troubleshooting
===============
@@ -387,3 +370,23 @@ something like the following in their logs::
This is normally caused by a misconfiguration in your reverse-proxy. See
`<docs/reverse_proxy.md>`_ and double-check that your settings are correct.
+
+.. |support| image:: https://img.shields.io/matrix/synapse:matrix.org?label=support&logo=matrix
+ :alt: (get support on #synapse:matrix.org)
+ :target: https://matrix.to/#/#synapse:matrix.org
+
+.. |development| image:: https://img.shields.io/matrix/synapse-dev:matrix.org?label=development&logo=matrix
+ :alt: (discuss development on #synapse-dev:matrix.org)
+ :target: https://matrix.to/#/#synapse-dev:matrix.org
+
+.. |license| image:: https://img.shields.io/github/license/matrix-org/synapse
+ :alt: (check license in LICENSE file)
+ :target: LICENSE
+
+.. |pypi| image:: https://img.shields.io/pypi/v/matrix-synapse
+ :alt: (latest version released on PyPI)
+ :target: https://pypi.org/project/matrix-synapse
+
+.. |python| image:: https://img.shields.io/pypi/pyversions/matrix-synapse
+ :alt: (supported python versions)
+ :target: https://pypi.org/project/matrix-synapse
diff --git a/UPGRADE.rst b/UPGRADE.rst
index 49e86e628f..5a68312217 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -75,6 +75,23 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
+Upgrading to v1.22.0
+====================
+
+ThirdPartyEventRules breaking changes
+-------------------------------------
+
+This release introduces a backwards-incompatible change to modules making use of
+``ThirdPartyEventRules`` in Synapse. If you make use of a module defined under the
+``third_party_event_rules`` config option, please make sure it is updated to handle
+the below change:
+
+The ``http_client`` argument is no longer passed to modules as they are initialised. Instead,
+modules are expected to make use of the ``http_client`` property on the ``ModuleApi`` class.
+Modules are now passed a ``module_api`` argument during initialisation, which is an instance of
+``ModuleApi``. ``ModuleApi`` instances have a ``http_client`` property which acts the same as
+the ``http_client`` argument previously passed to ``ThirdPartyEventRules`` modules.
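+
+As a minimal sketch, a module updated for this change might look like the
+following (the class name and config handling here are illustrative, not part
+of Synapse's API):
+
+.. code:: python
+
+    class ExampleThirdPartyRules:
+        def __init__(self, config, module_api):
+            self.config = config
+
+            # The http_client is no longer passed in directly; fetch it
+            # from the ModuleApi instance instead.
+            self.http_client = module_api.http_client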
+
Upgrading to v1.21.0
====================
diff --git a/changelog.d/7124.bugfix b/changelog.d/7124.bugfix
deleted file mode 100644
index 8fd177780d..0000000000
--- a/changelog.d/7124.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix a bug in the media repository where remote thumbnails with the same size but different crop methods would overwrite each other. Contributed by @deepbluev7.
diff --git a/changelog.d/7658.feature b/changelog.d/7658.feature
new file mode 100644
index 0000000000..fbf345988d
--- /dev/null
+++ b/changelog.d/7658.feature
@@ -0,0 +1 @@
+Add a configuration option for always using the "userinfo endpoint" for OpenID Connect. This fixes support for some identity providers, e.g. GitLab. Contributed by Benjamin Koch.
diff --git a/changelog.d/7796.bugfix b/changelog.d/7796.bugfix
deleted file mode 100644
index 65e5eb42a2..0000000000
--- a/changelog.d/7796.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix inconsistent handling of non-existent push rules, and stop tracking the `enabled` state of removed push rules.
diff --git a/changelog.d/8004.feature b/changelog.d/8004.feature
deleted file mode 100644
index a91b75e0e0..0000000000
--- a/changelog.d/8004.feature
+++ /dev/null
@@ -1 +0,0 @@
-Require the user to confirm that their password should be reset after clicking the email confirmation link.
\ No newline at end of file
diff --git a/changelog.d/8208.misc b/changelog.d/8208.misc
deleted file mode 100644
index e65da88c46..0000000000
--- a/changelog.d/8208.misc
+++ /dev/null
@@ -1 +0,0 @@
-Fix tests on distros which disable TLSv1.0. Contributed by @danc86.
diff --git a/changelog.d/8216.misc b/changelog.d/8216.misc
deleted file mode 100644
index b38911b0e5..0000000000
--- a/changelog.d/8216.misc
+++ /dev/null
@@ -1 +0,0 @@
-Simplify the distributor code to avoid unnecessary work.
diff --git a/changelog.d/8227.doc b/changelog.d/8227.doc
deleted file mode 100644
index 4a43015a83..0000000000
--- a/changelog.d/8227.doc
+++ /dev/null
@@ -1 +0,0 @@
-Add `/_synapse/client` to the reverse proxy documentation.
diff --git a/changelog.d/8230.misc b/changelog.d/8230.misc
deleted file mode 100644
index bf0ba76730..0000000000
--- a/changelog.d/8230.misc
+++ /dev/null
@@ -1 +0,0 @@
-Track the latest event for every destination and room for catch-up after federation outage.
diff --git a/changelog.d/8236.bugfix b/changelog.d/8236.bugfix
deleted file mode 100644
index 6f04871015..0000000000
--- a/changelog.d/8236.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix a longstanding bug where files that could not be thumbnailed would result in an Internal Server Error.
diff --git a/changelog.d/8243.misc b/changelog.d/8243.misc
deleted file mode 100644
index f7375d32d3..0000000000
--- a/changelog.d/8243.misc
+++ /dev/null
@@ -1 +0,0 @@
-Remove the 'populate_stats_process_rooms_2' background job and restore functionality to 'populate_stats_process_rooms'.
\ No newline at end of file
diff --git a/changelog.d/8247.misc b/changelog.d/8247.misc
deleted file mode 100644
index 3c27803be4..0000000000
--- a/changelog.d/8247.misc
+++ /dev/null
@@ -1 +0,0 @@
-Track the `stream_ordering` of the last successfully-sent event to every destination, so we can use this information to 'catch up' a remote server after an outage.
diff --git a/changelog.d/8248.feature b/changelog.d/8248.feature
deleted file mode 100644
index f3c4a74bc7..0000000000
--- a/changelog.d/8248.feature
+++ /dev/null
@@ -1 +0,0 @@
-Consolidate the SSO error template across all configuration.
diff --git a/changelog.d/8250.misc b/changelog.d/8250.misc
deleted file mode 100644
index b6896a9300..0000000000
--- a/changelog.d/8250.misc
+++ /dev/null
@@ -1 +0,0 @@
-Clean up type hints for `PaginationConfig`.
diff --git a/changelog.d/8256.misc b/changelog.d/8256.misc
deleted file mode 100644
index bf0ba76730..0000000000
--- a/changelog.d/8256.misc
+++ /dev/null
@@ -1 +0,0 @@
-Track the latest event for every destination and room for catch-up after federation outage.
diff --git a/changelog.d/8257.misc b/changelog.d/8257.misc
deleted file mode 100644
index 47ac583eb4..0000000000
--- a/changelog.d/8257.misc
+++ /dev/null
@@ -1 +0,0 @@
-Fix non-user visible bug in implementation of `MultiWriterIdGenerator.get_current_token_for_writer`.
diff --git a/changelog.d/8258.misc b/changelog.d/8258.misc
deleted file mode 100644
index 3c27803be4..0000000000
--- a/changelog.d/8258.misc
+++ /dev/null
@@ -1 +0,0 @@
-Track the `stream_ordering` of the last successfully-sent event to every destination, so we can use this information to 'catch up' a remote server after an outage.
diff --git a/changelog.d/8259.misc b/changelog.d/8259.misc
deleted file mode 100644
index a26779a664..0000000000
--- a/changelog.d/8259.misc
+++ /dev/null
@@ -1 +0,0 @@
-Switch to the JSON implementation from the standard library.
diff --git a/changelog.d/8260.misc b/changelog.d/8260.misc
deleted file mode 100644
index 164eea8b59..0000000000
--- a/changelog.d/8260.misc
+++ /dev/null
@@ -1 +0,0 @@
-Add type hints to `synapse.util.async_helpers`.
diff --git a/changelog.d/8261.misc b/changelog.d/8261.misc
deleted file mode 100644
index bc91e9375c..0000000000
--- a/changelog.d/8261.misc
+++ /dev/null
@@ -1 +0,0 @@
-Simplify tests that mock asynchronous functions.
diff --git a/changelog.d/8262.bugfix b/changelog.d/8262.bugfix
deleted file mode 100644
index 2b84927de3..0000000000
--- a/changelog.d/8262.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Upgrade canonicaljson to version 1.4.0 to fix a unicode encoding issue.
diff --git a/changelog.d/8265.bugfix b/changelog.d/8265.bugfix
deleted file mode 100644
index 981a836d21..0000000000
--- a/changelog.d/8265.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix longstanding bug which could lead to incomplete database upgrades on SQLite.
diff --git a/changelog.d/8268.bugfix b/changelog.d/8268.bugfix
deleted file mode 100644
index 4b15a60253..0000000000
--- a/changelog.d/8268.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix stack overflow when stderr is redirected to the logging system, and the logging system encounters an error.
diff --git a/changelog.d/8272.bugfix b/changelog.d/8272.bugfix
deleted file mode 100644
index 532d0e22fe..0000000000
--- a/changelog.d/8272.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix messages over federation being lost until an event is sent into the same room.
diff --git a/changelog.d/8275.feature b/changelog.d/8275.feature
deleted file mode 100644
index 17549c3df3..0000000000
--- a/changelog.d/8275.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add a config option to specify a whitelist of domains that a user can be redirected to after validating their email or phone number.
\ No newline at end of file
diff --git a/changelog.d/8278.bugfix b/changelog.d/8278.bugfix
deleted file mode 100644
index 50e40ca2a9..0000000000
--- a/changelog.d/8278.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix a bug which caused the logging system to report errors if `DEBUG` was enabled and no `context` filter was applied.
diff --git a/changelog.d/8279.misc b/changelog.d/8279.misc
deleted file mode 100644
index 99f669001f..0000000000
--- a/changelog.d/8279.misc
+++ /dev/null
@@ -1 +0,0 @@
-Add type hints to `StreamToken` and `RoomStreamToken` classes.
diff --git a/changelog.d/8281.misc b/changelog.d/8281.misc
deleted file mode 100644
index 74357120a7..0000000000
--- a/changelog.d/8281.misc
+++ /dev/null
@@ -1 +0,0 @@
-Change `StreamToken.room_key` to be a `RoomStreamToken` instance.
diff --git a/changelog.d/8282.misc b/changelog.d/8282.misc
deleted file mode 100644
index b6896a9300..0000000000
--- a/changelog.d/8282.misc
+++ /dev/null
@@ -1 +0,0 @@
-Clean up type hints for `PaginationConfig`.
diff --git a/changelog.d/8285.misc b/changelog.d/8285.misc
deleted file mode 100644
index 4646664ba1..0000000000
--- a/changelog.d/8285.misc
+++ /dev/null
@@ -1 +0,0 @@
-Blacklist [MSC2753](https://github.com/matrix-org/matrix-doc/pull/2753) SyTests until it is implemented.
\ No newline at end of file
diff --git a/changelog.d/8287.bugfix b/changelog.d/8287.bugfix
deleted file mode 100644
index 839781aa07..0000000000
--- a/changelog.d/8287.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix edge case where push could get delayed for a user until a later event was pushed.
diff --git a/changelog.d/8288.misc b/changelog.d/8288.misc
deleted file mode 100644
index c08a53a5ee..0000000000
--- a/changelog.d/8288.misc
+++ /dev/null
@@ -1 +0,0 @@
-Refactor notifier code to correctly use the max event stream position.
diff --git a/changelog.d/8292.feature b/changelog.d/8292.feature
new file mode 100644
index 0000000000..6d0335e2c8
--- /dev/null
+++ b/changelog.d/8292.feature
@@ -0,0 +1 @@
+Allow `ThirdPartyEventRules` modules to query and manipulate whether a room is in the public rooms directory.
\ No newline at end of file
diff --git a/changelog.d/8294.feature b/changelog.d/8294.feature
deleted file mode 100644
index b363e929ea..0000000000
--- a/changelog.d/8294.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add experimental support for sharding event persister.
diff --git a/changelog.d/8296.misc b/changelog.d/8296.misc
deleted file mode 100644
index f593a5b347..0000000000
--- a/changelog.d/8296.misc
+++ /dev/null
@@ -1 +0,0 @@
-Use slotted classes where possible.
diff --git a/changelog.d/8305.feature b/changelog.d/8305.feature
deleted file mode 100644
index 862dfdf959..0000000000
--- a/changelog.d/8305.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add the room topic and avatar to the room details admin API.
diff --git a/changelog.d/8312.feature b/changelog.d/8312.feature
new file mode 100644
index 0000000000..222a1b032a
--- /dev/null
+++ b/changelog.d/8312.feature
@@ -0,0 +1 @@
+Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)).
\ No newline at end of file
diff --git a/changelog.d/8324.bugfix b/changelog.d/8324.bugfix
deleted file mode 100644
index 32788a9284..0000000000
--- a/changelog.d/8324.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix fetching events from remote servers that are malformed.
diff --git a/changelog.d/8369.feature b/changelog.d/8369.feature
new file mode 100644
index 0000000000..542993110b
--- /dev/null
+++ b/changelog.d/8369.feature
@@ -0,0 +1 @@
+Allow running background tasks in a separate worker process.
diff --git a/changelog.d/8380.feature b/changelog.d/8380.feature
new file mode 100644
index 0000000000..05ccea19dc
--- /dev/null
+++ b/changelog.d/8380.feature
@@ -0,0 +1 @@
+Add support for device dehydration ([MSC2697](https://github.com/matrix-org/matrix-doc/pull/2697)).
diff --git a/changelog.d/8407.misc b/changelog.d/8407.misc
new file mode 100644
index 0000000000..d37002d75b
--- /dev/null
+++ b/changelog.d/8407.misc
@@ -0,0 +1 @@
+Add typing information to the device handler.
diff --git a/changelog.d/8432.misc b/changelog.d/8432.misc
new file mode 100644
index 0000000000..01fdad4caf
--- /dev/null
+++ b/changelog.d/8432.misc
@@ -0,0 +1 @@
+Check for unreachable code with mypy.
diff --git a/changelog.d/8433.misc b/changelog.d/8433.misc
new file mode 100644
index 0000000000..05f8b5bbf4
--- /dev/null
+++ b/changelog.d/8433.misc
@@ -0,0 +1 @@
+Add unit test for event persister sharding.
diff --git a/changelog.d/8439.misc b/changelog.d/8439.misc
new file mode 100644
index 0000000000..237cb3b311
--- /dev/null
+++ b/changelog.d/8439.misc
@@ -0,0 +1 @@
+Allow events to be sent to clients sooner when using sharded event persisters.
diff --git a/changelog.d/8443.misc b/changelog.d/8443.misc
new file mode 100644
index 0000000000..633598e6b3
--- /dev/null
+++ b/changelog.d/8443.misc
@@ -0,0 +1 @@
+Configure `public_baseurl` when using demo scripts.
diff --git a/changelog.d/8448.misc b/changelog.d/8448.misc
new file mode 100644
index 0000000000..5ddda1803b
--- /dev/null
+++ b/changelog.d/8448.misc
@@ -0,0 +1 @@
+Add SQL logging on queries that happen during startup.
diff --git a/changelog.d/8450.misc b/changelog.d/8450.misc
new file mode 100644
index 0000000000..4e04c523ab
--- /dev/null
+++ b/changelog.d/8450.misc
@@ -0,0 +1 @@
+Speed up unit tests when using PostgreSQL.
diff --git a/changelog.d/8452.misc b/changelog.d/8452.misc
new file mode 100644
index 0000000000..8288d91c78
--- /dev/null
+++ b/changelog.d/8452.misc
@@ -0,0 +1 @@
+Remove redundant database loads of stream_ordering for events we already have.
diff --git a/changelog.d/8454.bugfix b/changelog.d/8454.bugfix
new file mode 100644
index 0000000000..c06d490b6f
--- /dev/null
+++ b/changelog.d/8454.bugfix
@@ -0,0 +1 @@
+Fix a longstanding bug where invalid ignored users in account data could break clients.
diff --git a/changelog.d/8457.bugfix b/changelog.d/8457.bugfix
new file mode 100644
index 0000000000..545b06d180
--- /dev/null
+++ b/changelog.d/8457.bugfix
@@ -0,0 +1 @@
+Fix a bug where backfilling a room with an event that was missing the `redacts` field would break.
diff --git a/changelog.d/8458.feature b/changelog.d/8458.feature
new file mode 100644
index 0000000000..542993110b
--- /dev/null
+++ b/changelog.d/8458.feature
@@ -0,0 +1 @@
+Allow running background tasks in a separate worker process.
diff --git a/changelog.d/8461.feature b/changelog.d/8461.feature
new file mode 100644
index 0000000000..3665d670e1
--- /dev/null
+++ b/changelog.d/8461.feature
@@ -0,0 +1 @@
+Change default room version to "6", per [MSC2788](https://github.com/matrix-org/matrix-doc/pull/2788).
diff --git a/changelog.d/8462.doc b/changelog.d/8462.doc
new file mode 100644
index 0000000000..cf84db6db7
--- /dev/null
+++ b/changelog.d/8462.doc
@@ -0,0 +1 @@
+Update the directions for using the manhole with coroutines.
diff --git a/changelog.d/8463.misc b/changelog.d/8463.misc
new file mode 100644
index 0000000000..040c9bb90f
--- /dev/null
+++ b/changelog.d/8463.misc
@@ -0,0 +1 @@
+Reduce inconsistencies between codepaths for membership and non-membership events.
diff --git a/changelog.d/8464.misc b/changelog.d/8464.misc
new file mode 100644
index 0000000000..a552e88f9f
--- /dev/null
+++ b/changelog.d/8464.misc
@@ -0,0 +1 @@
+Combine `SpamCheckerApi` with the more generic `ModuleApi`.
diff --git a/changelog.d/8465.bugfix b/changelog.d/8465.bugfix
new file mode 100644
index 0000000000..73f895b268
--- /dev/null
+++ b/changelog.d/8465.bugfix
@@ -0,0 +1 @@
+Don't attempt to respond to some requests if the client has already disconnected.
\ No newline at end of file
diff --git a/changelog.d/8467.feature b/changelog.d/8467.feature
new file mode 100644
index 0000000000..6d0335e2c8
--- /dev/null
+++ b/changelog.d/8467.feature
@@ -0,0 +1 @@
+Allow `ThirdPartyEventRules` modules to query and manipulate whether a room is in the public rooms directory.
\ No newline at end of file
diff --git a/changelog.d/8468.misc b/changelog.d/8468.misc
new file mode 100644
index 0000000000..32ba991e64
--- /dev/null
+++ b/changelog.d/8468.misc
@@ -0,0 +1 @@
+Additional testing for `ThirdPartyEventRules`.
diff --git a/changelog.d/8474.misc b/changelog.d/8474.misc
new file mode 100644
index 0000000000..65e329a6e3
--- /dev/null
+++ b/changelog.d/8474.misc
@@ -0,0 +1 @@
+Unblacklist some sytests.
diff --git a/changelog.d/8477.misc b/changelog.d/8477.misc
new file mode 100644
index 0000000000..2ee1606b6e
--- /dev/null
+++ b/changelog.d/8477.misc
@@ -0,0 +1 @@
+Include the log level in the phone home stats.
diff --git a/changelog.d/8480.misc b/changelog.d/8480.misc
new file mode 100644
index 0000000000..81633af296
--- /dev/null
+++ b/changelog.d/8480.misc
@@ -0,0 +1 @@
+Remove outdated sphinx documentation, scripts and configuration.
\ No newline at end of file
diff --git a/changelog.d/8486.bugfix b/changelog.d/8486.bugfix
new file mode 100644
index 0000000000..63fc091ba6
--- /dev/null
+++ b/changelog.d/8486.bugfix
@@ -0,0 +1 @@
+Fix incremental sync returning an incorrect `prev_batch` token in timeline section, which when used to paginate returned events that were included in the incremental sync. Broken since v0.16.0.
diff --git a/changelog.d/8492.misc b/changelog.d/8492.misc
new file mode 100644
index 0000000000..a344aee791
--- /dev/null
+++ b/changelog.d/8492.misc
@@ -0,0 +1 @@
+Clarify error message when plugin config parsers raise an error.
diff --git a/changelog.d/8493.doc b/changelog.d/8493.doc
new file mode 100644
index 0000000000..26797cd99e
--- /dev/null
+++ b/changelog.d/8493.doc
@@ -0,0 +1 @@
+Improve readme by adding new shields.io badges.
diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py
index dfc1d294dc..ab1e1f1f4c 100755
--- a/contrib/cmdclient/console.py
+++ b/contrib/cmdclient/console.py
@@ -15,8 +15,6 @@
# limitations under the License.
""" Starts a synapse client console. """
-from __future__ import print_function
-
import argparse
import cmd
import getpass
diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py
index cd3260b27d..345120b612 100644
--- a/contrib/cmdclient/http.py
+++ b/contrib/cmdclient/http.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import print_function
-
import json
import urllib
from pprint import pformat
diff --git a/contrib/graph/graph.py b/contrib/graph/graph.py
index de33fac1c7..fdbac087bd 100644
--- a/contrib/graph/graph.py
+++ b/contrib/graph/graph.py
@@ -1,5 +1,3 @@
-from __future__ import print_function
-
import argparse
import cgi
import datetime
diff --git a/contrib/graph/graph3.py b/contrib/graph/graph3.py
index 91db98e7ef..dd0c19368b 100644
--- a/contrib/graph/graph3.py
+++ b/contrib/graph/graph3.py
@@ -1,5 +1,3 @@
-from __future__ import print_function
-
import argparse
import cgi
import datetime
diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py
index 69aa74bd34..b3de468687 100644
--- a/contrib/jitsimeetbridge/jitsimeetbridge.py
+++ b/contrib/jitsimeetbridge/jitsimeetbridge.py
@@ -10,8 +10,6 @@ the bridge.
Requires:
npm install jquery jsdom
"""
-from __future__ import print_function
-
import json
import subprocess
import time
diff --git a/contrib/scripts/kick_users.py b/contrib/scripts/kick_users.py
index 372dbd9e4f..f8e0c732fb 100755
--- a/contrib/scripts/kick_users.py
+++ b/contrib/scripts/kick_users.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python
-from __future__ import print_function
import json
import sys
@@ -8,11 +7,6 @@ from argparse import ArgumentParser
import requests
-try:
-    raw_input
-except NameError:  # Python 3
-    raw_input = input
-

def _mkurl(template, kws):
    for key in kws:
@@ -58,7 +52,7 @@ def main(hs, room_id, access_token, user_id_prefix, why):
print("The following user IDs will be kicked from %s" % room_name)
for uid in kick_list:
print(uid)
- doit = raw_input("Continue? [Y]es\n")
+ doit = input("Continue? [Y]es\n")
if len(doit) > 0 and doit.lower() == "y":
print("Kicking members...")
# encode them all
diff --git a/debian/changelog b/debian/changelog
index bb7c175ada..264ef9ce7c 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,8 +1,24 @@
-matrix-synapse-py3 (1.20.0ubuntu1) UNRELEASED; urgency=medium
+matrix-synapse-py3 (1.20.1) stable; urgency=medium
+
+ * New synapse release 1.20.1.
+
+ -- Synapse Packaging team <packages@matrix.org> Thu, 24 Sep 2020 16:25:22 +0100
+
+matrix-synapse-py3 (1.20.0) stable; urgency=medium
+
+ [ Synapse Packaging team ]
+ * New synapse release 1.20.0.
+
+ [ Dexter Chua ]
* Use Type=notify in systemd service
- -- Dexter Chua <dec41@srcf.net> Wed, 26 Aug 2020 12:41:36 +0000
+ -- Synapse Packaging team <packages@matrix.org> Tue, 22 Sep 2020 15:19:32 +0100
+
+matrix-synapse-py3 (1.19.3) stable; urgency=medium
+
+ * New synapse release 1.19.3.
+
+ -- Synapse Packaging team <packages@matrix.org> Fri, 18 Sep 2020 14:59:30 +0100
matrix-synapse-py3 (1.19.2) stable; urgency=medium
diff --git a/demo/start.sh b/demo/start.sh
index 83396e5c33..f6b5ea137f 100755
--- a/demo/start.sh
+++ b/demo/start.sh
@@ -30,6 +30,8 @@ for port in 8080 8081 8082; do
    if ! grep -F "Customisation made by demo/start.sh" -q $DIR/etc/$port.config; then
        printf '\n\n# Customisation made by demo/start.sh\n' >> $DIR/etc/$port.config
+        echo "public_baseurl: http://localhost:$port/" >> $DIR/etc/$port.config
+
        echo 'enable_registration: true' >> $DIR/etc/$port.config
# Warning, this heredoc depends on the interaction of tabs and spaces. Please don't
diff --git a/docs/admin_api/event_reports.rst b/docs/admin_api/event_reports.rst
new file mode 100644
index 0000000000..461be01230
--- /dev/null
+++ b/docs/admin_api/event_reports.rst
@@ -0,0 +1,129 @@
+Show reported events
+====================
+
+This API returns information about reported events.
+
+The API is::
+
+ GET /_synapse/admin/v1/event_reports?from=0&limit=10
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+It returns a JSON body like the following:
+
+.. code:: jsonc
+
+    {
+        "event_reports": [
+            {
+                "content": {
+                    "reason": "foo",
+                    "score": -100
+                },
+                "event_id": "$bNUFCwGzWca1meCGkjp-zwslF-GfVcXukvRLI1_FaVY",
+                "event_json": {
+                    "auth_events": [
+                        "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M",
+                        "$oggsNXxzPFRE3y53SUNd7nsj69-QzKv03a1RucHu-ws"
+                    ],
+                    "content": {
+                        "body": "matrix.org: This Week in Matrix",
+                        "format": "org.matrix.custom.html",
+                        "formatted_body": "<strong>matrix.org</strong>:<br><a href=\"https://matrix.org/blog/\"><strong>This Week in Matrix</strong></a>",
+                        "msgtype": "m.notice"
+                    },
+                    "depth": 546,
+                    "hashes": {
+                        "sha256": "xK1//xnmvHJIOvbgXlkI8eEqdvoMmihVDJ9J4SNlsAw"
+                    },
+                    "origin": "matrix.org",
+                    "origin_server_ts": 1592291711430,
+                    "prev_events": [
+                        "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M"
+                    ],
+                    "prev_state": [],
+                    "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org",
+                    "sender": "@foobar:matrix.org",
+                    "signatures": {
+                        "matrix.org": {
+                            "ed25519:a_JaEG": "cs+OUKW/iHx5pEidbWxh0UiNNHwe46Ai9LwNz+Ah16aWDNszVIe2gaAcVZfvNsBhakQTew51tlKmL2kspXk/Dg"
+                        }
+                    },
+                    "type": "m.room.message",
+                    "unsigned": {
+                        "age_ts": 1592291711430
+                    }
+                },
+                "id": 2,
+                "reason": "foo",
+                "received_ts": 1570897107409,
+                "room_alias": "#alias1:matrix.org",
+                "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org",
+                "sender": "@foobar:matrix.org",
+                "user_id": "@foo:matrix.org"
+            },
+            {
+                "content": {
+                    "reason": "bar",
+                    "score": -100
+                },
+                "event_id": "$3IcdZsDaN_En-S1DF4EMCy3v4gNRKeOJs8W5qTOKj4I",
+                "event_json": {
+                    // hidden items
+                    // see above
+                },
+                "id": 3,
+                "reason": "bar",
+                "received_ts": 1598889612059,
+                "room_alias": "#alias2:matrix.org",
+                "room_id": "!eGvUQuTCkHGVwNMOjv:matrix.org",
+                "sender": "@foobar:matrix.org",
+                "user_id": "@bar:matrix.org"
+            }
+        ],
+        "next_token": 2,
+        "total": 4
+    }
+
+To paginate, check for ``next_token`` and if present, call the endpoint again
+with ``from`` set to the value of ``next_token``. This will return a new page.
+
+If the endpoint does not return a ``next_token`` then there are no more
+reports to paginate through.
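+
+As an illustrative sketch, the pagination loop could be driven by a script like
+the one below. This uses the third-party ``requests`` library, and the server
+name and access token are placeholders, not part of the API:
+
+.. code:: python
+
+    import requests
+
+    BASE = "https://matrix.example.com"  # placeholder homeserver URL
+    HEADERS = {"Authorization": "Bearer <access_token>"}  # admin access token
+
+    reports = []
+    params = {"from": 0, "limit": 10}
+    while True:
+        resp = requests.get(
+            BASE + "/_synapse/admin/v1/event_reports",
+            headers=HEADERS,
+            params=params,
+        ).json()
+        reports.extend(resp["event_reports"])
+        # No next_token means there are no more reports to fetch.
+        if "next_token" not in resp:
+            break
+        params["from"] = resp["next_token"]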
+
+**URL parameters:**
+
+- ``limit``: integer - Is optional but is used for pagination,
+ denoting the maximum number of items to return in this call. Defaults to ``100``.
+- ``from``: integer - Is optional but used for pagination,
+ denoting the offset in the returned results. This should be treated as an opaque value and
+ not explicitly set to anything other than the return value of ``next_token`` from a previous call.
+ Defaults to ``0``.
+- ``dir``: string - Direction of event report order. Whether to fetch the most recent first (``b``) or the
+ oldest first (``f``). Defaults to ``b``.
+- ``user_id``: string - Is optional and filters to only return reports made by users whose user ID contains this value.
+  This is the user who reported the event and wrote the reason.
+- ``room_id``: string - Is optional and filters to only return reports of events in rooms whose room ID contains this value.
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- ``id``: integer - ID of event report.
+- ``received_ts``: integer - The timestamp (in milliseconds since the unix epoch) when this report was sent.
+- ``room_id``: string - The ID of the room in which the event being reported is located.
+- ``event_id``: string - The ID of the reported event.
+- ``user_id``: string - This is the user who reported the event and wrote the reason.
+- ``reason``: string - Comment made by the ``user_id`` in this report. May be blank.
+- ``content``: object - Content of reported event.
+
+ - ``reason``: string - Comment made by the ``user_id`` in this report. May be blank.
+ - ``score``: integer - Content is reported based upon a negative score, where -100 is "most offensive" and 0 is "inoffensive".
+
+- ``sender``: string - This is the ID of the user who sent the original message/event that was reported.
+- ``room_alias``: string - The alias of the room. ``null`` if the room does not have a canonical alias set.
+- ``event_json``: object - Details of the original event that was reported.
+- ``next_token``: integer - Indication for pagination. See above.
+- ``total``: integer - Total number of event reports related to the query (``user_id`` and ``room_id``).
+
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index e21c78a9c6..7ca902faba 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -304,6 +304,43 @@ To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.
+List room memberships of a user
+================================
+
+Gets a list of all ``room_id`` values for rooms of which a specific ``user_id`` is a member.
+
+The API is::
+
+ GET /_synapse/admin/v1/users/<user_id>/joined_rooms
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+A response body like the following is returned:
+
+.. code:: json
+
+    {
+        "joined_rooms": [
+            "!DuGcnbhHGaSZQoNQR:matrix.org",
+            "!ZtSaPCawyWtxfWiIy:matrix.org"
+        ],
+        "total": 2
+    }
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - fully qualified: for example, ``@user:server.com``.
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- ``joined_rooms`` - An array of ``room_id``.
+- ``total`` - Number of rooms.
+
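+A minimal usage sketch (using the third-party ``requests`` library; the server
+name, access token and user ID are placeholders):
+
+.. code:: python
+
+    import requests
+
+    resp = requests.get(
+        "https://matrix.example.com/_synapse/admin/v1/users/@user:server.com/joined_rooms",
+        headers={"Authorization": "Bearer <access_token>"},
+    ).json()
+    print(resp["total"], resp["joined_rooms"])
+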
+
User devices
============
diff --git a/docs/code_style.md b/docs/code_style.md
index 6ef6f80290..f6c825d7d4 100644
--- a/docs/code_style.md
+++ b/docs/code_style.md
@@ -64,8 +64,6 @@ save as it takes a while and is very resource intensive.
- Use underscores for functions and variables.
- **Docstrings**: should follow the [google code
style](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings).
- This is so that we can generate documentation with
- [sphinx](http://sphinxcontrib-napoleon.readthedocs.org/en/latest/).
See the
[examples](http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html)
in the sphinx documentation.
diff --git a/docs/manhole.md b/docs/manhole.md
index 7375f5ad46..75b6ae40e0 100644
--- a/docs/manhole.md
+++ b/docs/manhole.md
@@ -35,9 +35,12 @@ This gives a Python REPL in which `hs` gives access to the
`synapse.server.HomeServer` object - which in turn gives access to many other
parts of the process.
+Note that any call which returns a coroutine will need to be wrapped in `ensureDeferred`.
+
As a simple example, retrieving an event from the database:
-```
->>> hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org')
+```pycon
+>>> from twisted.internet import defer
+>>> defer.ensureDeferred(hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org'))
<Deferred at 0x7ff253fc6998 current result: <FrozenEvent event_id='$1416420717069yeQaw:matrix.org', type='m.room.create', state_key=''>>
```
diff --git a/docs/openid.md b/docs/openid.md
index 70b37f858b..4873681999 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -238,13 +238,36 @@ Synapse config:
```yaml
oidc_config:
-   enabled: true
-   issuer: "https://id.twitch.tv/oauth2/"
-   client_id: "your-client-id" # TO BE FILLED
-   client_secret: "your-client-secret" # TO BE FILLED
-   client_auth_method: "client_secret_post"
-   user_mapping_provider:
-     config:
-       localpart_template: '{{ user.preferred_username }}'
-       display_name_template: '{{ user.name }}'
+  enabled: true
+  issuer: "https://id.twitch.tv/oauth2/"
+  client_id: "your-client-id" # TO BE FILLED
+  client_secret: "your-client-secret" # TO BE FILLED
+  client_auth_method: "client_secret_post"
+  user_mapping_provider:
+    config:
+      localpart_template: "{{ user.preferred_username }}"
+      display_name_template: "{{ user.name }}"
+```
+
+### GitLab
+
+1. Create a [new application](https://gitlab.com/profile/applications).
+2. Add the `read_user` and `openid` scopes.
+3. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback`
+
+Synapse config:
+
+```yaml
+oidc_config:
+  enabled: true
+  issuer: "https://gitlab.com/"
+  client_id: "your-client-id" # TO BE FILLED
+  client_secret: "your-client-secret" # TO BE FILLED
+  client_auth_method: "client_secret_post"
+  scopes: ["openid", "read_user"]
+  user_profile_method: "userinfo_endpoint"
+  user_mapping_provider:
+    config:
+      localpart_template: '{{ user.nickname }}'
+      display_name_template: '{{ user.name }}'
```
diff --git a/docs/postgres.md b/docs/postgres.md
index e71a1975d8..c30cc1fd8c 100644
--- a/docs/postgres.md
+++ b/docs/postgres.md
@@ -106,6 +106,17 @@ Note that the above may fail with an error about duplicate rows if corruption
has already occurred, and such duplicate rows will need to be manually removed.
+## Fixing inconsistent sequences error
+
+Synapse uses Postgres sequences to generate IDs for various tables. A sequence
+and associated table can get out of sync if, for example, Synapse has been
+downgraded and then upgraded again.
+
+To fix the issue, shut down Synapse (including any and all workers) and run the
+SQL command included in the error message. Once done, Synapse should start
+successfully.
+
+
## Tuning Postgres
The default settings should be fine for most deployments. For larger
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index edd109fa7b..46d8f35771 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -121,6 +121,14 @@ example.com:8448 {
**NOTE**: ensure the `nocanon` options are included.
+**NOTE 2**: It appears that Synapse is currently incompatible with the ModSecurity module for Apache (`mod_security2`). If you need it enabled for other services on your web server, you can disable it for Synapse's two VirtualHosts by including the following lines before each of the two `</VirtualHost>` above:
+
+```
+<IfModule security2_module>
+    SecRuleEngine off
+</IfModule>
+```
+
### HAProxy
```
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index fb04ff283d..bb64662e28 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -33,10 +33,23 @@
## Server ##
-# The domain name of the server, with optional explicit port.
-# This is used by remote servers to connect to this server,
-# e.g. matrix.org, localhost:8080, etc.
-# This is also the last part of your UserID.
+# The public-facing domain of the server
+#
+# The server_name will appear at the end of usernames and room addresses
+# created on this server. For example, if the server_name was example.com,
+# usernames on this server would be in the format @user:example.com
+#
+# In most cases you should avoid using a matrix-specific subdomain such as
+# matrix.example.com or synapse.example.com as the server_name for the same
+# reasons you wouldn't use user@email.example.com as your email address.
+# See https://github.com/matrix-org/synapse/blob/master/docs/delegate.md
+# for information on how to host Synapse on a subdomain while preserving
+# a clean server_name.
+#
+# The server_name cannot be changed later so it is important to
+# configure this correctly before you start Synapse. It should be all
+# lowercase and may contain an explicit port.
+# Examples: matrix.org, localhost:8080
#
server_name: "SERVERNAME"
@@ -106,7 +119,7 @@ pid_file: DATADIR/homeserver.pid
# For example, for room version 1, default_room_version should be set
# to "1".
#
-#default_room_version: "5"
+#default_room_version: "6"
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
#
@@ -616,6 +629,7 @@ acme:
#tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
+## Federation ##
# Restrict federation to the following whitelist of domains.
# N.B. we recommend also firewalling your federation listener to limit
@@ -649,6 +663,17 @@ federation_ip_range_blacklist:
- 'fe80::/64'
- 'fc00::/7'
+# Report prometheus metrics on the age of PDUs being sent to and received from
+# the following domains. This can be used to give an idea of "delay" on inbound
+# and outbound federation, though be aware that any delay can be due to problems
+# at either end or with the intermediate network.
+#
+# By default, no domains are monitored in this way.
+#
+#federation_metrics_domains:
+# - matrix.org
+# - example.com
+
## Caching ##
@@ -1689,6 +1714,19 @@ oidc_config:
#
#skip_verification: true
+ # Whether to fetch the user profile from the userinfo endpoint. Valid
+ # values are: "auto" or "userinfo_endpoint".
+ #
+ # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
+ # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
+ #
+ #user_profile_method: "userinfo_endpoint"
+
+ # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
+ # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
+ #
+ #allow_existing_users: true
+
# An external module can be provided here as a custom solution to mapping
# attributes returned from a OIDC provider onto a matrix user.
#
@@ -1730,6 +1768,14 @@ oidc_config:
#
#display_name_template: "{{ user.given_name }} {{ user.last_name }}"
+ # Jinja2 templates for extra attributes to send back to the client during
+ # login.
+ #
+ # Note that these are non-standard and clients will ignore them without modifications.
+ #
+ #extra_attributes:
+ #  birthdate: "{{ user.birthdate }}"
+
# Enable CAS for registration and login.
@@ -2458,6 +2504,11 @@ opentracing:
# events: worker1
# typing: worker1
+# The worker that is used to run background tasks (e.g. cleaning up expired
+# data). If not provided this defaults to the main process.
+#
+#run_background_tasks_on: worker1
+
# Configuration for Redis when using workers. This *must* be enabled when
# using workers (unless using old style direct TCP configuration).
diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index eb10e115f9..7fc08f1b70 100644
--- a/docs/spam_checker.md
+++ b/docs/spam_checker.md
@@ -11,7 +11,7 @@ able to be imported by the running Synapse.
The Python class is instantiated with two objects:
* Any configuration (see below).
-* An instance of `synapse.spam_checker_api.SpamCheckerApi`.
+* An instance of `synapse.module_api.ModuleApi`.
It then implements methods which return a boolean to alter behavior in Synapse.
@@ -26,11 +26,8 @@ well as some specific methods:
The details of each of these methods (as well as their inputs and outputs)
are documented in the `synapse.events.spamcheck.SpamChecker` class.
-The `SpamCheckerApi` class provides a way for the custom spam checker class to
-call back into the homeserver internals. It currently implements the following
-methods:
-
-* `get_state_events_in_room`
+The `ModuleApi` class provides a way for the custom spam checker class to
+call back into the homeserver internals.
### Example
diff --git a/docs/sphinx/README.rst b/docs/sphinx/README.rst
deleted file mode 100644
index a7ab7c5500..0000000000
--- a/docs/sphinx/README.rst
+++ /dev/null
@@ -1 +0,0 @@
-TODO: how (if at all) is this actually maintained?
diff --git a/docs/sphinx/conf.py b/docs/sphinx/conf.py
deleted file mode 100644
index ca4b879526..0000000000
--- a/docs/sphinx/conf.py
+++ /dev/null
@@ -1,271 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Synapse documentation build configuration file, created by
-# sphinx-quickstart on Tue Jun 10 17:31:02 2014.
-#
-# This file is execfile()d with the current directory set to its
-# containing dir.
-#
-# Note that not all possible configuration values are present in this
-# autogenerated file.
-#
-# All configuration values have a default; values that are commented out
-# serve to show the default.
-
-import sys
-import os
-
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-sys.path.insert(0, os.path.abspath(".."))
-
-# -- General configuration ------------------------------------------------
-
-# If your documentation needs a minimal Sphinx version, state it here.
-# needs_sphinx = '1.0'
-
-# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
-# ones.
-extensions = [
- "sphinx.ext.autodoc",
- "sphinx.ext.intersphinx",
- "sphinx.ext.coverage",
- "sphinx.ext.ifconfig",
- "sphinxcontrib.napoleon",
-]
-
-# Add any paths that contain templates here, relative to this directory.
-templates_path = ["_templates"]
-
-# The suffix of source filenames.
-source_suffix = ".rst"
-
-# The encoding of source files.
-# source_encoding = 'utf-8-sig'
-
-# The master toctree document.
-master_doc = "index"
-
-# General information about the project.
-project = "Synapse"
-copyright = (
- "Copyright 2014-2017 OpenMarket Ltd, 2017 Vector Creations Ltd, 2017 New Vector Ltd"
-)
-
-# The version info for the project you're documenting, acts as replacement for
-# |version| and |release|, also used in various other places throughout the
-# built documents.
-#
-# The short X.Y version.
-version = "1.0"
-# The full version, including alpha/beta/rc tags.
-release = "1.0"
-
-# The language for content autogenerated by Sphinx. Refer to documentation
-# for a list of supported languages.
-# language = None
-
-# There are two options for replacing |today|: either, you set today to some
-# non-false value, then it is used:
-# today = ''
-# Else, today_fmt is used as the format for a strftime call.
-# today_fmt = '%B %d, %Y'
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-exclude_patterns = ["_build"]
-
-# The reST default role (used for this markup: `text`) to use for all
-# documents.
-# default_role = None
-
-# If true, '()' will be appended to :func: etc. cross-reference text.
-# add_function_parentheses = True
-
-# If true, the current module name will be prepended to all description
-# unit titles (such as .. function::).
-# add_module_names = True
-
-# If true, sectionauthor and moduleauthor directives will be shown in the
-# output. They are ignored by default.
-# show_authors = False
-
-# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = "sphinx"
-
-# A list of ignored prefixes for module index sorting.
-# modindex_common_prefix = []
-
-# If true, keep warnings as "system message" paragraphs in the built documents.
-# keep_warnings = False
-
-
-# -- Options for HTML output ----------------------------------------------
-
-# The theme to use for HTML and HTML Help pages. See the documentation for
-# a list of builtin themes.
-html_theme = "default"
-
-# Theme options are theme-specific and customize the look and feel of a theme
-# further. For a list of options available for each theme, see the
-# documentation.
-# html_theme_options = {}
-
-# Add any paths that contain custom themes here, relative to this directory.
-# html_theme_path = []
-
-# The name for this set of Sphinx documents. If None, it defaults to
-# "<project> v<release> documentation".
-# html_title = None
-
-# A shorter title for the navigation bar. Default is the same as html_title.
-# html_short_title = None
-
-# The name of an image file (relative to this directory) to place at the top
-# of the sidebar.
-# html_logo = None
-
-# The name of an image file (within the static path) to use as favicon of the
-# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
-# pixels large.
-# html_favicon = None
-
-# Add any paths that contain custom static files (such as style sheets) here,
-# relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ["_static"]
-
-# Add any extra paths that contain custom files (such as robots.txt or
-# .htaccess) here, relative to this directory. These files are copied
-# directly to the root of the documentation.
-# html_extra_path = []
-
-# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
-# using the given strftime format.
-# html_last_updated_fmt = '%b %d, %Y'
-
-# If true, SmartyPants will be used to convert quotes and dashes to
-# typographically correct entities.
-# html_use_smartypants = True
-
-# Custom sidebar templates, maps document names to template names.
-# html_sidebars = {}
-
-# Additional templates that should be rendered to pages, maps page names to
-# template names.
-# html_additional_pages = {}
-
-# If false, no module index is generated.
-# html_domain_indices = True
-
-# If false, no index is generated.
-# html_use_index = True
-
-# If true, the index is split into individual pages for each letter.
-# html_split_index = False
-
-# If true, links to the reST sources are added to the pages.
-# html_show_sourcelink = True
-
-# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
-# html_show_sphinx = True
-
-# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
-# html_show_copyright = True
-
-# If true, an OpenSearch description file will be output, and all pages will
-# contain a <link> tag referring to it. The value of this option must be the
-# base URL from which the finished HTML is served.
-# html_use_opensearch = ''
-
-# This is the file name suffix for HTML files (e.g. ".xhtml").
-# html_file_suffix = None
-
-# Output file base name for HTML help builder.
-htmlhelp_basename = "Synapsedoc"
-
-
-# -- Options for LaTeX output ---------------------------------------------
-
-latex_elements = {
- # The paper size ('letterpaper' or 'a4paper').
- #'papersize': 'letterpaper',
- # The font size ('10pt', '11pt' or '12pt').
- #'pointsize': '10pt',
- # Additional stuff for the LaTeX preamble.
- #'preamble': '',
-}
-
-# Grouping the document tree into LaTeX files. List of tuples
-# (source start file, target name, title,
-# author, documentclass [howto, manual, or own class]).
-latex_documents = [("index", "Synapse.tex", "Synapse Documentation", "TNG", "manual")]
-
-# The name of an image file (relative to this directory) to place at the top of
-# the title page.
-# latex_logo = None
-
-# For "manual" documents, if this is true, then toplevel headings are parts,
-# not chapters.
-# latex_use_parts = False
-
-# If true, show page references after internal links.
-# latex_show_pagerefs = False
-
-# If true, show URL addresses after external links.
-# latex_show_urls = False
-
-# Documents to append as an appendix to all manuals.
-# latex_appendices = []
-
-# If false, no module index is generated.
-# latex_domain_indices = True
-
-
-# -- Options for manual page output ---------------------------------------
-
-# One entry per manual page. List of tuples
-# (source start file, name, description, authors, manual section).
-man_pages = [("index", "synapse", "Synapse Documentation", ["TNG"], 1)]
-
-# If true, show URL addresses after external links.
-# man_show_urls = False
-
-
-# -- Options for Texinfo output -------------------------------------------
-
-# Grouping the document tree into Texinfo files. List of tuples
-# (source start file, target name, title, author,
-# dir menu entry, description, category)
-texinfo_documents = [
- (
- "index",
- "Synapse",
- "Synapse Documentation",
- "TNG",
- "Synapse",
- "One line description of project.",
- "Miscellaneous",
- )
-]
-
-# Documents to append as an appendix to all manuals.
-# texinfo_appendices = []
-
-# If false, no module index is generated.
-# texinfo_domain_indices = True
-
-# How to display URL addresses: 'footnote', 'no', or 'inline'.
-# texinfo_show_urls = 'footnote'
-
-# If true, do not generate a @detailmenu in the "Top" node's menu.
-# texinfo_no_detailmenu = False
-
-
-# Example configuration for intersphinx: refer to the Python standard library.
-intersphinx_mapping = {"http://docs.python.org/": None}
-
-napoleon_include_special_with_doc = True
-napoleon_use_ivar = True
diff --git a/docs/sphinx/index.rst b/docs/sphinx/index.rst
deleted file mode 100644
index 76a4c0c7bf..0000000000
--- a/docs/sphinx/index.rst
+++ /dev/null
@@ -1,20 +0,0 @@
-.. Synapse documentation master file, created by
- sphinx-quickstart on Tue Jun 10 17:31:02 2014.
- You can adapt this file completely to your liking, but it should at least
- contain the root `toctree` directive.
-
-Welcome to Synapse's documentation!
-===================================
-
-Contents:
-
-.. toctree::
- synapse
-
-Indices and tables
-==================
-
-* :ref:`genindex`
-* :ref:`modindex`
-* :ref:`search`
-
diff --git a/docs/sphinx/modules.rst b/docs/sphinx/modules.rst
deleted file mode 100644
index 1c7f70bd13..0000000000
--- a/docs/sphinx/modules.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse
-=======
-
-.. toctree::
- :maxdepth: 4
-
- synapse
diff --git a/docs/sphinx/synapse.api.auth.rst b/docs/sphinx/synapse.api.auth.rst
deleted file mode 100644
index 931eb59836..0000000000
--- a/docs/sphinx/synapse.api.auth.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.auth module
-=======================
-
-.. automodule:: synapse.api.auth
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.constants.rst b/docs/sphinx/synapse.api.constants.rst
deleted file mode 100644
index a1e3c47f68..0000000000
--- a/docs/sphinx/synapse.api.constants.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.constants module
-============================
-
-.. automodule:: synapse.api.constants
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.dbobjects.rst b/docs/sphinx/synapse.api.dbobjects.rst
deleted file mode 100644
index e9d31167e0..0000000000
--- a/docs/sphinx/synapse.api.dbobjects.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.dbobjects module
-============================
-
-.. automodule:: synapse.api.dbobjects
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.errors.rst b/docs/sphinx/synapse.api.errors.rst
deleted file mode 100644
index f1c6881478..0000000000
--- a/docs/sphinx/synapse.api.errors.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.errors module
-=========================
-
-.. automodule:: synapse.api.errors
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.event_stream.rst b/docs/sphinx/synapse.api.event_stream.rst
deleted file mode 100644
index 9291cb2dbc..0000000000
--- a/docs/sphinx/synapse.api.event_stream.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.event_stream module
-===============================
-
-.. automodule:: synapse.api.event_stream
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.events.factory.rst b/docs/sphinx/synapse.api.events.factory.rst
deleted file mode 100644
index 2e71ff6070..0000000000
--- a/docs/sphinx/synapse.api.events.factory.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.events.factory module
-=================================
-
-.. automodule:: synapse.api.events.factory
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.events.room.rst b/docs/sphinx/synapse.api.events.room.rst
deleted file mode 100644
index 6cd5998599..0000000000
--- a/docs/sphinx/synapse.api.events.room.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.events.room module
-==============================
-
-.. automodule:: synapse.api.events.room
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.events.rst b/docs/sphinx/synapse.api.events.rst
deleted file mode 100644
index b762da55ee..0000000000
--- a/docs/sphinx/synapse.api.events.rst
+++ /dev/null
@@ -1,18 +0,0 @@
-synapse.api.events package
-==========================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.api.events.factory
- synapse.api.events.room
-
-Module contents
----------------
-
-.. automodule:: synapse.api.events
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.handlers.events.rst b/docs/sphinx/synapse.api.handlers.events.rst
deleted file mode 100644
index d2e1b54ac0..0000000000
--- a/docs/sphinx/synapse.api.handlers.events.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.handlers.events module
-==================================
-
-.. automodule:: synapse.api.handlers.events
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.handlers.factory.rst b/docs/sphinx/synapse.api.handlers.factory.rst
deleted file mode 100644
index b04a93f740..0000000000
--- a/docs/sphinx/synapse.api.handlers.factory.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.handlers.factory module
-===================================
-
-.. automodule:: synapse.api.handlers.factory
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.handlers.federation.rst b/docs/sphinx/synapse.api.handlers.federation.rst
deleted file mode 100644
index 61a6542210..0000000000
--- a/docs/sphinx/synapse.api.handlers.federation.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.handlers.federation module
-======================================
-
-.. automodule:: synapse.api.handlers.federation
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.handlers.register.rst b/docs/sphinx/synapse.api.handlers.register.rst
deleted file mode 100644
index 388f144eca..0000000000
--- a/docs/sphinx/synapse.api.handlers.register.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.handlers.register module
-====================================
-
-.. automodule:: synapse.api.handlers.register
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.handlers.room.rst b/docs/sphinx/synapse.api.handlers.room.rst
deleted file mode 100644
index 8ca156c7ff..0000000000
--- a/docs/sphinx/synapse.api.handlers.room.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.handlers.room module
-================================
-
-.. automodule:: synapse.api.handlers.room
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.handlers.rst b/docs/sphinx/synapse.api.handlers.rst
deleted file mode 100644
index e84f563fcb..0000000000
--- a/docs/sphinx/synapse.api.handlers.rst
+++ /dev/null
@@ -1,21 +0,0 @@
-synapse.api.handlers package
-============================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.api.handlers.events
- synapse.api.handlers.factory
- synapse.api.handlers.federation
- synapse.api.handlers.register
- synapse.api.handlers.room
-
-Module contents
----------------
-
-.. automodule:: synapse.api.handlers
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.notifier.rst b/docs/sphinx/synapse.api.notifier.rst
deleted file mode 100644
index 631b42a497..0000000000
--- a/docs/sphinx/synapse.api.notifier.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.notifier module
-===========================
-
-.. automodule:: synapse.api.notifier
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.register_events.rst b/docs/sphinx/synapse.api.register_events.rst
deleted file mode 100644
index 79ad4ce211..0000000000
--- a/docs/sphinx/synapse.api.register_events.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.register_events module
-==================================
-
-.. automodule:: synapse.api.register_events
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.room_events.rst b/docs/sphinx/synapse.api.room_events.rst
deleted file mode 100644
index bead1711f5..0000000000
--- a/docs/sphinx/synapse.api.room_events.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.room_events module
-==============================
-
-.. automodule:: synapse.api.room_events
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.rst b/docs/sphinx/synapse.api.rst
deleted file mode 100644
index f4d39ff331..0000000000
--- a/docs/sphinx/synapse.api.rst
+++ /dev/null
@@ -1,30 +0,0 @@
-synapse.api package
-===================
-
-Subpackages
------------
-
-.. toctree::
-
- synapse.api.events
- synapse.api.handlers
- synapse.api.streams
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.api.auth
- synapse.api.constants
- synapse.api.errors
- synapse.api.notifier
- synapse.api.storage
-
-Module contents
----------------
-
-.. automodule:: synapse.api
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.server.rst b/docs/sphinx/synapse.api.server.rst
deleted file mode 100644
index b01600235e..0000000000
--- a/docs/sphinx/synapse.api.server.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.server module
-=========================
-
-.. automodule:: synapse.api.server
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.storage.rst b/docs/sphinx/synapse.api.storage.rst
deleted file mode 100644
index afa40685c4..0000000000
--- a/docs/sphinx/synapse.api.storage.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.storage module
-==========================
-
-.. automodule:: synapse.api.storage
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.stream.rst b/docs/sphinx/synapse.api.stream.rst
deleted file mode 100644
index 0d5e3f01bf..0000000000
--- a/docs/sphinx/synapse.api.stream.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.stream module
-=========================
-
-.. automodule:: synapse.api.stream
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.streams.event.rst b/docs/sphinx/synapse.api.streams.event.rst
deleted file mode 100644
index 2ac45a35c8..0000000000
--- a/docs/sphinx/synapse.api.streams.event.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.api.streams.event module
-================================
-
-.. automodule:: synapse.api.streams.event
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.api.streams.rst b/docs/sphinx/synapse.api.streams.rst
deleted file mode 100644
index 72eb205caf..0000000000
--- a/docs/sphinx/synapse.api.streams.rst
+++ /dev/null
@@ -1,17 +0,0 @@
-synapse.api.streams package
-===========================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.api.streams.event
-
-Module contents
----------------
-
-.. automodule:: synapse.api.streams
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.app.homeserver.rst b/docs/sphinx/synapse.app.homeserver.rst
deleted file mode 100644
index 54b93da8fe..0000000000
--- a/docs/sphinx/synapse.app.homeserver.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.app.homeserver module
-=============================
-
-.. automodule:: synapse.app.homeserver
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.app.rst b/docs/sphinx/synapse.app.rst
deleted file mode 100644
index 4535b79827..0000000000
--- a/docs/sphinx/synapse.app.rst
+++ /dev/null
@@ -1,17 +0,0 @@
-synapse.app package
-===================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.app.homeserver
-
-Module contents
----------------
-
-.. automodule:: synapse.app
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.db.rst b/docs/sphinx/synapse.db.rst
deleted file mode 100644
index 83df6c03db..0000000000
--- a/docs/sphinx/synapse.db.rst
+++ /dev/null
@@ -1,10 +0,0 @@
-synapse.db package
-==================
-
-Module contents
----------------
-
-.. automodule:: synapse.db
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.handler.rst b/docs/sphinx/synapse.federation.handler.rst
deleted file mode 100644
index 5597f5c46d..0000000000
--- a/docs/sphinx/synapse.federation.handler.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.federation.handler module
-=================================
-
-.. automodule:: synapse.federation.handler
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.messaging.rst b/docs/sphinx/synapse.federation.messaging.rst
deleted file mode 100644
index 4bbaabf3ef..0000000000
--- a/docs/sphinx/synapse.federation.messaging.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.federation.messaging module
-===================================
-
-.. automodule:: synapse.federation.messaging
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.pdu_codec.rst b/docs/sphinx/synapse.federation.pdu_codec.rst
deleted file mode 100644
index 8f0b15a63c..0000000000
--- a/docs/sphinx/synapse.federation.pdu_codec.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.federation.pdu_codec module
-===================================
-
-.. automodule:: synapse.federation.pdu_codec
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.persistence.rst b/docs/sphinx/synapse.federation.persistence.rst
deleted file mode 100644
index db7ab8ade1..0000000000
--- a/docs/sphinx/synapse.federation.persistence.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.federation.persistence module
-=====================================
-
-.. automodule:: synapse.federation.persistence
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.replication.rst b/docs/sphinx/synapse.federation.replication.rst
deleted file mode 100644
index 49e26e0928..0000000000
--- a/docs/sphinx/synapse.federation.replication.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.federation.replication module
-=====================================
-
-.. automodule:: synapse.federation.replication
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.rst b/docs/sphinx/synapse.federation.rst
deleted file mode 100644
index 7240c7901b..0000000000
--- a/docs/sphinx/synapse.federation.rst
+++ /dev/null
@@ -1,22 +0,0 @@
-synapse.federation package
-==========================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.federation.handler
- synapse.federation.pdu_codec
- synapse.federation.persistence
- synapse.federation.replication
- synapse.federation.transport
- synapse.federation.units
-
-Module contents
----------------
-
-.. automodule:: synapse.federation
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.transport.rst b/docs/sphinx/synapse.federation.transport.rst
deleted file mode 100644
index 877956b3c9..0000000000
--- a/docs/sphinx/synapse.federation.transport.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.federation.transport module
-===================================
-
-.. automodule:: synapse.federation.transport
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.federation.units.rst b/docs/sphinx/synapse.federation.units.rst
deleted file mode 100644
index 8f9212b07d..0000000000
--- a/docs/sphinx/synapse.federation.units.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.federation.units module
-===============================
-
-.. automodule:: synapse.federation.units
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.persistence.rst b/docs/sphinx/synapse.persistence.rst
deleted file mode 100644
index 37c0c23720..0000000000
--- a/docs/sphinx/synapse.persistence.rst
+++ /dev/null
@@ -1,19 +0,0 @@
-synapse.persistence package
-===========================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.persistence.service
- synapse.persistence.tables
- synapse.persistence.transactions
-
-Module contents
----------------
-
-.. automodule:: synapse.persistence
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.persistence.service.rst b/docs/sphinx/synapse.persistence.service.rst
deleted file mode 100644
index 3514d3c76f..0000000000
--- a/docs/sphinx/synapse.persistence.service.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.persistence.service module
-==================================
-
-.. automodule:: synapse.persistence.service
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.persistence.tables.rst b/docs/sphinx/synapse.persistence.tables.rst
deleted file mode 100644
index 907b02769d..0000000000
--- a/docs/sphinx/synapse.persistence.tables.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.persistence.tables module
-=================================
-
-.. automodule:: synapse.persistence.tables
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.persistence.transactions.rst b/docs/sphinx/synapse.persistence.transactions.rst
deleted file mode 100644
index 475c02a8c5..0000000000
--- a/docs/sphinx/synapse.persistence.transactions.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.persistence.transactions module
-=======================================
-
-.. automodule:: synapse.persistence.transactions
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.rest.base.rst b/docs/sphinx/synapse.rest.base.rst
deleted file mode 100644
index 84d2d9b31d..0000000000
--- a/docs/sphinx/synapse.rest.base.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.rest.base module
-========================
-
-.. automodule:: synapse.rest.base
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.rest.events.rst b/docs/sphinx/synapse.rest.events.rst
deleted file mode 100644
index ebbe26c746..0000000000
--- a/docs/sphinx/synapse.rest.events.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.rest.events module
-==========================
-
-.. automodule:: synapse.rest.events
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.rest.register.rst b/docs/sphinx/synapse.rest.register.rst
deleted file mode 100644
index a4a48a8a8f..0000000000
--- a/docs/sphinx/synapse.rest.register.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.rest.register module
-============================
-
-.. automodule:: synapse.rest.register
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.rest.room.rst b/docs/sphinx/synapse.rest.room.rst
deleted file mode 100644
index 63fc5c2840..0000000000
--- a/docs/sphinx/synapse.rest.room.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.rest.room module
-========================
-
-.. automodule:: synapse.rest.room
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.rest.rst b/docs/sphinx/synapse.rest.rst
deleted file mode 100644
index 016af926b2..0000000000
--- a/docs/sphinx/synapse.rest.rst
+++ /dev/null
@@ -1,20 +0,0 @@
-synapse.rest package
-====================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.rest.base
- synapse.rest.events
- synapse.rest.register
- synapse.rest.room
-
-Module contents
----------------
-
-.. automodule:: synapse.rest
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.rst b/docs/sphinx/synapse.rst
deleted file mode 100644
index e7869e0e5d..0000000000
--- a/docs/sphinx/synapse.rst
+++ /dev/null
@@ -1,30 +0,0 @@
-synapse package
-===============
-
-Subpackages
------------
-
-.. toctree::
-
- synapse.api
- synapse.app
- synapse.federation
- synapse.persistence
- synapse.rest
- synapse.util
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.server
- synapse.state
-
-Module contents
----------------
-
-.. automodule:: synapse
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.server.rst b/docs/sphinx/synapse.server.rst
deleted file mode 100644
index 7f33f084d7..0000000000
--- a/docs/sphinx/synapse.server.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.server module
-=====================
-
-.. automodule:: synapse.server
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.state.rst b/docs/sphinx/synapse.state.rst
deleted file mode 100644
index 744be2a8be..0000000000
--- a/docs/sphinx/synapse.state.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.state module
-====================
-
-.. automodule:: synapse.state
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.util.async.rst b/docs/sphinx/synapse.util.async.rst
deleted file mode 100644
index 542bb54444..0000000000
--- a/docs/sphinx/synapse.util.async.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.util.async module
-=========================
-
-.. automodule:: synapse.util.async
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.util.dbutils.rst b/docs/sphinx/synapse.util.dbutils.rst
deleted file mode 100644
index afaa9eb749..0000000000
--- a/docs/sphinx/synapse.util.dbutils.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.util.dbutils module
-===========================
-
-.. automodule:: synapse.util.dbutils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.util.http.rst b/docs/sphinx/synapse.util.http.rst
deleted file mode 100644
index 344af5a490..0000000000
--- a/docs/sphinx/synapse.util.http.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.util.http module
-========================
-
-.. automodule:: synapse.util.http
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.util.lockutils.rst b/docs/sphinx/synapse.util.lockutils.rst
deleted file mode 100644
index 16ee26cabd..0000000000
--- a/docs/sphinx/synapse.util.lockutils.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.util.lockutils module
-=============================
-
-.. automodule:: synapse.util.lockutils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.util.logutils.rst b/docs/sphinx/synapse.util.logutils.rst
deleted file mode 100644
index 2b79fa7a4b..0000000000
--- a/docs/sphinx/synapse.util.logutils.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.util.logutils module
-============================
-
-.. automodule:: synapse.util.logutils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.util.rst b/docs/sphinx/synapse.util.rst
deleted file mode 100644
index 01a0c3a591..0000000000
--- a/docs/sphinx/synapse.util.rst
+++ /dev/null
@@ -1,21 +0,0 @@
-synapse.util package
-====================
-
-Submodules
-----------
-
-.. toctree::
-
- synapse.util.async
- synapse.util.http
- synapse.util.lockutils
- synapse.util.logutils
- synapse.util.stringutils
-
-Module contents
----------------
-
-.. automodule:: synapse.util
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sphinx/synapse.util.stringutils.rst b/docs/sphinx/synapse.util.stringutils.rst
deleted file mode 100644
index ec626eee28..0000000000
--- a/docs/sphinx/synapse.util.stringutils.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-synapse.util.stringutils module
-===============================
-
-.. automodule:: synapse.util.stringutils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
index abea432343..32b06aa2c5 100644
--- a/docs/sso_mapping_providers.md
+++ b/docs/sso_mapping_providers.md
@@ -57,7 +57,7 @@ A custom mapping provider must specify the following methods:
- This method must return a string, which is the unique identifier for the
user. Commonly the ``sub`` claim of the response.
* `map_user_attributes(self, userinfo, token)`
- - This method should be async.
+ - This method must be async.
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
information from.
@@ -66,6 +66,18 @@ A custom mapping provider must specify the following methods:
- Returns a dictionary with two keys:
- localpart: A required string, used to generate the Matrix ID.
- displayname: An optional string, the display name for the user.
+* `get_extra_attributes(self, userinfo, token)`
+ - This method must be async.
+ - Arguments:
+ - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
+ information from.
+ - `token` - A dictionary which includes information necessary to make
+ further requests to the OpenID provider.
+ - Returns a dictionary that is suitable to be serialized to JSON. This
+ will be returned as part of the response during a successful login.
+
+ Note that care should be taken not to overwrite any of the parameters
+ usually returned as part of the [login response](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login).
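
As an illustration, a custom mapping provider implementing the hooks above
might look like this sketch (the class name is hypothetical and only the
methods relevant here are shown):

```python
class MyOidcMappingProvider:
    def __init__(self, config):
        self._config = config

    def get_remote_user_id(self, userinfo):
        # The unique identifier for the user, commonly the `sub` claim.
        return userinfo["sub"]

    async def map_user_attributes(self, userinfo, token):
        return {
            "localpart": userinfo["preferred_username"],
            "displayname": userinfo.get("name"),
        }

    async def get_extra_attributes(self, userinfo, token):
        # Returned as part of the login response, so avoid keys that clash
        # with the standard login response parameters.
        return {"birthdate": userinfo.get("birthdate")}
```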
### Default OpenID Mapping Provider
diff --git a/docs/workers.md b/docs/workers.md
index df0ac84d94..84a9759e34 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -243,6 +243,22 @@ for the room are in flight:
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/messages$
+Additionally, the following endpoints should be included if Synapse is configured
+to use SSO (you only need to include the ones for whichever SSO provider you're
+using):
+
+ # OpenID Connect requests.
+ ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
+ ^/_synapse/oidc/callback$
+
+ # SAML requests.
+ ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
+ ^/_matrix/saml2/authn_response$
+
+ # CAS requests.
+ ^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$
+ ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$
+
Note that an HTTP listener with `client` and `federation` resources must be
configured in the `worker_listeners` option in the worker config.
@@ -303,6 +319,23 @@ stream_writers:
events: event_persister1
```
+#### Background tasks
+
+There is also *experimental* support for moving background tasks to a separate
+worker. Background tasks are run periodically or started via replication. Exactly
+which tasks run depends on your Synapse configuration (e.g. whether stats
+collection is enabled).
+
+To enable this, the worker must have a `worker_name`. For example, to move
+background tasks to a dedicated worker,
+the shared configuration would include:
+
+```yaml
+run_background_tasks_on: background_worker
+```
+
+You might also wish to investigate the `update_user_directory` and
+`media_instance_running_background_jobs` settings.
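
For completeness, the dedicated worker itself would be declared in its own
config file, roughly as follows (a sketch assuming the generic worker app;
the worker name and replication host/port are placeholders):

```yaml
worker_app: synapse.app.generic_worker
worker_name: background_worker

# Replication connection back to the main process (placeholders).
worker_replication_host: 127.0.0.1
worker_replication_http_port: 9093
```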
### `synapse.app.pusher`
diff --git a/mypy.ini b/mypy.ini
index 7986781432..a7ffb81ef1 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -6,6 +6,7 @@ check_untyped_defs = True
show_error_codes = True
show_traceback = True
mypy_path = stubs
+warn_unreachable = True
files =
synapse/api,
synapse/appservice,
@@ -16,6 +17,7 @@ files =
synapse/federation,
synapse/handlers/auth.py,
synapse/handlers/cas_handler.py,
+ synapse/handlers/device.py,
synapse/handlers/directory.py,
synapse/handlers/events.py,
synapse/handlers/federation.py,
@@ -142,3 +144,6 @@ ignore_missing_imports = True
[mypy-nacl.*]
ignore_missing_imports = True
+
+[mypy-hiredis]
+ignore_missing_imports = True
diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages
index d055cf3287..d0685c8b35 100755
--- a/scripts-dev/build_debian_packages
+++ b/scripts-dev/build_debian_packages
@@ -25,6 +25,7 @@ DISTS = (
"ubuntu:xenial",
"ubuntu:bionic",
"ubuntu:focal",
+ "ubuntu:groovy",
)
DESC = '''\
diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh
new file mode 100755
index 0000000000..3cde53f5c0
--- /dev/null
+++ b/scripts-dev/complement.sh
@@ -0,0 +1,22 @@
+#! /bin/bash -eu
+# This script is designed for developers who want to test their code
+# against Complement.
+#
+# It makes a Synapse image which represents the current checkout,
+# then downloads Complement and runs it with that image.
+
+cd "$(dirname $0)/.."
+
+# Build the base Synapse image from the local checkout
+docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile .
+
+# Download Complement
+wget -N https://github.com/matrix-org/complement/archive/master.tar.gz
+tar -xzf master.tar.gz
+cd complement-master
+
+# Build the Synapse image from Complement, based on the above image we just built
+docker build -t complement-synapse -f dockerfiles/Synapse.Dockerfile ./dockerfiles
+
+# Run the tests on the resulting image!
+COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -count=1 ./tests
diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py
index 9eddb6d515..313860df13 100755
--- a/scripts-dev/definitions.py
+++ b/scripts-dev/definitions.py
@@ -1,7 +1,5 @@
#! /usr/bin/python
-from __future__ import print_function
-
import argparse
import ast
import os
@@ -13,7 +11,7 @@ import yaml
class DefinitionVisitor(ast.NodeVisitor):
def __init__(self):
- super(DefinitionVisitor, self).__init__()
+ super().__init__()
self.functions = {}
self.classes = {}
self.names = {}
diff --git a/scripts-dev/dump_macaroon.py b/scripts-dev/dump_macaroon.py
index 22b30fa78e..980b5e709f 100755
--- a/scripts-dev/dump_macaroon.py
+++ b/scripts-dev/dump_macaroon.py
@@ -1,7 +1,5 @@
#!/usr/bin/env python2
-from __future__ import print_function
-
import sys
import pymacaroons
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index ad12523c4d..abcec48c4f 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -15,8 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import print_function
-
import argparse
import base64
import json
@@ -323,7 +321,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
url = urlparse.urlunparse(
("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment)
)
- return super(MatrixConnectionAdapter, self).get_connection(url, proxies)
+ return super().get_connection(url, proxies)
if __name__ == "__main__":
diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py
index 89acb52e6a..8d6c3d24db 100644
--- a/scripts-dev/hash_history.py
+++ b/scripts-dev/hash_history.py
@@ -1,5 +1,3 @@
-from __future__ import print_function
-
import sqlite3
import sys
diff --git a/scripts-dev/sphinx_api_docs.sh b/scripts-dev/sphinx_api_docs.sh
deleted file mode 100644
index ee72b29657..0000000000
--- a/scripts-dev/sphinx_api_docs.sh
+++ /dev/null
@@ -1 +0,0 @@
-sphinx-apidoc -o docs/sphinx/ synapse/ -ef
diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py
index b5b63933ab..ab2e763386 100755
--- a/scripts/move_remote_media_to_new_store.py
+++ b/scripts/move_remote_media_to_new_store.py
@@ -32,8 +32,6 @@ To use, pipe the above into::
PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
"""
-from __future__ import print_function
-
import argparse
import logging
import os
diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user
index b450712ab7..8b9d30877d 100755
--- a/scripts/register_new_matrix_user
+++ b/scripts/register_new_matrix_user
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import print_function
-
from synapse._scripts.register_new_matrix_user import main
if __name__ == "__main__":
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index a34bdf1830..2d0b59ab53 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -89,6 +89,8 @@ BOOLEAN_COLUMNS = {
"redactions": ["have_censored"],
"room_stats_state": ["is_federatable"],
"local_media_repository": ["safe_from_quarantine"],
+ "users": ["shadow_banned"],
+ "e2e_fallback_keys_json": ["used"],
}
@@ -144,6 +146,7 @@ IGNORED_TABLES = {
# the sessions are transient anyway, so ignore them.
"ui_auth_sessions",
"ui_auth_sessions_credentials",
+ "ui_auth_sessions_ips",
}
@@ -487,7 +490,7 @@ class Porter(object):
hs = MockHomeserver(self.hs_config)
- with make_conn(db_config, engine) as db_conn:
+ with make_conn(db_config, engine, "portdb") as db_conn:
engine.check_database(
db_conn, allow_outdated_version=allow_outdated_version
)
@@ -627,6 +630,7 @@ class Porter(object):
self.progress.set_state("Setting up sequence generators")
await self._setup_state_group_id_seq()
await self._setup_user_id_seq()
+ await self._setup_events_stream_seqs()
self.progress.done()
except Exception as e:
@@ -803,6 +807,29 @@ class Porter(object):
return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
+    def _setup_events_stream_seqs(self):
+        def r(txn):
+            # Move the forward events stream sequence past the highest
+            # stream ordering that has been ported.
+            txn.execute("SELECT MAX(stream_ordering) FROM events")
+            curr_id = txn.fetchone()[0]
+            if curr_id:
+                next_id = curr_id + 1
+                txn.execute(
+                    "ALTER SEQUENCE events_stream_seq RESTART WITH %s", (next_id,)
+                )
+
+            # Backfilled events use negative stream orderings, so move the
+            # backfill sequence past the most negative ordering seen so far.
+            txn.execute("SELECT -MIN(stream_ordering) FROM events")
+            curr_id = txn.fetchone()[0]
+            if curr_id:
+                next_id = curr_id + 1
+                txn.execute(
+                    "ALTER SEQUENCE events_backfill_stream_seq RESTART WITH %s",
+                    (next_id,),
+                )
+
+        return self.postgres_store.db_pool.runInteraction(
+            "_setup_events_stream_seqs", r
+        )
+
##############################################
# The following is simply UI stuff
diff --git a/setup.cfg b/setup.cfg
index a32278ea8a..f46e43fad0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,8 +1,3 @@
-[build_sphinx]
-source-dir = docs/sphinx
-build-dir = docs/build
-all_files = 1
-
[trial]
test_suite = tests
diff --git a/setup.py b/setup.py
index 54ddec8f9f..926b1bc86f 100755
--- a/setup.py
+++ b/setup.py
@@ -94,6 +94,22 @@ ALL_OPTIONAL_REQUIREMENTS = dependencies["ALL_OPTIONAL_REQUIREMENTS"]
# Make `pip install matrix-synapse[all]` install all the optional dependencies.
CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
+# Developer dependencies should not get included in "all".
+#
+# We pin black so that our tests don't start failing on new releases.
+CONDITIONAL_REQUIREMENTS["lint"] = [
+ "isort==5.0.3",
+ "black==19.10b0",
+ "flake8-comprehensions",
+ "flake8",
+]
+
+# Dependencies which are exclusively required by unit test code. This is
+# NOT a list of all modules that are necessary to run the unit tests.
+# Tests assume that all optional dependencies are installed.
+#
+# parameterized_class decorator was introduced in parameterized 0.7.0
+CONDITIONAL_REQUIREMENTS["test"] = ["mock>=2.0", "parameterized>=0.7.0"]
setup(
name="matrix-synapse",
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index c66413f003..522244bb57 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -16,7 +16,7 @@
"""Contains *incomplete* type hints for txredisapi.
"""
-from typing import List, Optional, Union
+from typing import List, Optional, Type, Union
class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
@@ -42,3 +42,21 @@ def lazyConnection(
class SubscriberFactory:
def buildProtocol(self, addr): ...
+
+class ConnectionHandler: ...
+
+class RedisFactory:
+    continueTrying: bool
+    handler: RedisProtocol
+    def __init__(
+        self,
+        uuid: str,
+        dbid: Optional[int],
+        poolsize: int,
+        isLazy: bool = False,
+        handler: Type = ConnectionHandler,
+        charset: str = "utf-8",
+        password: Optional[str] = None,
+        replyTimeout: Optional[int] = None,
+        convertNumbers: Optional[int] = True,
+    ): ...
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 6b11c5681b..a86dc07ddc 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
-__version__ = "1.20.0rc4"
+__version__ = "1.21.0rc3"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 55cce2db22..da0996edbc 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import print_function
-
import argparse
import getpass
import hashlib
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 75388643ee..1071a0576e 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -218,11 +218,7 @@ class Auth:
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
- expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
- if (
- expiration_ts is not None
- and self.clock.time_msec() >= expiration_ts
- ):
+ if await self.store.is_account_expired(user_id, self.clock.time_msec()):
raise AuthError(
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
)
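
The refactored check delegates the comparison to a new store method. Based on
the inline logic it replaces, that method presumably looks something like this
sketch (the actual implementation lives in the registration store):

```python
async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
    """Whether the user's account has passed its expiration timestamp.

    Sketch reconstructed from the inline logic removed above.
    """
    expiration_ts = await self.get_expiration_ts_for_user(user_id)
    return expiration_ts is not None and current_ts >= expiration_ts
```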
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 46013cde15..592abd844b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -155,3 +155,8 @@ class EventContentFields:
class RoomEncryptionAlgorithms:
MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
DEFAULT = MEGOLM_V1_AES_SHA2
+
+
+class AccountDataTypes:
+    DIRECT = "m.direct"
+    IGNORED_USER_LIST = "m.ignored_user_list"
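
These constants are presumably intended to replace raw string literals at call
sites, e.g.:

```python
from synapse.api.constants import AccountDataTypes

# Instead of comparing against the literals "m.direct" and
# "m.ignored_user_list":
direct_rooms = account_data.get(AccountDataTypes.DIRECT, {})
ignored = account_data.get(AccountDataTypes.IGNORED_USER_LIST, {})
```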
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 94a9e58eae..cd6670d0a2 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -87,7 +87,7 @@ class CodeMessageException(RuntimeError):
"""
def __init__(self, code: Union[int, HTTPStatus], msg: str):
- super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
+ super().__init__("%d: %s" % (code, msg))
# Some calls to this method pass instances of http.HTTPStatus for `code`.
# While HTTPStatus is a subclass of int, it has magic __str__ methods
@@ -138,7 +138,7 @@ class SynapseError(CodeMessageException):
msg: The human-readable error message.
errcode: The matrix error code e.g 'M_FORBIDDEN'
"""
- super(SynapseError, self).__init__(code, msg)
+ super().__init__(code, msg)
self.errcode = errcode
def error_dict(self):
@@ -159,7 +159,7 @@ class ProxiedRequestError(SynapseError):
errcode: str = Codes.UNKNOWN,
additional_fields: Optional[Dict] = None,
):
- super(ProxiedRequestError, self).__init__(code, msg, errcode)
+ super().__init__(code, msg, errcode)
if additional_fields is None:
self._additional_fields = {} # type: Dict
else:
@@ -181,7 +181,7 @@ class ConsentNotGivenError(SynapseError):
msg: The human-readable error message
consent_url: The URL where the user can give their consent
"""
- super(ConsentNotGivenError, self).__init__(
+ super().__init__(
code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN
)
self._consent_uri = consent_uri
@@ -201,7 +201,7 @@ class UserDeactivatedError(SynapseError):
Args:
msg: The human-readable error message
"""
- super(UserDeactivatedError, self).__init__(
+ super().__init__(
code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED
)
@@ -225,7 +225,7 @@ class FederationDeniedError(SynapseError):
self.destination = destination
- super(FederationDeniedError, self).__init__(
+ super().__init__(
code=403,
msg="Federation denied with %s." % (self.destination,),
errcode=Codes.FORBIDDEN,
@@ -244,9 +244,7 @@ class InteractiveAuthIncompleteError(Exception):
"""
def __init__(self, session_id: str, result: "JsonDict"):
- super(InteractiveAuthIncompleteError, self).__init__(
- "Interactive auth not yet complete"
- )
+ super().__init__("Interactive auth not yet complete")
self.session_id = session_id
self.result = result
@@ -261,14 +259,14 @@ class UnrecognizedRequestError(SynapseError):
message = "Unrecognized request"
else:
message = args[0]
- super(UnrecognizedRequestError, self).__init__(400, message, **kwargs)
+ super().__init__(400, message, **kwargs)
class NotFoundError(SynapseError):
"""An error indicating we can't find the thing you asked for"""
def __init__(self, msg: str = "Not found", errcode: str = Codes.NOT_FOUND):
- super(NotFoundError, self).__init__(404, msg, errcode=errcode)
+ super().__init__(404, msg, errcode=errcode)
class AuthError(SynapseError):
@@ -279,7 +277,7 @@ class AuthError(SynapseError):
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.FORBIDDEN
- super(AuthError, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
class InvalidClientCredentialsError(SynapseError):
@@ -335,7 +333,7 @@ class ResourceLimitError(SynapseError):
):
self.admin_contact = admin_contact
self.limit_type = limit_type
- super(ResourceLimitError, self).__init__(code, msg, errcode=errcode)
+ super().__init__(code, msg, errcode=errcode)
def error_dict(self):
return cs_error(
@@ -352,7 +350,7 @@ class EventSizeError(SynapseError):
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.TOO_LARGE
- super(EventSizeError, self).__init__(413, *args, **kwargs)
+ super().__init__(413, *args, **kwargs)
class EventStreamError(SynapseError):
@@ -361,7 +359,7 @@ class EventStreamError(SynapseError):
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.BAD_PAGINATION
- super(EventStreamError, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
class LoginError(SynapseError):
@@ -384,7 +382,7 @@ class InvalidCaptchaError(SynapseError):
error_url: Optional[str] = None,
errcode: str = Codes.CAPTCHA_INVALID,
):
- super(InvalidCaptchaError, self).__init__(code, msg, errcode)
+ super().__init__(code, msg, errcode)
self.error_url = error_url
def error_dict(self):
@@ -402,7 +400,7 @@ class LimitExceededError(SynapseError):
retry_after_ms: Optional[int] = None,
errcode: str = Codes.LIMIT_EXCEEDED,
):
- super(LimitExceededError, self).__init__(code, msg, errcode)
+ super().__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms
def error_dict(self):
@@ -418,9 +416,7 @@ class RoomKeysVersionError(SynapseError):
Args:
current_version: the current version of the store they should have used
"""
- super(RoomKeysVersionError, self).__init__(
- 403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION
- )
+ super().__init__(403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION)
self.current_version = current_version
@@ -429,7 +425,7 @@ class UnsupportedRoomVersionError(SynapseError):
not support."""
def __init__(self, msg: str = "Homeserver does not support this room version"):
- super(UnsupportedRoomVersionError, self).__init__(
+ super().__init__(
code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION,
)
@@ -440,7 +436,7 @@ class ThreepidValidationError(SynapseError):
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.FORBIDDEN
- super(ThreepidValidationError, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
class IncompatibleRoomVersionError(SynapseError):
@@ -451,7 +447,7 @@ class IncompatibleRoomVersionError(SynapseError):
"""
def __init__(self, room_version: str):
- super(IncompatibleRoomVersionError, self).__init__(
+ super().__init__(
code=400,
msg="Your homeserver does not support the features required to "
"join this room",
@@ -473,7 +469,7 @@ class PasswordRefusedError(SynapseError):
msg: str = "This password doesn't comply with the server's policy",
errcode: str = Codes.WEAK_PASSWORD,
):
- super(PasswordRefusedError, self).__init__(
+ super().__init__(
code=400, msg=msg, errcode=errcode,
)
@@ -488,7 +484,7 @@ class RequestSendFailed(RuntimeError):
"""
def __init__(self, inner_exception, can_retry):
- super(RequestSendFailed, self).__init__(
+ super().__init__(
"Failed to send request: %s: %s"
% (type(inner_exception).__name__, inner_exception)
)
@@ -542,7 +538,7 @@ class FederationError(RuntimeError):
self.source = source
msg = "%s %s: %s" % (level, code, reason)
- super(FederationError, self).__init__(msg)
+ super().__init__(msg)
def get_dict(self):
return {
@@ -570,7 +566,7 @@ class HttpResponseException(CodeMessageException):
msg: reason phrase from HTTP response status line
response: body of response
"""
- super(HttpResponseException, self).__init__(code, msg)
+ super().__init__(code, msg)
self.response = response
def to_synapse_error(self):
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index bb33345be6..5caf336fd0 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -132,7 +132,7 @@ def matrix_user_id_validator(user_id_str):
class Filtering:
def __init__(self, hs):
- super(Filtering, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
async def get_user_filter(self, user_localpart, filter_id):
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index fb476ddaf5..f6f7b2bf42 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -28,6 +28,7 @@ from twisted.protocols.tls import TLSMemoryBIOFactory
import synapse
from synapse.app import check_bind_error
+from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.server import ListenerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
@@ -271,9 +272,19 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
hs.get_datastore().db_pool.start_profiling()
hs.get_pusherpool().start()
+    # Log when we start the shutdown process.
+    hs.get_reactor().addSystemEventTrigger(
+        "before", "shutdown", logger.info, "Shutting down..."
+    )
+
setup_sentry(hs)
setup_sdnotify(hs)
+    # If background tasks are running on the main process, start collecting the
+    # phone home stats.
+    if hs.config.run_background_tasks:
+        start_phone_stats_home(hs)
+
# We now freeze all allocated objects in the hopes that (almost)
# everything currently allocated are things that will be used for the
# rest of time. Doing so means less work each GC (hopefully).
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 7d309b1bb0..f0d65d08d7 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -208,6 +208,7 @@ def start(config_options):
# Explicitly disable background processes
config.update_user_directory = False
+ config.run_background_tasks = False
config.start_pushers = False
config.send_federation = False
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index f985810e88..d53181deb1 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -127,12 +127,16 @@ from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer, cache_in_self
from synapse.storage.databases.main.censor_events import CensorEventsStore
+from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
from synapse.storage.databases.main.media_repository import MediaRepositoryStore
+from synapse.storage.databases.main.metrics import ServerMetricsStore
from synapse.storage.databases.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.databases.main.presence import UserPresenceState
from synapse.storage.databases.main.search import SearchWorkerStore
+from synapse.storage.databases.main.stats import StatsStore
+from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
from synapse.storage.databases.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
@@ -152,7 +156,7 @@ class PresenceStatusStubServlet(RestServlet):
PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status")
def __init__(self, hs):
- super(PresenceStatusStubServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -176,7 +180,7 @@ class KeyUploadServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(KeyUploadServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.http_client = hs.get_simple_http_client()
@@ -454,6 +458,7 @@ class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
UserDirectoryStore,
+ StatsStore,
UIAuthWorkerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
@@ -463,6 +468,7 @@ class GenericWorkerSlavedStore(
SlavedAccountDataStore,
SlavedPusherStore,
CensorEventsStore,
+ ClientIpWorkerStore,
SlavedEventStore,
SlavedKeyStore,
RoomStore,
@@ -476,7 +482,9 @@ class GenericWorkerSlavedStore(
SlavedFilteringStore,
MonthlyActiveUsersWorkerStore,
MediaRepositoryStore,
+ ServerMetricsStore,
SearchWorkerStore,
+ TransactionWorkerStore,
BaseSlavedStore,
):
pass
@@ -646,7 +654,7 @@ class GenericWorkerServer(HomeServer):
class GenericWorkerReplicationHandler(ReplicationDataHandler):
def __init__(self, hs):
- super(GenericWorkerReplicationHandler, self).__init__(hs)
+ super().__init__(hs)
self.store = hs.get_datastore()
self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index b08319ca77..2b5465417f 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -15,18 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import print_function
-
import gc
import logging
-import math
import os
-import resource
import sys
from typing import Iterable
-from prometheus_client import Gauge
-
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
@@ -62,8 +56,6 @@ from synapse.http.server import (
from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.module_api import ModuleApi
from synapse.python_dependencies import check_requirements
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -113,7 +105,7 @@ class SynapseHomeServer(HomeServer):
additional_resources = listener_config.http_options.additional_resources
logger.debug("Configuring additional resources: %r", additional_resources)
- module_api = ModuleApi(self, self.get_auth_handler())
+ module_api = self.get_module_api()
for path, resmodule in additional_resources.items():
handler_cls, config = load_module(resmodule)
handler = handler_cls(config, module_api)
@@ -336,20 +328,6 @@ class SynapseHomeServer(HomeServer):
logger.warning("Unrecognized listener type: %s", listener.type)
-# Gauges to expose monthly active user control metrics
-current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
-current_mau_by_service_gauge = Gauge(
- "synapse_admin_mau_current_mau_by_service",
- "Current MAU by service",
- ["app_service"],
-)
-max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
-registered_reserved_users_mau_gauge = Gauge(
- "synapse_admin_mau:registered_reserved_users",
- "Registered users with reserved threepids",
-)
-
-
def setup(config_options):
"""
Args:
@@ -391,8 +369,6 @@ def setup(config_options):
except UpgradeDatabaseException as e:
quit_with_error("Failed to upgrade database: %s" % (e,))
- hs.setup_master()
-
async def do_acme() -> bool:
"""
Reprovision an ACME certificate, if it's required.
@@ -488,92 +464,6 @@ class SynapseService(service.Service):
return self._port.stopListening()
-# Contains the list of processes we will be monitoring
-# currently either 0 or 1
-_stats_process = []
-
-
-async def phone_stats_home(hs, stats, stats_process=_stats_process):
- logger.info("Gathering stats for reporting")
- now = int(hs.get_clock().time())
- uptime = int(now - hs.start_time)
- if uptime < 0:
- uptime = 0
-
- #
- # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test.
- #
- old = stats_process[0]
- new = (now, resource.getrusage(resource.RUSAGE_SELF))
- stats_process[0] = new
-
- # Get RSS in bytes
- stats["memory_rss"] = new[1].ru_maxrss
-
- # Get CPU time in % of a single core, not % of all cores
- used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
- old[1].ru_utime + old[1].ru_stime
- )
- if used_cpu_time == 0 or new[0] == old[0]:
- stats["cpu_average"] = 0
- else:
- stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
-
- #
- # General statistics
- #
-
- stats["homeserver"] = hs.config.server_name
- stats["server_context"] = hs.config.server_context
- stats["timestamp"] = now
- stats["uptime_seconds"] = uptime
- version = sys.version_info
- stats["python_version"] = "{}.{}.{}".format(
- version.major, version.minor, version.micro
- )
- stats["total_users"] = await hs.get_datastore().count_all_users()
-
- total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
- stats["total_nonbridged_users"] = total_nonbridged_users
-
- daily_user_type_results = await hs.get_datastore().count_daily_user_type()
- for name, count in daily_user_type_results.items():
- stats["daily_user_type_" + name] = count
-
- room_count = await hs.get_datastore().get_room_count()
- stats["total_room_count"] = room_count
-
- stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
- stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
- stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
- stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
-
- r30_results = await hs.get_datastore().count_r30_users()
- for name, count in r30_results.items():
- stats["r30_users_" + name] = count
-
- daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
- stats["daily_sent_messages"] = daily_sent_messages
- stats["cache_factor"] = hs.config.caches.global_factor
- stats["event_cache_size"] = hs.config.caches.event_cache_size
-
- #
- # Database version
- #
-
- # This only reports info about the *main* database.
- stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
- stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
-
- logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
- try:
- await hs.get_proxied_http_client().put_json(
- hs.config.report_stats_endpoint, stats
- )
- except Exception as e:
- logger.warning("Error reporting stats: %s", e)
-
-
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
@@ -599,81 +489,6 @@ def run(hs):
ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)
- clock = hs.get_clock()
-
- stats = {}
-
- def performance_stats_init():
- _stats_process.clear()
- _stats_process.append(
- (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
- )
-
- def start_phone_stats_home():
- return run_as_background_process(
- "phone_stats_home", phone_stats_home, hs, stats
- )
-
- def generate_user_daily_visit_stats():
- return run_as_background_process(
- "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits
- )
-
- # Rather than update on per session basis, batch up the requests.
- # If you increase the loop period, the accuracy of user_daily_visits
- # table will decrease
- clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
-
- # monthly active user limiting functionality
- def reap_monthly_active_users():
- return run_as_background_process(
- "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users
- )
-
- clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60)
- reap_monthly_active_users()
-
- async def generate_monthly_active_users():
- current_mau_count = 0
- current_mau_count_by_service = {}
- reserved_users = ()
- store = hs.get_datastore()
- if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
- current_mau_count = await store.get_monthly_active_count()
- current_mau_count_by_service = (
- await store.get_monthly_active_count_by_service()
- )
- reserved_users = await store.get_registered_reserved_users()
- current_mau_gauge.set(float(current_mau_count))
-
- for app_service, count in current_mau_count_by_service.items():
- current_mau_by_service_gauge.labels(app_service).set(float(count))
-
- registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
- max_mau_gauge.set(float(hs.config.max_mau_value))
-
- def start_generate_monthly_active_users():
- return run_as_background_process(
- "generate_monthly_active_users", generate_monthly_active_users
- )
-
- start_generate_monthly_active_users()
- if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
- clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000)
- # End of monthly active user settings
-
- if hs.config.report_stats:
- logger.info("Scheduling stats reporting for 3 hour intervals")
- clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)
-
- # We need to defer this init for the cases that we daemonize
- # otherwise the process ID we get is that of the non-daemon process
- clock.call_later(0, performance_stats_init)
-
- # We wait 5 minutes to send the first set of stats as the server can
- # be quite busy the first few minutes
- clock.call_later(5 * 60, start_phone_stats_home)
-
_base.start_reactor(
"synapse-homeserver",
soft_file_limit=hs.config.soft_file_limit,
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
new file mode 100644
index 0000000000..8a69104a04
--- /dev/null
+++ b/synapse/app/phone_stats_home.py
@@ -0,0 +1,198 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import math
+import resource
+import sys
+
+from prometheus_client import Gauge
+
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
+
+logger = logging.getLogger("synapse.app.homeserver")
+
+# Contains the list of processes we will be monitoring
+# currently either 0 or 1
+_stats_process = []
+
+# Gauges to expose monthly active user control metrics
+current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
+current_mau_by_service_gauge = Gauge(
+ "synapse_admin_mau_current_mau_by_service",
+ "Current MAU by service",
+ ["app_service"],
+)
+max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
+registered_reserved_users_mau_gauge = Gauge(
+ "synapse_admin_mau:registered_reserved_users",
+ "Registered users with reserved threepids",
+)
+
+
+@wrap_as_background_process("phone_stats_home")
+async def phone_stats_home(hs, stats, stats_process=_stats_process):
+ logger.info("Gathering stats for reporting")
+ now = int(hs.get_clock().time())
+ uptime = int(now - hs.start_time)
+ if uptime < 0:
+ uptime = 0
+
+ #
+ # Performance statistics. Keep this early in the function to maintain reliability of the `test_performance_100` test.
+ #
+ old = stats_process[0]
+ new = (now, resource.getrusage(resource.RUSAGE_SELF))
+ stats_process[0] = new
+
+ # Get RSS in bytes
+ stats["memory_rss"] = new[1].ru_maxrss
+
+ # Get CPU time in % of a single core, not % of all cores
+ used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
+ old[1].ru_utime + old[1].ru_stime
+ )
+ if used_cpu_time == 0 or new[0] == old[0]:
+ stats["cpu_average"] = 0
+ else:
+ stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
+
+ #
+ # General statistics
+ #
+
+ stats["homeserver"] = hs.config.server_name
+ stats["server_context"] = hs.config.server_context
+ stats["timestamp"] = now
+ stats["uptime_seconds"] = uptime
+ version = sys.version_info
+ stats["python_version"] = "{}.{}.{}".format(
+ version.major, version.minor, version.micro
+ )
+ stats["total_users"] = await hs.get_datastore().count_all_users()
+
+ total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
+ stats["total_nonbridged_users"] = total_nonbridged_users
+
+ daily_user_type_results = await hs.get_datastore().count_daily_user_type()
+ for name, count in daily_user_type_results.items():
+ stats["daily_user_type_" + name] = count
+
+ room_count = await hs.get_datastore().get_room_count()
+ stats["total_room_count"] = room_count
+
+ stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
+ stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
+ stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
+ stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
+
+ r30_results = await hs.get_datastore().count_r30_users()
+ for name, count in r30_results.items():
+ stats["r30_users_" + name] = count
+
+ daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
+ stats["daily_sent_messages"] = daily_sent_messages
+ stats["cache_factor"] = hs.config.caches.global_factor
+ stats["event_cache_size"] = hs.config.caches.event_cache_size
+
+ #
+ # Database version
+ #
+
+ # This only reports info about the *main* database.
+ stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
+ stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
+
+ #
+ # Logging configuration
+ #
+ synapse_logger = logging.getLogger("synapse")
+ log_level = synapse_logger.getEffectiveLevel()
+ stats["log_level"] = logging.getLevelName(log_level)
+
+ logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
+ try:
+ await hs.get_proxied_http_client().put_json(
+ hs.config.report_stats_endpoint, stats
+ )
+ except Exception as e:
+ logger.warning("Error reporting stats: %s", e)
+
+
+def start_phone_stats_home(hs):
+ """
+ Start the background tasks which report phone home stats.
+ """
+ clock = hs.get_clock()
+
+ stats = {}
+
+ def performance_stats_init():
+ _stats_process.clear()
+ _stats_process.append(
+ (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
+ )
+
+ # Rather than update on per session basis, batch up the requests.
+ # If you increase the loop period, the accuracy of user_daily_visits
+ # table will decrease
+ clock.looping_call(hs.get_datastore().generate_user_daily_visits, 5 * 60 * 1000)
+
+ # monthly active user limiting functionality
+ def reap_monthly_active_users():
+ return run_as_background_process(
+ "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users
+ )
+
+ clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60)
+ reap_monthly_active_users()
+
+ @wrap_as_background_process("generate_monthly_active_users")
+ async def generate_monthly_active_users():
+ current_mau_count = 0
+ current_mau_count_by_service = {}
+ reserved_users = ()
+ store = hs.get_datastore()
+ if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
+ current_mau_count = await store.get_monthly_active_count()
+ current_mau_count_by_service = (
+ await store.get_monthly_active_count_by_service()
+ )
+ reserved_users = await store.get_registered_reserved_users()
+ current_mau_gauge.set(float(current_mau_count))
+
+ for app_service, count in current_mau_count_by_service.items():
+ current_mau_by_service_gauge.labels(app_service).set(float(count))
+
+ registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
+ max_mau_gauge.set(float(hs.config.max_mau_value))
+
+ if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
+ generate_monthly_active_users()
+ clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
+ # End of monthly active user settings
+
+ if hs.config.report_stats:
+ logger.info("Scheduling stats reporting for 3 hour intervals")
+ clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000, hs, stats)
+
+ # We need to defer this init for the case where we daemonize,
+ # otherwise the process ID we get is that of the non-daemon process
+ clock.call_later(0, performance_stats_init)
+
+ # We wait 5 minutes to send the first set of stats as the server can
+ # be quite busy the first few minutes
+ clock.call_later(5 * 60, phone_stats_home, hs, stats)
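The scheduling above (first report after five minutes, then every three hours) maps onto plain Twisted primitives, which Synapse's `Clock` wraps. A rough sketch under that assumption, with `report_stats` standing in for the `phone_stats_home` coroutine:

```python
# Rough sketch of the phone-home scheduling using raw Twisted primitives.
from twisted.internet import reactor, task

def report_stats():
    print("phoning home...")

loop = task.LoopingCall(report_stats)
loop.start(3 * 60 * 60, now=False)       # every 3 hours, not immediately
reactor.callLater(5 * 60, report_stats)  # first report after 5 minutes
reactor.run()
```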
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index bb6fa8299a..c526c28b93 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -88,7 +88,7 @@ class ApplicationServiceApi(SimpleHttpClient):
"""
def __init__(self, hs):
- super(ApplicationServiceApi, self).__init__(hs)
+ super().__init__(hs)
self.clock = hs.get_clock()
self.protocol_meta_cache = ResponseCache(
@@ -178,7 +178,7 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.parse.quote(protocol),
)
try:
- info = await self.get_json(uri, {})
+ info = await self.get_json(uri)
if not _is_valid_3pe_metadata(info):
logger.warning(
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index bb9bf8598d..85f65da4d9 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -194,7 +194,10 @@ class Config:
return file_stream.read()
def read_templates(
- self, filenames: List[str], custom_template_directory: Optional[str] = None,
+ self,
+ filenames: List[str],
+ custom_template_directory: Optional[str] = None,
+ autoescape: bool = False,
) -> List[jinja2.Template]:
"""Load a list of template files from disk using the given variables.
@@ -210,6 +213,9 @@ class Config:
custom_template_directory: A directory to try to look for the templates
before using the default Synapse template directory instead.
+ autoescape: Whether to autoescape variables before inserting them into the
+ template.
+
Raises:
ConfigError: if the file's path is incorrect or otherwise cannot be read.
@@ -233,15 +239,14 @@ class Config:
search_directories.insert(0, custom_template_directory)
loader = jinja2.FileSystemLoader(search_directories)
- env = jinja2.Environment(loader=loader, autoescape=True)
+ env = jinja2.Environment(loader=loader, autoescape=autoescape)
# Update the environment with our custom filters
- env.filters.update(
- {
- "format_ts": _format_ts_filter,
- "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
- }
- )
+ env.filters.update({"format_ts": _format_ts_filter})
+ if self.public_baseurl:
+ env.filters.update(
+ {"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)}
+ )
for filename in filenames:
# Load the template
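For context, `read_templates` now builds its Jinja2 environment with autoescaping opt-in and registers `mxc_to_http` only when a `public_baseurl` is configured. A runnable sketch of the same construction (using a `DictLoader` in place of the real `FileSystemLoader` over the search directories; the filter body is a toy):

```python
import jinja2

# DictLoader stands in for the FileSystemLoader over template directories.
loader = jinja2.DictLoader({"recaptcha.html": "generated at {{ ts | format_ts }}"})
env = jinja2.Environment(loader=loader, autoescape=False)  # autoescape now opt-in
env.filters.update({"format_ts": lambda ts: str(ts)})      # toy filter body
print(env.get_template("recaptcha.html").render(ts=1601546400000))
```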
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index cd31b1c3c9..c74969a977 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List
+from typing import Any, Iterable
import jsonschema
@@ -20,7 +20,9 @@ from synapse.config._base import ConfigError
from synapse.types import JsonDict
-def validate_config(json_schema: JsonDict, config: Any, config_path: List[str]) -> None:
+def validate_config(
+ json_schema: JsonDict, config: Any, config_path: Iterable[str]
+) -> None:
"""Validates a config setting against a JsonSchema definition
This can be used to validate a section of the config file against a schema
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index 82f04d7966..cb00958165 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -28,6 +28,9 @@ class CaptchaConfig(Config):
"recaptcha_siteverify_api",
"https://www.recaptcha.net/recaptcha/api/siteverify",
)
+ self.recaptcha_template = self.read_templates(
+ ["recaptcha.html"], autoescape=True
+ )[0]
def generate_config_section(self, **kwargs):
return """\
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index aec9c4bbce..6efa59b110 100644
--- a/synapse/config/consent_config.py
+++ b/synapse/config/consent_config.py
@@ -77,7 +77,7 @@ class ConsentConfig(Config):
section = "consent"
def __init__(self, *args):
- super(ConsentConfig, self).__init__(*args)
+ super().__init__(*args)
self.user_consent_version = None
self.user_consent_template_dir = None
@@ -89,6 +89,8 @@ class ConsentConfig(Config):
def read_config(self, config, **kwargs):
consent_config = config.get("user_consent")
+ self.terms_template = self.read_templates(["terms.html"], autoescape=True)[0]
+
if consent_config is None:
return
self.user_consent_version = str(consent_config["version"])
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 72b42bfd62..cceffbfee2 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import print_function
# This file can't be called email.py because if it is, we cannot:
import email.utils
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index 2c77d8f85b..ffd8fca54e 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -17,7 +17,8 @@ from typing import Optional
from netaddr import IPSet
-from ._base import Config, ConfigError
+from synapse.config._base import Config, ConfigError
+from synapse.config._util import validate_config
class FederationConfig(Config):
@@ -52,8 +53,18 @@ class FederationConfig(Config):
"Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
)
+ federation_metrics_domains = config.get("federation_metrics_domains") or []
+ validate_config(
+ _METRICS_FOR_DOMAINS_SCHEMA,
+ federation_metrics_domains,
+ ("federation_metrics_domains",),
+ )
+ self.federation_metrics_domains = set(federation_metrics_domains)
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
+ ## Federation ##
+
# Restrict federation to the following whitelist of domains.
# N.B. we recommend also firewalling your federation listener to limit
# inbound federation traffic as early as possible, rather than relying
@@ -85,4 +96,18 @@ class FederationConfig(Config):
- '::1/128'
- 'fe80::/64'
- 'fc00::/7'
+
+ # Report prometheus metrics on the age of PDUs being sent to and received from
+ # the following domains. This can be used to give an idea of "delay" on inbound
+ # and outbound federation, though be aware that any delay can be due to problems
+ # at either end or with the intermediate network.
+ #
+ # By default, no domains are monitored in this way.
+ #
+ #federation_metrics_domains:
+ # - matrix.org
+ # - example.com
"""
+
+
+_METRICS_FOR_DOMAINS_SCHEMA = {"type": "array", "items": {"type": "string"}}
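The schema above is what `validate_config` checks the new setting against; a quick sketch of the acceptance/rejection behaviour:

```python
# The setting must be a list of domain strings; anything else is rejected.
import jsonschema

schema = {"type": "array", "items": {"type": "string"}}

jsonschema.validate(["matrix.org", "example.com"], schema)  # passes
try:
    jsonschema.validate("matrix.org", schema)  # a bare string is rejected
except jsonschema.ValidationError as e:
    print("invalid federation_metrics_domains:", e.message)
```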
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 556e291495..be65554524 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -92,5 +92,4 @@ class HomeServerConfig(RootConfig):
TracerConfig,
WorkerConfig,
RedisConfig,
- FederationConfig,
]
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index e0939bce84..7597fbc864 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -56,6 +56,8 @@ class OIDCConfig(Config):
self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
self.oidc_jwks_uri = oidc_config.get("jwks_uri")
self.oidc_skip_verification = oidc_config.get("skip_verification", False)
+ self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
+ self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
ump_config = oidc_config.get("user_mapping_provider", {})
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
@@ -158,6 +160,19 @@ class OIDCConfig(Config):
#
#skip_verification: true
+ # Whether to fetch the user profile from the userinfo endpoint. Valid
+ # values are: "auto" or "userinfo_endpoint".
+ #
+ # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
+ # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
+ #
+ #user_profile_method: "userinfo_endpoint"
+
+ # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
+ # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
+ #
+ #allow_existing_users: true
+
# An external module can be provided here as a custom solution to mapping
# attributes returned from a OIDC provider onto a matrix user.
#
@@ -198,6 +213,14 @@ class OIDCConfig(Config):
# If unset, no displayname will be set.
#
#display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
+
+ # Jinja2 templates for extra attributes to send back to the client during
+ # login.
+ #
+ # Note that these are non-standard and clients will ignore them without modifications.
+ #
+ #extra_attributes:
+ #birthdate: "{{{{ user.birthdate }}}}"
""".format(
mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
)
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index a185655774..d7e3690a32 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -30,7 +30,7 @@ class AccountValidityConfig(Config):
def __init__(self, config, synapse_config):
if config is None:
return
- super(AccountValidityConfig, self).__init__()
+ super().__init__()
self.enabled = config.get("enabled", False)
self.renew_by_email_enabled = "renew_at" in config
@@ -187,6 +187,11 @@ class RegistrationConfig(Config):
session_lifetime = self.parse_duration(session_lifetime)
self.session_lifetime = session_lifetime
+ # The success template used during fallback auth.
+ self.fallback_success_template = self.read_templates(
+ ["auth_success.html"], autoescape=True
+ )[0]
+
def generate_config_section(self, generate_secrets=False, **kwargs):
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 532b910470..85aa49c02d 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -39,7 +39,7 @@ logger = logging.Logger(__name__)
# in the list.
DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
-DEFAULT_ROOM_VERSION = "5"
+DEFAULT_ROOM_VERSION = "6"
ROOM_COMPLEXITY_TOO_GREAT = (
"Your homeserver is unable to join rooms this large or complex. "
@@ -641,10 +641,23 @@ class ServerConfig(Config):
"""\
## Server ##
- # The domain name of the server, with optional explicit port.
- # This is used by remote servers to connect to this server,
- # e.g. matrix.org, localhost:8080, etc.
- # This is also the last part of your UserID.
+ # The public-facing domain of the server
+ #
+ # The server_name will appear at the end of usernames and room addresses
+ # created on this server. For example if the server_name was example.com,
+ # usernames on this server would be in the format @user:example.com
+ #
+ # In most cases you should avoid using a Matrix-specific subdomain such as
+ # matrix.example.com or synapse.example.com as the server_name for the same
+ # reasons you wouldn't use user@email.example.com as your email address.
+ # See https://github.com/matrix-org/synapse/blob/master/docs/delegate.md
+ # for information on how to host Synapse on a subdomain while preserving
+ # a clean server_name.
+ #
+ # The server_name cannot be changed later so it is important to
+ # configure this correctly before you start Synapse. It should be all
+ # lowercase and may contain an explicit port.
+ # Examples: matrix.org, localhost:8080
#
server_name: "%(server_name)s"
diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py
index 6c427b6f92..57f69dc8e2 100644
--- a/synapse/config/server_notices_config.py
+++ b/synapse/config/server_notices_config.py
@@ -62,7 +62,7 @@ class ServerNoticesConfig(Config):
section = "servernotices"
def __init__(self, *args):
- super(ServerNoticesConfig, self).__init__(*args)
+ super().__init__(*args)
self.server_notices_mxid = None
self.server_notices_mxid_display_name = None
self.server_notices_mxid_avatar_url = None
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index 62485189ea..b559bfa411 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import division
-
import sys
from ._base import Config
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index e368ea564d..ad37b93c02 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -18,7 +18,7 @@ import os
import warnings
from datetime import datetime
from hashlib import sha256
-from typing import List
+from typing import List, Optional
from unpaddedbase64 import encode_base64
@@ -177,8 +177,8 @@ class TlsConfig(Config):
"use_insecure_ssl_client_just_for_testing_do_not_use"
)
- self.tls_certificate = None
- self.tls_private_key = None
+ self.tls_certificate = None # type: Optional[crypto.X509]
+ self.tls_private_key = None # type: Optional[crypto.PKey]
def is_disk_cert_valid(self, allow_self_signed=True):
"""
@@ -226,12 +226,12 @@ class TlsConfig(Config):
days_remaining = (expires_on - now).days
return days_remaining
- def read_certificate_from_disk(self, require_cert_and_key):
+ def read_certificate_from_disk(self, require_cert_and_key: bool):
"""
Read the certificates and private key from disk.
Args:
- require_cert_and_key (bool): set to True to throw an error if the certificate
+ require_cert_and_key: set to True to throw an error if the certificate
and key file are not given
"""
if require_cert_and_key:
@@ -471,7 +471,6 @@ class TlsConfig(Config):
# or by checking matrix.org/federationtester/api/report?server_name=$host
#
#tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
-
"""
# Lowercase the string representation of boolean values
% {
@@ -480,13 +479,13 @@ class TlsConfig(Config):
}
)
- def read_tls_certificate(self):
+ def read_tls_certificate(self) -> crypto.X509:
"""Reads the TLS certificate from the configured file, and returns it
Also checks if it is self-signed, and warns if so
Returns:
- OpenSSL.crypto.X509: the certificate
+ The certificate
"""
cert_path = self.tls_certificate_file
logger.info("Loading TLS certificate from %s", cert_path)
@@ -505,11 +504,11 @@ class TlsConfig(Config):
return cert
- def read_tls_private_key(self):
+ def read_tls_private_key(self) -> crypto.PKey:
"""Reads the TLS private key from the configured file, and returns it
Returns:
- OpenSSL.crypto.PKey: the private key
+ The private key
"""
private_key_path = self.tls_private_key_file
logger.info("Loading TLS key from %s", private_key_path)
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index f23e42cdf9..57ab097eba 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -132,6 +132,19 @@ class WorkerConfig(Config):
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
+ # Whether this worker should run background tasks or not.
+ #
+ # As a note for developers, the background tasks guarded by this should
+ # be able to run on only a single instance (meaning that they don't
+ # depend on any in-memory state of a particular worker).
+ #
+ # No effort is made to ensure only a single instance of these tasks is
+ # running.
+ background_tasks_instance = config.get("run_background_tasks_on") or "master"
+ self.run_background_tasks = (
+ self.worker_name is None and background_tasks_instance == "master"
+ ) or self.worker_name == background_tasks_instance
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
## Workers ##
@@ -167,6 +180,11 @@ class WorkerConfig(Config):
#stream_writers:
# events: worker1
# typing: worker1
+
+ # The worker that is used to run background tasks (e.g. cleaning up expired
+ # data). If not provided this defaults to the main process.
+ #
+ #run_background_tasks_on: worker1
"""
def read_arguments(self, args):
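The guard computed above reduces to a small predicate; a sketch with assumed names:

```python
# Background tasks run on the main process unless run_background_tasks_on
# names a worker, in which case only that worker runs them.
def should_run_background_tasks(worker_name, configured_instance="master"):
    if worker_name is None:
        return configured_instance == "master"
    return worker_name == configured_instance

assert should_run_background_tasks(None)                  # main process, default
assert not should_run_background_tasks(None, "worker1")   # moved to a worker
assert should_run_background_tasks("worker1", "worker1")  # that worker runs them
assert not should_run_background_tasks("worker2", "worker1")
```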
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 2b03f5ac76..79668a402e 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -45,7 +45,11 @@ _TLS_VERSION_MAP = {
class ServerContextFactory(ContextFactory):
"""Factory for PyOpenSSL SSL contexts that are used to handle incoming
- connections."""
+ connections.
+
+ TODO: replace this with an implementation of IOpenSSLServerConnectionCreator,
+ per https://github.com/matrix-org/synapse/issues/1691
+ """
def __init__(self, config):
# TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version,
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 32c31b1cd1..c04ad77cf9 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -42,7 +42,6 @@ from synapse.api.errors import (
)
from synapse.logging.context import (
PreserveLoggingContext,
- current_context,
make_deferred_yieldable,
preserve_fn,
run_in_background,
@@ -233,8 +232,6 @@ class Keyring:
"""
try:
- ctx = current_context()
-
# map from server name to a set of outstanding request ids
server_to_request_ids = {}
@@ -265,12 +262,8 @@ class Keyring:
# if there are no more requests for this server, we can drop the lock.
if not server_requests:
- with PreserveLoggingContext(ctx):
- logger.debug("Releasing key lookup lock on %s", server_name)
-
- # ... but not immediately, as that can cause stack explosions if
- # we get a long queue of lookups.
- self.clock.call_later(0, drop_server_lock, server_name)
+ logger.debug("Releasing key lookup lock on %s", server_name)
+ drop_server_lock(server_name)
return res
@@ -335,20 +328,32 @@ class Keyring:
)
# look for any requests which weren't satisfied
- with PreserveLoggingContext():
- for verify_request in remaining_requests:
- verify_request.key_ready.errback(
- SynapseError(
- 401,
- "No key for %s with ids in %s (min_validity %i)"
- % (
- verify_request.server_name,
- verify_request.key_ids,
- verify_request.minimum_valid_until_ts,
- ),
- Codes.UNAUTHORIZED,
- )
+ while remaining_requests:
+ verify_request = remaining_requests.pop()
+ rq_str = (
+ "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
+ % (
+ verify_request.server_name,
+ verify_request.key_ids,
+ verify_request.minimum_valid_until_ts,
)
+ )
+
+ # If we run the errback immediately, it may cancel our
+ # loggingcontext while we are still in it, so instead we
+ # schedule it for the next time round the reactor.
+ #
+ # (this also ensures that we don't get a stack overflow if we
+ # have a massive queue of lookups waiting for this server).
+ self.clock.call_later(
+ 0,
+ verify_request.key_ready.errback,
+ SynapseError(
+ 401,
+ "Failed to find any key to satisfy %s" % (rq_str,),
+ Codes.UNAUTHORIZED,
+ ),
+ )
except Exception as err:
# we don't really expect to get here, because any errors should already
# have been caught and logged. But if we do, let's log the error and make
@@ -410,10 +415,23 @@ class Keyring:
# key was not valid at this point
continue
- with PreserveLoggingContext():
- verify_request.key_ready.callback(
- (server_name, key_id, fetch_key_result.verify_key)
- )
+ # we have a valid key for this request. If we run the callback
+ # immediately, it may cancel our loggingcontext while we are still in
+ # it, so instead we schedule it for the next time round the reactor.
+ #
+ # (this also ensures that we don't get a stack overflow if we have
+ # a massive queue of lookups waiting for this server).
+ logger.debug(
+ "Found key %s:%s for %s",
+ server_name,
+ key_id,
+ verify_request.request_name,
+ )
+ self.clock.call_later(
+ 0,
+ verify_request.key_ready.callback,
+ (server_name, key_id, fetch_key_result.verify_key),
+ )
completed.append(verify_request)
break
@@ -558,7 +576,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the "perspectives" servers"""
def __init__(self, hs):
- super(PerspectivesKeyFetcher, self).__init__(hs)
+ super().__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_http_client()
self.key_servers = self.config.key_servers
@@ -728,7 +746,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the origin servers"""
def __init__(self, hs):
- super(ServerKeyFetcher, self).__init__(hs)
+ super().__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_http_client()
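The keyring changes above replace inline callback/errback invocation with `call_later(0, ...)`. The pattern in isolation (the key ID here is illustrative):

```python
# Resolving a Deferred on the next reactor tick rather than inline avoids
# re-entrancy (deep recursion through a long chain of waiting lookups) and
# keeps the caller's logcontext from being torn down while still in use.
from twisted.internet import defer, reactor

d = defer.Deferred()
d.addCallback(lambda key: print("key ready:", key))

reactor.callLater(0, d.callback, "ed25519:abc123")  # instead of d.callback(...)
reactor.callLater(0.1, reactor.stop)
reactor.run()
```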
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 5a08e136c4..66c9b97108 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -455,6 +455,8 @@ def check_redaction(
if room_version_obj.event_format == EventFormatVersions.V1:
redacter_domain = get_domain_from_id(event.event_id)
+ if not isinstance(event.redacts, str):
+ return False
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
return True
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index bf800a3852..7a51d0a22f 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -23,7 +23,7 @@ from typing import Dict, Optional, Tuple, Type
from unpaddedbase64 import encode_base64
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
-from synapse.types import JsonDict
+from synapse.types import JsonDict, RoomStreamToken
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
@@ -97,13 +97,16 @@ class DefaultDictProperty(DictProperty):
class _EventInternalMetadata:
- __slots__ = ["_dict"]
+ __slots__ = ["_dict", "stream_ordering"]
def __init__(self, internal_metadata_dict: JsonDict):
# we have to copy the dict, because it turns out that the same dict is
# reused. TODO: fix that
self._dict = dict(internal_metadata_dict)
+ # the stream ordering of this event. None, until it has been persisted.
+ self.stream_ordering = None # type: Optional[int]
+
outlier = DictProperty("outlier") # type: bool
out_of_band_membership = DictProperty("out_of_band_membership") # type: bool
send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str
@@ -113,13 +116,12 @@ class _EventInternalMetadata:
redacted = DictProperty("redacted") # type: bool
txn_id = DictProperty("txn_id") # type: str
token_id = DictProperty("token_id") # type: str
- stream_ordering = DictProperty("stream_ordering") # type: int
# XXX: These are set by StreamWorkerStore._set_before_and_after.
# I'm pretty sure that these are never persisted to the database, so shouldn't
# be here
- before = DictProperty("before") # type: str
- after = DictProperty("after") # type: str
+ before = DictProperty("before") # type: RoomStreamToken
+ after = DictProperty("after") # type: RoomStreamToken
order = DictProperty("order") # type: Tuple[int, int]
def get_dict(self) -> JsonDict:
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index b0fc859a47..bad18f7fdf 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -17,24 +17,25 @@
import inspect
from typing import Any, Dict, List, Optional, Tuple
-from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection
MYPY = False
if MYPY:
+ import synapse.events
import synapse.server
class SpamChecker:
def __init__(self, hs: "synapse.server.HomeServer"):
self.spam_checkers = [] # type: List[Any]
+ api = hs.get_module_api()
for module, config in hs.config.spam_checkers:
# Older spam checkers don't accept the `api` argument, so we
# try and detect support.
spam_args = inspect.getfullargspec(module)
if "api" in spam_args.args:
- api = SpamCheckerApi(hs)
self.spam_checkers.append(module(config=config, api=api))
else:
self.spam_checkers.append(module(config=config))
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 9d5310851c..1535cc5339 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -12,10 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Callable
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import Requester
+from synapse.types import Requester, StateMap
class ThirdPartyEventRules:
@@ -38,7 +39,7 @@ class ThirdPartyEventRules:
if module is not None:
self.third_party_rules = module(
- config=config, http_client=hs.get_simple_http_client()
+ config=config, module_api=hs.get_module_api(),
)
async def check_event_allowed(
@@ -59,12 +60,14 @@ class ThirdPartyEventRules:
prev_state_ids = await context.get_prev_state_ids()
# Retrieve the state events from the database.
- state_events = {}
- for key, event_id in prev_state_ids.items():
- state_events[key] = await self.store.get_event(event_id, allow_none=True)
+ events = await self.store.get_events(prev_state_ids.values())
+ state_events = {(ev.type, ev.state_key): ev for ev in events.values()}
- ret = await self.third_party_rules.check_event_allowed(event, state_events)
- return ret
+ # The module can modify the event slightly if it wants, but caution should be
+ # exercised, and it's likely to go very wrong if applied to events received over
+ # federation.
+
+ return await self.third_party_rules.check_event_allowed(event, state_events)
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
@@ -106,6 +109,48 @@ class ThirdPartyEventRules:
if self.third_party_rules is None:
return True
+ state_events = await self._get_state_map_for_room(room_id)
+
+ ret = await self.third_party_rules.check_threepid_can_be_invited(
+ medium, address, state_events
+ )
+ return ret
+
+ async def check_visibility_can_be_modified(
+ self, room_id: str, new_visibility: str
+ ) -> bool:
+ """Check if a room is allowed to be published to, or removed from, the public room
+ list.
+
+ Args:
+ room_id: The ID of the room.
+ new_visibility: The new visibility state. Either "public" or "private".
+
+ Returns:
+ True if the room's visibility can be modified, False if not.
+ """
+ if self.third_party_rules is None:
+ return True
+
+ check_func = getattr(
+ self.third_party_rules, "check_visibility_can_be_modified", None
+ )
+ if not check_func or not isinstance(check_func, Callable):
+ return True
+
+ state_events = await self._get_state_map_for_room(room_id)
+
+ return await check_func(room_id, state_events, new_visibility)
+
+ async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
+ """Given a room ID, return the state events of that room.
+
+ Args:
+ room_id: The ID of the room.
+
+ Returns:
+ A dict mapping (event type, state key) to state event.
+ """
state_ids = await self.store.get_filtered_current_state_ids(room_id)
room_state_events = await self.store.get_events(state_ids.values())
@@ -113,7 +158,4 @@ class ThirdPartyEventRules:
for key, event_id in state_ids.items():
state_events[key] = room_state_events[event_id]
- ret = await self.third_party_rules.check_threepid_can_be_invited(
- medium, address, state_events
- )
- return ret
+ return state_events
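Since the new `check_visibility_can_be_modified` hook is looked up with `getattr`, older modules remain compatible. A sketch of a module implementing it (class and logic are illustrative, not a real module):

```python
# A minimal third-party rules module: the constructor now receives a ModuleApi
# instance instead of a raw HTTP client, and the visibility hook is optional.
class ExampleThirdPartyRules:
    def __init__(self, config, module_api):
        self._config = config
        self._api = module_api

    async def check_event_allowed(self, event, state_events):
        return True  # allow everything

    async def check_visibility_can_be_modified(
        self, room_id, state_events, new_visibility
    ):
        # e.g. never allow rooms to be published to the public directory
        return new_visibility != "public"
```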
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 32c73d3413..355cbe05f1 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -49,6 +49,11 @@ def prune_event(event: EventBase) -> EventBase:
pruned_event_dict, event.room_version, event.internal_metadata.get_dict()
)
+ # copy the internal fields
+ pruned_event.internal_metadata.stream_ordering = (
+ event.internal_metadata.stream_ordering
+ )
+
# Mark the event as redacted
pruned_event.internal_metadata.redacted = True
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 063605eaff..c8936a28ea 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -26,10 +26,12 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
+ Union,
)
from prometheus_client import Counter
@@ -81,7 +83,7 @@ class InvalidResponseError(RuntimeError):
class FederationClient(FederationBase):
def __init__(self, hs):
- super(FederationClient, self).__init__(hs)
+ super().__init__(hs)
self.pdu_destination_tried = {}
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
@@ -219,11 +221,9 @@ class FederationClient(FederationBase):
for p in transaction_data["pdus"]
]
- # FIXME: We should handle signature failures more gracefully.
- pdus[:] = await make_deferred_yieldable(
- defer.gatherResults(
- self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ # Check signatures and hash of pdus, removing any from the list that fail checks
+ pdus[:] = await self._check_sigs_and_hash_and_fetch(
+ dest, pdus, outlier=True, room_version=room_version
)
return pdus
@@ -505,7 +505,7 @@ class FederationClient(FederationBase):
user_id: str,
membership: str,
content: dict,
- params: Dict[str, str],
+ params: Optional[Mapping[str, Union[str, Iterable[str]]]],
) -> Tuple[str, EventBase, RoomVersion]:
"""
Creates an m.room.member event, with context, without participating in the room.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 662325bab1..6035d2f664 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,13 +22,12 @@ from typing import (
Callable,
Dict,
List,
- Match,
Optional,
Tuple,
Union,
)
-from prometheus_client import Counter, Histogram
+from prometheus_client import Counter, Gauge, Histogram
from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
@@ -88,19 +87,32 @@ pdu_process_time = Histogram(
)
+last_pdu_age_metric = Gauge(
+ "synapse_federation_last_received_pdu_age",
+ "The age (in seconds) of the last PDU successfully received from the given domain",
+ labelnames=("server_name",),
+)
+
+
class FederationServer(FederationBase):
def __init__(self, hs):
- super(FederationServer, self).__init__(hs)
+ super().__init__(hs)
self.auth = hs.get_auth()
self.handler = hs.get_handlers().federation_handler
self.state = hs.get_state_handler()
self.device_handler = hs.get_device_handler()
+ self._federation_ratelimiter = hs.get_federation_ratelimiter()
self._server_linearizer = Linearizer("fed_server")
self._transaction_linearizer = Linearizer("fed_txn_handler")
+ # We cache results for transaction with the same ID
+ self._transaction_resp_cache = ResponseCache(
+ hs, "fed_txn_handler", timeout_ms=30000
+ )
+
self.transaction_actions = TransactionActions(self.store)
self.registry = hs.get_federation_registry()
@@ -112,6 +124,10 @@ class FederationServer(FederationBase):
hs, "state_ids_resp", timeout_ms=30000
)
+ self._federation_metrics_domains = (
+ hs.get_config().federation.federation_metrics_domains
+ )
+
async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
) -> Tuple[int, Dict[str, Any]]:
@@ -135,22 +151,44 @@ class FederationServer(FederationBase):
request_time = self._clock.time_msec()
transaction = Transaction(**transaction_data)
+ transaction_id = transaction.transaction_id # type: ignore
- if not transaction.transaction_id: # type: ignore
+ if not transaction_id:
raise Exception("Transaction missing transaction_id")
- logger.debug("[%s] Got transaction", transaction.transaction_id) # type: ignore
+ logger.debug("[%s] Got transaction", transaction_id)
- # use a linearizer to ensure that we don't process the same transaction
- # multiple times in parallel.
- with (
- await self._transaction_linearizer.queue(
- (origin, transaction.transaction_id) # type: ignore
- )
- ):
- result = await self._handle_incoming_transaction(
- origin, transaction, request_time
- )
+ # We wrap in a ResponseCache so that we de-duplicate retried
+ # transactions.
+ return await self._transaction_resp_cache.wrap(
+ (origin, transaction_id),
+ self._on_incoming_transaction_inner,
+ origin,
+ transaction,
+ request_time,
+ )
+
+ async def _on_incoming_transaction_inner(
+ self, origin: str, transaction: Transaction, request_time: int
+ ) -> Tuple[int, Dict[str, Any]]:
+ # Use a linearizer to ensure that transactions from a remote are
+ # processed in order.
+ with await self._transaction_linearizer.queue(origin):
+ # We rate limit here *after* we've queued up the incoming requests,
+ # so that we don't fill up the ratelimiter with blocked requests.
+ #
+ # This is important as the ratelimiter allows N concurrent requests
+ # at a time, and only starts ratelimiting if there are more requests
+ # than that being processed at a time. If we queued up requests in
+ # the linearizer/response cache *after* the ratelimiting then those
+ # queued up requests would count as part of the allowed limit of N
+ # concurrent requests.
+ with self._federation_ratelimiter.ratelimit(origin) as d:
+ await d
+
+ result = await self._handle_incoming_transaction(
+ origin, transaction, request_time
+ )
return result
@@ -234,7 +272,11 @@ class FederationServer(FederationBase):
pdus_by_room = {} # type: Dict[str, List[EventBase]]
+ newest_pdu_ts = 0
+
for p in transaction.pdus: # type: ignore
+ # FIXME (richardv): I don't think this works:
+ # https://github.com/matrix-org/synapse/issues/8429
if "unsigned" in p:
unsigned = p["unsigned"]
if "age" in unsigned:
@@ -272,6 +314,9 @@ class FederationServer(FederationBase):
event = event_from_pdu_json(p, room_version)
pdus_by_room.setdefault(room_id, []).append(event)
+ if event.origin_server_ts > newest_pdu_ts:
+ newest_pdu_ts = event.origin_server_ts
+
pdu_results = {}
# we can process different rooms in parallel (which is useful if they
@@ -312,6 +357,10 @@ class FederationServer(FederationBase):
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
)
+ if newest_pdu_ts and origin in self._federation_metrics_domains:
+ newest_pdu_age = self._clock.time_msec() - newest_pdu_ts
+ last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000)
+
return pdu_results
async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
@@ -801,14 +850,14 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
return False
-def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
+def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
if not isinstance(acl_entry, str):
logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
)
return False
regex = glob_to_regex(acl_entry)
- return regex.match(server_name)
+ return bool(regex.match(server_name))
class FederationHandlerRegistry:
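The `ResponseCache` wrap keyed on `(origin, transaction_id)` means a retried `/send` transaction shares one in-flight computation instead of being re-processed. A toy model of that semantics (simplified: the real cache also retains results for `timeout_ms`):

```python
import asyncio

class ToyResponseCache:
    """De-duplicate concurrent calls that share a key."""

    def __init__(self):
        self._pending = {}

    async def wrap(self, key, callback, *args):
        fut = self._pending.get(key)
        if fut is None:
            fut = asyncio.ensure_future(callback(*args))
            self._pending[key] = fut
            # the real cache keeps the entry for timeout_ms; we drop it on completion
            fut.add_done_callback(lambda _: self._pending.pop(key, None))
        return await fut

async def main():
    cache = ToyResponseCache()

    async def handle(origin, txn_id):
        await asyncio.sleep(0.1)  # simulate processing the transaction
        return 200, {}

    # a retry of the same transaction shares the first in-flight computation
    r1, r2 = await asyncio.gather(
        cache.wrap(("remote.example", "txn1"), handle, "remote.example", "txn1"),
        cache.wrap(("remote.example", "txn1"), handle, "remote.example", "txn1"),
    )
    assert r1 == r2 == (200, {})

asyncio.run(main())
```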
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 41a726878d..e33b29a42c 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -55,6 +55,15 @@ sent_pdus_destination_dist_total = Counter(
"Total number of PDUs queued for sending across all destinations",
)
+# Time (in s) after Synapse's startup that we will begin to wake up destinations
+# that have catch-up outstanding.
+CATCH_UP_STARTUP_DELAY_SEC = 15
+
+# Time (in s) to wait in between waking up each destination, i.e. one destination
+# will be woken up every <x> seconds after Synapse's startup until we have woken
+ # every destination that has outstanding catch-up.
+CATCH_UP_STARTUP_INTERVAL_SEC = 5
+
class FederationSender:
def __init__(self, hs: "synapse.server.HomeServer"):
@@ -125,6 +134,14 @@ class FederationSender:
1000.0 / hs.config.federation_rr_transactions_per_room_per_second
)
+ # wake up destinations that have outstanding PDUs to be caught up
+ self._catchup_after_startup_timer = self.clock.call_later(
+ CATCH_UP_STARTUP_DELAY_SEC,
+ run_as_background_process,
+ "wake_destinations_needing_catchup",
+ self._wake_destinations_needing_catchup,
+ )
+
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination
@@ -280,6 +297,8 @@ class FederationSender:
sent_pdus_destination_dist_total.inc(len(destinations))
sent_pdus_destination_dist_count.inc()
+ assert pdu.internal_metadata.stream_ordering
+
# track the fact that we have a PDU for these destinations,
# to allow us to perform catch-up later on if the remote is unreachable
# for a while.
@@ -560,3 +579,37 @@ class FederationSender:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return [], 0, False
+
+ async def _wake_destinations_needing_catchup(self):
+ """
+ Wakes up destinations that need catch-up and are not currently being
+ backed off from.
+
+ In order to reduce load spikes, adds a delay between each destination.
+ """
+
+ last_processed = None # type: Optional[str]
+
+ while True:
+ destinations_to_wake = await self.store.get_catch_up_outstanding_destinations(
+ last_processed
+ )
+
+ if not destinations_to_wake:
+ # finished waking all destinations!
+ self._catchup_after_startup_timer = None
+ break
+
+ destinations_to_wake = [
+ d
+ for d in destinations_to_wake
+ if self._federation_shard_config.should_handle(self._instance_name, d)
+ ]
+
+ for last_processed in destinations_to_wake:
+ logger.info(
+ "Destination %s has outstanding catch-up, waking up.",
+ last_processed,
+ )
+ self.wake_destination(last_processed)
+ await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC)
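The wake-up loop above deliberately paces itself; its shape, reduced to a sketch with assumed helper names:

```python
# Page through destinations with outstanding catch-up, waking one every
# `interval` seconds so a restart doesn't trigger a burst of transactions.
import asyncio

async def wake_all(get_batch_after, wake, interval=5.0):
    last = None
    while True:
        batch = await get_batch_after(last)  # destinations sorted after `last`
        if not batch:
            break  # every destination has been woken
        for last in batch:
            wake(last)
            await asyncio.sleep(interval)
```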
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 2657767fd1..db8e456fe8 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -158,6 +158,7 @@ class PerDestinationQueue:
# yet know if we have anything to catch up (None)
self._pending_pdus.append(pdu)
else:
+ assert pdu.internal_metadata.stream_ordering
self._catchup_last_skipped = pdu.internal_metadata.stream_ordering
self.attempt_new_transaction()
@@ -361,6 +362,7 @@ class PerDestinationQueue:
last_successful_stream_ordering = (
final_pdu.internal_metadata.stream_ordering
)
+ assert last_successful_stream_ordering
await self._store.set_destination_last_successful_stream_ordering(
self._destination, last_successful_stream_ordering
)
@@ -490,7 +492,7 @@ class PerDestinationQueue:
)
if logger.isEnabledFor(logging.INFO):
- rooms = (p.room_id for p in catchup_pdus)
+ rooms = [p.room_id for p in catchup_pdus]
logger.info("Catching up rooms to %s: %r", self._destination, rooms)
success = await self._transaction_manager.send_new_transaction(
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index c84072ab73..3e07f925e0 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -15,6 +15,8 @@
import logging
from typing import TYPE_CHECKING, List
+from prometheus_client import Gauge
+
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
@@ -34,6 +36,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+last_pdu_age_metric = Gauge(
+ "synapse_federation_last_sent_pdu_age",
+ "The age (in seconds) of the last PDU successfully sent to the given domain",
+ labelnames=("server_name",),
+)
+
class TransactionManager:
"""Helper class which handles building and sending transactions
@@ -48,6 +56,10 @@ class TransactionManager:
self._transaction_actions = TransactionActions(self._store)
self._transport_layer = hs.get_federation_transport_client()
+ self._federation_metrics_domains = (
+ hs.get_config().federation.federation_metrics_domains
+ )
+
# HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec())
@@ -119,6 +131,9 @@ class TransactionManager:
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
+ # FIXME (richardv): I also believe it no longer works. We (now?) store
+ # "age_ts" in "unsigned" rather than at the top level. See
+ # https://github.com/matrix-org/synapse/issues/8429.
def json_data_cb():
data = transaction.get_dict()
now = int(self.clock.time_msec())
@@ -167,5 +182,12 @@ class TransactionManager:
)
success = False
+ if success and pdus and destination in self._federation_metrics_domains:
+ last_pdu = pdus[-1]
+ last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts
+ last_pdu_age_metric.labels(server_name=destination).set(
+ last_pdu_age / 1000
+ )
+
set_tag(tags.ERROR, not success)
return success
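Both of the new PDU-age metrics follow the same labelled-Gauge pattern; a standalone sketch (the metric name here is illustrative):

```python
# One time-series per remote domain, set to the age in seconds of the
# newest PDU exchanged with it.
from prometheus_client import Gauge

pdu_age = Gauge(
    "example_last_pdu_age",
    "The age (in seconds) of the last PDU exchanged with the given domain",
    labelnames=("server_name",),
)
pdu_age.labels(server_name="matrix.org").set(2.5)
```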
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index e04704d10c..a2fb558b45 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -46,7 +46,6 @@ from synapse.logging.opentracing import (
)
from synapse.server import HomeServer
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
-from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
@@ -70,12 +69,10 @@ class TransportLayerServer(JsonResource):
self.clock = hs.get_clock()
self.servlet_groups = servlet_groups
- super(TransportLayerServer, self).__init__(hs, canonical_json=False)
+ super().__init__(hs, canonical_json=False)
self.authenticator = Authenticator(hs)
- self.ratelimiter = FederationRateLimiter(
- self.clock, config=hs.config.rc_federation
- )
+ self.ratelimiter = hs.get_federation_ratelimiter()
self.register_servlets()
@@ -273,6 +270,8 @@ class BaseFederationServlet:
PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version
+ RATELIMIT = True # Whether to rate limit requests or not
+
def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler
self.authenticator = authenticator
@@ -336,7 +335,7 @@ class BaseFederationServlet:
)
with scope:
- if origin:
+ if origin and self.RATELIMIT:
with ratelimiter.ratelimit(origin) as d:
await d
if request._disconnected:
@@ -373,10 +372,12 @@ class BaseFederationServlet:
class FederationSendServlet(BaseFederationServlet):
PATH = "/send/(?P<transaction_id>[^/]*)/?"
+ # We ratelimit manually in the handler as we queue up the requests and we
+ # don't want to fill up the ratelimiter with blocked requests.
+ RATELIMIT = False
+
def __init__(self, handler, server_name, **kwargs):
- super(FederationSendServlet, self).__init__(
- handler, server_name=server_name, **kwargs
- )
+ super().__init__(handler, server_name=server_name, **kwargs)
self.server_name = server_name
# This is when someone is trying to send us a bunch of data.
@@ -787,9 +788,7 @@ class PublicRoomList(BaseFederationServlet):
PATH = "/publicRooms"
def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
- super(PublicRoomList, self).__init__(
- handler, authenticator, ratelimiter, server_name
- )
+ super().__init__(handler, authenticator, ratelimiter, server_name)
self.allow_access = allow_access
async def on_GET(self, origin, content, query):
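
The per-servlet `RATELIMIT` class attribute lets `/send` opt out of the shared federation ratelimiter, since the transaction handler queues requests itself. A self-contained sketch of the pattern, with a toy ratelimiter standing in for Synapse's `FederationRateLimiter`:

    import asyncio
    from contextlib import contextmanager

    class Ratelimiter:
        """Toy stand-in: yields an awaitable that resolves immediately."""

        @contextmanager
        def ratelimit(self, origin: str):
            yield asyncio.sleep(0)

    class BaseServlet:
        # Subclasses set this to False to opt out of ratelimiting.
        RATELIMIT = True

        def __init__(self, ratelimiter: Ratelimiter) -> None:
            self.ratelimiter = ratelimiter

        async def handle(self, origin: str, call):
            if origin and self.RATELIMIT:
                with self.ratelimiter.ratelimit(origin) as wait:
                    await wait
                    return await call()
            return await call()

    class SendServlet(BaseServlet):
        # /send queues requests in the handler, so don't fill the
        # ratelimiter with blocked requests.
        RATELIMIT = False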
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 1dd20ee4e1..e5f85b472d 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -336,7 +336,7 @@ class GroupsServerWorkerHandler:
class GroupsServerHandler(GroupsServerWorkerHandler):
def __init__(self, hs):
- super(GroupsServerHandler, self).__init__(hs)
+ super().__init__(hs)
# Ensure attestations get renewed
hs.get_groups_attestation_renewer()
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 5e5a64037d..1ce2091b46 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
class AdminHandler(BaseHandler):
def __init__(self, hs):
- super(AdminHandler, self).__init__(hs)
+ super().__init__(hs)
self.storage = hs.get_storage()
self.state_store = self.storage.state
@@ -153,7 +153,7 @@ class AdminHandler(BaseHandler):
if not events:
break
- from_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
+ from_key = events[-1].internal_metadata.after
events = await filter_events_for_client(self.storage, user_id, events)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4e658d9a48..f6d17c53b1 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -137,6 +137,15 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
}
+@attr.s(slots=True)
+class SsoLoginExtraAttributes:
+ """Data we track about SAML2 sessions"""
+
+ # time the session was created, in milliseconds
+ creation_time = attr.ib(type=int)
+ extra_attributes = attr.ib(type=JsonDict)
+
+
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
@@ -145,7 +154,7 @@ class AuthHandler(BaseHandler):
Args:
hs (synapse.server.HomeServer):
"""
- super(AuthHandler, self).__init__(hs)
+ super().__init__(hs)
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
@@ -155,7 +164,14 @@ class AuthHandler(BaseHandler):
self.bcrypt_rounds = hs.config.bcrypt_rounds
+ # we can't use hs.get_module_api() here, because to do so will create an
+ # import loop.
+ #
+ # TODO: refactor this class to separate the lower-level stuff that
+ # ModuleApi can use from the higher-level stuff that uses ModuleApi, as a
+ # better way to break the loop.

account_handler = ModuleApi(hs, self)
+
self.password_providers = [
module(config=config, account_handler=account_handler)
for module, config in hs.config.password_providers
@@ -203,7 +219,7 @@ class AuthHandler(BaseHandler):
self._clock = self.hs.get_clock()
# Expire old UI auth sessions after a period of time.
- if hs.config.worker_app is None:
+ if hs.config.run_background_tasks:
self._clock.looping_call(
run_as_background_process,
5 * 60 * 1000,
@@ -239,6 +255,10 @@ class AuthHandler(BaseHandler):
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
+ # A mapping of user ID to extra attributes to include in the login
+ # response.
+ self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes]
+
async def validate_user_via_ui_auth(
self,
requester: Requester,
@@ -1165,6 +1185,7 @@ class AuthHandler(BaseHandler):
registered_user_id: str,
request: SynapseRequest,
client_redirect_url: str,
+ extra_attributes: Optional[JsonDict] = None,
):
"""Having figured out a mxid for this user, complete the HTTP request
@@ -1173,6 +1194,8 @@ class AuthHandler(BaseHandler):
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
+ extra_attributes: Extra attributes which will be passed to the client
+ during successful login. Must be JSON serializable.
"""
# If the account has been deactivated, do not proceed with the login
# flow.
@@ -1181,19 +1204,30 @@ class AuthHandler(BaseHandler):
respond_with_html(request, 403, self._sso_account_deactivated_template)
return
- self._complete_sso_login(registered_user_id, request, client_redirect_url)
+ self._complete_sso_login(
+ registered_user_id, request, client_redirect_url, extra_attributes
+ )
def _complete_sso_login(
self,
registered_user_id: str,
request: SynapseRequest,
client_redirect_url: str,
+ extra_attributes: Optional[JsonDict] = None,
):
"""
The synchronous portion of complete_sso_login.
This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
"""
+ # Store any extra attributes which will be passed in the login response.
+ # Note that this is per-user, so it may overwrite a previous value; this
+ # is considered OK since the newest SSO attributes should be the most valid.
+ if extra_attributes:
+ self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
+ self._clock.time_msec(), extra_attributes,
+ )
+
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
@@ -1226,6 +1260,37 @@ class AuthHandler(BaseHandler):
)
respond_with_html(request, 200, html)
+ async def _sso_login_callback(self, login_result: JsonDict) -> None:
+ """
+ A login callback which might add additional attributes to the login response.
+
+ Args:
+ login_result: The data to be sent to the client. Includes the user
+ ID and access token.
+ """
+ # Expire attributes before processing. Note that there shouldn't be any
+ # valid logins that still have extra attributes.
+ self._expire_sso_extra_attributes()
+
+ extra_attributes = self._extra_attributes.get(login_result["user_id"])
+ if extra_attributes:
+ login_result.update(extra_attributes.extra_attributes)
+
+ def _expire_sso_extra_attributes(self) -> None:
+ """
+ Iterate through the mapping of user IDs to extra attributes and remove any that are no longer valid.
+ """
+ # TODO This should match the amount of time the macaroon is valid for.
+ LOGIN_TOKEN_EXPIRATION_TIME = 2 * 60 * 1000
+ expire_before = self._clock.time_msec() - LOGIN_TOKEN_EXPIRATION_TIME
+ to_expire = set()
+ for user_id, data in self._extra_attributes.items():
+ if data.creation_time < expire_before:
+ to_expire.add(user_id)
+ for user_id in to_expire:
+ logger.debug("Expiring extra attributes for user %s", user_id)
+ del self._extra_attributes[user_id]
+
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
url_parts = list(urllib.parse.urlparse(url))
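
The expiry helper above bounds how long per-user SSO extra attributes are kept: anything older than the short-term login token lifetime is dropped. A standalone sketch of the logic, using `time.time()` in place of Synapse's clock:

    import time

    import attr

    @attr.s(slots=True)
    class SsoLoginExtraAttributes:
        creation_time = attr.ib(type=int)  # milliseconds
        extra_attributes = attr.ib(type=dict)

    # Should match the lifetime of the short-term login token macaroon.
    LOGIN_TOKEN_EXPIRATION_TIME = 2 * 60 * 1000

    def expire_sso_extra_attributes(store: dict) -> None:
        expire_before = time.time() * 1000 - LOGIN_TOKEN_EXPIRATION_TIME
        to_expire = [u for u, d in store.items() if d.creation_time < expire_before]
        for user_id in to_expire:
            del store[user_id]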
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 25169157c1..0635ad5708 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -29,7 +29,7 @@ class DeactivateAccountHandler(BaseHandler):
"""Handler which deals with deactivating user accounts."""
def __init__(self, hs):
- super(DeactivateAccountHandler, self).__init__(hs)
+ super().__init__(hs)
self.hs = hs
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 4b0a4f96cc..debb1b4f29 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,11 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api import errors
from synapse.api.constants import EventTypes
from synapse.api.errors import (
+ Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
@@ -28,8 +29,10 @@ from synapse.api.errors import (
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import (
- RoomStreamToken,
+ Collection,
+ JsonDict,
StreamToken,
+ UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
@@ -41,14 +44,17 @@ from synapse.util.retryutils import NotRetryingDestination
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
MAX_DEVICE_DISPLAY_NAME_LEN = 100
class DeviceWorkerHandler(BaseHandler):
- def __init__(self, hs):
- super(DeviceWorkerHandler, self).__init__(hs)
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
@@ -105,15 +111,16 @@ class DeviceWorkerHandler(BaseHandler):
@trace
@measure_func("device.get_user_ids_changed")
- async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
+ async def get_user_ids_changed(
+ self, user_id: str, from_token: StreamToken
+ ) -> JsonDict:
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
"""
set_tag("user_id", user_id)
set_tag("from_token", from_token)
- now_room_id = self.store.get_room_max_stream_ordering()
- now_room_key = RoomStreamToken(None, now_room_id)
+ now_room_key = self.store.get_room_max_token()
room_ids = await self.store.get_rooms_for_user(user_id)
@@ -222,8 +229,8 @@ class DeviceWorkerHandler(BaseHandler):
possibly_joined = possibly_changed & users_who_share_room
possibly_left = (possibly_changed | possibly_left) - users_who_share_room
else:
- possibly_joined = []
- possibly_left = []
+ possibly_joined = set()
+ possibly_left = set()
result = {"changed": list(possibly_joined), "left": list(possibly_left)}
@@ -231,7 +238,7 @@ class DeviceWorkerHandler(BaseHandler):
return result
- async def on_federation_query_user_devices(self, user_id):
+ async def on_federation_query_user_devices(self, user_id: str) -> JsonDict:
stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
user_id
)
@@ -250,8 +257,8 @@ class DeviceWorkerHandler(BaseHandler):
class DeviceHandler(DeviceWorkerHandler):
- def __init__(self, hs):
- super(DeviceHandler, self).__init__(hs)
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
self.federation_sender = hs.get_federation_sender()
@@ -265,9 +272,30 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
+ def _check_device_name_length(self, name: Optional[str]):
+ """
+ Checks whether a device name is longer than the maximum allowed length.
+
+ Args:
+ name: The name of the device.
+
+ Raises:
+ SynapseError: if the device name is too long.
+ """
+ if name and len(name) > MAX_DEVICE_DISPLAY_NAME_LEN:
+ raise SynapseError(
+ 400,
+ "Device display name is too long (max %i)"
+ % (MAX_DEVICE_DISPLAY_NAME_LEN,),
+ errcode=Codes.TOO_LARGE,
+ )
+
async def check_device_registered(
- self, user_id, device_id, initial_device_display_name=None
- ):
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_device_display_name: Optional[str] = None,
+ ) -> str:
"""
If the given device has not been registered, register it with the
supplied display name.
@@ -275,13 +303,15 @@ class DeviceHandler(DeviceWorkerHandler):
If no device_id is supplied, we make one up.
Args:
- user_id (str): @user:id
- device_id (str | None): device id supplied by client
- initial_device_display_name (str | None): device display name from
- client
+ user_id: @user:id
+ device_id: device id supplied by client
+ initial_device_display_name: device display name from client
Returns:
- str: device id (generated if none was supplied)
+ device id (generated if none was supplied)
"""
+
+ self._check_device_name_length(initial_device_display_name)
+
if device_id is not None:
new_device = await self.store.store_device(
user_id=user_id,
@@ -296,15 +326,15 @@ class DeviceHandler(DeviceWorkerHandler):
# times in case of a clash.
attempts = 0
while attempts < 5:
- device_id = stringutils.random_string(10).upper()
+ new_device_id = stringutils.random_string(10).upper()
new_device = await self.store.store_device(
user_id=user_id,
- device_id=device_id,
+ device_id=new_device_id,
initial_device_display_name=initial_device_display_name,
)
if new_device:
- await self.notify_device_update(user_id, [device_id])
- return device_id
+ await self.notify_device_update(user_id, [new_device_id])
+ return new_device_id
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@@ -397,12 +427,8 @@ class DeviceHandler(DeviceWorkerHandler):
# Reject a new displayname which is too long.
new_display_name = content.get("display_name")
- if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN:
- raise SynapseError(
- 400,
- "Device display name is too long (max %i)"
- % (MAX_DEVICE_DISPLAY_NAME_LEN,),
- )
+
+ self._check_device_name_length(new_display_name)
try:
await self.store.update_device(
@@ -417,7 +443,9 @@ class DeviceHandler(DeviceWorkerHandler):
@trace
@measure_func("notify_device_update")
- async def notify_device_update(self, user_id, device_ids):
+ async def notify_device_update(
+ self, user_id: str, device_ids: Collection[str]
+ ) -> None:
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
"""
@@ -429,7 +457,7 @@ class DeviceHandler(DeviceWorkerHandler):
user_id
)
- hosts = set()
+ hosts = set() # type: Set[str]
if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
@@ -481,7 +509,7 @@ class DeviceHandler(DeviceWorkerHandler):
self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
- async def user_left_room(self, user, room_id):
+ async def user_left_room(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
@@ -489,8 +517,89 @@ class DeviceHandler(DeviceWorkerHandler):
# receive device updates. Mark this in DB.
await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+ async def store_dehydrated_device(
+ self,
+ user_id: str,
+ device_data: JsonDict,
+ initial_device_display_name: Optional[str] = None,
+ ) -> str:
+ """Store a dehydrated device for a user. If the user had a previous
+ dehydrated device, it is removed.
+
+ Args:
+ user_id: the user that we are storing the device for
+ device_data: the dehydrated device information
+ initial_device_display_name: The display name to use for the device
+ Returns:
+ device id of the dehydrated device
+ """
+ device_id = await self.check_device_registered(
+ user_id, None, initial_device_display_name,
+ )
+ old_device_id = await self.store.store_dehydrated_device(
+ user_id, device_id, device_data
+ )
+ if old_device_id is not None:
+ await self.delete_device(user_id, old_device_id)
+ return device_id
+
+ async def get_dehydrated_device(
+ self, user_id: str
+ ) -> Optional[Tuple[str, JsonDict]]:
+ """Retrieve the information for a dehydrated device.
-def _update_device_from_client_ips(device, client_ips):
+ Args:
+ user_id: the user whose dehydrated device we are looking for
+ Returns:
+ a tuple whose first item is the device ID, and the second item is
+ the dehydrated device information
+ """
+ return await self.store.get_dehydrated_device(user_id)
+
+ async def rehydrate_device(
+ self, user_id: str, access_token: str, device_id: str
+ ) -> dict:
+ """Process a rehydration request from the user.
+
+ Args:
+ user_id: the user who is rehydrating the device
+ access_token: the access token used for the request
+ device_id: the ID of the device that will be rehydrated
+ Returns:
+ a dict containing {"success": True}
+ """
+ success = await self.store.remove_dehydrated_device(user_id, device_id)
+
+ if not success:
+ raise errors.NotFoundError()
+
+ # If the dehydrated device was successfully deleted (the device ID
+ # matched the stored dehydrated device), then modify the access
+ # token to use the dehydrated device's ID and copy the old device
+ # display name to the dehydrated device, and destroy the old device
+ # ID
+ old_device_id = await self.store.set_device_for_access_token(
+ access_token, device_id
+ )
+ old_device = await self.store.get_device(user_id, old_device_id)
+ await self.store.update_device(user_id, device_id, old_device["display_name"])
+ # can't call self.delete_device because that will clobber the
+ # access token so call the storage layer directly
+ await self.store.delete_device(user_id, old_device_id)
+ await self.store.delete_e2e_keys_by_device(
+ user_id=user_id, device_id=old_device_id
+ )
+
+ # tell everyone that the old device is gone and that the dehydrated
+ # device has a new display name
+ await self.notify_device_update(user_id, [old_device_id, device_id])
+
+ return {"success": True}
+
+
+def _update_device_from_client_ips(
+ device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
+) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@@ -498,7 +607,7 @@ def _update_device_from_client_ips(device, client_ips):
class DeviceListUpdater:
"Handles incoming device list updates from federation and updates the DB"
- def __init__(self, hs, device_handler):
+ def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
@@ -507,7 +616,9 @@ class DeviceListUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_device_list")
# user_id -> list of updates waiting to be handled.
- self._pending_updates = {}
+ self._pending_updates = (
+ {}
+ ) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
@@ -530,7 +641,9 @@ class DeviceListUpdater:
)
@trace
- async def incoming_device_list_update(self, origin, edu_content):
+ async def incoming_device_list_update(
+ self, origin: str, edu_content: JsonDict
+ ) -> None:
"""Called on incoming device list update from federation. Responsible
for parsing the EDU and adding to pending updates list.
"""
@@ -591,7 +704,7 @@ class DeviceListUpdater:
await self._handle_device_updates(user_id)
@measure_func("_incoming_device_list_update")
- async def _handle_device_updates(self, user_id):
+ async def _handle_device_updates(self, user_id: str) -> None:
"Actually handle pending updates."
with (await self._remote_edu_linearizer.queue(user_id)):
@@ -639,7 +752,9 @@ class DeviceListUpdater:
stream_id for _, stream_id, _, _ in pending_updates
)
- async def _need_to_do_resync(self, user_id, updates):
+ async def _need_to_do_resync(
+ self, user_id: str, updates: Iterable[Tuple[str, str, Iterable[str], JsonDict]]
+ ) -> bool:
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
@@ -670,7 +785,7 @@ class DeviceListUpdater:
return False
@trace
- async def _maybe_retry_device_resync(self):
+ async def _maybe_retry_device_resync(self) -> None:
"""Retry to resync device lists that are out of sync, except if another retry is
in progress.
"""
@@ -713,7 +828,7 @@ class DeviceListUpdater:
async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
- ) -> Optional[dict]:
+ ) -> Optional[JsonDict]:
"""Fetches all devices for a user and updates the device cache with them.
Args:
@@ -737,7 +852,7 @@ class DeviceListUpdater:
# it later.
await self.store.mark_remote_user_device_cache_as_stale(user_id)
- return
+ return None
except (RequestSendFailed, HttpResponseException) as e:
logger.warning(
"Failed to handle device list update for %s: %s", user_id, e,
@@ -754,12 +869,12 @@ class DeviceListUpdater:
# next time we get a device list update for this user_id.
# This makes it more likely that the device lists will
# eventually become consistent.
- return
+ return None
except FederationDeniedError as e:
set_tag("error", True)
log_kv({"reason": "FederationDeniedError"})
logger.info(e)
- return
+ return None
except Exception as e:
set_tag("error", True)
log_kv(
@@ -772,7 +887,7 @@ class DeviceListUpdater:
# it later.
await self.store.mark_remote_user_device_cache_as_stale(user_id)
- return
+ return None
log_kv({"result": result})
stream_id = result["stream_id"]
devices = result["devices"]
@@ -833,7 +948,7 @@ class DeviceListUpdater:
user_id: str,
master_key: Optional[Dict[str, Any]],
self_signing_key: Optional[Dict[str, Any]],
- ) -> list:
+ ) -> List[str]:
"""Process the given new master and self-signing key for the given remote user.
Args:
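
One of the smaller fixes above renames the loop variable in `check_device_registered` to `new_device_id`, so the `Optional[str]` parameter is no longer clobbered by the generated ID. A sketch of the retry loop under that shape (`store` is an assumed interface, and the ID generator stands in for `stringutils.random_string(10).upper()`):

    import random
    import string

    def _random_device_id() -> str:
        return "".join(random.choices(string.ascii_uppercase + string.digits, k=10))

    async def register_new_device(store, user_id: str) -> str:
        # A randomly generated ID may clash with an existing device, so try
        # a few times before giving up.
        for _ in range(5):
            new_device_id = _random_device_id()
            if await store.store_device(user_id=user_id, device_id=new_device_id):
                return new_device_id
        raise RuntimeError("Couldn't generate a device ID.")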
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 46826eb784..ad5683d251 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
class DirectoryHandler(BaseHandler):
def __init__(self, hs):
- super(DirectoryHandler, self).__init__(hs)
+ super().__init__(hs)
self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler()
@@ -46,6 +46,7 @@ class DirectoryHandler(BaseHandler):
self.config = hs.config
self.enable_room_list_search = hs.config.enable_room_list_search
self.require_membership = hs.config.require_membership_for_aliases
+ self.third_party_event_rules = hs.get_third_party_event_rules()
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
@@ -383,7 +384,7 @@ class DirectoryHandler(BaseHandler):
"""
creator = await self.store.get_room_alias_creator(alias.to_string())
- if creator is not None and creator == user_id:
+ if creator == user_id:
return True
# Resolve the alias to the corresponding room.
@@ -454,6 +455,15 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule?
raise SynapseError(403, "Not allowed to publish room")
+ # Check if publishing is blocked by a third party module
+ allowed_by_third_party_rules = await (
+ self.third_party_event_rules.check_visibility_can_be_modified(
+ room_id, visibility
+ )
+ )
+ if not allowed_by_third_party_rules:
+ raise SynapseError(403, "Not allowed to publish room")
+
await self.store.set_room_is_public(room_id, making_public)
async def edit_published_appservice_room_list(
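
Publishing a room to the directory can now be vetoed by a third-party rules module via `check_visibility_can_be_modified`. A hedged sketch of the gating pattern (`rules` and `store` are assumed interfaces, not Synapse's actual classes):

    async def set_room_visibility(rules, store, room_id: str, visibility: str) -> None:
        # Give any third-party rules module a chance to veto the change
        # before touching the directory.
        allowed = await rules.check_visibility_can_be_modified(room_id, visibility)
        if not allowed:
            raise PermissionError("Not allowed to publish room")
        await store.set_room_is_public(room_id, visibility == "public")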
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index dd40fd1299..611742ae72 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -496,6 +496,22 @@ class E2eKeysHandler:
log_kv(
{"message": "Did not update one_time_keys", "reason": "no keys given"}
)
+ fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
+ if fallback_keys and isinstance(fallback_keys, dict):
+ log_kv(
+ {
+ "message": "Updating fallback_keys for device.",
+ "user_id": user_id,
+ "device_id": device_id,
+ }
+ )
+ await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys)
+ elif fallback_keys:
+ log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"})
+ else:
+ log_kv(
+ {"message": "Did not update fallback_keys", "reason": "no keys given"}
+ )
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index fdce54c5c3..539b4fc32e 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
class EventStreamHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
- super(EventStreamHandler, self).__init__(hs)
+ super().__init__(hs)
self.clock = hs.get_clock()
@@ -133,8 +133,8 @@ class EventStreamHandler(BaseHandler):
chunk = {
"chunk": chunks,
- "start": tokens[0].to_string(),
- "end": tokens[1].to_string(),
+ "start": await tokens[0].to_string(self.store),
+ "end": await tokens[1].to_string(self.store),
}
return chunk
@@ -142,7 +142,7 @@ class EventStreamHandler(BaseHandler):
class EventHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
- super(EventHandler, self).__init__(hs)
+ super().__init__(hs)
self.storage = hs.get_storage()
async def get_event(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c3eb5f63d7..e5ddcd2171 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -22,7 +22,7 @@ import itertools
import logging
from collections.abc import Container
from http import HTTPStatus
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import attr
from signedjson.key import decode_verify_key_bytes
@@ -70,11 +70,13 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEventsRestServlet,
ReplicationStoreRoomOnInviteRestServlet,
)
-from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
JsonDict,
MutableStateMap,
+ PersistedEventPosition,
+ RoomStreamToken,
StateMap,
UserID,
get_domain_from_id,
@@ -84,6 +86,9 @@ from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr
from synapse.visibility import filter_events_for_server
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -115,8 +120,8 @@ class FederationHandler(BaseHandler):
rooms.
"""
- def __init__(self, hs):
- super(FederationHandler, self).__init__(hs)
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
self.hs = hs
@@ -125,6 +130,7 @@ class FederationHandler(BaseHandler):
self.state_store = self.storage.state
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
+ self._state_resolution_handler = hs.get_state_resolution_handler()
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.action_generator = hs.get_action_generator()
@@ -154,8 +160,9 @@ class FederationHandler(BaseHandler):
self._device_list_updater = hs.get_device_handler().device_list_updater
self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite
- # When joining a room we need to queue any events for that room up
- self.room_queues = {}
+ # When joining a room we need to queue any events for that room up.
+ # For each room, a list of (pdu, origin) tuples.
+ self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]]
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
self.third_party_event_rules = hs.get_third_party_event_rules()
@@ -280,7 +287,7 @@ class FederationHandler(BaseHandler):
raise Exception(
"Error fetching missing prev_events for %s: %s"
% (event_id, e)
- )
+ ) from e
# Update the set of things we've seen after trying to
# fetch the missing stuff
@@ -379,8 +386,7 @@ class FederationHandler(BaseHandler):
event_map[x.event_id] = x
room_version = await self.store.get_room_version_id(room_id)
- state_map = await resolve_events_with_store(
- self.clock,
+ state_map = await self._state_resolution_handler.resolve_events_with_store(
room_id,
room_version,
state_maps,
@@ -813,6 +819,9 @@ class FederationHandler(BaseHandler):
dest, room_id, limit=limit, extremities=extremities
)
+ if not events:
+ return []
+
# ideally we'd sanity check the events here for excess prev_events etc,
# but it's hard to reject events at this point without completely
# breaking backfill in the same way that it is currently broken by
@@ -918,15 +927,26 @@ class FederationHandler(BaseHandler):
return events
- async def maybe_backfill(self, room_id, current_depth):
+ async def maybe_backfill(
+ self, room_id: str, current_depth: int, limit: int
+ ) -> bool:
"""Checks the database to see if we should backfill before paginating,
and if so do.
+
+ Args:
+ room_id
+ current_depth: The depth from which we're paginating. This is
+ used to decide if we should backfill and what extremities to
+ use.
+ limit: The number of events that the pagination request will
+ return. This is used as part of the heuristic to decide if we
+ should back paginate.
"""
extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
- return
+ return False
# We only want to paginate if we can actually see the events we'll get,
# as otherwise we'll just spend a lot of resources to get redacted
@@ -979,16 +999,54 @@ class FederationHandler(BaseHandler):
sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
max_depth = sorted_extremeties_tuple[0][1]
+ # If we're approaching an extremity we trigger a backfill, otherwise we
+ # no-op.
+ #
+ # We chose twice the limit here as then clients paginating backwards
+ # will send pagination requests that trigger backfill at least twice
+ # using the most recent extremity before it gets removed (see below). We
+ # chose a factor greater than one to allow for failures, but choosing a
+ # much larger factor will result in triggering a backfill request much
+ # earlier than necessary.
+ if current_depth - 2 * limit > max_depth:
+ logger.debug(
+ "Not backfilling as we don't need to. %d < %d - 2 * %d",
+ max_depth,
+ current_depth,
+ limit,
+ )
+ return False
+
+ logger.debug(
+ "room_id: %s, backfill: current_depth: %s, max_depth: %s, extrems: %s",
+ room_id,
+ current_depth,
+ max_depth,
+ sorted_extremeties_tuple,
+ )
+
+ # We ignore extremities that have a greater depth than our current depth
+ # as:
+ # 1. we don't really care about getting events that have happened
+ # before our current position; and
+ # 2. we have likely previously tried and failed to backfill from that
+ # extremity, so to avoid getting "stuck" requesting the same
+ # backfill repeatedly we drop those extremities.
+ filtered_sorted_extremeties_tuple = [
+ t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth
+ ]
+
+ # However, we need to check that the filtered extremities are non-empty.
+ # If they are empty then we can either a) bail or b) still attempt to
+ # backfill. We opt to try backfilling anyway just in case we do get
+ # relevant events.
+ if filtered_sorted_extremeties_tuple:
+ sorted_extremeties_tuple = filtered_sorted_extremeties_tuple
+
# We don't want to specify too many extremities as it causes the backfill
# request URI to be too long.
extremities = dict(sorted_extremeties_tuple[:5])
- if current_depth > max_depth:
- logger.debug(
- "Not backfilling as we don't need to. %d < %d", max_depth, current_depth
- )
- return
-
# Now we need to decide which hosts to hit first.
# First we try hosts that are already in the room
@@ -2266,10 +2324,10 @@ class FederationHandler(BaseHandler):
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
- state_sets = await self.state_store.get_state_groups(
+ state_sets_d = await self.state_store.get_state_groups(
event.room_id, extrem_ids
)
- state_sets = list(state_sets.values())
+ state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]]
state_sets.append(state)
current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
@@ -3060,7 +3118,8 @@ class FederationHandler(BaseHandler):
)
return result["max_stream_id"]
else:
- max_stream_id = await self.storage.persistence.persist_events(
+ assert self.storage.persistence
+ max_stream_token = await self.storage.persistence.persist_events(
event_and_contexts, backfilled=backfilled
)
@@ -3071,12 +3130,12 @@ class FederationHandler(BaseHandler):
if not backfilled: # Never notify for backfilled events
for event, _ in event_and_contexts:
- await self._notify_persisted_event(event, max_stream_id)
+ await self._notify_persisted_event(event, max_stream_token)
- return max_stream_id
+ return max_stream_token.stream
async def _notify_persisted_event(
- self, event: EventBase, max_stream_id: int
+ self, event: EventBase, max_stream_token: RoomStreamToken
) -> None:
"""Checks to see if notifier/pushers should be notified about the
event or not.
@@ -3102,9 +3161,14 @@ class FederationHandler(BaseHandler):
elif event.internal_metadata.is_outlier():
return
- event_stream_id = event.internal_metadata.stream_ordering
+ # the event has been persisted so it should have a stream ordering.
+ assert event.internal_metadata.stream_ordering
+
+ event_pos = PersistedEventPosition(
+ self._instance_name, event.internal_metadata.stream_ordering
+ )
self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id, extra_users=extra_users
+ event, event_pos, max_stream_token, extra_users=extra_users
)
async def _clean_room_for_join(self, room_id: str) -> None:
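
`maybe_backfill` now takes the pagination `limit` and only backfills when the requester's depth is within twice the page size of the closest extremity; extremities deeper than the current position are also filtered out, to avoid repeatedly retrying known-bad backfill points. A condensed sketch of the heuristic (extremities map event IDs to depths):

    from typing import Dict

    def should_backfill(extremities: Dict[str, int], current_depth: int, limit: int) -> bool:
        """Decide whether a pagination request warrants a backfill."""
        if not extremities:
            return False
        max_depth = max(extremities.values())
        # Backfill only when we're within twice the page size of the
        # closest extremity; see the comment in the diff for why the
        # factor is two.
        return current_depth - 2 * limit <= max_depth

    def usable_extremities(extremities: Dict[str, int], current_depth: int) -> Dict[str, int]:
        # Drop extremities deeper than our current position: we likely
        # tried and failed to backfill from them before. If that leaves
        # nothing, fall back to trying them all anyway.
        filtered = {e: d for e, d in extremities.items() if d <= current_depth}
        return filtered or extremities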
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 44df567983..9684e60fc8 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -240,7 +240,7 @@ class GroupsLocalWorkerHandler:
class GroupsLocalHandler(GroupsLocalWorkerHandler):
def __init__(self, hs):
- super(GroupsLocalHandler, self).__init__(hs)
+ super().__init__(hs)
# Ensure attestations get renewed
hs.get_groups_attestation_renewer()
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 0ce6ddfbe4..bc3e9607ca 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -21,8 +21,6 @@ import logging
import urllib.parse
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
-from twisted.internet.error import TimeoutError
-
from synapse.api.errors import (
CodeMessageException,
Codes,
@@ -30,6 +28,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, Requester
from synapse.util import json_decoder
@@ -45,7 +44,7 @@ id_server_scheme = "https://"
class IdentityHandler(BaseHandler):
def __init__(self, hs):
- super(IdentityHandler, self).__init__(hs)
+ super().__init__(hs)
self.http_client = SimpleHttpClient(hs)
# We create a blacklisting instance of SimpleHttpClient for contacting identity
@@ -93,7 +92,7 @@ class IdentityHandler(BaseHandler):
try:
data = await self.http_client.get_json(url, query_params)
- except TimeoutError:
+ except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.info(
@@ -173,7 +172,7 @@ class IdentityHandler(BaseHandler):
if e.code != 404 or not use_v2:
logger.error("3PID bind failed with Matrix error: %r", e)
raise e.to_synapse_error()
- except TimeoutError:
+ except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
except CodeMessageException as e:
data = json_decoder.decode(e.msg) # XXX WAT?
@@ -273,7 +272,7 @@ class IdentityHandler(BaseHandler):
else:
logger.error("Failed to unbind threepid on identity server: %s", e)
raise SynapseError(500, "Failed to contact identity server")
- except TimeoutError:
+ except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
await self.store.remove_user_bound_threepid(
@@ -419,7 +418,7 @@ class IdentityHandler(BaseHandler):
except HttpResponseException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e.to_synapse_error()
- except TimeoutError:
+ except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
async def requestMsisdnToken(
@@ -471,7 +470,7 @@ class IdentityHandler(BaseHandler):
except HttpResponseException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e.to_synapse_error()
- except TimeoutError:
+ except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
assert self.hs.config.public_baseurl
@@ -553,7 +552,7 @@ class IdentityHandler(BaseHandler):
id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
body,
)
- except TimeoutError:
+ except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
@@ -627,7 +626,7 @@ class IdentityHandler(BaseHandler):
# require or validate it. See the following for context:
# https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950
return data["mxid"]
- except TimeoutError:
+ except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
except IOError as e:
logger.warning("Error from v1 identity server lookup: %s" % (e,))
@@ -655,7 +654,7 @@ class IdentityHandler(BaseHandler):
"%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
{"access_token": id_access_token},
)
- except TimeoutError:
+ except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
if not isinstance(hash_details, dict):
@@ -727,7 +726,7 @@ class IdentityHandler(BaseHandler):
},
headers=headers,
)
- except TimeoutError:
+ except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
except Exception as e:
logger.warning("Error when performing a v2 3pid lookup: %s", e)
@@ -823,7 +822,7 @@ class IdentityHandler(BaseHandler):
invite_config,
{"Authorization": create_id_access_token_header(id_access_token)},
)
- except TimeoutError:
+ except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
if e.code != 404:
@@ -841,7 +840,7 @@ class IdentityHandler(BaseHandler):
data = await self.blacklisting_http_client.post_json_get_json(
url, invite_config
)
- except TimeoutError:
+ except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.warning(
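
Every identity-server call site above swaps Twisted's `TimeoutError` for Synapse's own `RequestTimedOutError` (raised by `SimpleHttpClient`), still surfacing it to clients as a 500. A sketch of the translation pattern, with stand-in exception classes:

    class RequestTimedOutError(Exception):
        """Stand-in for synapse.http.RequestTimedOutError."""

    class SynapseError(Exception):
        def __init__(self, code: int, msg: str) -> None:
            super().__init__(msg)
            self.code = code

    async def query_identity_server(fetch):
        try:
            return await fetch()
        except RequestTimedOutError:
            # Every identity-server call maps timeouts to the same 500.
            raise SynapseError(500, "Timed out contacting identity server")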
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index ba4828c713..39a85801c1 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
class InitialSyncHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
- super(InitialSyncHandler, self).__init__(hs)
+ super().__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
@@ -203,8 +203,8 @@ class InitialSyncHandler(BaseHandler):
messages, time_now=time_now, as_client_event=as_client_event
)
),
- "start": start_token.to_string(),
- "end": end_token.to_string(),
+ "start": await start_token.to_string(self.store),
+ "end": await end_token.to_string(self.store),
}
d["state"] = await self._event_serializer.serialize_events(
@@ -249,7 +249,7 @@ class InitialSyncHandler(BaseHandler):
],
"account_data": account_data_events,
"receipts": receipt,
- "end": now_token.to_string(),
+ "end": await now_token.to_string(self.store),
}
return ret
@@ -325,7 +325,8 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
- stream_token = await self.store.get_stream_token_for_event(member_event_id)
+ leave_position = await self.store.get_position_for_event(member_event_id)
+ stream_token = leave_position.to_room_stream_token()
messages, token = await self.store.get_recent_events_for_room(
room_id, limit=limit, end_token=stream_token
@@ -347,8 +348,8 @@ class InitialSyncHandler(BaseHandler):
"chunk": (
await self._event_serializer.serialize_events(messages, time_now)
),
- "start": start_token.to_string(),
- "end": end_token.to_string(),
+ "start": await start_token.to_string(self.store),
+ "end": await end_token.to_string(self.store),
},
"state": (
await self._event_serializer.serialize_events(
@@ -446,8 +447,8 @@ class InitialSyncHandler(BaseHandler):
"chunk": (
await self._event_serializer.serialize_events(messages, time_now)
),
- "start": start_token.to_string(),
- "end": end_token.to_string(),
+ "start": await start_token.to_string(self.store),
+ "end": await end_token.to_string(self.store),
},
"state": state,
"presence": presence,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c13abd402d..5891939bb1 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -636,59 +636,6 @@ class EventCreationHandler:
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
- async def send_nonmember_event(
- self,
- requester: Requester,
- event: EventBase,
- context: EventContext,
- ratelimit: bool = True,
- ignore_shadow_ban: bool = False,
- ) -> int:
- """
- Persists and notifies local clients and federation of an event.
-
- Args:
- requester: The requester sending the event.
- event: The event to send.
- context: The context of the event.
- ratelimit: Whether to rate limit this send.
- ignore_shadow_ban: True if shadow-banned users should be allowed to
- send this event.
-
- Return:
- The stream_id of the persisted event.
-
- Raises:
- ShadowBanError if the requester has been shadow-banned.
- """
- if event.type == EventTypes.Member:
- raise SynapseError(
- 500, "Tried to send member event through non-member codepath"
- )
-
- if not ignore_shadow_ban and requester.shadow_banned:
- # We randomly sleep a bit just to annoy the requester.
- await self.clock.sleep(random.randint(1, 10))
- raise ShadowBanError()
-
- user = UserID.from_string(event.sender)
-
- assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
-
- if event.is_state():
- prev_event = await self.deduplicate_state_event(event, context)
- if prev_event is not None:
- logger.info(
- "Not bothering to persist state event %s duplicated by %s",
- event.event_id,
- prev_event.event_id,
- )
- return await self.store.get_stream_id_for_event(prev_event.event_id)
-
- return await self.handle_new_client_event(
- requester=requester, event=event, context=context, ratelimit=ratelimit
- )
-
async def deduplicate_state_event(
self, event: EventBase, context: EventContext
) -> Optional[EventBase]:
@@ -729,7 +676,7 @@ class EventCreationHandler:
"""
Creates an event, then sends it.
- See self.create_event and self.send_nonmember_event.
+ See self.create_event and self.handle_new_client_event.
Args:
requester: The requester sending the event.
@@ -739,9 +686,19 @@ class EventCreationHandler:
ignore_shadow_ban: True if shadow-banned users should be allowed to
send this event.
+ Returns:
+ The event, and its stream ordering (if state event deduplication happened,
+ this will be the previous, duplicate event).
+
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
+
+ if event_dict["type"] == EventTypes.Member:
+ raise SynapseError(
+ 500, "Tried to send member event through non-member codepath"
+ )
+
if not ignore_shadow_ban and requester.shadow_banned:
# We randomly sleep a bit just to annoy the requester.
await self.clock.sleep(random.randint(1, 10))
@@ -757,20 +714,27 @@ class EventCreationHandler:
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
)
+ assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
+ event.sender,
+ )
+
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, str):
spam_error = "Spam is not permitted here"
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
- stream_id = await self.send_nonmember_event(
- requester,
- event,
- context,
+ ev = await self.handle_new_client_event(
+ requester=requester,
+ event=event,
+ context=context,
ratelimit=ratelimit,
ignore_shadow_ban=ignore_shadow_ban,
)
- return event, stream_id
+
+ # we know it was persisted, so must have a stream ordering
+ assert ev.internal_metadata.stream_ordering
+ return ev, ev.internal_metadata.stream_ordering
@measure_func("create_new_client_event")
async def create_new_client_event(
@@ -844,8 +808,11 @@ class EventCreationHandler:
context: EventContext,
ratelimit: bool = True,
extra_users: List[UserID] = [],
- ) -> int:
- """Processes a new event. This includes checking auth, persisting it,
+ ignore_shadow_ban: bool = False,
+ ) -> EventBase:
+ """Processes a new event.
+
+ This includes deduplicating, checking auth, persisting,
notifying users, sending to remote servers, etc.
If called from a worker will hit out to the master process for final
@@ -858,10 +825,39 @@ class EventCreationHandler:
ratelimit
extra_users: Any extra users to notify about event
+ ignore_shadow_ban: True if shadow-banned users should be allowed to
+ send this event.
+
Return:
- The stream_id of the persisted event.
+ If the event was deduplicated, the previous, duplicate event. Otherwise,
+ `event`.
+
+ Raises:
+ ShadowBanError if the requester has been shadow-banned.
"""
+ # we don't apply shadow-banning to membership events here. Invites are blocked
+ # higher up the stack, and we allow shadow-banned users to send join and leave
+ # events as normal.
+ if (
+ event.type != EventTypes.Member
+ and not ignore_shadow_ban
+ and requester.shadow_banned
+ ):
+ # We randomly sleep a bit just to annoy the requester.
+ await self.clock.sleep(random.randint(1, 10))
+ raise ShadowBanError()
+
+ if event.is_state():
+ prev_event = await self.deduplicate_state_event(event, context)
+ if prev_event is not None:
+ logger.info(
+ "Not bothering to persist state event %s duplicated by %s",
+ event.event_id,
+ prev_event.event_id,
+ )
+ return prev_event
+
if event.is_state() and (event.type, event.state_key) == (
EventTypes.Create,
"",
@@ -916,13 +912,13 @@ class EventCreationHandler:
)
stream_id = result["stream_id"]
event.internal_metadata.stream_ordering = stream_id
- return stream_id
+ return event
stream_id = await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
- return stream_id
+ return event
except Exception:
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
@@ -1139,7 +1135,7 @@ class EventCreationHandler:
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")
- event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
+ event_pos, max_stream_token = await self.storage.persistence.persist_event(
event, context=context
)
@@ -1150,7 +1146,7 @@ class EventCreationHandler:
def _notify():
try:
self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id, extra_users=extra_users
+ event, event_pos, max_stream_token, extra_users=extra_users
)
except Exception:
logger.exception("Error notifying about new room event")
@@ -1162,7 +1158,7 @@ class EventCreationHandler:
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
- return event_stream_id
+ return event_pos.stream
async def _bump_active_time(self, user: UserID) -> None:
try:
@@ -1183,54 +1179,7 @@ class EventCreationHandler:
)
for room_id in room_ids:
- # For each room we need to find a joined member we can use to send
- # the dummy event with.
-
- latest_event_ids = await self.store.get_prev_events_for_room(room_id)
-
- members = await self.state.get_current_users_in_room(
- room_id, latest_event_ids=latest_event_ids
- )
- dummy_event_sent = False
- for user_id in members:
- if not self.hs.is_mine_id(user_id):
- continue
- requester = create_requester(user_id)
- try:
- event, context = await self.create_event(
- requester,
- {
- "type": "org.matrix.dummy_event",
- "content": {},
- "room_id": room_id,
- "sender": user_id,
- },
- prev_event_ids=latest_event_ids,
- )
-
- event.internal_metadata.proactively_send = False
-
- # Since this is a dummy-event it is OK if it is sent by a
- # shadow-banned user.
- await self.send_nonmember_event(
- requester,
- event,
- context,
- ratelimit=False,
- ignore_shadow_ban=True,
- )
- dummy_event_sent = True
- break
- except ConsentNotGivenError:
- logger.info(
- "Failed to send dummy event into room %s for user %s due to "
- "lack of consent. Will try another user" % (room_id, user_id)
- )
- except AuthError:
- logger.info(
- "Failed to send dummy event into room %s for user %s due to "
- "lack of power. Will try another user" % (room_id, user_id)
- )
+ dummy_event_sent = await self._send_dummy_event_for_room(room_id)
if not dummy_event_sent:
# Did not find a valid user in the room, so remove from future attempts
@@ -1243,6 +1192,63 @@ class EventCreationHandler:
now = self.clock.time_msec()
self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now
+ async def _send_dummy_event_for_room(self, room_id: str) -> bool:
+ """Attempt to send a dummy event for the given room.
+
+ Args:
+ room_id: room to try to send an event from
+
+ Returns:
+ True if a dummy event was successfully sent. False if no user was able
+ to send an event.
+ """
+
+ # For each room we need to find a joined member we can use to send
+ # the dummy event with.
+ latest_event_ids = await self.store.get_prev_events_for_room(room_id)
+ members = await self.state.get_current_users_in_room(
+ room_id, latest_event_ids=latest_event_ids
+ )
+ for user_id in members:
+ if not self.hs.is_mine_id(user_id):
+ continue
+ requester = create_requester(user_id)
+ try:
+ event, context = await self.create_event(
+ requester,
+ {
+ "type": "org.matrix.dummy_event",
+ "content": {},
+ "room_id": room_id,
+ "sender": user_id,
+ },
+ prev_event_ids=latest_event_ids,
+ )
+
+ event.internal_metadata.proactively_send = False
+
+ # Since this is a dummy-event it is OK if it is sent by a
+ # shadow-banned user.
+ await self.handle_new_client_event(
+ requester=requester,
+ event=event,
+ context=context,
+ ratelimit=False,
+ ignore_shadow_ban=True,
+ )
+ return True
+ except ConsentNotGivenError:
+ logger.info(
+ "Failed to send dummy event into room %s for user %s due to "
+ "lack of consent. Will try another user" % (room_id, user_id)
+ )
+ except AuthError:
+ logger.info(
+ "Failed to send dummy event into room %s for user %s due to "
+ "lack of power. Will try another user" % (room_id, user_id)
+ )
+ return False
+
def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
to_expire = set()
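
With `send_nonmember_event` gone, `handle_new_client_event` absorbs the shadow-ban check and state-event deduplication, and returns the persisted `EventBase` (or the pre-existing duplicate) instead of a bare stream ID. A sketch of the resulting caller contract:

    async def send_and_get_ordering(handler, requester, event, context):
        # handle_new_client_event returns the persisted event, or the
        # previous duplicate if state-event deduplication kicked in.
        ev = await handler.handle_new_client_event(
            requester=requester, event=event, context=context
        )
        # Either way the returned event has been persisted, so it carries
        # a stream ordering.
        assert ev.internal_metadata.stream_ordering
        return ev, ev.internal_metadata.stream_ordering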
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 4230dbaf99..05ac86e697 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -37,7 +37,7 @@ from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -96,6 +96,7 @@ class OidcHandler:
self.hs = hs
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
+ self._user_profile_method = hs.config.oidc_user_profile_method # type: str
self._client_auth = ClientAuth(
hs.config.oidc_client_id,
hs.config.oidc_client_secret,
@@ -114,6 +115,7 @@ class OidcHandler:
hs.config.oidc_user_mapping_provider_config
) # type: OidcMappingProvider
self._skip_verification = hs.config.oidc_skip_verification # type: bool
+ self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
@@ -195,11 +197,11 @@ class OidcHandler:
% (m["response_types_supported"],)
)
- # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos
+ # Ensure there's a userinfo endpoint to fetch from if it is required.
if self._uses_userinfo:
if m.get("userinfo_endpoint") is None:
raise ValueError(
- 'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
+ 'provider has no "userinfo_endpoint", even though it is required'
)
else:
# If we're not using userinfo, we need a valid jwks to validate the ID token
@@ -219,8 +221,10 @@ class OidcHandler:
``access_token`` with the ``userinfo_endpoint``.
"""
- # Maybe that should be user-configurable and not inferred?
- return "openid" not in self._scopes
+ return (
+ "openid" not in self._scopes
+ or self._user_profile_method == "userinfo_endpoint"
+ )
async def load_metadata(self) -> OpenIDProviderMetadata:
"""Load and validate the provider metadata.
@@ -706,6 +710,15 @@ class OidcHandler:
self._render_error(request, "mapping_error", str(e))
return
+ # Mapping providers might not have get_extra_attributes: only call this
+ # method if it exists.
+ extra_attributes = None
+ get_extra_attributes = getattr(
+ self._user_mapping_provider, "get_extra_attributes", None
+ )
+ if get_extra_attributes:
+ extra_attributes = await get_extra_attributes(userinfo, token)
+
# and finally complete the login
if ui_auth_session_id:
await self._auth_handler.complete_sso_ui_auth(
@@ -713,7 +726,7 @@ class OidcHandler:
)
else:
await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url
+ user_id, request, client_redirect_url, extra_attributes
)
def _generate_oidc_session_token(
@@ -849,7 +862,8 @@ class OidcHandler:
If we don't find the user that way, we should register the user,
mapping the localpart and the display name from the UserInfo.
- If a user already exists with the mxid we've mapped, raise an exception.
+ If a user already exists with the mxid we've mapped and allow_existing_users
+ is disabled, raise an exception.
Args:
userinfo: an object representing the user
@@ -905,21 +919,31 @@ class OidcHandler:
localpart = map_username_to_mxid_localpart(attributes["localpart"])
- user_id = UserID(localpart, self._hostname)
- if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
- # This mxid is taken
- raise MappingException(
- "mxid '{}' is already taken".format(user_id.to_string())
+ user_id = UserID(localpart, self._hostname).to_string()
+ users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+ if users:
+ if self._allow_existing_users:
+ if len(users) == 1:
+ registered_user_id = next(iter(users))
+ elif user_id in users:
+ registered_user_id = user_id
+ else:
+ raise MappingException(
+ "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+ user_id, list(users.keys())
+ )
+ )
+ else:
+ # This mxid is taken
+ raise MappingException("mxid '{}' is already taken".format(user_id))
+ else:
+ # It's the first time this user is logging in and the mapped mxid was
+ # not taken, register the user
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart,
+ default_display_name=attributes["display_name"],
+ user_agent_ips=(user_agent, ip_address),
)
-
- # It's the first time this user is logging in and the mapped mxid was
- # not taken, register the user
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=attributes["display_name"],
- user_agent_ips=(user_agent, ip_address),
- )
-
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
)
@@ -972,7 +996,7 @@ class OidcMappingProvider(Generic[C]):
async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
- """Map a ``UserInfo`` objects into user attributes.
+ """Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
@@ -983,6 +1007,18 @@ class OidcMappingProvider(Generic[C]):
"""
raise NotImplementedError()
+ async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
+ """Map a `UserInfo` object into additional attributes passed to the client during login.
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+ token: A dict with the tokens returned by the provider
+
+ Returns:
+ A dict containing additional attributes. Must be JSON serializable.
+ """
+ return {}
+
# Used to clear out "None" values in templates
def jinja_finalize(thing):
@@ -997,6 +1033,7 @@ class JinjaOidcMappingConfig:
subject_claim = attr.ib() # type: str
localpart_template = attr.ib() # type: Template
display_name_template = attr.ib() # type: Optional[Template]
+ extra_attributes = attr.ib() # type: Dict[str, Template]
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1035,10 +1072,28 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
% (e,)
)
+ extra_attributes = {} # type: Dict[str, Template]
+ if "extra_attributes" in config:
+ extra_attributes_config = config.get("extra_attributes") or {}
+ if not isinstance(extra_attributes_config, dict):
+ raise ConfigError(
+ "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
+ )
+
+ for key, value in extra_attributes_config.items():
+ try:
+ extra_attributes[key] = env.from_string(value)
+ except Exception as e:
+ raise ConfigError(
+ "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
+ % (key, e)
+ )
+
return JinjaOidcMappingConfig(
subject_claim=subject_claim,
localpart_template=localpart_template,
display_name_template=display_name_template,
+ extra_attributes=extra_attributes,
)
def get_remote_user_id(self, userinfo: UserInfo) -> str:
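As a sketch of how the new option is consumed: parse_config above accepts a plain dict, and each extra_attributes value is compiled with env.from_string(). The "phone" attribute and its template below are hypothetical examples, not part of this change:

    # Hypothetical user_mapping_provider config exercising extra_attributes.
    config = {
        "subject_claim": "sub",
        "localpart_template": "{{ user.preferred_username }}",
        "display_name_template": "{{ user.name }}",
        # each value here is compiled into a Jinja2 Template as shown above
        "extra_attributes": {"phone": "{{ user.phone_number }}"},
    }
    parsed = JinjaOidcMappingProvider.parse_config(config)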
@@ -1059,3 +1114,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
display_name = None
return UserAttribute(localpart=localpart, display_name=display_name)
+
+ async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
+ extras = {} # type: Dict[str, str]
+ for key, template in self._config.extra_attributes.items():
+ try:
+ extras[key] = template.render(user=userinfo).strip()
+ except Exception as e:
+ # Log an error and skip this value (don't break login for this).
+ logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e))
+ return extras
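Because the handler looks up get_extra_attributes with getattr and tolerates its absence, existing third-party mapping providers keep working unchanged, and new ones opt in simply by defining the method. A minimal sketch, assuming a provider that forwards a (purely illustrative) phone claim:

    class ExampleMappingProvider(OidcMappingProvider):
        def get_remote_user_id(self, userinfo):
            return userinfo["sub"]

        async def map_user_attributes(self, userinfo, token):
            return UserAttribute(
                localpart=userinfo["preferred_username"], display_name=None
            )

        async def get_extra_attributes(self, userinfo, token):
            # Must be JSON serializable; it is handed to the client by
            # complete_sso_login at the end of the flow.
            return {"phone": userinfo.get("phone_number")}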
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index d929a68f7d..2c2a633938 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -25,7 +25,7 @@ from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
-from synapse.types import Requester, RoomStreamToken
+from synapse.types import Requester
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
@@ -358,9 +358,9 @@ class PaginationHandler:
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
- max_topo = room_token.topological
+ curr_topo = room_token.topological
else:
- max_topo = await self.store.get_max_topological_token(
+ curr_topo = await self.store.get_current_topological_token(
room_id, room_token.stream
)
@@ -373,19 +373,18 @@ class PaginationHandler:
# case "JOIN" would have been returned.
assert member_event_id
- leave_token_str = await self.store.get_topological_token_for_event(
+ leave_token = await self.store.get_topological_token_for_event(
member_event_id
)
- leave_token = RoomStreamToken.parse(leave_token_str)
assert leave_token.topological is not None
- if leave_token.topological < max_topo:
+ if leave_token.topological < curr_topo:
from_token = from_token.copy_and_replace(
"room_key", leave_token
)
await self.hs.get_handlers().federation_handler.maybe_backfill(
- room_id, max_topo
+ room_id, curr_topo, limit=pagin_config.limit,
)
to_room_key = None
@@ -414,8 +413,8 @@ class PaginationHandler:
if not events:
return {
"chunk": [],
- "start": from_token.to_string(),
- "end": next_token.to_string(),
+ "start": await from_token.to_string(self.store),
+ "end": await next_token.to_string(self.store),
}
state = None
@@ -443,8 +442,8 @@ class PaginationHandler:
events, time_now, as_client_event=as_client_event
)
),
- "start": from_token.to_string(),
- "end": next_token.to_string(),
+ "start": await from_token.to_string(self.store),
+ "end": await next_token.to_string(self.store),
}
if state:
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 0cb8fad89a..5453e6dfc8 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -44,7 +44,7 @@ class BaseProfileHandler(BaseHandler):
"""
def __init__(self, hs):
- super(BaseProfileHandler, self).__init__(hs)
+ super().__init__(hs)
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
@@ -369,7 +369,7 @@ class MasterProfileHandler(BaseProfileHandler):
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs):
- super(MasterProfileHandler, self).__init__(hs)
+ super().__init__(hs)
assert hs.config.worker_app is None
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index e3b528d271..c32f314a1c 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
class ReadMarkerHandler(BaseHandler):
def __init__(self, hs):
- super(ReadMarkerHandler, self).__init__(hs)
+ super().__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.read_marker_linearizer = Linearizer(name="read_marker")
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index bdd8e52edd..7225923757 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
class ReceiptsHandler(BaseHandler):
def __init__(self, hs):
- super(ReceiptsHandler, self).__init__(hs)
+ super().__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index cde2dbca92..538f4b2a61 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -42,7 +42,7 @@ class RegistrationHandler(BaseHandler):
Args:
hs (synapse.server.HomeServer):
"""
- super(RegistrationHandler, self).__init__(hs)
+ super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index eeade6ad3f..d0530a446c 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -70,7 +70,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
class RoomCreationHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
- super(RoomCreationHandler, self).__init__(hs)
+ super().__init__(hs)
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -185,6 +185,7 @@ class RoomCreationHandler(BaseHandler):
ShadowBanError if the requester is shadow-banned.
"""
user_id = requester.user.to_string()
+ assert self.hs.is_mine_id(user_id), "User must be our own: %s" % (user_id,)
# start by allocating a new room id
r = await self.store.get_room(old_room_id)
@@ -229,8 +230,8 @@ class RoomCreationHandler(BaseHandler):
)
# now send the tombstone
- await self.event_creation_handler.send_nonmember_event(
- requester, tombstone_event, tombstone_context
+ await self.event_creation_handler.handle_new_client_event(
+ requester=requester, event=tombstone_event, context=tombstone_context,
)
old_room_state = await tombstone_context.get_current_state_ids()
@@ -681,6 +682,15 @@ class RoomCreationHandler(BaseHandler):
creator_id=user_id, is_public=is_public, room_version=room_version,
)
+ # Check whether this visibility value is blocked by a third party module
+ allowed_by_third_party_rules = await (
+ self.third_party_event_rules.check_visibility_can_be_modified(
+ room_id, visibility
+ )
+ )
+ if not allowed_by_third_party_rules:
+ raise SynapseError(403, "Room visibility value not allowed.")
+
directory_handler = self.hs.get_handlers().directory_handler
if room_alias:
await directory_handler.create_association(
@@ -962,8 +972,6 @@ class RoomCreationHandler(BaseHandler):
try:
random_string = stringutils.random_string(18)
gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
- if isinstance(gen_room_id, bytes):
- gen_room_id = gen_room_id.decode("utf-8")
await self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
@@ -1077,11 +1085,13 @@ class RoomContextHandler:
# the token, which we replace.
token = StreamToken.START
- results["start"] = token.copy_and_replace(
+ results["start"] = await token.copy_and_replace(
"room_key", results["start"]
- ).to_string()
+ ).to_string(self.store)
- results["end"] = token.copy_and_replace("room_key", results["end"]).to_string()
+ results["end"] = await token.copy_and_replace(
+ "room_key", results["end"]
+ ).to_string(self.store)
return results
@@ -1134,14 +1144,14 @@ class RoomEventSource:
events[:] = events[:limit]
if events:
- end_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
+ end_key = events[-1].internal_metadata.after
else:
end_key = to_key
return (events, end_key)
def get_current_key(self) -> RoomStreamToken:
- return RoomStreamToken(None, self.store.get_room_max_stream_ordering())
+ return self.store.get_room_max_token()
def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
return self.store.get_room_events_max_id(room_id)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 5dd7b28391..4a13c8e912 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -38,7 +38,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler):
def __init__(self, hs):
- super(RoomListHandler, self).__init__(hs)
+ super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache(hs, "room_list")
self.remote_response_cache = ResponseCache(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 789cca5c3e..e3aae2375b 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
from unpaddedbase64 import encode_base64
from synapse import types
-from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
+from synapse.api.constants import MAX_DEPTH, AccountDataTypes, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -52,14 +52,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class RoomMemberHandler:
+class RoomMemberHandler(metaclass=abc.ABCMeta):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
- __metaclass__ = abc.ABCMeta
-
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
@@ -205,15 +203,6 @@ class RoomMemberHandler:
require_consent=require_consent,
)
- # Check if this event matches the previous membership event for the user.
- duplicate = await self.event_creation_handler.deduplicate_state_event(
- event, context
- )
- if duplicate is not None:
- # Discard the new event since this membership change is a no-op.
- _, stream_id = await self.store.get_event_ordering(duplicate.event_id)
- return duplicate.event_id, stream_id
-
prev_state_ids = await context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
@@ -238,7 +227,7 @@ class RoomMemberHandler:
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
- stream_id = await self.event_creation_handler.handle_new_client_event(
+ result_event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit,
)
@@ -248,7 +237,9 @@ class RoomMemberHandler:
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target, room_id)
- return event.event_id, stream_id
+ # we know it was persisted, so should have a stream ordering
+ assert result_event.internal_metadata.stream_ordering
+ return result_event.event_id, result_event.internal_metadata.stream_ordering
async def copy_room_tags_and_direct_to_room(
self, old_room_id, new_room_id, user_id
@@ -264,7 +255,7 @@ class RoomMemberHandler:
user_account_data, _ = await self.store.get_account_data_for_user(user_id)
# Copy direct message state if applicable
- direct_rooms = user_account_data.get("m.direct", {})
+ direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {})
# Check which key this room is under
if isinstance(direct_rooms, dict):
@@ -275,7 +266,7 @@ class RoomMemberHandler:
# Save back to user's m.direct account data
await self.store.add_account_data_for_user(
- user_id, "m.direct", direct_rooms
+ user_id, AccountDataTypes.DIRECT, direct_rooms
)
break
@@ -458,12 +449,12 @@ class RoomMemberHandler:
same_membership = old_membership == effective_membership_state
same_sender = requester.user.to_string() == old_state.sender
if same_sender and same_membership and same_content:
- _, stream_id = await self.store.get_event_ordering(
- old_state.event_id
- )
+ # duplicate event.
+ # we know it was persisted, so must have a stream ordering.
+ assert old_state.internal_metadata.stream_ordering
return (
old_state.event_id,
- stream_id,
+ old_state.internal_metadata.stream_ordering,
)
if old_membership in ["ban", "leave"] and action == "kick":
@@ -676,7 +667,7 @@ class RoomMemberHandler:
async def send_membership_event(
self,
- requester: Requester,
+ requester: Optional[Requester],
event: EventBase,
context: EventContext,
ratelimit: bool = True,
@@ -706,12 +697,6 @@ class RoomMemberHandler:
else:
requester = types.create_requester(target_user)
- prev_event = await self.event_creation_handler.deduplicate_state_event(
- event, context
- )
- if prev_event is not None:
- return
-
prev_state_ids = await context.get_prev_state_ids()
if event.membership == Membership.JOIN:
if requester.is_guest:
@@ -1219,10 +1204,13 @@ class RoomMemberMasterHandler(RoomMemberHandler):
context = await self.state_handler.compute_event_context(event)
context.app_service = requester.app_service
- stream_id = await self.event_creation_handler.handle_new_client_event(
+ result_event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[UserID.from_string(target_user)],
)
- return event.event_id, stream_id
+ # we know it was persisted, so must have a stream ordering
+ assert result_event.internal_metadata.stream_ordering
+
+ return result_event.event_id, result_event.internal_metadata.stream_ordering
async def _remote_knock(
self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict,
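The recurring change in this file: handle_new_client_event now returns the persisted event rather than a stream id, so callers recover the ordering from internal_metadata. Condensed, the new pattern is:

    result_event = await self.event_creation_handler.handle_new_client_event(
        requester, event, context
    )
    # persisted events always carry a stream ordering
    assert result_event.internal_metadata.stream_ordering
    return result_event.event_id, result_event.internal_metadata.stream_ordering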
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index e138b6bad4..cd9b3ef629 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
class RoomMemberWorkerHandler(RoomMemberHandler):
def __init__(self, hs):
- super(RoomMemberWorkerHandler, self).__init__(hs)
+ super().__init__(hs)
self._remote_join_client = ReplRemoteJoin.make_client(hs)
self._remote_reject_client = ReplRejectInvite.make_client(hs)
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index d58f9788c5..e9402e6e2e 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
def __init__(self, hs):
- super(SearchHandler, self).__init__(hs)
+ super().__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
@@ -362,13 +362,13 @@ class SearchHandler(BaseHandler):
self.storage, user.to_string(), res["events_after"]
)
- res["start"] = now_token.copy_and_replace(
+ res["start"] = await now_token.copy_and_replace(
"room_key", res["start"]
- ).to_string()
+ ).to_string(self.store)
- res["end"] = now_token.copy_and_replace(
+ res["end"] = await now_token.copy_and_replace(
"room_key", res["end"]
- ).to_string()
+ ).to_string(self.store)
if include_profile:
senders = {
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 4d245b618b..a5d67f828f 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -27,7 +27,7 @@ class SetPasswordHandler(BaseHandler):
"""Handler which deals with changing user account passwords"""
def __init__(self, hs):
- super(SetPasswordHandler, self).__init__(hs)
+ super().__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
self._password_policy_handler = hs.get_password_policy_handler()
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 8393f4a439..7e69e8eaa8 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -50,7 +50,7 @@ class StatsHandler:
# Guard to ensure we only process deltas one at a time
self._is_processing = False
- if hs.config.stats_enabled:
+ if self.stats_enabled and hs.config.run_background_tasks:
self.notifier.add_replication_callback(self.notify_new_event)
# We kick this off so that we don't have to wait for a change before
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9b3a4f638b..6fb8332f93 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tup
import attr
from prometheus_client import Counter
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.filtering import FilterCollection
from synapse.events import EventBase
from synapse.logging.context import current_context
@@ -87,7 +87,7 @@ class SyncConfig:
class TimelineBatch:
prev_batch = attr.ib(type=StreamToken)
events = attr.ib(type=List[EventBase])
- limited = attr.ib(bool)
+ limited = attr.ib(type=bool)
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -201,6 +201,8 @@ class SyncResult:
device_lists: List of user_ids whose devices have changed
device_one_time_keys_count: Dict of algorithm to count for one time keys
for this device
+ device_unused_fallback_key_types: List of key types that have an unused fallback
+ key
groups: Group updates, if any
"""
@@ -213,6 +215,7 @@ class SyncResult:
to_device = attr.ib(type=List[JsonDict])
device_lists = attr.ib(type=DeviceLists)
device_one_time_keys_count = attr.ib(type=JsonDict)
+ device_unused_fallback_key_types = attr.ib(type=List[str])
groups = attr.ib(type=Optional[GroupsSyncResult])
def __bool__(self) -> bool:
@@ -457,8 +460,13 @@ class SyncHandler:
recents = []
if not limited or block_all_timeline:
+ prev_batch_token = now_token
+ if recents:
+ room_key = recents[0].internal_metadata.before
+ prev_batch_token = now_token.copy_and_replace("room_key", room_key)
+
return TimelineBatch(
- events=recents, prev_batch=now_token, limited=False
+ events=recents, prev_batch=prev_batch_token, limited=False
)
filtering_factor = 2
@@ -519,7 +527,7 @@ class SyncHandler:
if len(recents) > timeline_limit:
limited = True
recents = recents[-timeline_limit:]
- room_key = RoomStreamToken.parse(recents[0].internal_metadata.before)
+ room_key = recents[0].internal_metadata.before
prev_batch_token = now_token.copy_and_replace("room_key", room_key)
@@ -967,7 +975,7 @@ class SyncHandler:
raise NotImplementedError()
else:
joined_room_ids = await self.get_rooms_for_user_at(
- user_id, now_token.room_stream_id
+ user_id, now_token.room_key
)
sync_result_builder = SyncResultBuilder(
sync_config,
@@ -1014,10 +1022,14 @@ class SyncHandler:
logger.debug("Fetching OTK data")
device_id = sync_config.device_id
one_time_key_counts = {} # type: JsonDict
+ unused_fallback_key_types = [] # type: List[str]
if device_id:
one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
+ unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
+ user_id, device_id
+ )
logger.debug("Fetching group data")
await self._generate_sync_entry_for_groups(sync_result_builder)
@@ -1041,6 +1053,7 @@ class SyncHandler:
device_lists=device_lists,
groups=sync_result_builder.groups,
device_one_time_keys_count=one_time_key_counts,
+ device_unused_fallback_key_types=unused_fallback_key_types,
next_batch=sync_result_builder.now_token,
)
@@ -1378,13 +1391,16 @@ class SyncHandler:
return set(), set(), set(), set()
ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
- "m.ignored_user_list", user_id=user_id
+ AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
)
+ # If there is ignored users account data and it matches the proper type,
+ # then use it.
+ ignored_users = frozenset() # type: FrozenSet[str]
if ignored_account_data:
- ignored_users = ignored_account_data.get("ignored_users", {}).keys()
- else:
- ignored_users = frozenset()
+ ignored_users_data = ignored_account_data.get("ignored_users", {})
+ if isinstance(ignored_users_data, dict):
+ ignored_users = frozenset(ignored_users_data.keys())
if since_token:
room_changes = await self._get_rooms_changed(
@@ -1478,7 +1494,7 @@ class SyncHandler:
return False
async def _get_rooms_changed(
- self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str]
+ self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
) -> _RoomChanges:
"""Gets the the changes that have happened since the last sync.
"""
@@ -1595,16 +1611,24 @@ class SyncHandler:
if leave_events:
leave_event = leave_events[-1]
- leave_stream_token = await self.store.get_stream_token_for_event(
+ leave_position = await self.store.get_position_for_event(
leave_event.event_id
)
- leave_token = since_token.copy_and_replace(
- "room_key", leave_stream_token
- )
- if since_token and since_token.is_after(leave_token):
+ # If the leave event happened before the since token then we
+ # bail.
+ if since_token and not leave_position.persisted_after(
+ since_token.room_key
+ ):
continue
+ # We can safely convert the position of the leave event into a
+ # stream token as it'll only be used in the context of this
+ # room. (cf. the docstring of `to_room_stream_token`).
+ leave_token = since_token.copy_and_replace(
+ "room_key", leave_position.to_room_stream_token()
+ )
+
# If this is an out of band message, like a remote invite
# rejection, we include it in the recents batch. Otherwise, we
# let _load_filtered_recents handle fetching the correct
@@ -1682,7 +1706,7 @@ class SyncHandler:
return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms)
async def _get_all_rooms(
- self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str]
+ self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
) -> _RoomChanges:
"""Returns entries for all rooms for the user.
@@ -1756,7 +1780,7 @@ class SyncHandler:
async def _generate_room_entry(
self,
sync_result_builder: "SyncResultBuilder",
- ignored_users: Set[str],
+ ignored_users: FrozenSet[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[Dict[str, Dict[str, Any]]],
@@ -1916,7 +1940,7 @@ class SyncHandler:
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
async def get_rooms_for_user_at(
- self, user_id: str, stream_ordering: int
+ self, user_id: str, room_key: RoomStreamToken
) -> FrozenSet[str]:
"""Get set of joined rooms for a user at the given stream ordering.
@@ -1942,15 +1966,15 @@ class SyncHandler:
# If the membership event was persisted after the given token, we need
# to go and work out if the user was in the room before.
- for room_id, membership_stream_ordering in joined_rooms:
- if membership_stream_ordering <= stream_ordering:
+ for room_id, event_pos in joined_rooms:
+ if not event_pos.persisted_after(room_key):
joined_room_ids.add(room_id)
continue
logger.info("User joined room after current token: %s", room_id)
extrems = await self.store.get_forward_extremeties_for_room(
- room_id, stream_ordering
+ room_id, event_pos.stream
)
users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
if user_id in users_in_room:
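The ignored-users change above is purely defensive; as a sketch, malformed m.ignored_user_list account data now degrades to an empty set instead of raising:

    ignored_users = frozenset()  # type: FrozenSet[str]
    account_data = {"ignored_users": "not-a-dict"}  # hypothetical bad data
    ignored_users_data = account_data.get("ignored_users", {})
    if isinstance(ignored_users_data, dict):
        ignored_users = frozenset(ignored_users_data.keys())
    # ignored_users stays empty; previously .keys() raised AttributeError here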
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index e21f8dbc58..79393c8829 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -37,7 +37,7 @@ class UserDirectoryHandler(StateDeltasHandler):
"""
def __init__(self, hs):
- super(UserDirectoryHandler, self).__init__(hs)
+ super().__init__(hs)
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 3880ce0d94..59b01b812c 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -16,8 +16,6 @@
import re
from twisted.internet import task
-from twisted.internet.defer import CancelledError
-from twisted.python import failure
from twisted.web.client import FileBodyProducer
from synapse.api.errors import SynapseError
@@ -26,19 +24,8 @@ from synapse.api.errors import SynapseError
class RequestTimedOutError(SynapseError):
"""Exception representing timeout of an outbound request"""
- def __init__(self):
- super(RequestTimedOutError, self).__init__(504, "Timed out")
-
-
-def cancelled_to_request_timed_out_error(value, timeout):
- """Turns CancelledErrors into RequestTimedOutErrors.
-
- For use with async.add_timeout_to_deferred
- """
- if isinstance(value, failure.Failure):
- value.trap(CancelledError)
- raise RequestTimedOutError()
- return value
+ def __init__(self, msg):
+ super().__init__(504, msg)
ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 13fcab3378..8324632cb6 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -13,10 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import urllib
from io import BytesIO
+from typing import (
+ Any,
+ BinaryIO,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+)
import treq
from canonicaljson import encode_canonical_json
@@ -26,7 +37,7 @@ from zope.interface import implementer, provider
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
-from twisted.internet import defer, protocol, ssl
+from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IResolutionReceiver,
@@ -34,16 +45,18 @@ from twisted.internet.interfaces import (
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
-from twisted.web.client import Agent, HTTPConnectionPool, readBody
+from twisted.web.client import (
+ Agent,
+ HTTPConnectionPool,
+ ResponseNeverReceived,
+ readBody,
+)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
from synapse.api.errors import Codes, HttpResponseException, SynapseError
-from synapse.http import (
- QuieterFileBodyProducer,
- cancelled_to_request_timed_out_error,
- redact_uri,
-)
+from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
@@ -57,6 +70,19 @@ incoming_responses_counter = Counter(
"synapse_http_client_responses", "", ["method", "code"]
)
+# the type of the headers list, to be passed to the t.w.h.Headers.
+# Actually we can mix str and bytes keys, but Mapping is invariant in its key
+# type, so we keep it simple.
+RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]]
+
+# the value actually has to be a List, but List is invariant so we can't specify that
+# the entries can be either str or bytes.
+RawHeaderValue = Sequence[Union[str, bytes]]
+
+# the type of the query params, to be passed into `urlencode`
+QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
+QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
+
def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
"""
@@ -285,16 +311,27 @@ class SimpleHttpClient:
ip_blacklist=self._ip_blacklist,
)
- async def request(self, method, uri, data=None, headers=None):
+ async def request(
+ self,
+ method: str,
+ uri: str,
+ data: Optional[bytes] = None,
+ headers: Optional[Headers] = None,
+ ) -> IResponse:
"""
Args:
- method (str): HTTP method to use.
- uri (str): URI to query.
- data (bytes): Data to send in the request body, if applicable.
- headers (t.w.http_headers.Headers): Request headers.
+ method: HTTP method to use.
+ uri: URI to query.
+ data: Data to send in the request body, if applicable.
+ headers: Request headers.
+
+ Returns:
+ Response object, once the headers have been read.
+
+ Raises:
+ RequestTimedOutError: if the request times out before the headers are read
+
"""
- # A small wrapper around self.agent.request() so we can easily attach
- # counters to it
outgoing_requests_counter.labels(method).inc()
# log request but strip `access_token` (AS requests for example include this)
@@ -323,13 +360,17 @@ class SimpleHttpClient:
data=body_producer,
headers=headers,
**self._extra_treq_args
- )
+ ) # type: defer.Deferred
+
+ # we use our own timeout mechanism rather than treq's as a workaround
+ # for https://twistedmatrix.com/trac/ticket/9534.
request_deferred = timeout_deferred(
- request_deferred,
- 60,
- self.hs.get_reactor(),
- cancelled_to_request_timed_out_error,
+ request_deferred, 60, self.hs.get_reactor(),
)
+
+ # turn timeouts into RequestTimedOutErrors
+ request_deferred.addErrback(_timeout_to_request_timed_out_error)
+
response = await make_deferred_yieldable(request_deferred)
incoming_responses_counter.labels(method, response.code).inc()
@@ -353,18 +394,26 @@ class SimpleHttpClient:
set_tag("error_reason", e.args[0])
raise
- async def post_urlencoded_get_json(self, uri, args={}, headers=None):
+ async def post_urlencoded_get_json(
+ self,
+ uri: str,
+ args: Mapping[str, Union[str, List[str]]] = {},
+ headers: Optional[RawHeaders] = None,
+ ) -> Any:
"""
Args:
- uri (str):
- args (dict[str, str|List[str]]): query params
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: uri to query
+ args: parameters to be url-encoded in the body
+ headers: a map from header name to a list of values for that header
Returns:
- object: parsed json
+ parsed json
Raises:
+ RequestTimedOutError: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException: On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -398,19 +447,24 @@ class SimpleHttpClient:
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- async def post_json_get_json(self, uri, post_json, headers=None):
+ async def post_json_get_json(
+ self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None
+ ) -> Any:
"""
Args:
- uri (str):
- post_json (object):
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: URI to query.
+ post_json: request body, to be encoded as json
+ headers: a map from header name to a list of values for that header
Returns:
- object: parsed json
+ parsed json
Raises:
+ RequestTimedOutError: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException: On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -440,21 +494,22 @@ class SimpleHttpClient:
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- async def get_json(self, uri, args={}, headers=None):
- """ Gets some json from the given URI.
+ async def get_json(
+ self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
+ ) -> Any:
+ """Gets some json from the given URI.
Args:
- uri (str): The URI to request, not including query parameters
- args (dict): A dictionary used to create query strings, defaults to
- None.
- **Note**: The value of each key is assumed to be an iterable
- and *not* a string.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: The URI to request, not including query parameters
+ args: A dictionary used to create query string
+ headers: a map from header name to a list of values for that header
Returns:
- Succeeds when we get *any* 2xx HTTP response, with the
- HTTP body as JSON.
+ Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
Raises:
+ RequestTimedOutError: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException: On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -466,22 +521,27 @@ class SimpleHttpClient:
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))
- async def put_json(self, uri, json_body, args={}, headers=None):
- """ Puts some json to the given URI.
+ async def put_json(
+ self,
+ uri: str,
+ json_body: Any,
+ args: QueryParams = {},
+ headers: Optional[RawHeaders] = None,
+ ) -> Any:
+ """Puts some json to the given URI.
Args:
- uri (str): The URI to request, not including query parameters
- json_body (dict): The JSON to put in the HTTP body,
- args (dict): A dictionary used to create query strings, defaults to
- None.
- **Note**: The value of each key is assumed to be an iterable
- and *not* a string.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: The URI to request, not including query parameters
+ json_body: The JSON to put in the HTTP body.
+ args: A dictionary used to create query strings
+ headers: a map from header name to a list of values for that header
Returns:
- Succeeds when we get *any* 2xx HTTP response, with the
- HTTP body as JSON.
+ Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
Raises:
+ RequestTimedOutError: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException: On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -513,21 +573,23 @@ class SimpleHttpClient:
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- async def get_raw(self, uri, args={}, headers=None):
- """ Gets raw text from the given URI.
+ async def get_raw(
+ self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
+ ) -> bytes:
+ """Gets raw text from the given URI.
Args:
- uri (str): The URI to request, not including query parameters
- args (dict): A dictionary used to create query strings, defaults to
- None.
- **Note**: The value of each key is assumed to be an iterable
- and *not* a string.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: The URI to request, not including query parameters
+ args: A dictionary used to create query strings
+ headers: a map from header name to a list of values for that header
Returns:
- Succeeds when we get *any* 2xx HTTP response, with the
+ Succeeds when we get a 2xx HTTP response, with the
HTTP body as bytes.
Raises:
+ RequestTimedOutError: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException: On a non-2xx HTTP response.
"""
if len(args):
@@ -552,16 +614,29 @@ class SimpleHttpClient:
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
- async def get_file(self, url, output_stream, max_size=None, headers=None):
+ async def get_file(
+ self,
+ url: str,
+ output_stream: BinaryIO,
+ max_size: Optional[int] = None,
+ headers: Optional[RawHeaders] = None,
+ ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
"""GETs a file from a given URL
Args:
- url (str): The URL to GET
- output_stream (file): File to write the response body to.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ url: The URL to GET
+ output_stream: File to write the response body to.
+ headers: A map from header name to a list of values for that header
Returns:
- A (int,dict,string,int) tuple of the file length, dict of the response
+ A tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.
+
+ Raises:
+ RequestTimedOutError: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
+ SynapseError: if the response is not a 2xx, the remote file is too large, or
+ another exception happens during the download.
"""
actual_headers = {b"User-Agent": [self.user_agent]}
@@ -609,6 +684,18 @@ class SimpleHttpClient:
)
+def _timeout_to_request_timed_out_error(f: Failure):
+ if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError):
+ # The TCP connection has its own timeout (set by the 'connectTimeout' param
+ # on the Agent), which raises twisted_error.TimeoutError exception.
+ raise RequestTimedOutError("Timeout connecting to remote server")
+ elif f.check(defer.TimeoutError, ResponseNeverReceived):
+ # this one means that we hit our overall timeout on the request
+ raise RequestTimedOutError("Timeout waiting for response from remote server")
+
+ return f
+
+
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
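A caller-side sketch of the error contract documented above (the URL is illustrative): the new RequestTimedOutError fires if the response headers do not arrive within the 60s budget, while reading the body currently has no timeout.

    try:
        body = await client.get_json("https://example.com/_matrix/key/v2/server")
    except RequestTimedOutError:
        # headers never arrived in time; retry or surface a 504 as appropriate
        ...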
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 3c86cbc546..c23a4d7c0c 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -171,7 +171,7 @@ async def _handle_json_response(
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
body = await make_deferred_yieldable(d)
- except TimeoutError as e:
+ except defer.TimeoutError as e:
logger.warning(
"{%s} [%s] Timed out reading response - %s %s",
request.txn_id,
@@ -473,8 +473,6 @@ class MatrixFederationHttpClient:
)
response = await request_deferred
- except TimeoutError as e:
- raise RequestSendFailed(e, can_retry=True) from e
except DNSLookupError as e:
raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
except Exception as e:
@@ -657,10 +655,14 @@ class MatrixFederationHttpClient:
long_retries (bool): whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response headers
- (including connecting to the server), *for each attempt*.
+ timeout (int|None): number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
+ Note that we may make several attempts to send the request; this
+ timeout applies to the time spent waiting for response headers for
+ *each* attempt (including connection time) as well as the time spent
+ reading the response body after a 200 response.
+
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): True if we should count a 404 response as
@@ -706,8 +708,13 @@ class MatrixFederationHttpClient:
timeout=timeout,
)
+ if timeout is not None:
+ _sec_timeout = timeout / 1000
+ else:
+ _sec_timeout = self.default_timeout
+
body = await _handle_json_response(
- self.reactor, self.default_timeout, request, response, start_ms
+ self.reactor, _sec_timeout, request, response, start_ms
)
return body
@@ -736,10 +743,14 @@ class MatrixFederationHttpClient:
long_retries (bool): whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response headers
- (including connecting to the server), *for each attempt*.
+ timeout (int|None): number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
+ Note that we may make several attempts to send the request; this
+ timeout applies to the time spent waiting for response headers for
+ *each* attempt (including connection time) as well as the time spent
+ reading the response body after a 200 response.
+
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
@@ -803,10 +814,14 @@ class MatrixFederationHttpClient:
args (dict|None): A dictionary used to create query strings, defaults to
None.
- timeout (int|None): number of milliseconds to wait for the response headers
- (including connecting to the server), *for each attempt*.
+ timeout (int|None): number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
+ Note that we may make several attempts to send the request; this
+ timeout applies to the time spent waiting for response headers for
+ *each* attempt (including connection time) as well as the time spent
+ reading the response body after a 200 response.
+
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
@@ -842,8 +857,13 @@ class MatrixFederationHttpClient:
timeout=timeout,
)
+ if timeout is not None:
+ _sec_timeout = timeout / 1000
+ else:
+ _sec_timeout = self.default_timeout
+
body = await _handle_json_response(
- self.reactor, self.default_timeout, request, response, start_ms
+ self.reactor, _sec_timeout, request, response, start_ms
)
return body
@@ -867,10 +887,14 @@ class MatrixFederationHttpClient:
long_retries (bool): whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response headers
- (including connecting to the server), *for each attempt*.
+ timeout (int|None): number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
+ Note that we may make several attempts to send the request; this
+ timeout applies to the time spent waiting for response headers for
+ *each* attempt (including connection time) as well as the time spent
+ reading the response body after a 200 response.
+
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
@@ -902,8 +926,13 @@ class MatrixFederationHttpClient:
ignore_backoff=ignore_backoff,
)
+ if timeout is not None:
+ _sec_timeout = timeout / 1000
+ else:
+ _sec_timeout = self.default_timeout
+
body = await _handle_json_response(
- self.reactor, self.default_timeout, request, response, start_ms
+ self.reactor, _sec_timeout, request, response, start_ms
)
return body
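A subtle bug fixed at all three call sites above: the caller-supplied timeout (milliseconds) was previously ignored when reading the response body, which always used the default. The normalisation now applied before each _handle_json_response call is equivalent to:

    # timeout is in milliseconds; _handle_json_response expects seconds
    _sec_timeout = timeout / 1000 if timeout is not None else self.default_timeout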
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index 332da02a8d..e32d3f43e0 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -44,8 +44,11 @@ class ProxyAgent(_AgentBase):
`BrowserLikePolicyForHTTPS`, so unless you have special
requirements you can leave this as-is.
- connectTimeout (float): The amount of time that this Agent will wait
- for the peer to accept a connection.
+ connectTimeout (Optional[float]): The amount of time that this Agent will wait
+ for the peer to accept a connection, in seconds. If 'None',
+ HostnameEndpoint's default (30s) will be used.
+
+ This is used for connections to both proxies and destination servers.
bindAddress (bytes): The local address for client sockets to bind to.
@@ -108,6 +111,15 @@ class ProxyAgent(_AgentBase):
Returns:
Deferred[IResponse]: completes when the header of the response has
been received (regardless of the response status code).
+
+ Can fail with:
+ SchemeNotSupported: if the uri is not http or https
+
+ twisted.internet.error.TimeoutError if the server we are connecting
+ to (proxy or destination) does not accept a connection before
+ connectTimeout.
+
+ ... other things too.
"""
uri = uri.strip()
if not _VALID_URI.match(uri):
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 996a31a9ec..00b98af3d4 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -257,7 +257,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
- callback_return = raw_callback_return
+ callback_return = raw_callback_return # type: ignore
return callback_return
@@ -406,7 +406,7 @@ class JsonResource(DirectServeJsonResource):
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
- callback_return = raw_callback_return
+ callback_return = raw_callback_return # type: ignore
return callback_return
@@ -651,6 +651,11 @@ def respond_with_json_bytes(
Returns:
twisted.web.server.NOT_DONE_YET if the request is still active.
"""
+ if request._disconnected:
+ logger.warning(
+ "Not sending response to request %s, already disconnected.", request
+ )
+ return
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 144506c8f2..0fc2ea609e 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import os.path
import sys
@@ -89,14 +88,7 @@ class LogContextObserver:
context = current_context()
# Copy the context information to the log event.
- if context is not None:
- context.copy_to_twisted_log_entry(event)
- else:
- # If there's no logging context, not even the root one, we might be
- # starting up or it might be from non-Synapse code. Log it as if it
- # came from the root logger.
- event["request"] = None
- event["scope"] = None
+ context.copy_to_twisted_log_entry(event)
self.observer(event)
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 2e282d9d67..ca0c774cc5 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -65,6 +65,11 @@ except Exception:
return None
+# a hook which can be set during testing to assert that we aren't abusing logcontexts.
+def logcontext_error(msg: str):
+ logger.warning(msg)
+
+
# get an id for the current thread.
#
# threading.get_ident doesn't actually return an OS-level tid, and annoyingly,
@@ -330,10 +335,9 @@ class LoggingContext:
"""Enters this logging context into thread local storage"""
old_context = set_current_context(self)
if self.previous_context != old_context:
- logger.warning(
- "Expected previous context %r, found %r",
- self.previous_context,
- old_context,
+ logcontext_error(
+ "Expected previous context %r, found %r"
+ % (self.previous_context, old_context,)
)
return self
@@ -346,10 +350,10 @@ class LoggingContext:
current = set_current_context(self.previous_context)
if current is not self:
if current is SENTINEL_CONTEXT:
- logger.warning("Expected logging context %s was lost", self)
+ logcontext_error("Expected logging context %s was lost" % (self,))
else:
- logger.warning(
- "Expected logging context %s but found %s", self, current
+ logcontext_error(
+ "Expected logging context %s but found %s" % (self, current)
)
# the fact that we are here suggests that the caller thinks that everything
@@ -387,16 +391,16 @@ class LoggingContext:
support getrusage.
"""
if get_thread_id() != self.main_thread:
- logger.warning("Started logcontext %s on different thread", self)
+ logcontext_error("Started logcontext %s on different thread" % (self,))
return
if self.finished:
- logger.warning("Re-starting finished log context %s", self)
+ logcontext_error("Re-starting finished log context %s" % (self,))
# If we haven't already started record the thread resource usage so
# far
if self.usage_start:
- logger.warning("Re-starting already-active log context %s", self)
+ logcontext_error("Re-starting already-active log context %s" % (self,))
else:
self.usage_start = rusage
@@ -414,7 +418,7 @@ class LoggingContext:
try:
if get_thread_id() != self.main_thread:
- logger.warning("Stopped logcontext %s on different thread", self)
+ logcontext_error("Stopped logcontext %s on different thread" % (self,))
return
if not rusage:
@@ -422,9 +426,9 @@ class LoggingContext:
# Record the cpu used since we started
if not self.usage_start:
- logger.warning(
- "Called stop on logcontext %s without recording a start rusage",
- self,
+ logcontext_error(
+ "Called stop on logcontext %s without recording a start rusage"
+ % (self,)
)
return
@@ -584,14 +588,13 @@ class PreserveLoggingContext:
if context != self._new_context:
if not context:
- logger.warning(
- "Expected logging context %s was lost", self._new_context
+ logcontext_error(
+ "Expected logging context %s was lost" % (self._new_context,)
)
else:
- logger.warning(
- "Expected logging context %s but found %s",
- self._new_context,
- context,
+ logcontext_error(
+ "Expected logging context %s but found %s"
+ % (self._new_context, context,)
)
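logcontext_error is deliberately a module-level hook so that, per its comment, tests can swap it out and turn logcontext misuse into hard failures. One hypothetical way to use it (the diff only defines the hook itself):

    from synapse.logging import context

    def _fail_on_logcontext_error(msg: str) -> None:
        raise AssertionError(msg)

    # install the stricter hook for the duration of a test run
    context.logcontext_error = _fail_on_logcontext_error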
diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py
index d736ad5b9b..11f60a77f7 100644
--- a/synapse/logging/formatter.py
+++ b/synapse/logging/formatter.py
@@ -30,7 +30,7 @@ class LogFormatter(logging.Formatter):
"""
def __init__(self, *args, **kwargs):
- super(LogFormatter, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
def formatException(self, ei):
sio = StringIO()
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index 026854b4c7..7b9c657456 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -107,7 +107,7 @@ class _LogContextScope(Scope):
finish_on_close (Boolean):
if True finish the span when the scope is closed
"""
- super(_LogContextScope, self).__init__(manager, span)
+ super().__init__(manager, span)
self.logcontext = logcontext
self._finish_on_close = finish_on_close
self._enter_logcontext = enter_logcontext
@@ -120,9 +120,9 @@ class _LogContextScope(Scope):
def __exit__(self, type, value, traceback):
if type == twisted.internet.defer._DefGen_Return:
- super(_LogContextScope, self).__exit__(None, None, None)
+ super().__exit__(None, None, None)
else:
- super(_LogContextScope, self).__exit__(type, value, traceback)
+ super().__exit__(type, value, traceback)
if self._enter_logcontext:
self.logcontext.__exit__(type, value, traceback)
else: # the logcontext existed before the creation of the scope
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index a1f7ca3449..b8d2a8e8a9 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -15,6 +15,7 @@
import functools
import gc
+import itertools
import logging
import os
import platform
@@ -27,8 +28,8 @@ from prometheus_client import Counter, Gauge, Histogram
from prometheus_client.core import (
REGISTRY,
CounterMetricFamily,
+ GaugeHistogramMetricFamily,
GaugeMetricFamily,
- HistogramMetricFamily,
)
from twisted.internet import reactor
@@ -46,7 +47,7 @@ logger = logging.getLogger(__name__)
METRICS_PREFIX = "/_synapse/metrics"
running_on_pypy = platform.python_implementation() == "PyPy"
-all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge, BucketCollector]]
+all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge]]
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
@@ -205,63 +206,83 @@ class InFlightGauge:
all_gauges[self.name] = self
-@attr.s(slots=True, hash=True)
-class BucketCollector:
- """
- Like a Histogram, but allows buckets to be point-in-time instead of
- incrementally added to.
+class GaugeBucketCollector:
+ """Like a Histogram, but the buckets are Gauges which are updated atomically.
- Args:
- name (str): Base name of metric to be exported to Prometheus.
- data_collector (callable -> dict): A synchronous callable that
- returns a dict mapping bucket to number of items in the
- bucket. If these buckets are not the same as the buckets
- given to this class, they will be remapped into them.
- buckets (list[float]): List of floats/ints of the buckets to
- give to Prometheus. +Inf is ignored, if given.
+ The data is updated by calling `update_data` with an iterable of measurements.
+ We assume that the data is updated less frequently than it is reported to
+ Prometheus, and optimise for that case.
"""
- name = attr.ib()
- data_collector = attr.ib()
- buckets = attr.ib()
+ __slots__ = ("_name", "_documentation", "_bucket_bounds", "_metric")
- def collect(self):
+ def __init__(
+ self,
+ name: str,
+ documentation: str,
+ buckets: Iterable[float],
+ registry=REGISTRY,
+ ):
+ """
+ Args:
+ name: base name of metric to be exported to Prometheus. (a _bucket suffix
+ will be added.)
+ documentation: help text for the metric
+ buckets: The top bounds of the buckets to report
+ registry: metric registry to register with
+ """
+ self._name = name
+ self._documentation = documentation
- # Fetch the data -- this must be synchronous!
- data = self.data_collector()
+ # the tops of the buckets
+ self._bucket_bounds = [float(b) for b in buckets]
+ if self._bucket_bounds != sorted(self._bucket_bounds):
+ raise ValueError("Buckets not in sorted order")
- buckets = {} # type: Dict[float, int]
+ if self._bucket_bounds[-1] != float("inf"):
+ self._bucket_bounds.append(float("inf"))
- res = []
- for x in data.keys():
- for i, bound in enumerate(self.buckets):
- if x <= bound:
- buckets[bound] = buckets.get(bound, 0) + data[x]
+ self._metric = self._values_to_metric([])
+ registry.register(self)
- for i in self.buckets:
- res.append([str(i), buckets.get(i, 0)])
+ def collect(self):
+ yield self._metric
- res.append(["+Inf", sum(data.values())])
+ def update_data(self, values: Iterable[float]):
+ """Update the data to be reported by the metric
- metric = HistogramMetricFamily(
- self.name, "", buckets=res, sum_value=sum(x * y for x, y in data.items())
+ The existing data is cleared, and each measurement in the input is assigned
+ to the relevant bucket.
+ """
+ self._metric = self._values_to_metric(values)
+
+ def _values_to_metric(self, values: Iterable[float]) -> GaugeHistogramMetricFamily:
+ total = 0.0
+ bucket_values = [0 for _ in self._bucket_bounds]
+
+ for v in values:
+ # assign each value to a bucket
+ for i, bound in enumerate(self._bucket_bounds):
+ if v <= bound:
+ bucket_values[i] += 1
+ break
+
+ # ... and increment the sum
+ total += v
+
+ # now, aggregate the bucket values so that they count the number of entries in
+ # that bucket or below.
+ accumulated_values = itertools.accumulate(bucket_values)
+
+ return GaugeHistogramMetricFamily(
+ self._name,
+ self._documentation,
+ buckets=list(
+ zip((str(b) for b in self._bucket_bounds), accumulated_values)
+ ),
+ gsum_value=total,
)
- yield metric
-
- def __attrs_post_init__(self):
- self.buckets = [float(x) for x in self.buckets if x != "+Inf"]
- if self.buckets != sorted(self.buckets):
- raise ValueError("Buckets not sorted")
-
- self.buckets = tuple(self.buckets)
-
- if self.name in all_gauges.keys():
- logger.warning("%s already registered, reregistering" % (self.name,))
- REGISTRY.unregister(all_gauges.pop(self.name))
-
- REGISTRY.register(self)
- all_gauges[self.name] = self
#
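A minimal usage sketch of the new collector (the metric name, help text and bucket bounds below are illustrative, and we assume the class is importable from synapse.metrics):

    from synapse.metrics import GaugeBucketCollector

    # Exported to Prometheus as a gauge histogram, with _bucket/_gcount/_gsum
    # samples under the given base name.
    example_delay = GaugeBucketCollector(
        "synapse_example_persist_delay",
        "Distribution of example persistence delays, in seconds",
        buckets=[0.1, 0.5, 1.0, 5.0],  # bucket tops; +Inf is appended if missing
    )

    # Replace the reported data wholesale. Each measurement lands in the first
    # bucket whose upper bound it does not exceed; counts are then accumulated
    # so each bucket reports the number of measurements at or below its bound.
    example_delay.update_data([0.05, 0.3, 2.2])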
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index 4304c60d56..734271e765 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -24,9 +24,9 @@ expect, and the newer "best practice" version of the up-to-date official client.
import math
import threading
-from collections import namedtuple
from http.server import BaseHTTPRequestHandler, HTTPServer
from socketserver import ThreadingMixIn
+from typing import Dict, List
from urllib.parse import parse_qs, urlparse
from prometheus_client import REGISTRY
@@ -35,14 +35,6 @@ from twisted.web.resource import Resource
from synapse.util import caches
-try:
- from prometheus_client.samples import Sample
-except ImportError:
- Sample = namedtuple( # type: ignore[no-redef] # noqa
- "Sample", ["name", "labels", "value", "timestamp", "exemplar"]
- )
-
-
CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")
@@ -93,17 +85,6 @@ def sample_line(line, name):
)
-def nameify_sample(sample):
- """
- If we get a prometheus_client<0.4.0 sample as a tuple, transform it into a
- namedtuple which has the names we expect.
- """
- if not isinstance(sample, Sample):
- sample = Sample(*sample, None, None)
-
- return sample
-
-
def generate_latest(registry, emit_help=False):
# Trigger the cache metrics to be rescraped, which updates the common
@@ -144,16 +125,33 @@ def generate_latest(registry, emit_help=False):
)
)
output.append("# TYPE {0} {1}\n".format(mname, mtype))
- for sample in map(nameify_sample, metric.samples):
- # Get rid of the OpenMetrics specific samples
+
+ om_samples = {} # type: Dict[str, List[str]]
+ for s in metric.samples:
for suffix in ["_created", "_gsum", "_gcount"]:
- if sample.name.endswith(suffix):
+ if s.name == metric.name + suffix:
+ # OpenMetrics specific sample, put in a gauge at the end.
+ # (these come from gaugehistograms which don't get renamed,
+ # so no need to faff with mnewname)
+ om_samples.setdefault(suffix, []).append(sample_line(s, s.name))
break
else:
- newname = sample.name.replace(mnewname, mname)
+ newname = s.name.replace(mnewname, mname)
if ":" in newname and newname.endswith("_total"):
newname = newname[: -len("_total")]
- output.append(sample_line(sample, newname))
+ output.append(sample_line(s, newname))
+
+ for suffix, lines in sorted(om_samples.items()):
+ if emit_help:
+ output.append(
+ "# HELP {0}{1} {2}\n".format(
+ metric.name,
+ suffix,
+ metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
+ )
+ )
+ output.append("# TYPE {0}{1} gauge\n".format(metric.name, suffix))
+ output.extend(lines)
# Get rid of the weird colon things while we're at it
if mtype == "counter":
@@ -172,16 +170,16 @@ def generate_latest(registry, emit_help=False):
)
)
output.append("# TYPE {0} {1}\n".format(mnewname, mtype))
- for sample in map(nameify_sample, metric.samples):
- # Get rid of the OpenMetrics specific samples
+
+ for s in metric.samples:
+ # Get rid of the OpenMetrics specific samples (we should already have
+ # dealt with them above anyway.)
for suffix in ["_created", "_gsum", "_gcount"]:
- if sample.name.endswith(suffix):
+ if s.name == metric.name + suffix:
break
else:
output.append(
- sample_line(
- sample, sample.name.replace(":total", "").replace(":", "_")
- )
+ sample_line(s, s.name.replace(":total", "").replace(":", "_"))
)
return "".join(output).encode("utf-8")
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index fcbd5378c4..b410e3ad9c 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -14,13 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Iterable, Optional, Tuple
from twisted.internet import defer
+from synapse.http.client import SimpleHttpClient
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.storage.state import StateFilter
from synapse.types import UserID
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
"""
This package defines the 'stable' API which can be used by extension modules which
are loaded into Synapse.
@@ -43,6 +49,27 @@ class ModuleApi:
self._auth = hs.get_auth()
self._auth_handler = auth_handler
+ # We expose these as properties below in order to attach a helpful docstring.
+ self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
+ self._public_room_list_manager = PublicRoomListManager(hs)
+
+ @property
+ def http_client(self):
+ """Allows making outbound HTTP requests to remote resources.
+
+ An instance of synapse.http.client.SimpleHttpClient
+ """
+ return self._http_client
+
+ @property
+ def public_room_list_manager(self):
+ """Allows adding to, removing from and checking the status of rooms in the
+ public room list.
+
+ An instance of synapse.module_api.PublicRoomListManager
+ """
+ return self._public_room_list_manager
+
def get_user_by_req(self, req, allow_guest=False):
"""Check the access_token provided for a request
@@ -266,3 +293,70 @@ class ModuleApi:
await self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url,
)
+
+ @defer.inlineCallbacks
+ def get_state_events_in_room(
+ self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
+ ) -> defer.Deferred:
+ """Gets current state events for the given room.
+
+ (This is exposed for compatibility with the old SpamCheckerApi. We should
+ probably deprecate it and replace it with an async method in a subclass.)
+
+ Args:
+ room_id: The room ID to get state events in.
+            types: A list of (event type, state key) tuples to fetch; a state
+                key of None matches any state key.
+
+ Returns:
+ twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
+ The filtered state events in the room.
+ """
+ state_ids = yield defer.ensureDeferred(
+ self._store.get_filtered_current_state_ids(
+ room_id=room_id, state_filter=StateFilter.from_types(types)
+ )
+ )
+ state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
+ return state.values()
+
+
+class PublicRoomListManager:
+ """Contains methods for adding to, removing from and querying whether a room
+ is in the public room list.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ self._store = hs.get_datastore()
+
+ async def room_is_in_public_room_list(self, room_id: str) -> bool:
+ """Checks whether a room is in the public room list.
+
+ Args:
+ room_id: The ID of the room.
+
+ Returns:
+ Whether the room is in the public room list. Returns False if the room does
+ not exist.
+ """
+ room = await self._store.get_room(room_id)
+ if not room:
+ return False
+
+ return room.get("is_public", False)
+
+ async def add_room_to_public_room_list(self, room_id: str) -> None:
+ """Publishes a room to the public room list.
+
+ Args:
+ room_id: The ID of the room.
+ """
+ await self._store.set_room_is_public(room_id, True)
+
+ async def remove_room_from_public_room_list(self, room_id: str) -> None:
+ """Removes a room from the public room list.
+
+ Args:
+ room_id: The ID of the room.
+ """
+ await self._store.set_room_is_public(room_id, False)
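A sketch of how an extension module might drive the new surface (the function, the policy URL and the JSON shape are hypothetical; `module_api` is the ModuleApi instance handed to the module):

    async def sync_room_visibility(module_api, room_id: str):
        # Fetch an external policy document via the homeserver's HTTP client.
        policy = await module_api.http_client.get_json("https://example.com/policy.json")

        manager = module_api.public_room_list_manager
        if policy.get("publish"):
            await manager.add_room_to_public_room_list(room_id)
        elif await manager.room_is_in_public_room_list(room_id):
            await manager.remove_room_from_public_room_list(room_id)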
diff --git a/synapse/notifier.py b/synapse/notifier.py
index a8fd3ef886..59415f6f88 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -42,7 +42,13 @@ from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.streams.config import PaginationConfig
-from synapse.types import Collection, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+ Collection,
+ PersistedEventPosition,
+ RoomStreamToken,
+ StreamToken,
+ UserID,
+)
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
@@ -157,7 +163,7 @@ class _NotifierUserStream:
"""
# Immediately wake up stream if something has already since happened
# since their last token.
- if self.last_notified_token.is_after(token):
+ if self.last_notified_token != token:
return _NotificationListener(defer.succeed(self.current_token))
else:
return _NotificationListener(self.notify_deferred.observe())
@@ -187,7 +193,7 @@ class Notifier:
self.store = hs.get_datastore()
self.pending_new_room_events = (
[]
- ) # type: List[Tuple[int, EventBase, Collection[UserID]]]
+ ) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]]
# Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]]
@@ -246,8 +252,8 @@ class Notifier:
def on_new_room_event(
self,
event: EventBase,
- room_stream_id: int,
- max_room_stream_id: int,
+ event_pos: PersistedEventPosition,
+ max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [],
):
""" Used by handlers to inform the notifier something has happened
@@ -261,16 +267,16 @@ class Notifier:
until all previous events have been persisted before notifying
the client streams.
"""
- self.pending_new_room_events.append((room_stream_id, event, extra_users))
- self._notify_pending_new_room_events(max_room_stream_id)
+ self.pending_new_room_events.append((event_pos, event, extra_users))
+ self._notify_pending_new_room_events(max_room_stream_token)
self.notify_replication()
- def _notify_pending_new_room_events(self, max_room_stream_id: int):
+ def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
Args:
- max_room_stream_id: The highest stream_id below which all
+            max_room_stream_token: The highest stream token below which all
events have been persisted.
"""
pending = self.pending_new_room_events
@@ -279,11 +285,9 @@ class Notifier:
users = set() # type: Set[UserID]
rooms = set() # type: Set[str]
- for room_stream_id, event, extra_users in pending:
- if room_stream_id > max_room_stream_id:
- self.pending_new_room_events.append(
- (room_stream_id, event, extra_users)
- )
+ for event_pos, event, extra_users in pending:
+ if event_pos.persisted_after(max_room_stream_token):
+ self.pending_new_room_events.append((event_pos, event, extra_users))
else:
if (
event.type == EventTypes.Member
@@ -296,39 +300,38 @@ class Notifier:
if users or rooms:
self.on_new_event(
- "room_key",
- RoomStreamToken(None, max_room_stream_id),
- users=users,
- rooms=rooms,
+ "room_key", max_room_stream_token, users=users, rooms=rooms,
)
- self._on_updated_room_token(max_room_stream_id)
+ self._on_updated_room_token(max_room_stream_token)
- def _on_updated_room_token(self, max_room_stream_id: int):
+ def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken):
"""Poke services that might care that the room position has been
updated.
"""
# poke any interested application service.
run_as_background_process(
- "_notify_app_services", self._notify_app_services, max_room_stream_id
+ "_notify_app_services", self._notify_app_services, max_room_stream_token
)
run_as_background_process(
- "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_id
+ "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_token
)
if self.federation_sender:
- self.federation_sender.notify_new_events(max_room_stream_id)
+ self.federation_sender.notify_new_events(max_room_stream_token.stream)
- async def _notify_app_services(self, max_room_stream_id: int):
+ async def _notify_app_services(self, max_room_stream_token: RoomStreamToken):
try:
- await self.appservice_handler.notify_interested_services(max_room_stream_id)
+ await self.appservice_handler.notify_interested_services(
+ max_room_stream_token.stream
+ )
except Exception:
logger.exception("Error notifying application services of event")
- async def _notify_pusher_pool(self, max_room_stream_id: int):
+ async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
try:
- await self._pusher_pool.on_new_notifications(max_room_stream_id)
+ await self._pusher_pool.on_new_notifications(max_room_stream_token.stream)
except Exception:
logger.exception("Error pusher pool of event")
@@ -467,7 +470,7 @@ class Notifier:
async def check_for_updates(
before_token: StreamToken, after_token: StreamToken
) -> EventStreamResult:
- if not after_token.is_after(before_token):
+ if after_token == before_token:
return EventStreamResult([], (from_token, from_token))
events = [] # type: List[EventBase]
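The ordering check the notifier now relies on boils down to a stream-position comparison; a simplified sketch of the relevant pieces of synapse.types (not the full definitions):

    from typing import Optional

    import attr

    @attr.s(slots=True, frozen=True)
    class RoomStreamToken:
        topological = attr.ib(type=Optional[int])
        stream = attr.ib(type=int)

    @attr.s(slots=True, frozen=True)
    class PersistedEventPosition:
        # Position of a newly persisted event, on a given event persister.
        instance_name = attr.ib(type=str)
        stream = attr.ib(type=int)

        def persisted_after(self, token: RoomStreamToken) -> bool:
            # The event is still ahead of the token, so its notification
            # stays queued until later events have been persisted.
            return token.stream < self.stream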
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index edf45dc599..5a437f9810 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -16,4 +16,4 @@
class PusherConfigException(Exception):
def __init__(self, msg):
- super(PusherConfigException, self).__init__(msg)
+ super().__init__(msg)
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 709ace01e5..3a68ce636f 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -16,7 +16,7 @@
import logging
import re
-from typing import Any, Dict, List, Pattern, Union
+from typing import Any, Dict, List, Optional, Pattern, Union
from synapse.events import EventBase
from synapse.types import UserID
@@ -181,7 +181,7 @@ class PushRuleEvaluatorForEvent:
return r.search(body)
- def _get_value(self, dotted_key: str) -> str:
+ def _get_value(self, dotted_key: str) -> Optional[str]:
return self._value_cache.get(dotted_key, None)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index cc839ffce4..76150e117b 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -60,6 +60,8 @@ class PusherPool:
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
+ self._account_validity = hs.config.account_validity
+
# We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name()
@@ -202,6 +204,14 @@ class PusherPool:
)
for u in users_affected:
+ # Don't push if the user account has expired
+ if self._account_validity.enabled:
+ expired = await self.store.is_account_expired(
+ u, self.clock.time_msec()
+ )
+ if expired:
+ continue
+
if u in self.pushers:
for p in self.pushers[u].values():
p.on_new_notifications(max_stream_id)
@@ -222,6 +232,14 @@ class PusherPool:
)
for u in users_affected:
+ # Don't push if the user account has expired
+ if self._account_validity.enabled:
+ expired = await self.store.is_account_expired(
+ u, self.clock.time_msec()
+ )
+ if expired:
+ continue
+
if u in self.pushers:
for p in self.pushers[u].values():
p.on_new_receipts(min_stream_id, max_stream_id)
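The expiry check itself is a timestamp comparison against the account-validity data; an illustrative shape of the store method (simplified; the real implementation lives in the registration datastore):

    async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
        # Fetch the user's expiration timestamp, if any, and treat the account
        # as expired once the current time (in ms) has passed it.
        expiration_ts = await self.db_pool.simple_select_one_onecol(
            table="account_validity",
            keyvalues={"user_id": user_id},
            retcol="expiration_ts_ms",
            allow_none=True,
        )
        return expiration_ts is not None and current_ts >= expiration_ts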
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index ff0c67228b..0ddead8a0f 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -37,6 +37,9 @@ logger = logging.getLogger(__name__)
# installed when that optional dependency requirement is specified. It is passed
# to setup() as extras_require in setup.py
#
+# Note that these both represent runtime dependencies (and the versions
+# installed are checked at runtime).
+#
# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
REQUIREMENTS = [
@@ -65,7 +68,11 @@ REQUIREMENTS = [
"pymacaroons>=0.13.0",
"msgpack>=0.5.2",
"phonenumbers>=8.2.0",
- "prometheus_client>=0.0.18,<0.9.0",
+ # we use GaugeHistogramMetric, which was added in prom-client 0.4.0.
+ # prom-client has a history of breaking backwards compatibility between
+ # minor versions (https://github.com/prometheus/client_python/issues/317),
+ # so we also pin the minor version.
+ "prometheus_client>=0.4.0,<0.9.0",
# we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note:
# Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33
# is out in November.)
@@ -92,12 +99,6 @@ CONDITIONAL_REQUIREMENTS = {
"oidc": ["authlib>=0.14.0"],
"systemd": ["systemd-python>=231"],
"url_preview": ["lxml>=3.5.0"],
- # Dependencies which are exclusively required by unit test code. This is
- # NOT a list of all modules that are necessary to run the unit tests.
- # Tests assume that all optional dependencies are installed.
- #
- # parameterized_class decorator was introduced in parameterized 0.7.0
- "test": ["mock>=2.0", "parameterized>=0.7.0"],
"sentry": ["sentry-sdk>=0.7.2"],
"opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
"jwt": ["pyjwt>=1.6.4"],
@@ -110,6 +111,7 @@ ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
# Exclude systemd as it's a system-based requirement.
+ # Exclude lint as it's a dev-based requirement.
if name not in ["systemd"]:
ALL_OPTIONAL_REQUIREMENTS = set(optional_deps) | ALL_OPTIONAL_REQUIREMENTS
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index ba16f22c91..64edadb624 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -20,20 +20,30 @@ import urllib
from inspect import signature
from typing import Dict, List, Tuple
-from synapse.api.errors import (
- CodeMessageException,
- HttpResponseException,
- RequestSendFailed,
- SynapseError,
-)
+from prometheus_client import Counter, Gauge
+
+from synapse.api.errors import HttpResponseException, SynapseError
+from synapse.http import RequestTimedOutError
from synapse.logging.opentracing import inject_active_span_byte_dict, trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
+_pending_outgoing_requests = Gauge(
+ "synapse_pending_outgoing_replication_requests",
+ "Number of active outgoing replication requests, by replication method name",
+ ["name"],
+)
+
+_outgoing_request_counter = Counter(
+ "synapse_outgoing_replication_requests",
+ "Number of outgoing replication requests, by replication method name and result",
+ ["name", "code"],
+)
+
-class ReplicationEndpoint:
+class ReplicationEndpoint(metaclass=abc.ABCMeta):
"""Helper base class for defining new replication HTTP endpoints.
This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
@@ -72,8 +82,6 @@ class ReplicationEndpoint:
is received.
"""
- __metaclass__ = abc.ABCMeta
-
NAME = abc.abstractproperty() # type: str # type: ignore
PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore
METHOD = "POST"
@@ -140,7 +148,10 @@ class ReplicationEndpoint:
instance_map = hs.config.worker.instance_map
+ outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
+
@trace(opname="outgoing_replication_request")
+ @outgoing_gauge.track_inprogress()
async def send_request(instance_name="master", **kwargs):
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
@@ -195,23 +206,26 @@ class ReplicationEndpoint:
try:
result = await request_func(uri, data, headers=headers)
break
- except CodeMessageException as e:
- if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
+ except RequestTimedOutError:
+ if not cls.RETRY_ON_TIMEOUT:
raise
- logger.warning("%s request timed out", cls.NAME)
+ logger.warning("%s request timed out; retrying", cls.NAME)
# If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway.
await clock.sleep(1)
except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError
- # on the master process that we should send to the client. (And
+ # on the main process that we should send to the client. (And
# importantly, not stack traces everywhere)
+ _outgoing_request_counter.labels(cls.NAME, e.code).inc()
raise e.to_synapse_error()
- except RequestSendFailed as e:
- raise SynapseError(502, "Failed to talk to master") from e
+ except Exception as e:
+ _outgoing_request_counter.labels(cls.NAME, "ERR").inc()
+ raise SynapseError(502, "Failed to talk to main process") from e
+ _outgoing_request_counter.labels(cls.NAME, 200).inc()
return result
return send_request
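Both metrics use stock prometheus_client APIs; a standalone sketch of the same pattern (the names and the do_http_request helper are illustrative):

    from prometheus_client import Counter, Gauge

    pending = Gauge("example_pending_requests", "in-flight requests", ["name"])
    results = Counter("example_requests", "requests by name and result", ["name", "code"])

    gauge = pending.labels("send_event")

    @gauge.track_inprogress()  # +1 while a call is outstanding, -1 when it finishes
    async def send_request(**kwargs):
        try:
            result = await do_http_request(**kwargs)  # hypothetical transport call
        except Exception:
            results.labels("send_event", "ERR").inc()
            raise
        results.labels("send_event", 200).inc()
        return result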
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index 20f3ba76c0..807b85d2e1 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -53,7 +53,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
CACHE = False
def __init__(self, hs):
- super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.device_list_updater = hs.get_device_handler().device_list_updater
self.store = hs.get_datastore()
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 5c8be747e1..5393b9a9e7 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -57,7 +57,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
PATH_ARGS = ()
def __init__(self, hs):
- super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.store = hs.get_datastore()
self.storage = hs.get_storage()
@@ -150,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
PATH_ARGS = ("edu_type",)
def __init__(self, hs):
- super(ReplicationFederationSendEduRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@@ -193,7 +193,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
CACHE = False
def __init__(self, hs):
- super(ReplicationGetQueryRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@@ -236,7 +236,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
PATH_ARGS = ("room_id",)
def __init__(self, hs):
- super(ReplicationCleanRoomRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.store = hs.get_datastore()
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index fb326bb869..4c81e2d784 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -32,7 +32,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
PATH_ARGS = ("user_id",)
def __init__(self, hs):
- super(RegisterDeviceReplicationServlet, self).__init__(hs)
+ super().__init__(hs)
self.registration_handler = hs.get_registration_handler()
@staticmethod
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 08095fdf7d..30680baee8 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -45,7 +45,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
PATH_ARGS = ("room_id", "user_id")
def __init__(self, hs):
- super(ReplicationRemoteJoinRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
@@ -107,7 +107,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
PATH_ARGS = ("invite_event_id",)
def __init__(self, hs: "HomeServer"):
- super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@@ -168,7 +168,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
CACHE = False # No point caching as should return instantly.
def __init__(self, hs):
- super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.registeration_handler = hs.get_registration_handler()
self.store = hs.get_datastore()
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index a02b27474d..7b12ec9060 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -29,7 +29,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
PATH_ARGS = ("user_id",)
def __init__(self, hs):
- super(ReplicationRegisterServlet, self).__init__(hs)
+ super().__init__(hs)
self.store = hs.get_datastore()
self.registration_handler = hs.get_registration_handler()
@@ -104,7 +104,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
PATH_ARGS = ("user_id",)
def __init__(self, hs):
- super(ReplicationPostRegisterActionsServlet, self).__init__(hs)
+ super().__init__(hs)
self.store = hs.get_datastore()
self.registration_handler = hs.get_registration_handler()
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index f13d452426..9a3a694d5d 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -52,7 +52,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
PATH_ARGS = ("event_id",)
def __init__(self, hs):
- super(ReplicationSendEventRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastore()
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 60f2e1245f..d0089fe06c 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -26,16 +26,18 @@ logger = logging.getLogger(__name__)
class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(BaseSlavedStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
+ stream_name="caches",
instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
sequence_name="cache_invalidation_stream_seq",
+ writers=[],
) # type: Optional[MultiWriterIdGenerator]
else:
self._cache_id_gen = None
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index bb66ba9b80..4268565fc8 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -34,7 +34,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
],
)
- super(SlavedAccountDataStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index a6fdedde63..1f8dafe7ea 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -22,7 +22,7 @@ from ._base import BaseSlavedStore
class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 533d927701..5b045bed02 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -24,7 +24,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
db_conn, "device_inbox", "stream_id"
)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 3b788c9625..e0d86240dd 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -24,7 +24,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.hs = hs
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index da1cc836cf..fbffe6d85c 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -56,7 +56,7 @@ class SlavedEventStore(
BaseSlavedStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(SlavedEventStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 2562b6fc38..6a23252861 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -21,7 +21,7 @@ from ._base import BaseSlavedStore
class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired
get_user_filter = FilteringStore.__dict__["get_user_filter"]
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 567b4a5cc1..30955bcbfe 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -23,7 +23,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.hs = hs
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 025f6f6be8..55620c03d8 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -25,7 +25,7 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedPresenceStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
self._presence_on_startup = self._get_active_presence(db_conn) # type: ignore
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 9da218bfe8..c418730ba8 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -24,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(SlavedPusherStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._pushers_id_gen = SlavedIdTracker(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 5c2986e050..6195917376 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -30,7 +30,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
db_conn, "receipts_linearized", "stream_id"
)
- super(SlavedReceiptsStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 80ae803ad9..109ac6bea1 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -23,7 +23,7 @@ from ._slaved_id_tracker import SlavedIdTracker
class RoomStore(RoomWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id"
)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e82b9e386f..e165429cad 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow,
EventsStreamRow,
)
-from synapse.types import UserID
+from synapse.types import PersistedEventPosition, UserID
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -151,8 +151,12 @@ class ReplicationDataHandler:
extra_users = () # type: Tuple[UserID, ...]
if event.type == EventTypes.Member:
extra_users = (UserID.from_string(event.state_key),)
- max_token = self.store.get_room_max_stream_ordering()
- self.notifier.on_new_room_event(event, token, max_token, extra_users)
+
+ max_token = self.store.get_room_max_token()
+ event_pos = PersistedEventPosition(instance_name, token)
+ self.notifier.on_new_room_event(
+ event, event_pos, max_token, extra_users
+ )
# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b323841f73..e92da7b263 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -251,10 +251,9 @@ class ReplicationCommandHandler:
using TCP.
"""
if hs.config.redis.redis_enabled:
- import txredisapi
-
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
+ lazyConnection,
)
logger.info(
@@ -271,7 +270,8 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called).
# First create the connection for sending commands.
- outbound_redis_connection = txredisapi.lazyConnection(
+ outbound_redis_connection = lazyConnection(
+ reactor=hs.get_reactor(),
host=hs.config.redis_host,
port=hs.config.redis_port,
password=hs.config.redis.redis_password,
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0b0d204e64..a509e599c2 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -51,10 +51,11 @@ import fcntl
import logging
import struct
from inspect import isawaitable
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
from prometheus_client import Counter
+from twisted.internet import task
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
@@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0
- self.time_we_closed = None # When we requested the connection be closed
+ # When we requested the connection be closed
+ self.time_we_closed = None # type: Optional[int]
- self.received_ping = False # Have we reecived a ping from the other side
+ self.received_ping = False # Have we received a ping from the other side
self.state = ConnectionStates.CONNECTING
@@ -165,7 +167,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.pending_commands = [] # type: List[Command]
# The LoopingCall for sending pings.
- self._send_ping_loop = None
+ self._send_ping_loop = None # type: Optional[task.LoopingCall]
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index f225e533de..de19705c1f 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
import logging
from inspect import isawaitable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
import txredisapi
@@ -228,3 +228,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
p.password = self.password
return p
+
+
+def lazyConnection(
+ reactor,
+ host: str = "localhost",
+ port: int = 6379,
+ dbid: Optional[int] = None,
+ reconnect: bool = True,
+ charset: str = "utf-8",
+ password: Optional[str] = None,
+ connectTimeout: Optional[int] = None,
+ replyTimeout: Optional[int] = None,
+ convertNumbers: bool = True,
+) -> txredisapi.RedisProtocol:
+ """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
+ reactor.
+ """
+
+ isLazy = True
+ poolsize = 1
+
+ uuid = "%s:%d" % (host, port)
+ factory = txredisapi.RedisFactory(
+ uuid,
+ dbid,
+ poolsize,
+ isLazy,
+ txredisapi.ConnectionHandler,
+ charset,
+ password,
+ replyTimeout,
+ convertNumbers,
+ )
+ factory.continueTrying = reconnect
+ for x in range(poolsize):
+ reactor.connectTCP(host, port, factory, connectTimeout)
+
+ return factory.handler
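Call sites mirror txredisapi.lazyConnection but pass the reactor explicitly, as in the handler change above (values illustrative):

    outbound_redis_connection = lazyConnection(
        reactor=hs.get_reactor(),
        host="localhost",
        port=6379,
        password=None,
        reconnect=True,
    )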
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 1f609f158c..54dccd15a6 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -345,7 +345,7 @@ class PushRulesStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
- super(PushRulesStream, self).__init__(
+ super().__init__(
hs.get_instance_name(),
self._current_token,
self.store.get_all_push_rule_updates,
diff --git a/synapse/res/templates/auth_success.html b/synapse/res/templates/auth_success.html
new file mode 100644
index 0000000000..baf4633142
--- /dev/null
+++ b/synapse/res/templates/auth_success.html
@@ -0,0 +1,21 @@
+<html>
+<head>
+<title>Success!</title>
+<meta name='viewport' content='width=device-width, initial-scale=1,
+ user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
+<script>
+if (window.onAuthDone) {
+ window.onAuthDone();
+} else if (window.opener && window.opener.postMessage) {
+ window.opener.postMessage("authDone", "*");
+}
+</script>
+</head>
+<body>
+ <div>
+ <p>Thank you</p>
+ <p>You may now close this window and return to the application</p>
+ </div>
+</body>
+</html>
diff --git a/synapse/res/templates/recaptcha.html b/synapse/res/templates/recaptcha.html
new file mode 100644
index 0000000000..63944dc608
--- /dev/null
+++ b/synapse/res/templates/recaptcha.html
@@ -0,0 +1,37 @@
+<html>
+<head>
+<title>Authentication</title>
+<meta name='viewport' content='width=device-width, initial-scale=1,
+ user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+<script src="https://www.recaptcha.net/recaptcha/api.js"
+ async defer></script>
+<script src="//code.jquery.com/jquery-1.11.2.min.js"></script>
+<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
+<script>
+function captchaDone() {
+ $('#registrationForm').submit();
+}
+</script>
+</head>
+<body>
+<form id="registrationForm" method="post" action="{{ myurl }}">
+ <div>
+ <p>
+ Hello! We need to prevent computer programs and other automated
+ things from creating accounts on this server.
+ </p>
+ <p>
+ Please verify that you're not a robot.
+ </p>
+ <input type="hidden" name="session" value="{{ session }}" />
+ <div class="g-recaptcha"
+ data-sitekey="{{ sitekey }}"
+ data-callback="captchaDone">
+ </div>
+ <noscript>
+ <input type="submit" value="All Done" />
+ </noscript>
+ </div>
+</form>
+</body>
+</html>
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index af8459719a..944bc9c9ca 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -12,7 +12,7 @@
<p>
There was an error during authentication:
</p>
- <div id="errormsg" style="margin:20px 80px">{{ error_description }}</div>
+ <div id="errormsg" style="margin:20px 80px">{{ error_description | e }}</div>
<p>
If you are seeing this page after clicking a link sent to you via email, make
sure you only click the confirmation link once, and that you open the
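The `| e` filter HTML-escapes the error description before interpolation, so attacker-controlled markup renders inert. A quick illustration using Jinja2 directly:

    from jinja2 import Environment

    template = Environment().from_string("<div>{{ error_description | e }}</div>")
    print(template.render(error_description="<script>alert(1)</script>"))
    # -> <div>&lt;script&gt;alert(1)&lt;/script&gt;</div>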
diff --git a/synapse/res/templates/terms.html b/synapse/res/templates/terms.html
new file mode 100644
index 0000000000..dfef9897ee
--- /dev/null
+++ b/synapse/res/templates/terms.html
@@ -0,0 +1,20 @@
+<html>
+<head>
+<title>Authentication</title>
+<meta name='viewport' content='width=device-width, initial-scale=1,
+ user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
+</head>
+<body>
+<form id="registrationForm" method="post" action="{{ myurl }}">
+ <div>
+ <p>
+ Please click the button below if you agree to the
+ <a href="{{ terms_url }}">privacy policy of this homeserver.</a>
+ </p>
+ <input type="hidden" name="session" value="{{ session }}" />
+ <input type="submit" value="Agree" />
+ </div>
+</form>
+</body>
+</html>
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 1c88c93f38..789431ef25 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -16,13 +16,13 @@
import logging
import platform
-import re
import synapse
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.admin._base import (
+ admin_patterns,
assert_requester_is_admin,
historical_admin_path_patterns,
)
@@ -31,6 +31,7 @@ from synapse.rest.admin.devices import (
DeviceRestServlet,
DevicesRestServlet,
)
+from synapse.rest.admin.event_reports import EventReportsRestServlet
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
@@ -49,19 +50,21 @@ from synapse.rest.admin.users import (
ResetPasswordRestServlet,
SearchUsersRestServlet,
UserAdminServlet,
+ UserMembershipRestServlet,
UserRegisterServlet,
UserRestServletV2,
UsersRestServlet,
UsersRestServletV2,
WhoisRestServlet,
)
+from synapse.types import RoomStreamToken
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
class VersionServlet(RestServlet):
- PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),)
+ PATTERNS = admin_patterns("/server_version$")
def __init__(self, hs):
self.res = {
@@ -107,7 +110,10 @@ class PurgeHistoryRestServlet(RestServlet):
if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.")
- token = await self.store.get_topological_token_for_event(event_id)
+ room_token = RoomStreamToken(
+ event.depth, event.internal_metadata.stream_ordering
+ )
+ token = await room_token.to_string(self.store)
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
elif "purge_up_to_ts" in body:
@@ -209,11 +215,13 @@ def register_servlets(hs, http_server):
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
+ UserMembershipRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
+ EventReportsRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index d82eaf5e38..db9fea263a 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -44,7 +44,7 @@ def historical_admin_path_patterns(path_regex):
]
-def admin_patterns(path_regex: str):
+def admin_patterns(path_regex: str, version: str = "v1"):
"""Returns the list of patterns for an admin endpoint
Args:
@@ -54,7 +54,7 @@ def admin_patterns(path_regex: str):
Returns:
A list of regex patterns.
"""
- admin_prefix = "^/_synapse/admin/v1"
+ admin_prefix = "^/_synapse/admin/" + version
patterns = [re.compile(admin_prefix + path_regex)]
return patterns
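With the new version argument, v2 endpoints no longer need hand-written prefixes; for example (as used by the devices servlet below):

    # Expands to [re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices$")]
    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")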
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 8d32677339..a163863322 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import re
from synapse.api.errors import NotFoundError, SynapseError
from synapse.http.servlet import (
@@ -21,7 +20,7 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
-from synapse.rest.admin._base import assert_requester_is_admin
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import UserID
logger = logging.getLogger(__name__)
@@ -32,14 +31,12 @@ class DeviceRestServlet(RestServlet):
Get, update or delete the given user's device
"""
- PATTERNS = (
- re.compile(
- "^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$"
- ),
+ PATTERNS = admin_patterns(
+ "/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2"
)
def __init__(self, hs):
- super(DeviceRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@@ -98,7 +95,7 @@ class DevicesRestServlet(RestServlet):
Retrieve the given user's devices
"""
- PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices$"),)
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
def __init__(self, hs):
"""
@@ -131,9 +128,7 @@ class DeleteDevicesRestServlet(RestServlet):
key which lists the device_ids to delete.
"""
- PATTERNS = (
- re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/delete_devices$"),
- )
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
def __init__(self, hs):
self.hs = hs
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
new file mode 100644
index 0000000000..5b8d0594cd
--- /dev/null
+++ b/synapse/rest/admin/event_reports.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+
+logger = logging.getLogger(__name__)
+
+
+class EventReportsRestServlet(RestServlet):
+ """
+ List all reported events that are known to the homeserver. Results are returned
+ in a dictionary containing report information. Supports pagination.
+ The requester must have administrator access in Synapse.
+
+ GET /_synapse/admin/v1/event_reports
+ returns:
+        200 OK with a list of reports on success, otherwise an error.
+
+ Args:
+        The parameters `from` and `limit` are optional and control pagination.
+ By default, a `limit` of 100 is used.
+ The parameter `dir` can be used to define the order of results.
+ The parameter `user_id` can be used to filter by user id.
+ The parameter `room_id` can be used to filter by room id.
+ Returns:
+ A list of reported events and an integer representing the total number of
+ reported events that exist given this query
+ """
+
+ PATTERNS = admin_patterns("/event_reports$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request):
+ await assert_requester_is_admin(self.auth, request)
+
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
+ direction = parse_string(request, "dir", default="b")
+ user_id = parse_string(request, "user_id")
+ room_id = parse_string(request, "room_id")
+
+ if start < 0:
+ raise SynapseError(
+ 400,
+ "The start parameter must be a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if limit < 0:
+ raise SynapseError(
+ 400,
+ "The limit parameter must be a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if direction not in ("f", "b"):
+ raise SynapseError(
+ 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+ )
+
+ event_reports, total = await self.store.get_event_reports_paginate(
+ start, limit, direction, user_id, room_id
+ )
+ ret = {"event_reports": event_reports, "total": total}
+ if (start + limit) < total:
+ ret["next_token"] = start + len(event_reports)
+
+ return 200, ret
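An illustrative request/response pair for the new endpoint (all field values invented for the example):

    GET /_synapse/admin/v1/event_reports?from=0&limit=10&dir=b

    200 OK
    {
        "event_reports": [ ... ],
        "total": 23,
        "next_token": 10
    }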
diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
index f474066542..8b7bb6d44e 100644
--- a/synapse/rest/admin/purge_room_servlet.py
+++ b/synapse/rest/admin/purge_room_servlet.py
@@ -12,14 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import re
-
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.rest.admin import assert_requester_is_admin
+from synapse.rest.admin._base import admin_patterns
class PurgeRoomServlet(RestServlet):
@@ -35,7 +34,7 @@ class PurgeRoomServlet(RestServlet):
{}
"""
- PATTERNS = (re.compile("^/_synapse/admin/v1/purge_room$"),)
+ PATTERNS = admin_patterns("/purge_room$")
def __init__(self, hs):
"""
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 6e9a874121..375d055445 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import re
-
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
@@ -22,6 +20,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
from synapse.rest.admin import assert_requester_is_admin
+from synapse.rest.admin._base import admin_patterns
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import UserID
@@ -56,13 +55,13 @@ class SendServerNoticeServlet(RestServlet):
self.snm = hs.get_server_notices_manager()
def register(self, json_resource):
- PATTERN = "^/_synapse/admin/v1/send_server_notice"
+ PATTERN = "/send_server_notice"
json_resource.register_paths(
- "POST", (re.compile(PATTERN + "$"),), self.on_POST, self.__class__.__name__
+ "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
)
json_resource.register_paths(
"PUT",
- (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),),
+ admin_patterns(PATTERN + "/(?P<txn_id>[^/]*)$"),
self.on_PUT,
self.__class__.__name__,
)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f3e77da850..20dc1d0e05 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -15,7 +15,6 @@
import hashlib
import hmac
import logging
-import re
from http import HTTPStatus
from synapse.api.constants import UserTypes
@@ -29,6 +28,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.rest.admin._base import (
+ admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
historical_admin_path_patterns,
@@ -60,7 +60,7 @@ class UsersRestServlet(RestServlet):
class UsersRestServletV2(RestServlet):
- PATTERNS = (re.compile("^/_synapse/admin/v2/users$"),)
+ PATTERNS = admin_patterns("/users$", "v2")
"""Get request to list all local users.
This needs user to have administrator access in Synapse.
@@ -105,7 +105,7 @@ class UsersRestServletV2(RestServlet):
class UserRestServletV2(RestServlet):
- PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]+)$"),)
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2")
"""Get request to list user details.
This needs user to have administrator access in Synapse.
@@ -642,7 +642,7 @@ class UserAdminServlet(RestServlet):
{}
"""
- PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>[^/]*)/admin$"),)
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
def __init__(self, hs):
self.hs = hs
@@ -683,3 +683,29 @@ class UserAdminServlet(RestServlet):
await self.store.set_server_admin(target_user, set_admin_to)
return 200, {}
+
+
+class UserMembershipRestServlet(RestServlet):
+ """
+    Get the list of rooms that a given user is a member of.
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
+
+ def __init__(self, hs):
+ self.is_mine = hs.is_mine
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.is_mine(UserID.from_string(user_id)):
+ raise SynapseError(400, "Can only lookup local users")
+
+ room_ids = await self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ raise NotFoundError("User not found")
+
+ ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
+ return 200, ret
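For example (user and room IDs invented):

    GET /_synapse/admin/v1/users/@alice:example.com/joined_rooms

    200 OK
    {
        "joined_rooms": ["!abc123:example.com"],
        "total": 1
    }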
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index b210015173..faabeeb91c 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -40,7 +40,7 @@ class ClientDirectoryServer(RestServlet):
PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
def __init__(self, hs):
- super(ClientDirectoryServer, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@@ -120,7 +120,7 @@ class ClientDirectoryListServer(RestServlet):
PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
def __init__(self, hs):
- super(ClientDirectoryListServer, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@@ -160,7 +160,7 @@ class ClientAppserviceDirectoryListServer(RestServlet):
)
def __init__(self, hs):
- super(ClientAppserviceDirectoryListServer, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 25effd0261..1ecb77aa26 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -30,9 +30,10 @@ class EventStreamRestServlet(RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs):
- super(EventStreamRestServlet, self).__init__()
+ super().__init__()
self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
@@ -44,7 +45,7 @@ class EventStreamRestServlet(RestServlet):
if b"room_id" in request.args:
room_id = request.args[b"room_id"][0].decode("ascii")
- pagin_config = PaginationConfig.from_request(request)
+ pagin_config = await PaginationConfig.from_request(self.store, request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if b"timeout" in request.args:
try:
@@ -74,7 +75,7 @@ class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
def __init__(self, hs):
- super(EventRestServlet, self).__init__()
+ super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 910b3b4eeb..91da0ee573 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -24,14 +24,15 @@ class InitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/initialSync$", v1=True)
def __init__(self, hs):
- super(InitialSyncRestServlet, self).__init__()
+ super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request)
as_client_event = b"raw" not in request.args
- pagination_config = PaginationConfig.from_request(request)
+ pagination_config = await PaginationConfig.from_request(self.store, request)
include_archived = parse_boolean(request, "archived", default=False)
content = await self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index a14618ac84..3d1693d7ac 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -18,6 +18,7 @@ from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
+from synapse.appservice import ApplicationService
from synapse.handlers.auth import (
convert_client_dict_legacy_fields_to_identifier,
login_id_phone_to_thirdparty,
@@ -44,9 +45,10 @@ class LoginRestServlet(RestServlet):
TOKEN_TYPE = "m.login.token"
JWT_TYPE = "org.matrix.login.jwt"
JWT_TYPE_DEPRECATED = "m.login.jwt"
+ APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
def __init__(self, hs):
- super(LoginRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
# JWT configuration variables.
@@ -61,6 +63,8 @@ class LoginRestServlet(RestServlet):
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
+ self.auth = hs.get_auth()
+
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
@@ -116,8 +120,12 @@ class LoginRestServlet(RestServlet):
self._address_ratelimiter.ratelimit(request.getClientIP())
login_submission = parse_json_object_from_request(request)
+
try:
- if self.jwt_enabled and (
+ if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
+ appservice = self.auth.get_appservice_by_req(request)
+ result = await self._do_appservice_login(login_submission, appservice)
+ elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
@@ -134,6 +142,33 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
+ def _get_qualified_user_id(self, identifier):
+ if identifier["type"] != "m.id.user":
+ raise SynapseError(400, "Unknown login identifier type")
+ if "user" not in identifier:
+ raise SynapseError(400, "User identifier is missing 'user' key")
+
+ if identifier["user"].startswith("@"):
+ return identifier["user"]
+ else:
+ return UserID(identifier["user"], self.hs.hostname).to_string()
+
+ async def _do_appservice_login(
+ self, login_submission: JsonDict, appservice: ApplicationService
+ ):
+ logger.info(
+ "Got appservice login request with identifier: %r",
+ login_submission.get("identifier"),
+ )
+
+ identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
+ qualified_user_id = self._get_qualified_user_id(identifier)
+
+ if not appservice.is_interested_in_user(qualified_user_id):
+ raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
+
+ return await self._complete_login(qualified_user_id, login_submission)
+
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
@@ -219,15 +254,7 @@ class LoginRestServlet(RestServlet):
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
- if identifier["type"] != "m.id.user":
- raise SynapseError(400, "Unknown login identifier type")
- if "user" not in identifier:
- raise SynapseError(400, "User identifier is missing 'user' key")
-
- if identifier["user"].startswith("@"):
- qualified_user_id = identifier["user"]
- else:
- qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()
+ qualified_user_id = self._get_qualified_user_id(identifier)
# Check if we've hit the failed ratelimit (but don't update it)
self._failed_attempts_ratelimiter.ratelimit(
@@ -255,9 +282,7 @@ class LoginRestServlet(RestServlet):
self,
user_id: str,
login_submission: JsonDict,
- callback: Optional[
- Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
- ] = None,
+ callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
@@ -270,12 +295,12 @@ class LoginRestServlet(RestServlet):
Args:
user_id: ID of the user to register.
login_submission: Dictionary of login information.
- callback: Callback function to run after registration.
+ callback: Callback function to run after login.
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
Returns:
- result: Dictionary of account information after successful registration.
+ result: Dictionary of account information after successful login.
"""
# Before we actually log them in we check if they've already logged in
@@ -310,14 +335,24 @@ class LoginRestServlet(RestServlet):
return result
async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
+ """
+ Handle the final stage of SSO login.
+
+ Args:
+ login_submission: The JSON request body.
+
+ Returns:
+ The body of the JSON response.
+ """
token = login_submission["token"]
auth_handler = self.auth_handler
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
token
)
- result = await self._complete_login(user_id, login_submission)
- return result
+ return await self._complete_login(
+ user_id, login_submission, self.auth_handler._sso_login_callback
+ )
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission.get("token", None)
@@ -400,7 +435,7 @@ class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True)
def __init__(self, hs):
- super(CasTicketServlet, self).__init__()
+ super().__init__()
self._cas_handler = hs.get_cas_handler()
async def on_GET(self, request: SynapseRequest) -> None:
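For context, the experimental uk.half-shot.msc2778.login.application_service flow added above lets an application service log in as a user within its namespace. A hedged sketch of the request, assuming a homeserver at localhost:8008 and a registered appservice token:

    import requests

    # Authenticated with the appservice's as_token; the identifier must fall
    # inside the appservice's user namespace, or the servlet returns 403.
    resp = requests.post(
        "http://localhost:8008/_matrix/client/r0/login",
        headers={"Authorization": "Bearer <as_token>"},
        json={
            "type": "uk.half-shot.msc2778.login.application_service",
            "identifier": {"type": "m.id.user", "user": "@_bridge_alice:example.com"},
        },
    )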
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index b0c30b65be..f792b50cdc 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -25,7 +25,7 @@ class LogoutRestServlet(RestServlet):
PATTERNS = client_patterns("/logout$", v1=True)
def __init__(self, hs):
- super(LogoutRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@@ -53,7 +53,7 @@ class LogoutAllRestServlet(RestServlet):
PATTERNS = client_patterns("/logout/all$", v1=True)
def __init__(self, hs):
- super(LogoutAllRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 970fdd5834..79d8e3057f 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -30,7 +30,7 @@ class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
def __init__(self, hs):
- super(PresenceStatusRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index e7fe50ed72..b686cd671f 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -25,7 +25,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
def __init__(self, hs):
- super(ProfileDisplaynameRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@@ -73,7 +73,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
def __init__(self, hs):
- super(ProfileAvatarURLRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@@ -124,7 +124,7 @@ class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
def __init__(self, hs):
- super(ProfileRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index ddf8ed5e9c..f9eecb7cf5 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -38,7 +38,7 @@ class PushRuleRestServlet(RestServlet):
)
def __init__(self, hs):
- super(PushRuleRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 5f65cb7d83..28dabf1c7a 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -44,7 +44,7 @@ class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
def __init__(self, hs):
- super(PushersRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -68,7 +68,7 @@ class PushersSetRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/set$", v1=True)
def __init__(self, hs):
- super(PushersSetRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
@@ -153,7 +153,7 @@ class PushersRemoveRestServlet(RestServlet):
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
- super(PushersRemoveRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 84baf3d59b..b63389e5fe 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -57,7 +57,7 @@ logger = logging.getLogger(__name__)
class TransactionRestServlet(RestServlet):
def __init__(self, hs):
- super(TransactionRestServlet, self).__init__()
+ super().__init__()
self.txns = HttpTransactionCache(hs)
@@ -65,7 +65,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here
def __init__(self, hs):
- super(RoomCreateRestServlet, self).__init__(hs)
+ super().__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler()
self.auth = hs.get_auth()
@@ -111,7 +111,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomStateEventRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
@@ -229,7 +229,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomSendEventRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
@@ -280,7 +280,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(TransactionRestServlet):
def __init__(self, hs):
- super(JoinRoomAliasServlet, self).__init__(hs)
+ super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
@@ -343,7 +343,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)
def __init__(self, hs):
- super(PublicRoomListRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
@@ -448,9 +448,10 @@ class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
def __init__(self, hs):
- super(RoomMemberListRestServlet, self).__init__()
+ super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
@@ -465,7 +466,7 @@ class RoomMemberListRestServlet(RestServlet):
if at_token_string is None:
at_token = None
else:
- at_token = StreamToken.from_string(at_token_string)
+ at_token = await StreamToken.from_string(self.store, at_token_string)
# let you filter down on particular memberships.
# XXX: this may not be the best shape for this API - we could pass in a filter
@@ -499,7 +500,7 @@ class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
def __init__(self, hs):
- super(JoinedRoomMemberListRestServlet, self).__init__()
+ super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@@ -518,13 +519,16 @@ class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
def __init__(self, hs):
- super(RoomMessageListRestServlet, self).__init__()
+ super().__init__()
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- pagination_config = PaginationConfig.from_request(request, default_limit=10)
+ pagination_config = await PaginationConfig.from_request(
+ self.store, request, default_limit=10
+ )
as_client_event = b"raw" not in request.args
filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str:
@@ -557,7 +561,7 @@ class RoomStateRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
def __init__(self, hs):
- super(RoomStateRestServlet, self).__init__()
+ super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@@ -577,13 +581,14 @@ class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
def __init__(self, hs):
- super(RoomInitialSyncRestServlet, self).__init__()
+ super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- pagination_config = PaginationConfig.from_request(request)
+ pagination_config = await PaginationConfig.from_request(self.store, request)
content = await self.initial_sync_handler.room_initial_sync(
room_id=room_id, requester=requester, pagin_config=pagination_config
)
@@ -596,7 +601,7 @@ class RoomEventServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomEventServlet, self).__init__()
+ super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
@@ -628,7 +633,7 @@ class RoomEventContextServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomEventContextServlet, self).__init__()
+ super().__init__()
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
@@ -675,7 +680,7 @@ class RoomEventContextServlet(RestServlet):
class RoomForgetRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomForgetRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
@@ -701,7 +706,7 @@ class RoomForgetRestServlet(TransactionRestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomMembershipRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
@@ -792,7 +797,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomRedactEventRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
@@ -841,7 +846,7 @@ class RoomTypingRestServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomTypingRestServlet, self).__init__()
+ super().__init__()
self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
@@ -914,7 +919,7 @@ class SearchRestServlet(RestServlet):
PATTERNS = client_patterns("/search$", v1=True)
def __init__(self, hs):
- super(SearchRestServlet, self).__init__()
+ super().__init__()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@@ -935,7 +940,7 @@ class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_patterns("/joined_rooms$", v1=True)
def __init__(self, hs):
- super(JoinedRoomsRestServlet, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 50277c6cf6..b8d491ca5c 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -25,7 +25,7 @@ class VoipRestServlet(RestServlet):
PATTERNS = client_patterns("/voip/turnServer$", v1=True)
def __init__(self, hs):
- super(VoipRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index ade97a6708..ab5815e7f7 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -52,7 +52,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/email/requestToken$")
def __init__(self, hs):
- super(EmailPasswordRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.datastore = hs.get_datastore()
self.config = hs.config
@@ -96,15 +96,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
- raise SynapseError(
- 403,
- "Your email domain is not authorized on this server",
- Codes.THREEPID_DENIED,
- )
-
- # Raise if the provided next_link value isn't valid
- assert_valid_next_link(self.hs, next_link)
+ if next_link:
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
@@ -156,7 +150,7 @@ class PasswordRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password$")
def __init__(self, hs):
- super(PasswordRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@@ -282,7 +276,7 @@ class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
def __init__(self, hs):
- super(DeactivateAccountRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@@ -330,7 +324,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
- super(EmailThreepidRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.config = hs.config
self.identity_handler = hs.get_handlers().identity_handler
@@ -379,8 +373,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- # Raise if the provided next_link value isn't valid
- assert_valid_next_link(self.hs, next_link)
+ if next_link:
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
existing_user_id = await self.store.get_user_id_by_threepid("email", email)
@@ -427,7 +422,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
def __init__(self, hs):
self.hs = hs
- super(MsisdnThreepidRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.store = self.hs.get_datastore()
self.identity_handler = hs.get_handlers().identity_handler
@@ -453,8 +448,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- # Raise if the provided next_link value isn't valid
- assert_valid_next_link(self.hs, next_link)
+ if next_link:
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
@@ -606,7 +602,7 @@ class ThreepidRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid$")
def __init__(self, hs):
- super(ThreepidRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@@ -662,7 +658,7 @@ class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$")
def __init__(self, hs):
- super(ThreepidAddRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@@ -713,7 +709,7 @@ class ThreepidBindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/bind$")
def __init__(self, hs):
- super(ThreepidBindRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@@ -742,7 +738,7 @@ class ThreepidUnbindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/unbind$")
def __init__(self, hs):
- super(ThreepidUnbindRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@@ -773,7 +769,7 @@ class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/delete$")
def __init__(self, hs):
- super(ThreepidDeleteRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@@ -852,7 +848,7 @@ class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
def __init__(self, hs):
- super(WhoamiRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
async def on_GET(self, request):
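The next_link handling above now validates the parameter only when it is supplied. A sketch of the affected request, with placeholder values throughout:

    import requests

    resp = requests.post(
        "http://localhost:8008/_matrix/client/r0/account/password/email/requestToken",
        json={
            "client_secret": "<client secret>",
            "email": "alice@example.com",
            "send_attempt": 1,
            # Optional: only passed through assert_valid_next_link when present.
            "next_link": "https://example.com/password_reset_done",
        },
    )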
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index c1d4cd0caf..87a5b1b86b 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -34,7 +34,7 @@ class AccountDataServlet(RestServlet):
)
def __init__(self, hs):
- super(AccountDataServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@@ -86,7 +86,7 @@ class RoomAccountDataServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomAccountDataServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index d06336ceea..bd7f9ae203 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -32,7 +32,7 @@ class AccountValidityRenewServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(AccountValidityRenewServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
@@ -67,7 +67,7 @@ class AccountValiditySendMailServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(AccountValiditySendMailServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 8e585e9153..5fbfae5991 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -25,94 +25,6 @@ from ._base import client_patterns
logger = logging.getLogger(__name__)
-RECAPTCHA_TEMPLATE = """
-<html>
-<head>
-<title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
- user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
-<script src="https://www.recaptcha.net/recaptcha/api.js"
- async defer></script>
-<script src="//code.jquery.com/jquery-1.11.2.min.js"></script>
-<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
-<script>
-function captchaDone() {
- $('#registrationForm').submit();
-}
-</script>
-</head>
-<body>
-<form id="registrationForm" method="post" action="%(myurl)s">
- <div>
- <p>
- Hello! We need to prevent computer programs and other automated
- things from creating accounts on this server.
- </p>
- <p>
- Please verify that you're not a robot.
- </p>
- <input type="hidden" name="session" value="%(session)s" />
- <div class="g-recaptcha"
- data-sitekey="%(sitekey)s"
- data-callback="captchaDone">
- </div>
- <noscript>
- <input type="submit" value="All Done" />
- </noscript>
- </div>
- </div>
-</form>
-</body>
-</html>
-"""
-
-TERMS_TEMPLATE = """
-<html>
-<head>
-<title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
- user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
-<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
-</head>
-<body>
-<form id="registrationForm" method="post" action="%(myurl)s">
- <div>
- <p>
- Please click the button below if you agree to the
- <a href="%(terms_url)s">privacy policy of this homeserver.</a>
- </p>
- <input type="hidden" name="session" value="%(session)s" />
- <input type="submit" value="Agree" />
- </div>
-</form>
-</body>
-</html>
-"""
-
-SUCCESS_TEMPLATE = """
-<html>
-<head>
-<title>Success!</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
- user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
-<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
-<script>
-if (window.onAuthDone) {
- window.onAuthDone();
-} else if (window.opener && window.opener.postMessage) {
- window.opener.postMessage("authDone", "*");
-}
-</script>
-</head>
-<body>
- <div>
- <p>Thank you</p>
- <p>You may now close this window and return to the application</p>
- </div>
-</body>
-</html>
-"""
-
class AuthRestServlet(RestServlet):
"""
@@ -124,7 +36,7 @@ class AuthRestServlet(RestServlet):
PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs):
- super(AuthRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@@ -145,26 +57,30 @@ class AuthRestServlet(RestServlet):
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
+ self.recaptcha_template = hs.config.recaptcha_template
+ self.terms_template = hs.config.terms_template
+ self.success_template = hs.config.fallback_success_template
+
async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
if stagetype == LoginType.RECAPTCHA:
- html = RECAPTCHA_TEMPLATE % {
- "session": session,
- "myurl": "%s/r0/auth/%s/fallback/web"
+ html = self.recaptcha_template.render(
+ session=session,
+ myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
- "sitekey": self.hs.config.recaptcha_public_key,
- }
+ sitekey=self.hs.config.recaptcha_public_key,
+ )
elif stagetype == LoginType.TERMS:
- html = TERMS_TEMPLATE % {
- "session": session,
- "terms_url": "%s_matrix/consent?v=%s"
+ html = self.terms_template.render(
+ session=session,
+ terms_url="%s_matrix/consent?v=%s"
% (self.hs.config.public_baseurl, self.hs.config.user_consent_version),
- "myurl": "%s/r0/auth/%s/fallback/web"
+ myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
- }
+ )
elif stagetype == LoginType.SSO:
# Display a confirmation page which prompts the user to
@@ -222,14 +138,14 @@ class AuthRestServlet(RestServlet):
)
if success:
- html = SUCCESS_TEMPLATE
+ html = self.success_template.render()
else:
- html = RECAPTCHA_TEMPLATE % {
- "session": session,
- "myurl": "%s/r0/auth/%s/fallback/web"
+ html = self.recaptcha_template.render(
+ session=session,
+ myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
- "sitekey": self.hs.config.recaptcha_public_key,
- }
+ sitekey=self.hs.config.recaptcha_public_key,
+ )
elif stagetype == LoginType.TERMS:
authdict = {"session": session}
@@ -238,18 +154,18 @@ class AuthRestServlet(RestServlet):
)
if success:
- html = SUCCESS_TEMPLATE
+ html = self.success_template.render()
else:
- html = TERMS_TEMPLATE % {
- "session": session,
- "terms_url": "%s_matrix/consent?v=%s"
+ html = self.terms_template.render(
+ session=session,
+ terms_url="%s_matrix/consent?v=%s"
% (
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
- "myurl": "%s/r0/auth/%s/fallback/web"
+ myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
- }
+ )
elif stagetype == LoginType.SSO:
# The SSO fallback workflow should not post here,
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
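The fallback auth pages above now render Jinja2 templates supplied via the homeserver config (recaptcha_template, terms_template, fallback_success_template) instead of %-interpolating inline HTML. A minimal sketch of the rendering pattern, with an inline template standing in for the configured file:

    from jinja2 import Template

    # Inline stand-in for the configured recaptcha template file.
    recaptcha_template = Template(
        '<input type="hidden" name="session" value="{{ session }}" />\n'
        '<div class="g-recaptcha" data-sitekey="{{ sitekey }}"></div>'
    )
    html = recaptcha_template.render(session="abc123", sitekey="<public sitekey>")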
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index fe9d019c44..76879ac559 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -32,7 +32,7 @@ class CapabilitiesRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(CapabilitiesRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.config = hs.config
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index c0714fcfb1..af117cb27c 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,6 +22,7 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from ._base import client_patterns, interactive_auth_handler
@@ -35,7 +37,7 @@ class DevicesRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(DevicesRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@@ -57,7 +59,7 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/delete_devices")
def __init__(self, hs):
- super(DeleteDevicesRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@@ -102,7 +104,7 @@ class DeviceRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(DeviceRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@@ -151,7 +153,139 @@ class DeviceRestServlet(RestServlet):
return 200, {}
+class DehydratedDeviceServlet(RestServlet):
+ """Retrieve or store a dehydrated device.
+
+ GET /org.matrix.msc2697.v2/dehydrated_device
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+
+ {
+ "device_id": "dehydrated_device_id",
+ "device_data": {
+ "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
+ "account": "dehydrated_device"
+ }
+ }
+
+    PUT /org.matrix.msc2697.v2/dehydrated_device
+ Content-Type: application/json
+
+ {
+ "device_data": {
+ "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
+ "account": "dehydrated_device"
+ }
+ }
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+
+ {
+ "device_id": "dehydrated_device_id"
+ }
+
+ """
+
+ PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=())
+
+ def __init__(self, hs):
+ super().__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+
+ async def on_GET(self, request: SynapseRequest):
+ requester = await self.auth.get_user_by_req(request)
+ dehydrated_device = await self.device_handler.get_dehydrated_device(
+ requester.user.to_string()
+ )
+ if dehydrated_device is not None:
+ (device_id, device_data) = dehydrated_device
+ result = {"device_id": device_id, "device_data": device_data}
+ return (200, result)
+ else:
+ raise errors.NotFoundError("No dehydrated device available")
+
+ async def on_PUT(self, request: SynapseRequest):
+ submission = parse_json_object_from_request(request)
+ requester = await self.auth.get_user_by_req(request)
+
+ if "device_data" not in submission:
+ raise errors.SynapseError(
+ 400, "device_data missing", errcode=errors.Codes.MISSING_PARAM,
+ )
+ elif not isinstance(submission["device_data"], dict):
+ raise errors.SynapseError(
+ 400,
+ "device_data must be an object",
+ errcode=errors.Codes.INVALID_PARAM,
+ )
+
+ device_id = await self.device_handler.store_dehydrated_device(
+ requester.user.to_string(),
+ submission["device_data"],
+ submission.get("initial_device_display_name", None),
+ )
+ return 200, {"device_id": device_id}
+
+
+class ClaimDehydratedDeviceServlet(RestServlet):
+ """Claim a dehydrated device.
+
+ POST /org.matrix.msc2697.v2/dehydrated_device/claim
+ Content-Type: application/json
+
+ {
+ "device_id": "dehydrated_device_id"
+ }
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+
+ {
+ "success": true,
+ }
+
+ """
+
+ PATTERNS = client_patterns(
+ "/org.matrix.msc2697.v2/dehydrated_device/claim", releases=()
+ )
+
+ def __init__(self, hs):
+ super().__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+
+ async def on_POST(self, request: SynapseRequest):
+ requester = await self.auth.get_user_by_req(request)
+
+ submission = parse_json_object_from_request(request)
+
+ if "device_id" not in submission:
+ raise errors.SynapseError(
+ 400, "device_id missing", errcode=errors.Codes.MISSING_PARAM,
+ )
+ elif not isinstance(submission["device_id"], str):
+ raise errors.SynapseError(
+ 400, "device_id must be a string", errcode=errors.Codes.INVALID_PARAM,
+ )
+
+ result = await self.device_handler.rehydrate_device(
+ requester.user.to_string(),
+ self.auth.get_access_token_from_request(request),
+ submission["device_id"],
+ )
+
+ return (200, result)
+
+
def register_servlets(hs, http_server):
DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
+ DehydratedDeviceServlet(hs).register(http_server)
+ ClaimDehydratedDeviceServlet(hs).register(http_server)
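Taken together, the two servlets above implement the MSC2697 flow: store a dehydrated device, then claim it from a later login. An end-to-end sketch, assuming the unstable-prefixed client API and placeholder URL, tokens and device data:

    import requests

    BASE = "http://localhost:8008/_matrix/client/unstable"
    HEADERS = {"Authorization": "Bearer <access_token>"}

    # Store a dehydrated device.
    put = requests.put(
        BASE + "/org.matrix.msc2697.v2/dehydrated_device",
        headers=HEADERS,
        json={
            "device_data": {
                "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
                "account": "<pickled olm account>",
            }
        },
    )
    device_id = put.json()["device_id"]

    # Later, from a fresh login, claim (rehydrate) it.
    claim = requests.post(
        BASE + "/org.matrix.msc2697.v2/dehydrated_device/claim",
        headers=HEADERS,
        json={"device_id": device_id},
    )
    assert claim.json().get("success") is True  # per the docstring above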
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index b28da017cd..7cc692643b 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -28,7 +28,7 @@ class GetFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs):
- super(GetFilterRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
@@ -64,7 +64,7 @@ class CreateFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs):
- super(CreateFilterRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 13ecf7005d..a3bb095c2d 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -32,7 +32,7 @@ class GroupServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
def __init__(self, hs):
- super(GroupServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -66,7 +66,7 @@ class GroupSummaryServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
def __init__(self, hs):
- super(GroupSummaryServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -97,7 +97,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupSummaryRoomsCatServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -137,7 +137,7 @@ class GroupCategoryServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupCategoryServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -181,7 +181,7 @@ class GroupCategoriesServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
def __init__(self, hs):
- super(GroupCategoriesServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -204,7 +204,7 @@ class GroupRoleServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
def __init__(self, hs):
- super(GroupRoleServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -248,7 +248,7 @@ class GroupRolesServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
def __init__(self, hs):
- super(GroupRolesServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -279,7 +279,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupSummaryUsersRoleServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -317,7 +317,7 @@ class GroupRoomServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
def __init__(self, hs):
- super(GroupRoomServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -343,7 +343,7 @@ class GroupUsersServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
def __init__(self, hs):
- super(GroupUsersServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -366,7 +366,7 @@ class GroupInvitedUsersServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
def __init__(self, hs):
- super(GroupInvitedUsersServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -389,7 +389,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
def __init__(self, hs):
- super(GroupSettingJoinPolicyServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
@@ -413,7 +413,7 @@ class GroupCreateServlet(RestServlet):
PATTERNS = client_patterns("/create_group$")
def __init__(self, hs):
- super(GroupCreateServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -444,7 +444,7 @@ class GroupAdminRoomsServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupAdminRoomsServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -481,7 +481,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupAdminRoomsConfigServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -507,7 +507,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupAdminUsersInviteServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -536,7 +536,7 @@ class GroupAdminUsersKickServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupAdminUsersKickServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -560,7 +560,7 @@ class GroupSelfLeaveServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
def __init__(self, hs):
- super(GroupSelfLeaveServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -584,7 +584,7 @@ class GroupSelfJoinServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
def __init__(self, hs):
- super(GroupSelfJoinServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -608,7 +608,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
def __init__(self, hs):
- super(GroupSelfAcceptInviteServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -632,7 +632,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
def __init__(self, hs):
- super(GroupSelfUpdatePublicityServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -655,7 +655,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
def __init__(self, hs):
- super(PublicisedGroupsForUserServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -676,7 +676,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
PATTERNS = client_patterns("/publicised_groups$")
def __init__(self, hs):
- super(PublicisedGroupsForUsersServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -700,7 +700,7 @@ class GroupsForUserServlet(RestServlet):
PATTERNS = client_patterns("/joined_groups$")
def __init__(self, hs):
- super(GroupsForUserServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 24bb090822..b91996c738 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -64,9 +65,10 @@ class KeyUploadServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(KeyUploadServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
+ self.device_handler = hs.get_device_handler()
@trace(opname="upload_keys")
async def on_POST(self, request, device_id):
@@ -75,23 +77,28 @@ class KeyUploadServlet(RestServlet):
body = parse_json_object_from_request(request)
if device_id is not None:
- # passing the device_id here is deprecated; however, we allow it
- # for now for compatibility with older clients.
+ # Providing the device_id should only be done for setting keys
+ # for dehydrated devices; however, we allow it for any device for
+ # compatibility with older clients.
if requester.device_id is not None and device_id != requester.device_id:
- set_tag("error", True)
- log_kv(
- {
- "message": "Client uploading keys for a different device",
- "logged_in_id": requester.device_id,
- "key_being_uploaded": device_id,
- }
- )
- logger.warning(
- "Client uploading keys for a different device "
- "(logged in as %s, uploading for %s)",
- requester.device_id,
- device_id,
+ dehydrated_device = await self.device_handler.get_dehydrated_device(
+ user_id
)
+ if dehydrated_device is not None and device_id != dehydrated_device[0]:
+ set_tag("error", True)
+ log_kv(
+ {
+ "message": "Client uploading keys for a different device",
+ "logged_in_id": requester.device_id,
+ "key_being_uploaded": device_id,
+ }
+ )
+ logger.warning(
+ "Client uploading keys for a different device "
+ "(logged in as %s, uploading for %s)",
+ requester.device_id,
+ device_id,
+ )
else:
device_id = requester.device_id
@@ -147,7 +154,7 @@ class KeyQueryServlet(RestServlet):
Args:
hs (synapse.server.HomeServer):
"""
- super(KeyQueryServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@@ -177,9 +184,10 @@ class KeyChangesServlet(RestServlet):
Args:
hs (synapse.server.HomeServer):
"""
- super(KeyChangesServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
+ self.store = hs.get_datastore()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
@@ -191,7 +199,7 @@ class KeyChangesServlet(RestServlet):
# changes after the "to" as well as before.
set_tag("to", parse_string(request, "to"))
- from_token = StreamToken.from_string(from_token_string)
+ from_token = await StreamToken.from_string(self.store, from_token_string)
user_id = requester.user.to_string()
@@ -222,7 +230,7 @@ class OneTimeKeyServlet(RestServlet):
PATTERNS = client_patterns("/keys/claim$")
def __init__(self, hs):
- super(OneTimeKeyServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@@ -250,7 +258,7 @@ class SigningKeyUploadServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(SigningKeyUploadServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@@ -308,7 +316,7 @@ class SignaturesUploadServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(SignaturesUploadServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
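With the relaxed check above, uploading keys for the user's dehydrated device (rather than the device the requester is logged in as) no longer logs a mismatch warning. A sketch, with the device ID, URL, token and key material as placeholders:

    import requests

    resp = requests.post(
        "http://localhost:8008/_matrix/client/r0/keys/upload/<dehydrated_device_id>",
        headers={"Authorization": "Bearer <access_token>"},
        json={"one_time_keys": {"signed_curve25519:AAAAHQ": {"key": "<key>", "signatures": {}}}},
    )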
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index aa911d75ee..87063ec8b1 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -27,7 +27,7 @@ class NotificationsServlet(RestServlet):
PATTERNS = client_patterns("/notifications$")
def __init__(self, hs):
- super(NotificationsServlet, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index 6ae9a5a8e9..5b996e2d63 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -60,7 +60,7 @@ class IdTokenServlet(RestServlet):
EXPIRES_MS = 3600 * 1000
def __init__(self, hs):
- super(IdTokenServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py
index 968403cca4..68b27ff23a 100644
--- a/synapse/rest/client/v2_alpha/password_policy.py
+++ b/synapse/rest/client/v2_alpha/password_policy.py
@@ -30,7 +30,7 @@ class PasswordPolicyServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(PasswordPolicyServlet, self).__init__()
+ super().__init__()
self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index 67cbc37312..55c6688f52 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -26,7 +26,7 @@ class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
def __init__(self, hs):
- super(ReadMarkerRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 92555bd4a9..6f7246a394 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -31,7 +31,7 @@ class ReceiptRestServlet(RestServlet):
)
def __init__(self, hs):
- super(ReceiptRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index b6b90a8b30..ffa2dfce42 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -76,7 +76,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(EmailRegisterRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.config = hs.config
@@ -174,7 +174,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(MsisdnRegisterRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@@ -249,7 +249,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RegistrationSubmitTokenServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.config = hs.config
@@ -319,7 +319,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(UsernameAvailabilityRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.registration_handler = hs.get_registration_handler()
self.ratelimiter = FederationRateLimiter(
@@ -363,7 +363,7 @@ class RegisterRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RegisterRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -431,11 +431,14 @@ class RegisterRestServlet(RestServlet):
access_token = self.auth.get_access_token_from_request(request)
- if isinstance(desired_username, str):
- result = await self._do_appservice_registration(
- desired_username, access_token, body
- )
- return 200, result # we throw for non 200 responses
+ if not isinstance(desired_username, str):
+ raise SynapseError(400, "Desired Username is missing or not a string")
+
+ result = await self._do_appservice_registration(
+ desired_username, access_token, body
+ )
+
+ return 200, result
# == Normal User Registration == (everyone else)
if not self._registration_enabled:
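The registration change above replaces a silent fall-through with an explicit failure: an appservice registration whose username is missing or not a string now gets a 400. A sketch of the affected request, assuming a registered appservice and placeholder values:

    import requests

    resp = requests.post(
        "http://localhost:8008/_matrix/client/r0/register",
        headers={"Authorization": "Bearer <as_token>"},
        json={"type": "m.login.application_service", "username": "_bridge_alice"},
    )
    # Omitting "username", or sending a non-string, now yields
    # 400 "Desired Username is missing or not a string".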
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index e29f49f7f5..18c75738f8 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -61,7 +61,7 @@ class RelationSendServlet(RestServlet):
)
def __init__(self, hs):
- super(RelationSendServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.event_creation_handler = hs.get_event_creation_handler()
self.txns = HttpTransactionCache(hs)
@@ -138,7 +138,7 @@ class RelationPaginationServlet(RestServlet):
)
def __init__(self, hs):
- super(RelationPaginationServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@@ -233,7 +233,7 @@ class RelationAggregationPaginationServlet(RestServlet):
)
def __init__(self, hs):
- super(RelationAggregationPaginationServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
@@ -311,7 +311,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
)
def __init__(self, hs):
- super(RelationAggregationGroupPaginationServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index e15927c4ea..215d619ca1 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -32,7 +32,7 @@ class ReportEventRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$")
def __init__(self, hs):
- super(ReportEventRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 59529707df..53de97923f 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -37,7 +37,7 @@ class RoomKeysServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RoomKeysServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
@@ -248,7 +248,7 @@ class RoomKeysNewVersionServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RoomKeysNewVersionServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
@@ -301,7 +301,7 @@ class RoomKeysVersionServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RoomKeysVersionServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index 39a5518614..bf030e0ff4 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -53,7 +53,7 @@ class RoomUpgradeRestServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomUpgradeRestServlet, self).__init__()
+ super().__init__()
self._hs = hs
self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth()
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index db829f3098..bc4f43639a 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -36,7 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(SendToDeviceRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)
diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py
index 2492634dac..c866d5151c 100644
--- a/synapse/rest/client/v2_alpha/shared_rooms.py
+++ b/synapse/rest/client/v2_alpha/shared_rooms.py
@@ -34,7 +34,7 @@ class UserSharedRoomsServlet(RestServlet):
)
def __init__(self, hs):
- super(UserSharedRoomsServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.user_directory_active = hs.config.update_user_directory
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a0b00135e1..2b84eb89c0 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -74,9 +74,10 @@ class SyncRestServlet(RestServlet):
ALLOWED_PRESENCE = {"online", "offline", "unavailable"}
def __init__(self, hs):
- super(SyncRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
self.sync_handler = hs.get_sync_handler()
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
@@ -151,10 +152,9 @@ class SyncRestServlet(RestServlet):
device_id=device_id,
)
+ since_token = None
if since is not None:
- since_token = StreamToken.from_string(since)
- else:
- since_token = None
+ since_token = await StreamToken.from_string(self.store, since)
# send any outstanding server notices to the user.
await self._server_notices_sender.on_user_syncing(user.to_string())
@@ -236,7 +236,8 @@ class SyncRestServlet(RestServlet):
"leave": sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
- "next_batch": sync_result.next_batch.to_string(),
+ "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
+ "next_batch": await sync_result.next_batch.to_string(self.store),
}
@staticmethod
@@ -413,7 +414,7 @@ class SyncRestServlet(RestServlet):
result = {
"timeline": {
"events": serialized_timeline,
- "prev_batch": room.timeline.prev_batch.to_string(),
+ "prev_batch": await room.timeline.prev_batch.to_string(self.store),
"limited": room.timeline.limited,
},
"state": {"events": serialized_state},
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index a3f12e8a77..bf3a79db44 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -31,7 +31,7 @@ class TagListServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
def __init__(self, hs):
- super(TagListServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -56,7 +56,7 @@ class TagServlet(RestServlet):
)
def __init__(self, hs):
- super(TagServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 23709960ad..0c127a1b5f 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -28,7 +28,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocols")
def __init__(self, hs):
- super(ThirdPartyProtocolsServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@@ -44,7 +44,7 @@ class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs):
- super(ThirdPartyProtocolServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@@ -65,7 +65,7 @@ class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
- super(ThirdPartyUserServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@@ -87,7 +87,7 @@ class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
- super(ThirdPartyLocationServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 83f3b6b70a..79317c74ba 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -28,7 +28,7 @@ class TokenRefreshRestServlet(RestServlet):
PATTERNS = client_patterns("/tokenrefresh")
def __init__(self, hs):
- super(TokenRefreshRestServlet, self).__init__()
+ super().__init__()
async def on_POST(self, request):
raise AuthError(403, "tokenrefresh is no longer supported.")
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index bef91a2d3e..ad598cefe0 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -31,7 +31,7 @@ class UserDirectorySearchRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(UserDirectorySearchRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 853657d020..7a5c739b23 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -19,6 +19,7 @@
import logging
import re
+from synapse.api.constants import RoomCreationPreset
from synapse.http.servlet import RestServlet
logger = logging.getLogger(__name__)
@@ -28,9 +29,23 @@ class VersionsRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/versions$")]
def __init__(self, hs):
- super(VersionsRestServlet, self).__init__()
+ super().__init__()
self.config = hs.config
+ # Calculate these once since they shouldn't change after start-up.
+ self.e2ee_forced_public = (
+ RoomCreationPreset.PUBLIC_CHAT
+ in self.config.encryption_enabled_by_default_for_room_presets
+ )
+ self.e2ee_forced_private = (
+ RoomCreationPreset.PRIVATE_CHAT
+ in self.config.encryption_enabled_by_default_for_room_presets
+ )
+ self.e2ee_forced_trusted_private = (
+ RoomCreationPreset.TRUSTED_PRIVATE_CHAT
+ in self.config.encryption_enabled_by_default_for_room_presets
+ )
+
def on_GET(self, request):
return (
200,
@@ -62,6 +77,10 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc2432": True,
# Implements additional endpoints as described in MSC2666
"uk.half-shot.msc2666": True,
+ # Whether new rooms will be set to encrypted or not (based on presets).
+ "io.element.e2ee_forced.public": self.e2ee_forced_public,
+ "io.element.e2ee_forced.private": self.e2ee_forced_private,
+ "io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private,
# Implements additional endpoints and features as described in MSC2403
"xyz.amorgan.knock": True,
},
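Clients can pick these flags up from `GET /_matrix/client/versions` under `unstable_features`. A hedged client-side sketch (the homeserver URL is a placeholder):

# Sketch: read the advertised e2ee_forced flags from /versions.
import json
from urllib.request import urlopen

def e2ee_forced_for_private_rooms(homeserver: str) -> bool:
    with urlopen(homeserver + "/_matrix/client/versions") as resp:
        versions = json.load(resp)
    features = versions.get("unstable_features", {})
    return features.get("io.element.e2ee_forced.private", False)

# e.g. e2ee_forced_for_private_rooms("https://matrix.example.org")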
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 6568e61829..67aa993f19 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -213,6 +213,12 @@ async def respond_with_responder(
file_size (int|None): Size in bytes of the media. If not known it should be None
upload_name (str|None): The name of the requested file, if any.
"""
+ if request._disconnected:
+ logger.warning(
+ "Not sending response to request %s, already disconnected.", request
+ )
+ return
+
if not responder:
respond_404(request)
return
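The guard above avoids doing any work for a client that has already gone away. The same early-return pattern, sketched generically (not Synapse's helper; `produce_body` is an assumed callable):

# Sketch: bail out early rather than writing to a dead transport.
import logging

logger = logging.getLogger(__name__)

async def respond(request, produce_body) -> None:
    if getattr(request, "_disconnected", False):
        logger.warning(
            "Not sending response to request %s, already disconnected.", request
        )
        return
    request.write(await produce_body())
    request.finish()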
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 69f353d46f..e1192b47cd 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -139,7 +139,7 @@ class MediaRepository:
async def create_content(
self,
media_type: str,
- upload_name: str,
+ upload_name: Optional[str],
content: IO,
content_length: int,
auth_user: str,
@@ -147,8 +147,8 @@ class MediaRepository:
"""Store uploaded content for a local user and return the mxc URL
Args:
- media_type: The content type of the file
- upload_name: The name of the file
+ media_type: The content type of the file.
+ upload_name: The name of the file, if provided.
content: A file like object that is the content to store
content_length: The length of the content
auth_user: The user_id of the uploader
@@ -156,6 +156,7 @@ class MediaRepository:
Returns:
The mxc url of the stored content
"""
+
media_id = random_string(24)
file_info = FileInfo(server_name=None, file_id=media_id)
@@ -636,7 +637,7 @@ class MediaRepository:
thumbnailer = Thumbnailer(input_path)
except ThumbnailError as e:
logger.warning(
- "Unable to generate thumbnails for remote media %s from %s using a method of %s and type of %s: %s",
+ "Unable to generate thumbnails for remote media %s from %s of type %s: %s",
media_id,
server_name,
media_type,
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 5681677fc9..a9586fb0b7 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -141,31 +141,34 @@ class MediaStorage:
Returns:
Returns a Responder if the file was found, otherwise None.
"""
+ paths = [self._file_info_to_path(file_info)]
- path = self._file_info_to_path(file_info)
- local_path = os.path.join(self.local_media_directory, path)
- if os.path.exists(local_path):
- return FileResponder(open(local_path, "rb"))
-
- # Fallback for paths without method names
- # Should be removed in the future
+ # fallback for remote thumbnails with no method in the filename
if file_info.thumbnail and file_info.server_name:
- legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
- server_name=file_info.server_name,
- file_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
+ paths.append(
+ self.filepaths.remote_media_thumbnail_rel_legacy(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ )
)
- legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
- if os.path.exists(legacy_local_path):
- return FileResponder(open(legacy_local_path, "rb"))
+
+ for path in paths:
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ logger.debug("responding with local file %s", local_path)
+ return FileResponder(open(local_path, "rb"))
+ logger.debug("local file %s did not exist", local_path)
for provider in self.storage_providers:
- res = await provider.fetch(path, file_info) # type: Any
- if res:
- logger.debug("Streaming %s from %s", path, provider)
- return res
+ for path in paths:
+ res = await provider.fetch(path, file_info) # type: Any
+ if res:
+ logger.debug("Streaming %s from %s", path, provider)
+ return res
+ logger.debug("%s not found on %s", path, provider)
return None
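The rewrite collapses two near-identical lookups into a list of candidate paths, tried first against local storage and only then against each (potentially slow) storage provider. The general shape, as a sketch with an assumed async `fetch` callable per provider:

# Sketch: try each candidate relative path locally, then per provider.
import os
from typing import Awaitable, Callable, List, Optional

async def find_file(
    base_dir: str,
    paths: List[str],
    providers: List[Callable[[str], Awaitable[Optional[str]]]],
) -> Optional[str]:
    for path in paths:
        local_path = os.path.join(base_dir, path)
        if os.path.exists(local_path):
            return local_path
    # Only fall back to remote providers once all local candidates
    # have been exhausted.
    for fetch in providers:
        for path in paths:
            res = await fetch(path)
            if res:
                return res
    return None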
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 987765e877..dce6c4d168 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e
- async def _download_url(self, url, user):
+ async def _download_url(self, url: str, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
# If this URL can be accessed via oEmbed, use that instead.
- url_to_download = url
+ url_to_download = url # type: Optional[str]
oembed_url = self._get_oembed_url(url)
if oembed_url:
# The result might be a new URL to download, or it might be HTML content.
@@ -520,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource):
# FIXME: we should calculate a proper expiration based on the
# Cache-Control and Expire headers. But for now, assume 1 hour.
expires = ONE_HOUR
- etag = headers["ETag"][0] if "ETag" in headers else None
+ etag = (
+ headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
+ )
else:
- html_bytes = oembed_result.html.encode("utf-8") # type: ignore
+ # we can only get here if we did an oembed request and have an oembed_result.html
+ assert oembed_result.html is not None
+ assert oembed_url is not None
+
+ html_bytes = oembed_result.html.encode("utf-8")
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
f.write(html_bytes)
await finish()
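The `ETag` fix matters because this headers mapping is keyed by bytes, so the old `"ETag" in headers` check could never match. A minimal illustration, using a plain dict as a stand-in for the real headers object:

# Sketch: header maps keyed by bytes require bytes lookups and an
# explicit decode before the value is used as a str.
from typing import Dict, List, Optional

def get_etag(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
    if b"ETag" not in headers:  # "ETag" (str) would never match
        return None
    return headers[b"ETag"][0].decode("ascii")

assert get_etag({b"ETag": [b'"abc123"']}) == '"abc123"'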
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 3ebf7a68e6..d76f7389e1 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -63,6 +63,10 @@ class UploadResource(DirectServeJsonResource):
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400
)
+    # If the name is falsy (e.g. an empty byte string) ensure it is None.
+ else:
+ upload_name = None
+
headers = request.requestHeaders
if headers.hasHeader(b"Content-Type"):
diff --git a/synapse/server.py b/synapse/server.py
index 9055b97ac3..f83dd6148c 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -91,6 +91,7 @@ from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
+from synapse.module_api import ModuleApi
from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
@@ -114,6 +115,7 @@ from synapse.streams.events import EventSources
from synapse.types import DomainSpecificString
from synapse.util import Clock
from synapse.util.distributor import Distributor
+from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
@@ -184,7 +186,10 @@ class HomeServer(metaclass=abc.ABCMeta):
we are listening on to provide HTTP services.
"""
- REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
+ REQUIRED_ON_BACKGROUND_TASK_STARTUP = [
+ "auth",
+ "stats",
+ ]
# This is overridden in derived application classes
# (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
@@ -250,14 +255,20 @@ class HomeServer(metaclass=abc.ABCMeta):
self.datastores = Databases(self.DATASTORE_CLASS, self)
logger.info("Finished setting up.")
- def setup_master(self) -> None:
+ # Register background tasks required by this server. This must be done
+ # somewhat manually due to the background tasks not being registered
+ # unless handlers are instantiated.
+ if self.config.run_background_tasks:
+ self.setup_background_tasks()
+
+ def setup_background_tasks(self) -> None:
"""
Some handlers have side effects on instantiation (like registering
background updates). This function causes them to be fetched, and
therefore instantiated, to run those side effects.
"""
- for i in self.REQUIRED_ON_MASTER_STARTUP:
- getattr(self, "get_" + i)()
+ for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
+ getattr(self, "get_" + i + "_handler")()
def get_reactor(self) -> twisted.internet.base.ReactorBase:
"""
@@ -642,6 +653,14 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_replication_streams(self) -> Dict[str, Stream]:
return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()}
+ @cache_in_self
+ def get_federation_ratelimiter(self) -> FederationRateLimiter:
+ return FederationRateLimiter(self.clock, config=self.config.rc_federation)
+
+ @cache_in_self
+ def get_module_api(self) -> ModuleApi:
+ return ModuleApi(self, self.get_auth_handler())
+
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
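Both new getters use the `@cache_in_self` pattern seen throughout `HomeServer`: build the dependency on first access, then memoise it on the instance. A simplified sketch of such a decorator (Synapse's real one also guards against re-entrant construction; this version is illustrative only):

# Sketch: memoise a zero-argument get_foo() method on the instance.
import functools

def cache_in_self(builder):
    attr_name = "_" + builder.__name__

    @functools.wraps(builder)
    def getter(self):
        cached = getattr(self, attr_name, None)
        if cached is None:
            cached = builder(self)
            setattr(self, attr_name, cached)
        return cached

    return getter

class Server:
    @cache_in_self
    def get_clock(self):
        return object()  # built exactly once per Server instance

s = Server()
assert s.get_clock() is s.get_clock()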
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
index 395ac5ab02..3ce25bb012 100644
--- a/synapse/spam_checker_api/__init__.py
+++ b/synapse/spam_checker_api/__init__.py
@@ -12,19 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
from enum import Enum
-from twisted.internet import defer
-
-from synapse.storage.state import StateFilter
-
-MYPY = False
-if MYPY:
- import synapse.server
-
-logger = logging.getLogger(__name__)
-
class RegistrationBehaviour(Enum):
"""
@@ -34,35 +23,3 @@ class RegistrationBehaviour(Enum):
ALLOW = "allow"
SHADOW_BAN = "shadow_ban"
DENY = "deny"
-
-
-class SpamCheckerApi:
- """A proxy object that gets passed to spam checkers so they can get
- access to rooms and other relevant information.
- """
-
- def __init__(self, hs: "synapse.server.HomeServer"):
- self.hs = hs
-
- self._store = hs.get_datastore()
-
- @defer.inlineCallbacks
- def get_state_events_in_room(self, room_id: str, types: tuple) -> defer.Deferred:
- """Gets state events for the given room.
-
- Args:
- room_id: The room ID to get state events in.
- types: The event type and state key (using None
- to represent 'any') of the room state to acquire.
-
- Returns:
- twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
- The filtered state events in the room.
- """
- state_ids = yield defer.ensureDeferred(
- self._store.get_filtered_current_state_ids(
- room_id=room_id, state_filter=StateFilter.from_types(types)
- )
- )
- state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
- return state.values()
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 56d6afb863..5b0900aa3c 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -13,43 +13,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import heapq
import logging
-from collections import namedtuple
+from collections import defaultdict, namedtuple
from typing import (
+ Any,
Awaitable,
+ Callable,
+ DefaultDict,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
+ Tuple,
Union,
- cast,
overload,
)
import attr
from frozendict import frozendict
-from prometheus_client import Histogram
+from prometheus_client import Counter, Histogram
from typing_extensions import Literal
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
+from synapse.logging.context import ContextResourceUsage
from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
-from synapse.types import Collection, MutableStateMap, StateMap
-from synapse.util import Clock
+from synapse.types import Collection, StateMap
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func
logger = logging.getLogger(__name__)
-
+metrics_logger = logging.getLogger("synapse.state.metrics")
# Metrics for number of state groups involved in a resolution.
state_groups_histogram = Histogram(
@@ -449,19 +452,44 @@ class StateHandler:
state_map = {ev.event_id: ev for st in state_sets for ev in st}
- with Measure(self.clock, "state._resolve_events"):
- new_state = await resolve_events_with_store(
- self.clock,
- event.room_id,
- room_version,
- state_set_ids,
- event_map=state_map,
- state_res_store=StateResolutionStore(self.store),
- )
+ new_state = await self._state_resolution_handler.resolve_events_with_store(
+ event.room_id,
+ room_version,
+ state_set_ids,
+ event_map=state_map,
+ state_res_store=StateResolutionStore(self.store),
+ )
return {key: state_map[ev_id] for key, ev_id in new_state.items()}
+@attr.s(slots=True)
+class _StateResMetrics:
+ """Keeps track of some usage metrics about state res."""
+
+ # System and User CPU time, in seconds
+ cpu_time = attr.ib(type=float, default=0.0)
+
+ # time spent on database transactions (excluding scheduling time). This roughly
+ # corresponds to the amount of work done on the db server, excluding event fetches.
+ db_time = attr.ib(type=float, default=0.0)
+
+ # number of events fetched from the db.
+ db_events = attr.ib(type=int, default=0)
+
+
+_biggest_room_by_cpu_counter = Counter(
+ "synapse_state_res_cpu_for_biggest_room_seconds",
+ "CPU time spent performing state resolution for the single most expensive "
+ "room for state resolution",
+)
+_biggest_room_by_db_counter = Counter(
+ "synapse_state_res_db_for_biggest_room_seconds",
+ "Database time spent performing state resolution for the single most "
+ "expensive room for state resolution",
+)
+
+
class StateResolutionHandler:
"""Responsible for doing state conflict resolution.
@@ -472,10 +500,9 @@ class StateResolutionHandler:
def __init__(self, hs):
self.clock = hs.get_clock()
- # dict of set of event_ids -> _StateCacheEntry.
- self._state_cache = None
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+ # dict of set of event_ids -> _StateCacheEntry.
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
@@ -485,6 +512,17 @@ class StateResolutionHandler:
reset_expiry_on_get=True,
)
+ #
+ # stuff for tracking time spent on state-res by room
+ #
+
+ # tracks the amount of work done on state res per room
+ self._state_res_metrics = defaultdict(
+ _StateResMetrics
+ ) # type: DefaultDict[str, _StateResMetrics]
+
+ self.clock.looping_call(self._report_metrics, 120 * 1000)
+
@log_function
async def resolve_state_groups(
self,
@@ -519,57 +557,26 @@ class StateResolutionHandler:
Returns:
The resolved state
"""
- logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
-
group_names = frozenset(state_groups_ids.keys())
with (await self.resolve_linearizer.queue(group_names)):
- if self._state_cache is not None:
- cache = self._state_cache.get(group_names, None)
- if cache:
- return cache
+ cache = self._state_cache.get(group_names, None)
+ if cache:
+ return cache
logger.info(
- "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
+ "Resolving state for %s with groups %s", room_id, list(group_names),
)
state_groups_histogram.observe(len(state_groups_ids))
- # start by assuming we won't have any conflicted state, and build up the new
- # state map by iterating through the state groups. If we discover a conflict,
- # we give up and instead use `resolve_events_with_store`.
- #
- # XXX: is this actually worthwhile, or should we just let
- # resolve_events_with_store do it?
- new_state = {} # type: MutableStateMap[str]
- conflicted_state = False
- for st in state_groups_ids.values():
- for key, e_id in st.items():
- if key in new_state:
- conflicted_state = True
- break
- new_state[key] = e_id
- if conflicted_state:
- break
-
- if conflicted_state:
- logger.info("Resolving conflicted state for %r", room_id)
- with Measure(self.clock, "state._resolve_events"):
- # resolve_events_with_store returns a StateMap, but we can
- # treat it as a MutableStateMap as it is above. It isn't
- # actually mutated anymore (and is frozen in
- # _make_state_cache_entry below).
- new_state = cast(
- MutableStateMap,
- await resolve_events_with_store(
- self.clock,
- room_id,
- room_version,
- list(state_groups_ids.values()),
- event_map=event_map,
- state_res_store=state_res_store,
- ),
- )
+ new_state = await self.resolve_events_with_store(
+ room_id,
+ room_version,
+ list(state_groups_ids.values()),
+ event_map=event_map,
+ state_res_store=state_res_store,
+ )
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
@@ -579,11 +586,118 @@ class StateResolutionHandler:
with Measure(self.clock, "state.create_group_ids"):
cache = _make_state_cache_entry(new_state, state_groups_ids)
- if self._state_cache is not None:
- self._state_cache[group_names] = cache
+ self._state_cache[group_names] = cache
return cache
+ async def resolve_events_with_store(
+ self,
+ room_id: str,
+ room_version: str,
+ state_sets: Sequence[StateMap[str]],
+ event_map: Optional[Dict[str, EventBase]],
+ state_res_store: "StateResolutionStore",
+ ) -> StateMap[str]:
+ """
+ Args:
+ room_id: the room we are working in
+
+ room_version: Version of the room
+
+ state_sets: List of dicts of (type, state_key) -> event_id,
+ which are the different state groups to resolve.
+
+ event_map:
+ a dict from event_id to event, for any events that we happen to
+ have in flight (eg, those currently being persisted). This will be
+            used as a starting point for finding the state we need; any missing
+ events will be requested via state_map_factory.
+
+ If None, all events will be fetched via state_res_store.
+
+ state_res_store: a place to fetch events from
+
+ Returns:
+ a map from (type, state_key) to event_id.
+ """
+ try:
+ with Measure(self.clock, "state._resolve_events") as m:
+ v = KNOWN_ROOM_VERSIONS[room_version]
+ if v.state_res == StateResolutionVersions.V1:
+ return await v1.resolve_events_with_store(
+ room_id, state_sets, event_map, state_res_store.get_events
+ )
+ else:
+ return await v2.resolve_events_with_store(
+ self.clock,
+ room_id,
+ room_version,
+ state_sets,
+ event_map,
+ state_res_store,
+ )
+ finally:
+ self._record_state_res_metrics(room_id, m.get_resource_usage())
+
+ def _record_state_res_metrics(self, room_id: str, rusage: ContextResourceUsage):
+ room_metrics = self._state_res_metrics[room_id]
+ room_metrics.cpu_time += rusage.ru_utime + rusage.ru_stime
+ room_metrics.db_time += rusage.db_txn_duration_sec
+ room_metrics.db_events += rusage.evt_db_fetch_count
+
+ def _report_metrics(self):
+ if not self._state_res_metrics:
+ # no state res has happened since the last iteration: don't bother logging.
+ return
+
+ self._report_biggest(
+ lambda i: i.cpu_time, "CPU time", _biggest_room_by_cpu_counter,
+ )
+
+ self._report_biggest(
+ lambda i: i.db_time, "DB time", _biggest_room_by_db_counter,
+ )
+
+ self._state_res_metrics.clear()
+
+ def _report_biggest(
+ self,
+ extract_key: Callable[[_StateResMetrics], Any],
+ metric_name: str,
+ prometheus_counter_metric: Counter,
+ ) -> None:
+ """Report metrics on the biggest rooms for state res
+
+ Args:
+ extract_key: a callable which, given a _StateResMetrics, extracts a single
+ metric to sort by.
+ metric_name: the name of the metric we have extracted, for the log line
+ prometheus_counter_metric: a prometheus metric recording the sum of the
+                extracted metric
+ """
+ n_to_log = 10
+ if not metrics_logger.isEnabledFor(logging.DEBUG):
+ # only need the most expensive if we don't have debug logging, which
+ # allows nlargest() to degrade to max()
+ n_to_log = 1
+
+ items = self._state_res_metrics.items()
+
+ # log the N biggest rooms
+ biggest = heapq.nlargest(
+ n_to_log, items, key=lambda i: extract_key(i[1])
+ ) # type: List[Tuple[str, _StateResMetrics]]
+ metrics_logger.debug(
+ "%i biggest rooms for state-res by %s: %s",
+ len(biggest),
+ metric_name,
+ ["%s (%gs)" % (r, extract_key(m)) for (r, m) in biggest],
+ )
+
+ # report info on the single biggest to prometheus
+ _, biggest_metrics = biggest[0]
+ prometheus_counter_metric.inc(extract_key(biggest_metrics))
+
def _make_state_cache_entry(
new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
@@ -624,7 +738,7 @@ def _make_state_cache_entry(
# failing that, look for the closest match.
prev_group = None
- delta_ids = None
+ delta_ids = None # type: Optional[StateMap[str]]
for old_group, old_state in state_groups_ids.items():
n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
@@ -637,47 +751,6 @@ def _make_state_cache_entry(
)
-def resolve_events_with_store(
- clock: Clock,
- room_id: str,
- room_version: str,
- state_sets: Sequence[StateMap[str]],
- event_map: Optional[Dict[str, EventBase]],
- state_res_store: "StateResolutionStore",
-) -> Awaitable[StateMap[str]]:
- """
- Args:
- room_id: the room we are working in
-
- room_version: Version of the room
-
- state_sets: List of dicts of (type, state_key) -> event_id,
- which are the different state groups to resolve.
-
- event_map:
- a dict from event_id to event, for any events that we happen to
- have in flight (eg, those currently being persisted). This will be
- used as a starting point fof finding the state we need; any missing
- events will be requested via state_map_factory.
-
- If None, all events will be fetched via state_res_store.
-
- state_res_store: a place to fetch events from
-
- Returns:
- a map from (type, state_key) to event_id.
- """
- v = KNOWN_ROOM_VERSIONS[room_version]
- if v.state_res == StateResolutionVersions.V1:
- return v1.resolve_events_with_store(
- room_id, state_sets, event_map, state_res_store.get_events
- )
- else:
- return v2.resolve_events_with_store(
- clock, room_id, room_version, state_sets, event_map, state_res_store
- )
-
-
@attr.s(slots=True)
class StateResolutionStore:
"""Interface that allows state resolution algorithms to access the database
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 79ec8f119d..0ba3a025cf 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -32,6 +32,7 @@ from typing import (
overload,
)
+import attr
from prometheus_client import Histogram
from typing_extensions import Literal
@@ -90,13 +91,17 @@ def make_pool(
return adbapi.ConnectionPool(
db_config.config["name"],
cp_reactor=reactor,
- cp_openfun=engine.on_new_connection,
+ cp_openfun=lambda conn: engine.on_new_connection(
+ LoggingDatabaseConnection(conn, engine, "on_new_connection")
+ ),
**db_config.config.get("args", {})
)
def make_conn(
- db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+ db_config: DatabaseConnectionConfig,
+ engine: BaseDatabaseEngine,
+ default_txn_name: str,
) -> Connection:
"""Make a new connection to the database and return it.
@@ -109,11 +114,60 @@ def make_conn(
for k, v in db_config.config.get("args", {}).items()
if not k.startswith("cp_")
}
- db_conn = engine.module.connect(**db_params)
+ native_db_conn = engine.module.connect(**db_params)
+ db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name)
+
engine.on_new_connection(db_conn)
return db_conn
+@attr.s(slots=True)
+class LoggingDatabaseConnection:
+ """A wrapper around a database connection that returns `LoggingTransaction`
+ as its cursor class.
+
+ This is mainly used on startup to ensure that queries get logged correctly
+ """
+
+ conn = attr.ib(type=Connection)
+ engine = attr.ib(type=BaseDatabaseEngine)
+ default_txn_name = attr.ib(type=str)
+
+ def cursor(
+ self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+ ) -> "LoggingTransaction":
+ if not txn_name:
+ txn_name = self.default_txn_name
+
+ return LoggingTransaction(
+ self.conn.cursor(),
+ name=txn_name,
+ database_engine=self.engine,
+ after_callbacks=after_callbacks,
+ exception_callbacks=exception_callbacks,
+ )
+
+ def close(self) -> None:
+ self.conn.close()
+
+ def commit(self) -> None:
+ self.conn.commit()
+
+ def rollback(self, *args, **kwargs) -> None:
+ self.conn.rollback(*args, **kwargs)
+
+ def __enter__(self) -> "Connection":
+ self.conn.__enter__()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback) -> bool:
+ return self.conn.__exit__(exc_type, exc_value, traceback)
+
+ # Proxy through any unknown lookups to the DB conn class.
+ def __getattr__(self, name):
+ return getattr(self.conn, name)
+
+
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
#
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
@@ -247,6 +301,12 @@ class LoggingTransaction:
def close(self) -> None:
self.txn.close()
+ def __enter__(self) -> "LoggingTransaction":
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+
class PerformanceCounters:
def __init__(self):
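`LoggingDatabaseConnection` above is a thin delegating wrapper: it overrides `cursor()` to hand back a `LoggingTransaction` and proxies everything else to the wrapped connection via `__getattr__`. The core of that proxy pattern, reduced to a sketch:

# Sketch: wrap an object, overriding one method and proxying the rest.
class LoggingWrapper:
    def __init__(self, inner, name: str):
        self._inner = inner
        self._name = name

    def cursor(self):
        # The real class builds a LoggingTransaction here; we just delegate.
        return self._inner.cursor()

    def __getattr__(self, attr):
        # Only consulted for attributes not found on the wrapper itself,
        # so commit(), rollback(), etc. fall through to the inner object.
        return getattr(self._inner, attr)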
@@ -395,7 +455,7 @@ class DatabasePool:
def new_transaction(
self,
- conn: Connection,
+ conn: LoggingDatabaseConnection,
desc: str,
after_callbacks: List[_CallbackListEntry],
exception_callbacks: List[_CallbackListEntry],
@@ -403,6 +463,24 @@ class DatabasePool:
*args: Any,
**kwargs: Any
) -> R:
+ """Start a new database transaction with the given connection.
+
+ Note: The given func may be called multiple times under certain
+ failure modes. This is normally fine when in a standard transaction,
+ but care must be taken if the connection is in `autocommit` mode that
+ the function will correctly handle being aborted and retried half way
+ through its execution.
+
+ Args:
+            conn: database connection to run the transaction on.
+            desc: description of the transaction, for logging and metrics.
+            after_callbacks: callbacks to run on successful commit.
+            exception_callbacks: callbacks to run if the transaction fails.
+            func: function to run in the transaction, passed the cursor.
+            *args: positional args to pass to `func`.
+            **kwargs: named args to pass to `func`.
+ """
+
start = monotonic_time()
txn_id = self._TXN_ID
@@ -418,12 +496,10 @@ class DatabasePool:
i = 0
N = 5
while True:
- cursor = LoggingTransaction(
- conn.cursor(),
- name,
- self.engine,
- after_callbacks,
- exception_callbacks,
+ cursor = conn.cursor(
+ txn_name=name,
+ after_callbacks=after_callbacks,
+ exception_callbacks=exception_callbacks,
)
try:
r = func(cursor, *args, **kwargs)
@@ -508,7 +584,12 @@ class DatabasePool:
sql_txn_timer.labels(desc).observe(duration)
async def runInteraction(
- self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
+ self,
+ desc: str,
+ func: "Callable[..., R]",
+ *args: Any,
+ db_autocommit: bool = False,
+ **kwargs: Any
) -> R:
"""Starts a transaction on the database and runs a given function
@@ -518,6 +599,18 @@ class DatabasePool:
database transaction (twisted.enterprise.adbapi.Transaction) as
its first argument, followed by `args` and `kwargs`.
+ db_autocommit: Whether to run the function in "autocommit" mode,
+ i.e. outside of a transaction. This is useful for transactions
+ that are only a single query.
+
+ Currently, this is only implemented for Postgres. SQLite will still
+ run the function inside a transaction.
+
+ WARNING: This means that if func fails half way through then
+ the changes will *not* be rolled back. `func` may also get
+ called multiple times if the transaction is retried, so must
+ correctly handle that case.
+
args: positional args to pass to `func`
kwargs: named args to pass to `func`
@@ -538,6 +631,7 @@ class DatabasePool:
exception_callbacks,
func,
*args,
+ db_autocommit=db_autocommit,
**kwargs
)
@@ -551,7 +645,11 @@ class DatabasePool:
return cast(R, result)
async def runWithConnection(
- self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+ self,
+ func: "Callable[..., R]",
+ *args: Any,
+ db_autocommit: bool = False,
+ **kwargs: Any
) -> R:
"""Wraps the .runWithConnection() method on the underlying db_pool.
@@ -560,6 +658,9 @@ class DatabasePool:
database connection (twisted.enterprise.adbapi.Connection) as
its first argument, followed by `args` and `kwargs`.
args: positional args to pass to `func`
+ db_autocommit: Whether to run the function in "autocommit" mode,
+                i.e. outside of a transaction. This is useful for transactions
+ that are only a single query. Currently only affects postgres.
kwargs: named args to pass to `func`
Returns:
@@ -575,6 +676,13 @@ class DatabasePool:
start_time = monotonic_time()
def inner_func(conn, *args, **kwargs):
+ # We shouldn't be in a transaction. If we are then something
+ # somewhere hasn't committed after doing work. (This is likely only
+ # possible during startup, as `run*` will ensure changes are
+ # committed/rolled back before putting the connection back in the
+ # pool).
+ assert not self.engine.in_transaction(conn)
+
with LoggingContext("runWithConnection", parent_context) as context:
sched_duration_sec = monotonic_time() - start_time
sql_scheduling_timer.observe(sched_duration_sec)
@@ -584,7 +692,17 @@ class DatabasePool:
logger.debug("Reconnecting closed database connection")
conn.reconnect()
- return func(conn, *args, **kwargs)
+ try:
+ if db_autocommit:
+ self.engine.attempt_to_set_autocommit(conn, True)
+
+ db_conn = LoggingDatabaseConnection(
+ conn, self.engine, "runWithConnection"
+ )
+ return func(db_conn, *args, **kwargs)
+ finally:
+ if db_autocommit:
+ self.engine.attempt_to_set_autocommit(conn, False)
return await make_deferred_yieldable(
self._db_pool.runWithConnection(inner_func, *args, **kwargs)
@@ -1621,7 +1739,7 @@ class DatabasePool:
def get_cache_dict(
self,
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
table: str,
entity_column: str,
stream_column: str,
@@ -1642,9 +1760,7 @@ class DatabasePool:
"limit": limit,
}
- sql = self.engine.convert_param_style(sql)
-
- txn = db_conn.cursor()
+ txn = db_conn.cursor(txn_name="get_cache_dict")
txn.execute(sql, (int(max_value),))
cache = {row[0]: int(row[1]) for row in txn}
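As the docstrings above warn, `db_autocommit` is only safe for work that is a single, idempotent statement, since a retried `func` re-runs from the top with nothing rolled back. A hedged usage sketch (table and query are illustrative):

# Sketch: a single-statement read is a good candidate for autocommit,
# because it has nothing to roll back and is safe to retry.
async def count_users(db_pool) -> int:
    def _count(txn) -> int:
        txn.execute("SELECT COUNT(*) FROM users")
        (count,) = txn.fetchone()
        return count

    # Runs outside an explicit transaction on Postgres; SQLite still
    # wraps it in one. Avoid multi-statement writes here.
    return await db_pool.runInteraction("count_users", _count, db_autocommit=True)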
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index aa5d490624..0c24325011 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -46,7 +46,7 @@ class Databases:
db_name = database_config.name
engine = create_engine(database_config.config)
- with make_conn(database_config, engine) as db_conn:
+ with make_conn(database_config, engine, "startup") as db_conn:
logger.info("[database config %r]: Checking database server", db_name)
engine.check_database(db_conn)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 2ae2fbd5d7..9b16f45f3e 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,9 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import calendar
import logging
-import time
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import PresenceState
@@ -160,19 +158,25 @@ class DataStore(
)
if isinstance(self.database_engine, PostgresEngine):
+ # We set the `writers` to an empty list here as we don't care about
+ # missing updates over restarts, as we'll not have anything in our
+            # caches to invalidate. (This reduces the number of writes
+            # to the DB.)
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
- instance_name="master",
+ stream_name="caches",
+ instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
sequence_name="cache_invalidation_stream_seq",
+ writers=[],
)
else:
self._cache_id_gen = None
- super(DataStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._presence_on_startup = self._get_active_presence(db_conn)
@@ -262,9 +266,6 @@ class DataStore(
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
- # Used in _generate_user_daily_visits to keep track of progress
- self._last_user_visit_update = self._get_start_of_day()
-
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
@@ -283,7 +284,6 @@ class DataStore(
" last_user_sync_ts, status_msg, currently_active FROM presence_stream"
" WHERE state != ?"
)
- sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
@@ -295,192 +295,6 @@ class DataStore(
return [UserPresenceState(**row) for row in rows]
- async def count_daily_users(self) -> int:
- """
- Counts the number of users who used this homeserver in the last 24 hours.
- """
- yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
- return await self.db_pool.runInteraction(
- "count_daily_users", self._count_users, yesterday
- )
-
- async def count_monthly_users(self) -> int:
- """
- Counts the number of users who used this homeserver in the last 30 days.
- Note this method is intended for phonehome metrics only and is different
- from the mau figure in synapse.storage.monthly_active_users which,
- amongst other things, includes a 3 day grace period before a user counts.
- """
- thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- return await self.db_pool.runInteraction(
- "count_monthly_users", self._count_users, thirty_days_ago
- )
-
- def _count_users(self, txn, time_from):
- """
- Returns number of users seen in the past time_from period
- """
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT user_id FROM user_ips
- WHERE last_seen > ?
- GROUP BY user_id
- ) u
- """
- txn.execute(sql, (time_from,))
- (count,) = txn.fetchone()
- return count
-
- async def count_r30_users(self) -> Dict[str, int]:
- """
- Counts the number of 30 day retained users, defined as:-
- * Users who have created their accounts more than 30 days ago
- * Where last seen at most 30 days ago
- * Where account creation and last_seen are > 30 days apart
-
- Returns:
- A mapping of counts globally as well as broken out by platform.
- """
-
- def _count_r30_users(txn):
- thirty_days_in_secs = 86400 * 30
- now = int(self._clock.time())
- thirty_days_ago_in_secs = now - thirty_days_in_secs
-
- sql = """
- SELECT platform, COALESCE(count(*), 0) FROM (
- SELECT
- users.name, platform, users.creation_ts * 1000,
- MAX(uip.last_seen)
- FROM users
- INNER JOIN (
- SELECT
- user_id,
- last_seen,
- CASE
- WHEN user_agent LIKE '%%Android%%' THEN 'android'
- WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
- WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
- WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
- WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
- ELSE 'unknown'
- END
- AS platform
- FROM user_ips
- ) uip
- ON users.name = uip.user_id
- AND users.appservice_id is NULL
- AND users.creation_ts < ?
- AND uip.last_seen/1000 > ?
- AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
- GROUP BY users.name, platform, users.creation_ts
- ) u GROUP BY platform
- """
-
- results = {}
- txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
- for row in txn:
- if row[0] == "unknown":
- pass
- results[row[0]] = row[1]
-
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT users.name, users.creation_ts * 1000,
- MAX(uip.last_seen)
- FROM users
- INNER JOIN (
- SELECT
- user_id,
- last_seen
- FROM user_ips
- ) uip
- ON users.name = uip.user_id
- AND appservice_id is NULL
- AND users.creation_ts < ?
- AND uip.last_seen/1000 > ?
- AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
- GROUP BY users.name, users.creation_ts
- ) u
- """
-
- txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
- (count,) = txn.fetchone()
- results["all"] = count
-
- return results
-
- return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
-
- def _get_start_of_day(self):
- """
- Returns millisecond unixtime for start of UTC day.
- """
- now = time.gmtime()
- today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
- return today_start * 1000
-
- async def generate_user_daily_visits(self) -> None:
- """
- Generates daily visit data for use in cohort/ retention analysis
- """
-
- def _generate_user_daily_visits(txn):
- logger.info("Calling _generate_user_daily_visits")
- today_start = self._get_start_of_day()
- a_day_in_milliseconds = 24 * 60 * 60 * 1000
- now = self.clock.time_msec()
-
- sql = """
- INSERT INTO user_daily_visits (user_id, device_id, timestamp)
- SELECT u.user_id, u.device_id, ?
- FROM user_ips AS u
- LEFT JOIN (
- SELECT user_id, device_id, timestamp FROM user_daily_visits
- WHERE timestamp = ?
- ) udv
- ON u.user_id = udv.user_id AND u.device_id=udv.device_id
- INNER JOIN users ON users.name=u.user_id
- WHERE last_seen > ? AND last_seen <= ?
- AND udv.timestamp IS NULL AND users.is_guest=0
- AND users.appservice_id IS NULL
- GROUP BY u.user_id, u.device_id
- """
-
- # This means that the day has rolled over but there could still
- # be entries from the previous day. There is an edge case
- # where if the user logs in at 23:59 and overwrites their
- # last_seen at 00:01 then they will not be counted in the
- # previous day's stats - it is important that the query is run
- # often to minimise this case.
- if today_start > self._last_user_visit_update:
- yesterday_start = today_start - a_day_in_milliseconds
- txn.execute(
- sql,
- (
- yesterday_start,
- yesterday_start,
- self._last_user_visit_update,
- today_start,
- ),
- )
- self._last_user_visit_update = today_start
-
- txn.execute(
- sql, (today_start, today_start, self._last_user_visit_update, now)
- )
- # Update _last_user_visit_update to now. The reason to do this
- # rather just clamping to the beginning of the day is to limit
- # the size of the join - meaning that the query can be run more
- # frequently
- self._last_user_visit_update = now
-
- await self.db_pool.runInteraction(
- "generate_user_daily_visits", _generate_user_daily_visits
- )
-
async def get_users(self) -> List[Dict[str, Any]]:
"""Function to retrieve a list of users in users table.
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 4436b1a83d..49ee23470d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -18,6 +18,7 @@ import abc
import logging
from typing import Dict, List, Optional, Tuple
+from synapse.api.constants import AccountDataTypes
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
@@ -29,22 +30,20 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-class AccountDataWorkerStore(SQLBaseStore):
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
+class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_max_account_data_stream_id` which can be called in the initializer.
"""
- # This ABCMeta metaclass ensures that we cannot be instantiated without
- # the abstract methods being implemented.
- __metaclass__ = abc.ABCMeta
-
def __init__(self, database: DatabasePool, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
)
- super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
@abc.abstractmethod
def get_max_account_data_stream_id(self):
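The metaclass change above is a behavioural fix, not a cosmetic one: on Python 3 a class-body `__metaclass__` attribute is silently ignored, so the old spelling never actually enforced the abstract methods. A short demonstration:

# Sketch: __metaclass__ does nothing on Python 3; metaclass= is required.
import abc

class Broken:
    __metaclass__ = abc.ABCMeta  # ignored on Python 3

    @abc.abstractmethod
    def f(self):
        ...

class Fixed(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def f(self):
        ...

Broken()  # instantiates fine: the abstract method is not enforced
try:
    Fixed()
except TypeError:
    pass  # correctly refuses to instantiate without f() implemented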
@@ -293,14 +292,18 @@ class AccountDataWorkerStore(SQLBaseStore):
self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
) -> bool:
ignored_account_data = await self.get_global_account_data_by_type_for_user(
- "m.ignored_user_list",
+ AccountDataTypes.IGNORED_USER_LIST,
ignorer_user_id,
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
return False
- return ignored_user_id in ignored_account_data.get("ignored_users", {})
+ try:
+ return ignored_user_id in ignored_account_data.get("ignored_users", {})
+ except TypeError:
+ # The type of the ignored_users field is invalid.
+ return False
class AccountDataStore(AccountDataWorkerStore):
@@ -315,7 +318,7 @@ class AccountDataStore(AccountDataWorkerStore):
],
)
- super(AccountDataStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream id for the private user data stream
@@ -341,7 +344,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
@@ -389,7 +392,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
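`get_next()` now returns an async context manager, so both acquiring the ID and releasing it on exit can await the database; the old `with await ...` form could only await the acquisition. A minimal sketch of an allocator in that shape (illustrative, not Synapse's generator):

# Sketch: an ID generator usable as `async with gen.get_next() as next_id:`,
# where both entry and exit may perform async work.
import asyncio
from contextlib import asynccontextmanager

class IdGenerator:
    def __init__(self) -> None:
        self._current = 0
        self._lock = asyncio.Lock()

    @asynccontextmanager
    async def get_next(self):
        async with self._lock:  # entry may await
            self._current += 1
            yield self._current
        # exit runs after the block; a real generator would mark the ID
        # as persisted / advance the current token here.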
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 454c0bc50c..85f6b1e3fd 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -52,7 +52,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
- super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
def get_app_services(self):
return self.services_cache
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index f211ddbaf8..4bb2b9c28c 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -21,8 +21,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.databases.main.events import encode_json
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util.frozenutils import frozendict_json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
- pruned_json = encode_json(
+ pruned_json = frozendict_json_encoder.encode(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
@@ -171,7 +171,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
return
# Prune the event's dict then convert it to JSON.
- pruned_json = encode_json(
+ pruned_json = frozendict_json_encoder.encode(
prune_event_dict(event.room_version, event.get_dict())
)
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index c2fc847fbc..a25a888443 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -31,7 +31,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
"user_ips_device_index",
@@ -351,16 +351,70 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return updated
-class ClientIpStore(ClientIpBackgroundUpdateStore):
+class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self.user_ips_max_age = hs.config.user_ips_max_age
+
+ if hs.config.run_background_tasks and self.user_ips_max_age:
+ self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+
+ @wrap_as_background_process("prune_old_user_ips")
+ async def _prune_old_user_ips(self):
+ """Removes entries in user IPs older than the configured period.
+ """
+
+ if self.user_ips_max_age is None:
+ # Nothing to do
+ return
+
+ if not await self.db_pool.updates.has_completed_background_update(
+ "devices_last_seen"
+ ):
+ # Only start pruning if we have finished populating the devices
+ # last seen info.
+ return
+
+ # We do a slightly funky SQL delete to ensure we don't try and delete
+ # too much at once (as the table may be very large from before we
+ # started pruning).
+ #
+ # This works by finding the max last_seen that is less than the given
+ # time, but has no more than N rows before it, deleting all rows with
+ # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
+ # returns exactly one row).
+ sql = """
+ DELETE FROM user_ips
+ WHERE last_seen <= (
+ SELECT COALESCE(MAX(last_seen), -1)
+ FROM (
+ SELECT last_seen FROM user_ips
+ WHERE last_seen <= ?
+ ORDER BY last_seen ASC
+ LIMIT 5000
+ ) AS u
+ )
+ """
+
+ timestamp = self.clock.time_msec() - self.user_ips_max_age
+
+ def _prune_old_user_ips_txn(txn):
+ txn.execute(sql, (timestamp,))
+
+ await self.db_pool.runInteraction(
+ "_prune_old_user_ips", _prune_old_user_ips_txn
+ )
+
+
+class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000
)
- super(ClientIpStore, self).__init__(database, db_conn, hs)
-
- self.user_ips_max_age = hs.config.user_ips_max_age
+ super().__init__(database, db_conn, hs)
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
@@ -372,9 +426,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
"before", "shutdown", self._update_client_ips_batch
)
- if self.user_ips_max_age:
- self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
-
async def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
):
@@ -525,49 +576,3 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
}
for (access_token, ip), (user_agent, last_seen) in results.items()
]
-
- @wrap_as_background_process("prune_old_user_ips")
- async def _prune_old_user_ips(self):
- """Removes entries in user IPs older than the configured period.
- """
-
- if self.user_ips_max_age is None:
- # Nothing to do
- return
-
- if not await self.db_pool.updates.has_completed_background_update(
- "devices_last_seen"
- ):
- # Only start pruning if we have finished populating the devices
- # last seen info.
- return
-
- # We do a slightly funky SQL delete to ensure we don't try and delete
- # too much at once (as the table may be very large from before we
- # started pruning).
- #
- # This works by finding the max last_seen that is less than the given
- # time, but has no more than N rows before it, deleting all rows with
- # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
- # returns exactly one row).
- sql = """
- DELETE FROM user_ips
- WHERE last_seen <= (
- SELECT COALESCE(MAX(last_seen), -1)
- FROM (
- SELECT last_seen FROM user_ips
- WHERE last_seen <= ?
- ORDER BY last_seen ASC
- LIMIT 5000
- ) AS u
- )
- """
-
- timestamp = self.clock.time_msec() - self.user_ips_max_age
-
- def _prune_old_user_ips_txn(txn):
- txn.execute(sql, (timestamp,))
-
- await self.db_pool.runInteraction(
- "_prune_old_user_ips", _prune_old_user_ips_txn
- )
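Moving the pruning loop into `ClientIpWorkerStore` and gating it on `hs.config.run_background_tasks` ensures exactly one process schedules it. The looping shape, sketched with plain asyncio rather than Synapse's clock and metrics helpers:

# Sketch: schedule a periodic, bounded pruning task on one process only.
import asyncio

async def prune_loop(run_background_tasks: bool, prune_once, interval: float = 5.0):
    if not run_background_tasks:
        return  # another worker owns the background tasks
    while True:
        await prune_once()  # deletes at most a bounded batch of rows
        await asyncio.sleep(interval)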
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 0044433110..d42faa3f1f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -283,7 +283,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
"device_inbox_stream_index",
@@ -313,7 +313,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(DeviceInboxStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
- with await self._device_inbox_id_gen.get_next() as stream_id:
+ async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device
)
- with await self._device_inbox_id_gen.get_next() as stream_id:
+ async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 306fc6947c..2d0a6408b5 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ from synapse.storage.database import (
make_tuple_comparison_clause,
)
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
-from synapse.util import json_encoder
+from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import Cache, cached, cachedList
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
The new stream ID.
"""
- with await self._device_list_id_gen.get_next() as stream_id:
+ async with self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -698,10 +698,84 @@ class DeviceWorkerStore(SQLBaseStore):
_mark_remote_user_device_list_as_unsubscribed_txn,
)
+ async def get_dehydrated_device(
+ self, user_id: str
+ ) -> Optional[Tuple[str, JsonDict]]:
+ """Retrieve the information for a dehydrated device.
+
+ Args:
+ user_id: the user whose dehydrated device we are looking for
+ Returns:
+ a tuple whose first item is the device ID, and the second item is
+ the dehydrated device information
+ """
+ # FIXME: make sure device ID still exists in devices table
+ row = await self.db_pool.simple_select_one(
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id},
+ retcols=["device_id", "device_data"],
+ allow_none=True,
+ )
+ return (
+ (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
+ )
+
+ def _store_dehydrated_device_txn(
+ self, txn, user_id: str, device_id: str, device_data: str
+ ) -> Optional[str]:
+ old_device_id = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id},
+ retcol="device_id",
+ allow_none=True,
+ )
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id},
+ values={"device_id": device_id, "device_data": device_data},
+ )
+ return old_device_id
+
+ async def store_dehydrated_device(
+ self, user_id: str, device_id: str, device_data: JsonDict
+ ) -> Optional[str]:
+ """Store a dehydrated device for a user.
+
+ Args:
+ user_id: the user that we are storing the device for
+ device_id: the ID of the dehydrated device
+ device_data: the dehydrated device information
+ Returns:
+ device id of the user's previous dehydrated device, if any
+ """
+ return await self.db_pool.runInteraction(
+ "store_dehydrated_device_txn",
+ self._store_dehydrated_device_txn,
+ user_id,
+ device_id,
+ json_encoder.encode(device_data),
+ )
+
+ async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
+ """Remove a dehydrated device.
+
+ Args:
+ user_id: the user that the dehydrated device belongs to
+ device_id: the ID of the dehydrated device
+ """
+ count = await self.db_pool.simple_delete(
+ "dehydrated_devices",
+ {"user_id": user_id, "device_id": device_id},
+ desc="remove_dehydrated_device",
+ )
+ return count >= 1
+
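
The dehydrated-device storage keeps at most one row per user and reports the previously stored device ID. The same select-then-upsert shape, sketched against a hypothetical sqlite3 table (the `ON CONFLICT` upsert needs SQLite 3.24+):

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE dehydrated_devices ("
    "user_id TEXT PRIMARY KEY, device_id TEXT, device_data TEXT)"
)

def store_dehydrated_device(user_id: str, device_id: str, device_data: dict):
    """Upsert the user's single dehydrated device; return the old device ID."""
    row = conn.execute(
        "SELECT device_id FROM dehydrated_devices WHERE user_id = ?", (user_id,)
    ).fetchone()
    conn.execute(
        "INSERT INTO dehydrated_devices (user_id, device_id, device_data) "
        "VALUES (?, ?, ?) "
        "ON CONFLICT(user_id) DO UPDATE SET device_id = excluded.device_id, "
        "device_data = excluded.device_data",
        (user_id, device_id, json.dumps(device_data)),
    )
    return row[0] if row else None

print(store_dehydrated_device("@u:example.org", "DEV1", {"algorithm": "x"}))  # None
print(store_dehydrated_device("@u:example.org", "DEV2", {"algorithm": "y"}))  # DEV1
```
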
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
"device_lists_stream_idx",
@@ -826,7 +900,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(DeviceStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
@@ -837,7 +911,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
async def store_device(
- self, user_id: str, device_id: str, initial_device_display_name: str
+ self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
) -> bool:
"""Ensure the given device is known; add it to the store if not
@@ -955,7 +1029,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def update_remote_device_list_cache_entry(
- self, user_id: str, device_id: str, content: JsonDict, stream_id: int
+ self, user_id: str, device_id: str, content: JsonDict, stream_id: str
) -> None:
"""Updates a single device in the cache of a remote user's devicelist.
@@ -983,7 +1057,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_id: str,
content: JsonDict,
- stream_id: int,
+ stream_id: str,
) -> None:
if content.get("deleted"):
self.db_pool.simple_delete_txn(
@@ -1093,7 +1167,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return
- with await self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@@ -1108,7 +1182,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
- with await self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c8df0bcb3f..359dc6e968 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -367,6 +367,57 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
+ async def set_e2e_fallback_keys(
+ self, user_id: str, device_id: str, fallback_keys: JsonDict
+ ) -> None:
+ """Set the user's e2e fallback keys.
+
+ Args:
+ user_id: the user whose keys are being set
+ device_id: the device whose keys are being set
+ fallback_keys: the keys to set. This is a map from key ID (which is
+ of the form "algorithm:id") to key data.
+ """
+ # fallback_keys will usually only have one item in it, so using a for
+ # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
+ # FIXME: make sure that only one key per algorithm is uploaded
+ for key_id, fallback_key in fallback_keys.items():
+ algorithm, key_id = key_id.split(":", 1)
+ await self.db_pool.simple_upsert(
+ "e2e_fallback_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ values={
+ "key_id": key_id,
+ "key_json": json_encoder.encode(fallback_key),
+ "used": False,
+ },
+ desc="set_e2e_fallback_key",
+ )
+
+ @cached(max_entries=10000)
+ async def get_e2e_unused_fallback_key_types(
+ self, user_id: str, device_id: str
+ ) -> List[str]:
+ """Returns the fallback key types that have an unused key.
+
+ Args:
+ user_id: the user whose keys are being queried
+ device_id: the device whose keys are being queried
+
+ Returns:
+ a list of key types
+ """
+ return await self.db_pool.simple_select_onecol(
+ "e2e_fallback_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
+ retcol="algorithm",
+ desc="get_e2e_unused_fallback_key_types",
+ )
+
async def get_e2e_cross_signing_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
) -> Optional[dict]:
@@ -701,15 +752,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" LIMIT 1"
)
+ fallback_sql = (
+ "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
+ " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+ " LIMIT 1"
+ )
result = {}
delete = []
+ used_fallbacks = []
for user_id, device_id, algorithm in query_list:
user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm))
- for key_id, key_json in txn:
+ otk_row = txn.fetchone()
+ if otk_row is not None:
+ key_id, key_json = otk_row
device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id))
+ else:
+ # no one-time key available, so see if there's a fallback
+ # key
+ txn.execute(fallback_sql, (user_id, device_id, algorithm))
+ fallback_row = txn.fetchone()
+ if fallback_row is not None:
+ key_id, key_json, used = fallback_row
+ device_result[algorithm + ":" + key_id] = key_json
+ if not used:
+ used_fallbacks.append(
+ (user_id, device_id, algorithm, key_id)
+ )
+
+ # drop any one-time keys that were claimed
sql = (
"DELETE FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
@@ -726,6 +799,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
+ # mark fallback keys as used
+ for user_id, device_id, algorithm, key_id in used_fallbacks:
+ self.db_pool.simple_update_txn(
+ txn,
+ "e2e_fallback_keys_json",
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ "key_id": key_id,
+ },
+ {"used": True},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ )
+
return result
return await self.db_pool.runInteraction(
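
The claim logic above prefers one-time keys, which are deleted once handed out, and only then serves the fallback key, which stays in place and is merely flagged as used. A toy in-memory model of that behaviour (not the real schema):

```python
from typing import Dict, List, Optional, Tuple

one_time_keys: Dict[str, List[str]] = {"signed_curve25519": ["otk1"]}
fallback_keys: Dict[str, Tuple[str, bool]] = {"signed_curve25519": ("fb1", False)}

def claim_key(algorithm: str) -> Optional[str]:
    otks = one_time_keys.get(algorithm)
    if otks:
        return otks.pop(0)  # one-time keys are consumed on claim
    fb = fallback_keys.get(algorithm)
    if fb is not None:
        key, _used = fb
        fallback_keys[algorithm] = (key, True)  # mark used, keep serving it
        return key
    return None

print(claim_key("signed_curve25519"))  # otk1 (consumed)
print(claim_key("signed_curve25519"))  # fb1 (fallback, now marked used)
print(claim_key("signed_curve25519"))  # fb1 again: fallbacks are not deleted
```
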
@@ -754,6 +844,19 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="dehydrated_devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="e2e_fallback_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ )
await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
@@ -831,7 +934,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key (dict): the key data
"""
- with await self._cross_signing_id_gen.get_next() as stream_id:
+ async with self._cross_signing_id_gen.get_next() as stream_id:
return await self.db_pool.runInteraction(
"add_e2e_cross_signing_key",
self._set_e2e_cross_signing_key_txn,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4c3c162acf..6d3689c09e 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -600,7 +600,7 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventFederationStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 7805fb814e..80f3b4d740 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple, Union
import attr
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -68,17 +68,13 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
self.stream_ordering_month_ago = None
self.stream_ordering_day_ago = None
- cur = LoggingTransaction(
- db_conn.cursor(),
- name="_find_stream_orderings_for_times_txn",
- database_engine=self.database_engine,
- )
+ cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
self._find_stream_orderings_for_times_txn(cur)
cur.close()
@@ -661,7 +657,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventPushActionsStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
self.EPA_HIGHLIGHT_INDEX,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 9a80f419e3..b4abd961b9 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,7 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@@ -52,16 +52,6 @@ event_counter = Counter(
)
-def encode_json(json_object):
- """
- Encode a Python object as JSON and return it in a Unicode string.
- """
- out = frozendict_json_encoder.encode(json_object)
- if isinstance(out, bytes):
- out = out.decode("utf8")
- return out
-
-
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
@@ -156,15 +146,15 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
- stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
+ stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
- stream_ordering_manager = await self._stream_id_gen.get_next_mult(
+ stream_ordering_manager = self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
- with stream_ordering_manager as stream_orderings:
+ async with stream_ordering_manager as stream_orderings:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
@@ -341,6 +331,10 @@ class PersistEventsStore:
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
+ # stream orderings should have been assigned by now
+ assert min_stream_order
+ assert max_stream_order
+
self._update_forward_extremities_txn(
txn,
new_forward_extremities=new_forward_extremeties,
@@ -743,7 +737,9 @@ class PersistEventsStore:
logger.exception("")
raise
- metadata_json = encode_json(event.internal_metadata.get_dict())
+ metadata_json = frozendict_json_encoder.encode(
+ event.internal_metadata.get_dict()
+ )
sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
txn.execute(sql, (metadata_json, event.event_id))
@@ -797,10 +793,10 @@ class PersistEventsStore:
{
"event_id": event.event_id,
"room_id": event.room_id,
- "internal_metadata": encode_json(
+ "internal_metadata": frozendict_json_encoder.encode(
event.internal_metadata.get_dict()
),
- "json": encode_json(event_dict(event)),
+ "json": frozendict_json_encoder.encode(event_dict(event)),
"format_version": event.format_version,
}
for event, _ in events_and_contexts
@@ -1108,6 +1104,10 @@ class PersistEventsStore:
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
+
+ def str_or_none(val: Any) -> Optional[str]:
+ return val if isinstance(val, str) else None
+
self.db_pool.simple_insert_many_txn(
txn,
table="room_memberships",
@@ -1118,8 +1118,8 @@ class PersistEventsStore:
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
+ "display_name": str_or_none(event.content.get("displayname")),
+ "avatar_url": str_or_none(event.content.get("avatar_url")),
}
for event in events
],
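
`str_or_none` guards the `room_memberships` columns against event content in which `displayname` or `avatar_url` is present but not a string; for example:

```python
from typing import Any, Optional

def str_or_none(val: Any) -> Optional[str]:
    """Collapse anything that isn't a string to NULL before it hits the DB."""
    return val if isinstance(val, str) else None

content = {"displayname": {"not": "a string"}, "avatar_url": "mxc://x/y"}
print(str_or_none(content.get("displayname")))  # None: malformed value dropped
print(str_or_none(content.get("avatar_url")))   # mxc://x/y
```
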
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index e53c6373a8..5e4af2eb51 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -29,7 +29,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 17f5997b89..b7ed8ca6ab 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import division
-
import itertools
import logging
import threading
@@ -76,8 +74,15 @@ class EventRedactBehaviour(Names):
class EventsWorkerStore(SQLBaseStore):
+ # Whether to use dedicated DB threads for event fetching. This is only used
+ # if there are multiple DB threads available. When used, it will lock the DB
+ # thread for periods of time (so unit tests want to disable this when they
+ # run DB transactions on the main thread). See EVENT_QUEUE_* for more
+ # options controlling this.
+ USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
+
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventsWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
@@ -85,21 +90,25 @@ class EventsWorkerStore(SQLBaseStore):
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ stream_name="events",
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_stream_seq",
+ writers=hs.config.worker.writers.events,
)
self._backfill_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ stream_name="backfill",
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_backfill_stream_seq",
positive=False,
+ writers=hs.config.worker.writers.events,
)
else:
# We shouldn't be running in worker mode with SQLite, but it's useful
@@ -520,7 +529,11 @@ class EventsWorkerStore(SQLBaseStore):
if not event_list:
single_threaded = self.database_engine.single_threaded
- if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
self._event_fetch_ongoing -= 1
return
else:
@@ -710,6 +723,7 @@ class EventsWorkerStore(SQLBaseStore):
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
+ original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
event_map[event_id] = original_ev
@@ -777,6 +791,8 @@ class EventsWorkerStore(SQLBaseStore):
* event_id (str)
+ * stream_ordering (int): stream ordering for this event
+
* json (str): json-encoded event structure
* internal_metadata (str): json-encoded internal metadata dict
@@ -809,13 +825,15 @@ class EventsWorkerStore(SQLBaseStore):
sql = """\
SELECT
e.event_id,
- e.internal_metadata,
- e.json,
- e.format_version,
+ e.stream_ordering,
+ ej.internal_metadata,
+ ej.json,
+ ej.format_version,
r.room_version,
rej.reason
- FROM event_json as e
- LEFT JOIN rooms r USING (room_id)
+ FROM events AS e
+ JOIN event_json AS ej USING (event_id)
+ LEFT JOIN rooms r ON r.room_id = e.room_id
LEFT JOIN rejections as rej USING (event_id)
WHERE """
@@ -829,11 +847,12 @@ class EventsWorkerStore(SQLBaseStore):
event_id = row[0]
event_dict[event_id] = {
"event_id": event_id,
- "internal_metadata": row[1],
- "json": row[2],
- "format_version": row[3],
- "room_version_id": row[4],
- "rejected_reason": row[5],
+ "stream_ordering": row[1],
+ "internal_metadata": row[2],
+ "json": row[3],
+ "format_version": row[4],
+ "room_version_id": row[5],
+ "rejected_reason": row[6],
"redactions": [],
}
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index ccfbb2135e..7218191965 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- with await self._group_updates_id_gen.get_next() as next_id:
+ async with self._group_updates_id_gen.get_next() as next_id:
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 1d76c761a6..cc538c5c10 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -24,9 +24,7 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(MediaRepositoryBackgroundUpdateStore, self).__init__(
- database, db_conn, hs
- )
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
update_name="local_media_repository_url_idx",
@@ -94,7 +92,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 686052bd83..0acf0617ca 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -12,17 +12,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import typing
-from collections import Counter
+import calendar
+import logging
+import time
+from typing import Dict
-from synapse.metrics import BucketCollector
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics import GaugeBucketCollector
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
+logger = logging.getLogger(__name__)
+
+# Collect metrics on the number of forward extremities that exist.
+_extremities_collecter = GaugeBucketCollector(
+ "synapse_forward_extremities",
+ "Number of rooms on the server with the given number of forward extremities"
+ " or fewer",
+ buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500],
+)
+
+# we also expose metrics on the "number of excess extremity events", which is
+# (E-1)*N, where E is the number of extremities and N is the number of state
+# events in the room. This is an approximation to the number of state events
+# we could remove from state resolution by reducing the graph to a single
+# forward extremity.
+_excess_state_events_collecter = GaugeBucketCollector(
+ "synapse_excess_extremity_events",
+ "Number of rooms on the server with the given number of excess extremity "
+ "events, or fewer",
+ buckets=[0] + [1 << n for n in range(12)],
+)
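
As a worked example of the excess-extremity metric: a room with 4 forward extremities and 200 state events contributes (4 - 1) * 200 = 600, while a healthy single-extremity room contributes 0, whatever its size:

```python
# Rooms as (num_forward_extremities, num_state_events) pairs -- toy data.
rooms = [(1, 500), (4, 200), (7, 50)]

# Mirrors the "(x[0] - 1) * x[1] for x in res if x[1]" expression below.
excess = [(e - 1) * n for e, n in rooms if n]
print(excess)  # [0, 600, 300]: only multi-extremity rooms contribute
```
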
+
class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""Functions to pull various metrics from the DB, for e.g. phone home
@@ -32,40 +56,37 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- # Collect metrics on the number of forward extremities that exist.
- # Counter of number of extremities to count
- self._current_forward_extremities_amount = (
- Counter()
- ) # type: typing.Counter[int]
-
- BucketCollector(
- "synapse_forward_extremities",
- lambda: self._current_forward_extremities_amount,
- buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
- )
-
# Read the extrems every 60 minutes
- def read_forward_extremities():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "read_forward_extremities", self._read_forward_extremities
- )
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000)
- hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+ # Used in _generate_user_daily_visits to keep track of progress
+ self._last_user_visit_update = self._get_start_of_day()
+ @wrap_as_background_process("read_forward_extremities")
async def _read_forward_extremities(self):
def fetch(txn):
txn.execute(
"""
- select count(*) c from event_forward_extremities
- group by room_id
+ SELECT t1.c, t2.c
+ FROM (
+ SELECT room_id, COUNT(*) c FROM event_forward_extremities
+ GROUP BY room_id
+ ) t1 LEFT JOIN (
+ SELECT room_id, COUNT(*) c FROM current_state_events
+ GROUP BY room_id
+ ) t2 ON t1.room_id = t2.room_id
"""
)
return txn.fetchall()
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
- self._current_forward_extremities_amount = Counter([x[0] for x in res])
+
+ _extremities_collecter.update_data(x[0] for x in res)
+
+ _excess_state_events_collecter.update_data(
+ (x[0] - 1) * x[1] for x in res if x[1]
+ )
async def count_daily_messages(self):
"""
@@ -120,3 +141,190 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
return count
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
+
+ async def count_daily_users(self) -> int:
+ """
+ Counts the number of users who used this homeserver in the last 24 hours.
+ """
+ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
+ return await self.db_pool.runInteraction(
+ "count_daily_users", self._count_users, yesterday
+ )
+
+ async def count_monthly_users(self) -> int:
+ """
+ Counts the number of users who used this homeserver in the last 30 days.
+ Note this method is intended for phone-home metrics only and is different
+ from the MAU figure in synapse.storage.monthly_active_users which,
+ amongst other things, includes a 3 day grace period before a user counts.
+ """
+ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+ return await self.db_pool.runInteraction(
+ "count_monthly_users", self._count_users, thirty_days_ago
+ )
+
+ def _count_users(self, txn, time_from):
+ """
+ Returns the number of users seen since the given `time_from` timestamp.
+ """
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT user_id FROM user_ips
+ WHERE last_seen > ?
+ GROUP BY user_id
+ ) u
+ """
+ txn.execute(sql, (time_from,))
+ (count,) = txn.fetchone()
+ return count
+
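
The `GROUP BY` subquery means each user counts once however many `user_ips` rows fall inside the window. A toy sqlite3 check with made-up timestamps:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user_ips (user_id TEXT, last_seen BIGINT)")
conn.executemany(
    "INSERT INTO user_ips VALUES (?, ?)",
    [("@a:hs", 100), ("@a:hs", 150), ("@b:hs", 120), ("@c:hs", 10)],
)

# Same shape as _count_users: @a:hs counts once despite two rows in window.
sql = """
    SELECT COALESCE(count(*), 0) FROM (
        SELECT user_id FROM user_ips WHERE last_seen > ? GROUP BY user_id
    ) u
"""
(count,) = conn.execute(sql, (50,)).fetchone()
print(count)  # 2
```
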
+ async def count_r30_users(self) -> Dict[str, int]:
+ """
+ Counts the number of 30 day retained users, defined as:
+ * Users who have created their accounts more than 30 days ago
+ * Where last seen at most 30 days ago
+ * Where account creation and last_seen are > 30 days apart
+
+ Returns:
+ A mapping of counts globally as well as broken out by platform.
+ """
+
+ def _count_r30_users(txn):
+ thirty_days_in_secs = 86400 * 30
+ now = int(self._clock.time())
+ thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+ sql = """
+ SELECT platform, COALESCE(count(*), 0) FROM (
+ SELECT
+ users.name, platform, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen,
+ CASE
+ WHEN user_agent LIKE '%%Android%%' THEN 'android'
+ WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+ WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+ WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+ WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+ ELSE 'unknown'
+ END
+ AS platform
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND users.appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, platform, users.creation_ts
+ ) u GROUP BY platform
+ """
+
+ results = {}
+ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+ for row in txn:
+ if row[0] == "unknown":
+ # NB: this branch is currently a no-op: "unknown" platforms
+ # still fall through and are counted below.
+ pass
+ results[row[0]] = row[1]
+
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT users.name, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, users.creation_ts
+ ) u
+ """
+
+ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+ (count,) = txn.fetchone()
+ results["all"] = count
+
+ return results
+
+ return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+
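
The R30 predicate encoded in the SQL can be restated in plain Python (seconds throughout for simplicity; the real query mixes second-resolution `creation_ts` with millisecond-resolution `last_seen`):

```python
THIRTY_DAYS = 86400 * 30
now = 1_600_000_000  # arbitrary "current" unix time for the example

def is_r30(creation_ts: int, last_seen_ts: int) -> bool:
    return (
        creation_ts < now - THIRTY_DAYS                # account > 30 days old
        and last_seen_ts > now - THIRTY_DAYS           # seen in last 30 days
        and last_seen_ts - creation_ts > THIRTY_DAYS   # > 30 days apart
    )

print(is_r30(now - 90 * 86400, now - 86400))       # True: long-lived, active
print(is_r30(now - 10 * 86400, now - 86400))       # False: account too new
print(is_r30(now - 90 * 86400, now - 40 * 86400))  # False: not seen recently
```
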
+ def _get_start_of_day(self):
+ """
+ Returns millisecond unixtime for start of UTC day.
+ """
+ now = time.gmtime()
+ today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
+ return today_start * 1000
+
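
`_get_start_of_day` leans on `calendar.timegm` reading only the first six fields of the tuple it is given. An equivalent `datetime` formulation, as a cross-check (ignoring the pathological case where midnight passes between the two calls):

```python
import calendar
import time
from datetime import datetime, timezone

def get_start_of_day_ms() -> int:
    """Millisecond unixtime for the start of the current UTC day."""
    now = time.gmtime()
    return calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) * 1000

midnight = datetime.now(timezone.utc).replace(
    hour=0, minute=0, second=0, microsecond=0
)
assert get_start_of_day_ms() == int(midnight.timestamp() * 1000)
```
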
+ @wrap_as_background_process("generate_user_daily_visits")
+ async def generate_user_daily_visits(self) -> None:
+ """
+ Generates daily visit data for use in cohort/ retention analysis
+ """
+
+ def _generate_user_daily_visits(txn):
+ logger.info("Calling _generate_user_daily_visits")
+ today_start = self._get_start_of_day()
+ a_day_in_milliseconds = 24 * 60 * 60 * 1000
+ now = self._clock.time_msec()
+
+ sql = """
+ INSERT INTO user_daily_visits (user_id, device_id, timestamp)
+ SELECT u.user_id, u.device_id, ?
+ FROM user_ips AS u
+ LEFT JOIN (
+ SELECT user_id, device_id, timestamp FROM user_daily_visits
+ WHERE timestamp = ?
+ ) udv
+ ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+ INNER JOIN users ON users.name=u.user_id
+ WHERE last_seen > ? AND last_seen <= ?
+ AND udv.timestamp IS NULL AND users.is_guest=0
+ AND users.appservice_id IS NULL
+ GROUP BY u.user_id, u.device_id
+ """
+
+ # This means that the day has rolled over but there could still
+ # be entries from the previous day. There is an edge case
+ # where if the user logs in at 23:59 and overwrites their
+ # last_seen at 00:01 then they will not be counted in the
+ # previous day's stats - it is important that the query is run
+ # often to minimise this case.
+ if today_start > self._last_user_visit_update:
+ yesterday_start = today_start - a_day_in_milliseconds
+ txn.execute(
+ sql,
+ (
+ yesterday_start,
+ yesterday_start,
+ self._last_user_visit_update,
+ today_start,
+ ),
+ )
+ self._last_user_visit_update = today_start
+
+ txn.execute(
+ sql, (today_start, today_start, self._last_user_visit_update, now)
+ )
+ # Update _last_user_visit_update to now. The reason to do this
+ # rather than just clamping to the beginning of the day is to limit
+ # the size of the join - meaning that the query can be run more
+ # frequently
+ self._last_user_visit_update = now
+
+ await self.db_pool.runInteraction(
+ "generate_user_daily_visits", _generate_user_daily_visits
+ )
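
When the day has rolled over, the transaction above runs the same INSERT twice: once to backfill visits into yesterday's bucket, then once for today up to "now". A toy model of just that bookkeeping, with made-up timestamps:

```python
DAY_MS = 24 * 60 * 60 * 1000

def visit_windows(last_update: int, today_start: int, now: int):
    """Yield the (bucket timestamp, (from, to) window) pairs the txn writes."""
    if today_start > last_update:
        # Day rolled over: backfill stragglers into yesterday's bucket first.
        yield (today_start - DAY_MS, (last_update, today_start))
        last_update = today_start
    yield (today_start, (last_update, now))

# A run five minutes after midnight, last updated ten minutes before it.
today = 1_700_000_000_000 - (1_700_000_000_000 % DAY_MS)
for bucket, window in visit_windows(today - 600_000, today, today + 300_000):
    print(bucket, window)
```
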
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 1d793d3deb..c66f558567 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -28,10 +28,13 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
+ self._limit_usage_by_mau = hs.config.limit_usage_by_mau
+ self._max_mau_value = hs.config.max_mau_value
+
@cached(num_args=0)
async def get_monthly_active_count(self) -> int:
"""Generates current count of monthly active users
@@ -41,7 +44,14 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
def _count_users(txn):
- sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
+ # Exclude app service users
+ sql = """
+ SELECT COALESCE(count(*), 0)
+ FROM monthly_active_users
+ LEFT JOIN users
+ ON monthly_active_users.user_id=users.name
+ WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
+ """
txn.execute(sql)
(count,) = txn.fetchone()
return count
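
The reworked count excludes appservice users via the LEFT JOIN. A toy sqlite3 session with one human and one hypothetical bridge user shows the effect:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE monthly_active_users (user_id TEXT)")
conn.execute("CREATE TABLE users (name TEXT, appservice_id TEXT)")
conn.executemany(
    "INSERT INTO monthly_active_users VALUES (?)",
    [("@human:hs",), ("@bridge_bot:hs",)],
)
conn.executemany(
    "INSERT INTO users VALUES (?, ?)",
    [("@human:hs", None), ("@bridge_bot:hs", "irc-bridge")],
)

sql = """
    SELECT COALESCE(count(*), 0)
    FROM monthly_active_users
    LEFT JOIN users ON monthly_active_users.user_id = users.name
    WHERE (users.appservice_id IS NULL OR users.appservice_id = '')
"""
(count,) = conn.execute(sql).fetchone()
print(count)  # 1: the bridge user is ignored
```
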
@@ -117,60 +127,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
desc="user_last_seen_monthly_active",
)
-
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
-
- self._limit_usage_by_mau = hs.config.limit_usage_by_mau
- self._mau_stats_only = hs.config.mau_stats_only
- self._max_mau_value = hs.config.max_mau_value
-
- # Do not add more reserved users than the total allowable number
- # cur = LoggingTransaction(
- self.db_pool.new_transaction(
- db_conn,
- "initialise_mau_threepids",
- [],
- [],
- self._initialise_reserved_users,
- hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
- )
-
- def _initialise_reserved_users(self, txn, threepids):
- """Ensures that reserved threepids are accounted for in the MAU table, should
- be called on start up.
-
- Args:
- txn (cursor):
- threepids (list[dict]): List of threepid dicts to reserve
- """
-
- # XXX what is this function trying to achieve? It upserts into
- # monthly_active_users for each *registered* reserved mau user, but why?
- #
- # - shouldn't there already be an entry for each reserved user (at least
- # if they have been active recently)?
- #
- # - if it's important that the timestamp is kept up to date, why do we only
- # run this at startup?
-
- for tp in threepids:
- user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
-
- if user_id:
- is_support = self.is_support_user_txn(txn, user_id)
- if not is_support:
- # We do this manually here to avoid hitting #6791
- self.db_pool.simple_upsert_txn(
- txn,
- table="monthly_active_users",
- keyvalues={"user_id": user_id},
- values={"timestamp": int(self._clock.time_msec())},
- )
- else:
- logger.warning("mau limit reserved threepid %s not found in db" % tp)
-
async def reap_monthly_active_users(self):
"""Cleans out monthly active user table to ensure that no stale
entries exist.
@@ -250,6 +206,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"reap_monthly_active_users", _reap_users, reserved_users
)
+
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self._mau_stats_only = hs.config.mau_stats_only
+
+ # Do not add more reserved users than the total allowable number
+ self.db_pool.new_transaction(
+ db_conn,
+ "initialise_mau_threepids",
+ [],
+ [],
+ self._initialise_reserved_users,
+ hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
+ )
+
+ def _initialise_reserved_users(self, txn, threepids):
+ """Ensures that reserved threepids are accounted for in the MAU table, should
+ be called on start up.
+
+ Args:
+ txn (cursor):
+ threepids (list[dict]): List of threepid dicts to reserve
+ """
+
+ # XXX what is this function trying to achieve? It upserts into
+ # monthly_active_users for each *registered* reserved mau user, but why?
+ #
+ # - shouldn't there already be an entry for each reserved user (at least
+ # if they have been active recently)?
+ #
+ # - if it's important that the timestamp is kept up to date, why do we only
+ # run this at startup?
+
+ for tp in threepids:
+ user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
+
+ if user_id:
+ is_support = self.is_support_user_txn(txn, user_id)
+ if not is_support:
+ # We do this manually here to avoid hitting #6791
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="monthly_active_users",
+ keyvalues={"user_id": user_id},
+ values={"timestamp": int(self._clock.time_msec())},
+ )
+ else:
+ logger.warning("mau limit reserved threepid %s not found in db", tp)
+
async def upsert_monthly_active_user(self, user_id: str) -> None:
"""Updates or inserts the user into the monthly active user table, which
is used to track the current MAU usage of the server
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index c9f655dfb7..dbbb99cb95 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
async def update_presence(self, presence_states):
- stream_ordering_manager = await self._presence_id_gen.get_next_mult(
+ stream_ordering_manager = self._presence_id_gen.get_next_mult(
len(presence_states)
)
- with stream_ordering_manager as stream_orderings:
+ async with stream_ordering_manager as stream_orderings:
await self.db_pool.runInteraction(
"update_presence",
self._update_presence_txn,
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index d7a03cbf7d..ecfc6717b3 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -42,17 +42,17 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
The set of state groups that are referenced by deleted events.
"""
+ parsed_token = await RoomStreamToken.parse(self, token)
+
return await self.db_pool.runInteraction(
"purge_history",
self._purge_history_txn,
room_id,
- token,
+ parsed_token,
delete_local_events,
)
- def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
- token = RoomStreamToken.parse(token_str)
-
+ def _purge_history_txn(self, txn, room_id, token, delete_local_events):
# Tables that should be pruned:
# event_auth
# event_backward_extremities
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 9790a31998..711d5aa23d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -61,6 +61,8 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False):
return rules
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
class PushRulesWorkerStore(
ApplicationServiceWorkerStore,
ReceiptsWorkerStore,
@@ -68,17 +70,14 @@ class PushRulesWorkerStore(
RoomMemberWorkerStore,
EventsWorkerStore,
SQLBaseStore,
+ metaclass=abc.ABCMeta,
):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
- # This ABCMeta metaclass ensures that we cannot be instantiated without
- # the abstract methods being implemented.
- __metaclass__ = abc.ABCMeta
-
def __init__(self, database: DatabasePool, db_conn, hs):
- super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
self._push_rules_stream_id_gen = StreamIdGenerator(
@@ -339,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
if before or after:
@@ -586,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@@ -617,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore):
Raises:
NotFoundError if the rule does not exist.
"""
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
@@ -755,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index c388468273..df8609b97b 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
last_stream_ordering,
profile_tag="",
) -> None:
- with await self._pushers_id_gen.get_next() as stream_id:
+ async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
},
)
- with await self._pushers_id_gen.get_next() as stream_id:
+ async with self._pushers_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 4a0d5a320e..c79ddff680 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -31,17 +31,15 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-class ReceiptsWorkerStore(SQLBaseStore):
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
+class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_max_receipt_stream_id` which can be called in the initializer.
"""
- # This ABCMeta metaclass ensures that we cannot be instantiated without
- # the abstract methods being implemented.
- __metaclass__ = abc.ABCMeta
-
def __init__(self, database: DatabasePool, db_conn, hs):
- super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
@@ -388,7 +386,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
db_conn, "receipts_linearized", "stream_id"
)
- super(ReceiptsStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
@@ -526,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"insert_receipt_conv", graph_to_linear
)
- with await self._receipts_id_gen.get_next() as stream_id:
+ async with self._receipts_id_gen.get_next() as stream_id:
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 01f20c03c2..a85867936f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,14 +14,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.types import Cursor
@@ -36,15 +38,33 @@ logger = logging.getLogger(__name__)
class RegistrationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.config = hs.config
self.clock = hs.get_clock()
+ # Note: we don't check this sequence for consistency as we'd have to
+ # call `find_max_generated_user_id_localpart` each time, which is
+ # expensive if there are many entries.
self._user_id_seq = build_sequence_generator(
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
)
+ self._account_validity = hs.config.account_validity
+ if hs.config.run_background_tasks and self._account_validity.enabled:
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "account_validity_set_expiration_dates",
+ self._set_expiration_date_when_missing,
+ )
+
+ # Create a background job for culling expired 3PID validity tokens
+ if hs.config.run_background_tasks:
+ self.clock.looping_call(
+ self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
+ )
+
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
@@ -116,6 +136,20 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_expiration_ts_for_user",
)
+ async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
+ """
+ Returns whether a user account has expired.
+
+ Args:
+ user_id: The user's ID
+ current_ts: The current timestamp
+
+ Returns:
+ Whether the user account has expired
+ """
+ expiration_ts = await self.get_expiration_ts_for_user(user_id)
+ return expiration_ts is not None and current_ts >= expiration_ts
+
async def set_account_validity_for_user(
self,
user_id: str,
@@ -379,7 +413,7 @@ class RegistrationWorkerStore(SQLBaseStore):
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
- ) -> str:
+ ) -> Optional[str]:
"""Look up a user by their external auth id
Args:
@@ -387,7 +421,7 @@ class RegistrationWorkerStore(SQLBaseStore):
external_id: id on that system
Returns:
- str|None: the mxid of the user, or None if they are not known
+ the mxid of the user, or None if they are not known
"""
return await self.db_pool.simple_select_one_onecol(
table="user_external_ids",
@@ -761,10 +795,82 @@ class RegistrationWorkerStore(SQLBaseStore):
"delete_threepid_session", delete_threepid_session_txn
)
+ @wrap_as_background_process("cull_expired_threepid_validation_tokens")
+ async def cull_expired_threepid_validation_tokens(self) -> None:
+ """Remove threepid validation tokens with expiry dates that have passed"""
+
+ def cull_expired_threepid_validation_tokens_txn(txn, ts):
+ sql = """
+ DELETE FROM threepid_validation_token WHERE
+ expires < ?
+ """
+ txn.execute(sql, (ts,))
+
+ await self.db_pool.runInteraction(
+ "cull_expired_threepid_validation_tokens",
+ cull_expired_threepid_validation_tokens_txn,
+ self.clock.time_msec(),
+ )
+
+ async def _set_expiration_date_when_missing(self):
+ """
+ Retrieves the list of registered users that don't have an expiration date, and
+ adds an expiration date for each of them.
+ """
+
+ def select_users_with_no_expiration_date_txn(txn):
+ """Retrieves the list of registered users with no expiration date from the
+ database, filtering out deactivated users.
+ """
+ sql = (
+ "SELECT users.name FROM users"
+ " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+ " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+ )
+ txn.execute(sql, [])
+
+ res = self.db_pool.cursor_to_dict(txn)
+ if res:
+ for user in res:
+ self.set_expiration_date_for_user_txn(
+ txn, user["name"], use_delta=True
+ )
+
+ await self.db_pool.runInteraction(
+ "get_users_with_no_expiration_date",
+ select_users_with_no_expiration_date_txn,
+ )
+
+ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ """Sets an expiration date to the account with the given user ID.
+
+ Args:
+ user_id (str): User ID to set an expiration date for.
+ use_delta (bool): If set to False, the expiration date for the user will be
+ now + validity period. If set to True, this expiration date will be a
+ random value in the [now + period - d ; now + period] range, d being a
+ delta equal to 10% of the validity period.
+ """
+ now_ms = self._clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+
+ if use_delta:
+ expiration_ts = self.rand.randrange(
+ expiration_ts - self._account_validity.startup_job_max_delta,
+ expiration_ts,
+ )
+
+ self.db_pool.simple_upsert_txn(
+ txn,
+ "account_validity",
+ keyvalues={"user_id": user_id},
+ values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+ )
+
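
The `use_delta` branch spreads startup-assigned expiration dates across the last 10% of the validity period, so pre-existing accounts do not all expire in the same instant. A sketch with an assumed 30-day period (the real values come from the `account_validity` config):

```python
import random

PERIOD = 30 * 24 * 60 * 60 * 1000  # hypothetical validity period, in ms
MAX_DELTA = PERIOD * 10 // 100     # 10% jitter, per the docstring above

def expiration_ts(now_ms: int, use_delta: bool = False) -> int:
    ts = now_ms + PERIOD
    if use_delta:
        # Uniformly random point in [now + period - delta, now + period).
        ts = random.randrange(ts - MAX_DELTA, ts)
    return ts

now = 1_700_000_000_000
print(expiration_ts(now))  # exactly now + PERIOD
assert now + PERIOD - MAX_DELTA <= expiration_ts(now, True) < now + PERIOD
```
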
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.clock = hs.get_clock()
self.config = hs.config
@@ -892,30 +998,10 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RegistrationStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
- self._account_validity = hs.config.account_validity
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
- if self._account_validity.enabled:
- self._clock.call_later(
- 0.0,
- run_as_background_process,
- "account_validity_set_expiration_dates",
- self._set_expiration_date_when_missing,
- )
-
- # Create a background job for culling expired 3PID validity tokens
- def start_cull():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "cull_expired_threepid_validation_tokens",
- self.cull_expired_threepid_validation_tokens,
- )
-
- hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
-
async def add_access_token_to_user(
self,
user_id: str,
@@ -947,6 +1033,36 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_access_token_to_user",
)
+ def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+ old_device_id = self.db_pool.simple_select_one_onecol_txn(
+ txn, "access_tokens", {"token": token}, "device_id"
+ )
+
+ self.db_pool.simple_update_txn(
+ txn, "access_tokens", {"token": token}, {"device_id": device_id}
+ )
+
+ self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
+
+ return old_device_id
+
+ async def set_device_for_access_token(self, token: str, device_id: str) -> str:
+ """Sets the device ID associated with an access token.
+
+ Args:
+ token: The access token to modify.
+ device_id: The new device ID.
+ Returns:
+ The old device ID associated with the access token.
+ """
+
+ return await self.db_pool.runInteraction(
+ "set_device_for_access_token",
+ self._set_device_for_access_token_txn,
+ token,
+ device_id,
+ )
+
async def register_user(
self,
user_id: str,
@@ -1430,22 +1546,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn,
)
- async def cull_expired_threepid_validation_tokens(self) -> None:
- """Remove threepid validation tokens with expiry dates that have passed"""
-
- def cull_expired_threepid_validation_tokens_txn(txn, ts):
- sql = """
- DELETE FROM threepid_validation_token WHERE
- expires < ?
- """
- txn.execute(sql, (ts,))
-
- await self.db_pool.runInteraction(
- "cull_expired_threepid_validation_tokens",
- cull_expired_threepid_validation_tokens_txn,
- self.clock.time_msec(),
- )
-
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
@@ -1475,61 +1575,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
txn.call_after(self.is_guest.invalidate, (user_id,))
- async def _set_expiration_date_when_missing(self):
- """
- Retrieves the list of registered users that don't have an expiration date, and
- adds an expiration date for each of them.
- """
-
- def select_users_with_no_expiration_date_txn(txn):
- """Retrieves the list of registered users with no expiration date from the
- database, filtering out deactivated users.
- """
- sql = (
- "SELECT users.name FROM users"
- " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
- " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
- )
- txn.execute(sql, [])
-
- res = self.db_pool.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn, user["name"], use_delta=True
- )
-
- await self.db_pool.runInteraction(
- "get_users_with_no_expiration_date",
- select_users_with_no_expiration_date_txn,
- )
-
- def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
- """Sets an expiration date to the account with the given user ID.
-
- Args:
- user_id (str): User ID to set an expiration date for.
- use_delta (bool): If set to False, the expiration date for the user will be
- now + validity period. If set to True, this expiration date will be a
- random value in the [now + period - d ; now + period] range, d being a
- delta equal to 10% of the validity period.
- """
- now_ms = self._clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
-
- if use_delta:
- expiration_ts = self.rand.randrange(
- expiration_ts - self._account_validity.startup_job_max_delta,
- expiration_ts,
- )
-
- self.db_pool.simple_upsert_txn(
- txn,
- "account_validity",
- keyvalues={"user_id": user_id},
- values={"expiration_ts_ms": expiration_ts, "email_sent": False},
- )
-
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 127588ce4c..c0f2af0785 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -69,7 +69,7 @@ class RoomSortOrder(Enum):
class RoomWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -192,6 +192,18 @@ class RoomWorkerStore(SQLBaseStore):
"count_public_rooms", _count_public_rooms_txn
)
+ async def get_room_count(self) -> int:
+ """Retrieve the total number of rooms.
+ """
+
+ def f(txn):
+ sql = "SELECT count(*) FROM rooms"
+ txn.execute(sql)
+ row = txn.fetchone()
+ return row[0] or 0
+
+ return await self.db_pool.runInteraction("get_rooms", f)
+
async def get_largest_public_rooms(
self,
network_tuple: Optional[ThirdPartyInstanceID],
@@ -863,7 +875,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -1074,7 +1086,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -1137,7 +1149,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id
)
@@ -1204,7 +1216,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
@@ -1284,7 +1296,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
@@ -1292,18 +1304,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
self.hs.get_notifier().on_new_replication_data()
- async def get_room_count(self) -> int:
- """Retrieve the total number of rooms.
- """
-
- def f(txn):
- sql = "SELECT count(*) FROM rooms"
- txn.execute(sql)
- row = txn.fetchone()
- return row[0] or 0
-
- return await self.db_pool.runInteraction("get_rooms", f)
-
async def add_event_report(
self,
room_id: str,
@@ -1328,6 +1328,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
desc="add_event_report",
)
+ async def get_event_reports_paginate(
+ self,
+ start: int,
+ limit: int,
+ direction: str = "b",
+ user_id: Optional[str] = None,
+ room_id: Optional[str] = None,
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ """Retrieve a paginated list of event reports
+
+ Args:
+ start: event offset to begin the query from
+ limit: number of rows to retrieve
+ direction: Whether to fetch the most recent first (`"b"`) or the
+ oldest first (`"f"`)
+            user_id: substring to match against the reporting user's ID.
+                Ignored if None.
+            room_id: substring to match against the room ID. Ignored if None.
+        Returns:
+            event_reports: JSON list of event reports
+            count: total number of event reports matching the filter criteria
+ """
+
+ def _get_event_reports_paginate_txn(txn):
+ filters = []
+ args = []
+
+ if user_id:
+ filters.append("er.user_id LIKE ?")
+ args.extend(["%" + user_id + "%"])
+ if room_id:
+ filters.append("er.room_id LIKE ?")
+ args.extend(["%" + room_id + "%"])
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+ sql = """
+ SELECT COUNT(*) as total_event_reports
+ FROM event_reports AS er
+ {}
+ """.format(
+ where_clause
+ )
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = """
+ SELECT
+ er.id,
+ er.received_ts,
+ er.room_id,
+ er.event_id,
+ er.user_id,
+ er.reason,
+ er.content,
+ events.sender,
+ room_aliases.room_alias,
+ event_json.json AS event_json
+ FROM event_reports AS er
+ LEFT JOIN room_aliases
+ ON room_aliases.room_id = er.room_id
+ JOIN events
+ ON events.event_id = er.event_id
+ JOIN event_json
+ ON event_json.event_id = er.event_id
+ {where_clause}
+ ORDER BY er.received_ts {order}
+ LIMIT ?
+ OFFSET ?
+ """.format(
+ where_clause=where_clause, order=order,
+ )
+
+ args += [limit, start]
+ txn.execute(sql, args)
+ event_reports = self.db_pool.cursor_to_dict(txn)
+
+ if count > 0:
+ for row in event_reports:
+ try:
+ row["content"] = db_to_json(row["content"])
+ row["event_json"] = db_to_json(row["event_json"])
+ except Exception:
+ continue
+
+ return event_reports, count
+
+ return await self.db_pool.runInteraction(
+ "get_event_reports_paginate", _get_event_reports_paginate_txn
+ )
+
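+    # Illustrative usage (hedged sketch; the values are hypothetical): fetch
+    # the second page of ten reports, newest first, for rooms whose ID
+    # contains "abcd":
+    #
+    #     reports, total = await store.get_event_reports_paginate(
+    #         start=10, limit=10, direction="b", room_id="abcd"
+    #     )
+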
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 91a8b43da3..20fcdaa529 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
@@ -22,12 +21,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import (
- LoggingTransaction,
- SQLBaseStore,
- db_to_json,
- make_in_list_sql_clause,
-)
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
@@ -37,7 +31,7 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
-from synapse.types import Collection, get_domain_from_id
+from synapse.types import Collection, PersistedEventPosition, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -55,21 +49,22 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# Is the current_state_events.membership up to date? Or is the
# background update still running?
self._current_state_events_membership_up_to_date = False
- txn = LoggingTransaction(
- db_conn.cursor(),
- name="_check_safe_current_state_events_membership_updated",
- database_engine=self.database_engine,
+ txn = db_conn.cursor(
+ txn_name="_check_safe_current_state_events_membership_updated"
)
self._check_safe_current_state_events_membership_updated_txn(txn)
txn.close()
- if self.hs.config.metrics_flags.known_servers:
+ if (
+ self.hs.config.run_background_tasks
+ and self.hs.config.metrics_flags.known_servers
+ ):
self._known_servers_count = 1
self.hs.get_clock().looping_call(
run_as_background_process,
@@ -387,7 +382,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# for rooms the server is participating in.
if self._current_state_events_membership_up_to_date:
sql = """
- SELECT room_id, e.stream_ordering
+ SELECT room_id, e.instance_name, e.stream_ordering
FROM current_state_events AS c
INNER JOIN events AS e USING (room_id, event_id)
WHERE
@@ -397,7 +392,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
else:
sql = """
- SELECT room_id, e.stream_ordering
+ SELECT room_id, e.instance_name, e.stream_ordering
FROM current_state_events AS c
INNER JOIN room_memberships AS m USING (room_id, event_id)
INNER JOIN events AS e USING (room_id, event_id)
@@ -408,7 +403,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (user_id, Membership.JOIN))
- return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+ return frozenset(
+ GetRoomsForUserWithStreamOrdering(
+ room_id, PersistedEventPosition(instance, stream_id)
+ )
+ for room_id, instance, stream_id in txn
+ )
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
@@ -819,7 +819,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
@@ -973,7 +973,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomMemberStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id."""
diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py
index 3edfcfd783..45b846e6a7 100644
--- a/synapse/storage/databases/main/schema/delta/20/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/20/pushers.py
@@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs):
row[8] = bytes(row[8]).decode("utf-8")
row[11] = bytes(row[11]).decode("utf-8")
cur.execute(
- database_engine.convert_param_style(
- """
- INSERT into pushers2 (
- id, user_name, access_token, profile_tag, kind,
- app_id, app_display_name, device_display_name,
- pushkey, ts, lang, data, last_token, last_success,
- failing_since
- ) values (%s)"""
- % (",".join(["?" for _ in range(len(row))]))
- ),
+ """
+ INSERT into pushers2 (
+ id, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
+ pushkey, ts, lang, data, last_token, last_success,
+ failing_since
+ ) values (%s)
+ """
+ % (",".join(["?" for _ in range(len(row))])),
row,
)
count += 1
diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py
index ee675e71ff..21f57825d4 100644
--- a/synapse/storage/databases/main/schema/delta/25/fts.py
+++ b/synapse/storage/databases/main/schema/delta/25/fts.py
@@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_search", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py
index b7972cfa8e..1c6058063f 100644
--- a/synapse/storage/databases/main/schema/delta/27/ts.py
+++ b/synapse/storage/databases/main/schema/delta/27/ts.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_origin_server_ts", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py
index b42c02710a..7f08fabe9f 100644
--- a/synapse/storage/databases/main/schema/delta/30/as_users.py
+++ b/synapse/storage/databases/main/schema/delta/30/as_users.py
@@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n))
for chunk in user_chunks:
cur.execute(
- database_engine.convert_param_style(
- "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
- % (",".join("?" for _ in chunk),)
- ),
+ "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
+ % (",".join("?" for _ in chunk),),
[as_id] + chunk,
)
diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py
index 9bb504aad5..5be81c806a 100644
--- a/synapse/storage/databases/main/schema/delta/31/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/31/pushers.py
@@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs):
row = list(row)
row[12] = token_to_stream_ordering(row[12])
cur.execute(
- database_engine.convert_param_style(
- """
- INSERT into pushers2 (
- id, user_name, access_token, profile_tag, kind,
- app_id, app_display_name, device_display_name,
- pushkey, ts, lang, data, last_stream_ordering, last_success,
- failing_since
- ) values (%s)"""
- % (",".join(["?" for _ in range(len(row))]))
- ),
+ """
+ INSERT into pushers2 (
+ id, user_name, access_token, profile_tag, kind,
+ app_id, app_display_name, device_display_name,
+ pushkey, ts, lang, data, last_stream_ordering, last_success,
+ failing_since
+ ) values (%s)
+ """
+ % (",".join(["?" for _ in range(len(row))])),
row,
)
count += 1
diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py
index 63b757ade6..b84c844e3a 100644
--- a/synapse/storage/databases/main/schema/delta/31/search_update.py
+++ b/synapse/storage/databases/main/schema/delta/31/search_update.py
@@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_search_order", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py
index a3e81eeac7..e928c66a8f 100644
--- a/synapse/storage/databases/main/schema/delta/33/event_fields.py
+++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
" VALUES (?, ?)"
)
- sql = database_engine.convert_param_style(sql)
-
cur.execute(sql, ("event_fields_sender_url", progress_json))
diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index a26057dfb6..ad875c733a 100644
--- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
@@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs):
def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute(
- database_engine.convert_param_style(
- "UPDATE remote_media_cache SET last_access_ts = ?"
- ),
- (int(time.time() * 1000),),
+ "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
)
diff --git a/synapse/storage/databases/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
index 5e29c1da19..ccf287971c 100644
--- a/synapse/storage/databases/main/schema/delta/56/event_labels.sql
+++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
@@ -13,7 +13,7 @@
* limitations under the License.
*/
--- room_id and topoligical_ordering are denormalised from the events table in order to
+-- room_id and topological_ordering are denormalised from the events table in order to
-- make the index work.
CREATE TABLE IF NOT EXISTS event_labels (
event_id TEXT,
diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
index 1de8b54961..bb7296852a 100644
--- a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
+++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
@@ -1,6 +1,8 @@
import logging
+from io import StringIO
from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
logger = logging.getLogger(__name__)
@@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs):
select_clause,
)
- if isinstance(database_engine, PostgresEngine):
- cur.execute(sql)
- else:
- cur.executescript(sql)
+ execute_statements_from_stream(cur, StringIO(sql))
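+    # (Illustrative note: execute_statements_from_stream splits the stream on
+    # statement boundaries and runs each one, giving a single code path for
+    # both PostgreSQL and SQLite instead of branching on the engine.)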
diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
index 63b5acdcf7..44917f0a2e 100644
--- a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
+++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
@@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
INNER JOIN room_memberships AS r USING (event_id)
WHERE type = 'm.room.member' AND state_key LIKE ?
"""
- sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("%:" + config.server_name,))
cur.execute(
diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
new file mode 100644
index 0000000000..7851a0a825
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
@@ -0,0 +1,20 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS dehydrated_devices(
+ user_id TEXT NOT NULL PRIMARY KEY,
+ device_id TEXT NOT NULL,
+ device_data TEXT NOT NULL -- JSON-encoded client-defined data
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
new file mode 100644
index 0000000000..4ed981dbf8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
@@ -0,0 +1,24 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
+ user_id TEXT NOT NULL, -- The user this fallback key is for.
+ device_id TEXT NOT NULL, -- The device this fallback key is for.
+ algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
+ key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
+ key_json TEXT NOT NULL, -- The key as a JSON blob.
+ used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
+ CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
index 97c1e6a0c5..c31f9af82a 100644
--- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -21,6 +21,8 @@ SELECT setval('events_stream_seq', (
CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
+-- If the server has never backfilled a room then `-MIN(...)` will give a
+-- negative result, which is why we clamp it with `GREATEST(...)`.
SELECT setval('events_backfill_stream_seq', (
- SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
+ SELECT GREATEST(COALESCE(-MIN(stream_ordering), 1), 1) FROM events
));
diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
new file mode 100644
index 0000000000..985fd949a2
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE stream_positions (
+ stream_name TEXT NOT NULL,
+ instance_name TEXT NOT NULL,
+ stream_id BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name);
diff --git a/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
new file mode 100644
index 0000000000..841186b826
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A unique and immutable mapping between instance name and an integer ID. This
+-- lets us refer to instances via a small ID in e.g. stream tokens, without
+-- having to encode the full name.
+CREATE TABLE IF NOT EXISTS instance_map (
+ instance_id SERIAL PRIMARY KEY,
+ instance_name TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name);
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index f01cf2fd02..e34fce6281 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -89,7 +89,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
if not hs.config.enable_search:
return
@@ -342,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(SearchStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5c6168e301..3c1e33819b 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -56,7 +56,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def get_room_version(self, room_id: str) -> RoomVersion:
"""Get the room_version of a given room
@@ -320,7 +320,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -506,4 +506,4 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 9c1bf3c289..bc8e78e1f1 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -62,7 +62,7 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
class StatsStore(StateDeltasStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StatsStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
self.clock = self.hs.get_clock()
@@ -211,6 +211,7 @@ class StatsStore(StateDeltasStore):
* topic
* avatar
* canonical_alias
+ * guest_access
A is_federatable key can also be included with a boolean value.
@@ -235,6 +236,7 @@ class StatsStore(StateDeltasStore):
"topic",
"avatar",
"canonical_alias",
+ "guest_access",
):
field = fields.get(col, sentinel)
if field is not sentinel and (not isinstance(field, str) or "\0" in field):
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 2e95518752..e3b9ff5ca6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -35,7 +35,6 @@ what sort order was used:
 - topological tokens: "t%d-%d", where the integers map to the topological
and stream ordering columns respectively.
"""
-
import abc
import logging
from collections import namedtuple
@@ -54,7 +53,9 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.types import Collection, RoomStreamToken
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
+from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
@@ -209,6 +210,55 @@ def _make_generic_sql_bound(
)
+def _filter_results(
+ lower_token: Optional[RoomStreamToken],
+ upper_token: Optional[RoomStreamToken],
+ instance_name: str,
+ topological_ordering: int,
+ stream_ordering: int,
+) -> bool:
+ """Returns True if the event persisted by the given instance at the given
+ topological/stream_ordering falls between the two tokens (taking a None
+ token to mean unbounded).
+
+ Used to filter results from fetching events in the DB against the given
+ tokens. This is necessary to handle the case where the tokens include
+ position maps, which we handle by fetching more than necessary from the DB
+ and then filtering (rather than attempting to construct a complicated SQL
+ query).
+ """
+
+ event_historical_tuple = (
+ topological_ordering,
+ stream_ordering,
+ )
+
+ if lower_token:
+ if lower_token.topological is not None:
+ # If these are historical tokens we compare the `(topological, stream)`
+ # tuples.
+ if event_historical_tuple <= lower_token.as_historical_tuple():
+ return False
+
+ else:
+ # If these are live tokens we compare the stream ordering against the
+ # writers stream position.
+ if stream_ordering <= lower_token.get_stream_pos_for_instance(
+ instance_name
+ ):
+ return False
+
+ if upper_token:
+ if upper_token.topological is not None:
+ if upper_token.as_historical_tuple() < event_historical_tuple:
+ return False
+ else:
+ if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
+ return False
+
+ return True
+
+
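+# Illustrative behaviour of _filter_results (hedged sketch; instance names and
+# positions are hypothetical): given a live lower token
+# RoomStreamToken(None, 5, {"worker1": 7}), an event persisted by "worker1" at
+# stream_ordering 7 is excluded (it is at or before that writer's position),
+# while stream_ordering 8 from "worker1", or stream_ordering 6 from any other
+# instance, is included.
+
+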
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
@@ -259,16 +309,14 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
return " AND ".join(clauses), args
-class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_room_max_stream_ordering` and `get_room_min_stream_ordering`
which can be called in the initializer.
"""
- __metaclass__ = abc.ABCMeta
-
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
- super(StreamWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
self._send_federation = hs.should_send_federation()
@@ -307,6 +355,33 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def get_room_min_stream_ordering(self) -> int:
raise NotImplementedError()
+ def get_room_max_token(self) -> RoomStreamToken:
+ """Get a `RoomStreamToken` that marks the current maximum persisted
+ position of the events stream. Useful to get a token that represents
+ "now".
+
+ The token returned is a "live" token that may have an instance_map
+ component.
+ """
+
+ min_pos = self._stream_id_gen.get_current_token()
+
+ positions = {}
+ if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
+ # The `min_pos` is the minimum position that we know all instances
+ # have finished persisting to, so we only care about instances whose
+ # positions are ahead of that. (Instance positions can be behind the
+ # min position as there are times we can work out that the minimum
+ # position is ahead of the naive minimum across all current
+ # positions. See MultiWriterIdGenerator for details)
+ positions = {
+ i: p
+ for i, p in self._stream_id_gen.get_positions().items()
+ if p > min_pos
+ }
+
+ return RoomStreamToken(None, min_pos, positions)
+
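+    # For example (hedged sketch; positions are hypothetical): if the minimum
+    # position all writers have reached is 100 but "worker1" has persisted up
+    # to 103, this returns RoomStreamToken(None, 100, {"worker1": 103}).
+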
async def get_room_events_stream_for_rooms(
self,
room_ids: Collection[str],
@@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if from_key == to_key:
return [], from_key
- from_id = from_key.stream
- to_id = to_key.stream
-
- has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
+ has_changed = self._events_stream_cache.has_entity_changed(
+ room_id, from_key.stream
+ )
if not has_changed:
return [], from_key
def f(txn):
- sql = (
- "SELECT event_id, stream_ordering FROM events WHERE"
- " room_id = ?"
- " AND not outlier"
- " AND stream_ordering > ? AND stream_ordering <= ?"
- " ORDER BY stream_ordering %s LIMIT ?"
- ) % (order,)
- txn.execute(sql, (room_id, from_id, to_id, limit))
-
- rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+ # To handle tokens with a non-empty instance_map we fetch more
+ # results than necessary and then filter down
+ min_from_id = from_key.stream
+ max_to_id = to_key.get_max_stream_pos()
+
+ sql = """
+ SELECT event_id, instance_name, topological_ordering, stream_ordering
+ FROM events
+ WHERE
+ room_id = ?
+ AND not outlier
+ AND stream_ordering > ? AND stream_ordering <= ?
+ ORDER BY stream_ordering %s LIMIT ?
+ """ % (
+ order,
+ )
+ txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
+
+ rows = [
+ _EventDictReturn(event_id, None, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ from_key,
+ to_key,
+ instance_name,
+ topological_ordering,
+ stream_ordering,
+ )
+ ][:limit]
return rows
rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
@@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
[r.event_id for r in rows], get_prev_content=True
)
- self._set_before_and_after(ret, rows, topo_order=from_id is None)
+ self._set_before_and_after(ret, rows, topo_order=False)
if order.lower() == "desc":
ret.reverse()
@@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
- from_id = from_key.stream
- to_id = to_key.stream
-
if from_key == to_key:
return []
- if from_id:
+ if from_key:
has_changed = self._membership_stream_cache.has_entity_changed(
- user_id, int(from_id)
+ user_id, int(from_key.stream)
)
if not has_changed:
return []
def f(txn):
- sql = (
- "SELECT m.event_id, stream_ordering FROM events AS e,"
- " room_memberships AS m"
- " WHERE e.event_id = m.event_id"
- " AND m.user_id = ?"
- " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
- )
- txn.execute(sql, (user_id, from_id, to_id))
-
- rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+ # To handle tokens with a non-empty instance_map we fetch more
+ # results than necessary and then filter down
+ min_from_id = from_key.stream
+ max_to_id = to_key.get_max_stream_pos()
+
+ sql = """
+ SELECT m.event_id, instance_name, topological_ordering, stream_ordering
+ FROM events AS e, room_memberships AS m
+ WHERE e.event_id = m.event_id
+ AND m.user_id = ?
+ AND e.stream_ordering > ? AND e.stream_ordering <= ?
+ ORDER BY e.stream_ordering ASC
+ """
+ txn.execute(sql, (user_id, min_from_id, max_to_id,))
+
+ rows = [
+ _EventDictReturn(event_id, None, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ from_key,
+ to_key,
+ instance_name,
+ topological_ordering,
+ stream_ordering,
+ )
+ ]
return rows
@@ -546,7 +651,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_event_before_stream_ordering(
self, room_id: str, stream_ordering: int
- ) -> Tuple[int, int, str]:
+ ) -> Optional[Tuple[int, int, str]]:
"""Gets details of the first event in a room at or before a stream ordering
Args:
@@ -589,19 +694,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return "t%d-%d" % (topo, token)
- async def get_stream_id_for_event(self, event_id: str) -> int:
- """The stream ID for an event
- Args:
- event_id: The id of the event to look up a stream token for.
- Raises:
- StoreError if the event wasn't in the database.
- Returns:
- A stream ID.
- """
- return await self.db_pool.runInteraction(
- "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
- )
-
def get_stream_id_for_event_txn(
self, txn: LoggingTransaction, event_id: str, allow_none=False,
) -> int:
@@ -613,26 +705,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
allow_none=allow_none,
)
- async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
- """The stream token for an event
- Args:
- event_id: The id of the event to look up a stream token for.
- Raises:
- StoreError if the event wasn't in the database.
- Returns:
- A stream token.
+ async def get_position_for_event(self, event_id: str) -> PersistedEventPosition:
+ """Get the persisted position for an event
"""
- stream_id = await self.get_stream_id_for_event(event_id)
- return RoomStreamToken(None, stream_id)
+ row = await self.db_pool.simple_select_one(
+ table="events",
+ keyvalues={"event_id": event_id},
+ retcols=("stream_ordering", "instance_name"),
+ desc="get_position_for_event",
+ )
+
+ return PersistedEventPosition(
+ row["instance_name"] or "master", row["stream_ordering"]
+ )
- async def get_topological_token_for_event(self, event_id: str) -> str:
+ async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event
Args:
event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A "t%d-%d" topological token.
+ A `RoomStreamToken` topological token.
"""
row = await self.db_pool.simple_select_one(
table="events",
@@ -640,25 +734,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
)
- return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
+ return RoomStreamToken(row["topological_ordering"], row["stream_ordering"])
- async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
- """Get the max topological token in a room before the given stream
+ async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
+ """Gets the topological token in a room after or at the given stream
ordering.
Args:
room_id
stream_key
-
- Returns:
- The maximum topological token.
"""
sql = (
- "SELECT coalesce(max(topological_ordering), 0) FROM events"
- " WHERE room_id = ? AND stream_ordering < ?"
+ "SELECT coalesce(MIN(topological_ordering), 0) FROM events"
+ " WHERE room_id = ? AND stream_ordering >= ?"
)
row = await self.db_pool.execute(
- "get_max_topological_token", None, sql, room_id, stream_key
+ "get_current_topological_token", None, sql, room_id, stream_key
)
return row[0][0] if row else 0
@@ -692,8 +783,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
else:
topo = None
internal = event.internal_metadata
- internal.before = str(RoomStreamToken(topo, stream - 1))
- internal.after = str(RoomStreamToken(topo, stream))
+ internal.before = RoomStreamToken(topo, stream - 1)
+ internal.after = RoomStreamToken(topo, stream)
internal.order = (int(topo) if topo else 0, int(stream))
async def get_events_around(
@@ -980,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
else:
order = "ASC"
+ # The bounds for the stream tokens are complicated by the fact
+ # that we need to handle the instance_map part of the tokens. We do this
+ # by fetching all events between the min stream token and the maximum
+ # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+ # then filtering the results.
+ if from_token.topological is not None:
+ from_bound = (
+ from_token.as_historical_tuple()
+ ) # type: Tuple[Optional[int], int]
+ elif direction == "b":
+ from_bound = (
+ None,
+ from_token.get_max_stream_pos(),
+ )
+ else:
+ from_bound = (
+ None,
+ from_token.stream,
+ )
+
+ to_bound = None # type: Optional[Tuple[Optional[int], int]]
+ if to_token:
+ if to_token.topological is not None:
+ to_bound = to_token.as_historical_tuple()
+ elif direction == "b":
+ to_bound = (
+ None,
+ to_token.stream,
+ )
+ else:
+ to_bound = (
+ None,
+ to_token.get_max_stream_pos(),
+ )
+
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token.as_tuple(),
- to_token=to_token.as_tuple() if to_token else None,
+ from_token=from_bound,
+ to_token=to_bound,
engine=self.database_engine,
)
@@ -994,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
bounds += " AND " + filter_clause
args.extend(filter_args)
- args.append(int(limit))
+ # We fetch more events as we'll filter the result set
+ args.append(int(limit) * 2)
select_keywords = "SELECT"
join_clause = ""
@@ -1016,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
select_keywords += "DISTINCT"
sql = """
- %(select_keywords)s event_id, topological_ordering, stream_ordering
+ %(select_keywords)s
+ event_id, instance_name,
+ topological_ordering, stream_ordering
FROM events
%(join_clause)s
WHERE outlier = ? AND room_id = ? AND %(bounds)s
@@ -1031,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(sql, args)
- rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+ # Filter the result set.
+ rows = [
+ _EventDictReturn(event_id, topological_ordering, stream_ordering)
+ for event_id, instance_name, topological_ordering, stream_ordering in txn
+ if _filter_results(
+ lower_token=to_token if direction == "b" else from_token,
+ upper_token=from_token if direction == "b" else to_token,
+ instance_name=instance_name,
+ topological_ordering=topological_ordering,
+ stream_ordering=stream_ordering,
+ )
+ ][:limit]
if rows:
topo = rows[-1].topological_ordering
@@ -1096,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
+ @cached()
+ async def get_id_for_instance(self, instance_name: str) -> int:
+ """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
+ """
+
+ def _get_id_for_instance_txn(txn):
+ instance_id = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ retcol="instance_id",
+ allow_none=True,
+ )
+ if instance_id is not None:
+ return instance_id
+
+            # If we don't have an entry, upsert one.
+ #
+ # We could do this before the first check, and rely on the cache for
+ # efficiency, but each UPSERT causes the next ID to increment which
+ # can quickly bloat the size of the generated IDs for new instances.
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ values={},
+ )
+
+ return self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="instance_map",
+ keyvalues={"instance_name": instance_name},
+ retcol="instance_id",
+ )
+
+ return await self.db_pool.runInteraction(
+ "get_id_for_instance", _get_id_for_instance_txn
+ )
+
+ @cached()
+ async def get_name_from_instance_id(self, instance_id: int) -> str:
+ """Get the instance name from an ID previously returned by
+ `get_id_for_instance`.
+ """
+
+ return await self.db_pool.simple_select_one_onecol(
+ table="instance_map",
+ keyvalues={"instance_id": instance_id},
+ retcol="instance_name",
+ desc="get_name_from_instance_id",
+ )
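+
+    # Round-trip sketch (illustrative; the worker name is hypothetical): the
+    # first call for a new worker upserts an instance_map row and returns its
+    # SERIAL id; later calls, and the reverse lookup, are served from caches.
+    #
+    #     instance_id = await store.get_id_for_instance("event_persister-1")
+    #     name = await store.get_name_from_instance_id(instance_id)
+    #     # name == "event_persister-1"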
+
class StreamStore(StreamWorkerStore):
def get_room_max_stream_ordering(self) -> int:
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 96ffe26cc9..9f120d3cb6 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 091367006e..7d46090267 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -19,7 +19,7 @@ from typing import Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -43,14 +43,32 @@ _UpdateTransactionRow = namedtuple(
SENTINEL = object()
-class TransactionStore(SQLBaseStore):
+class TransactionWorkerStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
+
+ @wrap_as_background_process("cleanup_transactions")
+ async def _cleanup_transactions(self) -> None:
+ now = self._clock.time_msec()
+ month_ago = now - 30 * 24 * 60 * 60 * 1000
+
+ def _cleanup_transactions_txn(txn):
+ txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
+
+ await self.db_pool.runInteraction(
+ "_cleanup_transactions", _cleanup_transactions_txn
+ )
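+
+    # Note (illustrative): @wrap_as_background_process makes the coroutine
+    # safe to schedule directly from looping_call: each run happens as a
+    # named background process with its own logcontext and metrics, which
+    # replaces the old _start_cleanup_transactions shim removed below.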
+
+
+class TransactionStore(TransactionWorkerStore):
"""A collection of queries for handling PDUs.
"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(TransactionStore, self).__init__(database, db_conn, hs)
-
- self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
+ super().__init__(database, db_conn, hs)
self._destination_retry_cache = ExpiringCache(
cache_name="get_destination_retry_timings",
@@ -218,6 +236,7 @@ class TransactionStore(SQLBaseStore):
retry_interval = EXCLUDED.retry_interval
WHERE
EXCLUDED.retry_interval = 0
+ OR destinations.retry_interval IS NULL
OR destinations.retry_interval < EXCLUDED.retry_interval
"""
@@ -249,7 +268,11 @@ class TransactionStore(SQLBaseStore):
"retry_interval": retry_interval,
},
)
- elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
+ elif (
+ retry_interval == 0
+ or prev_row["retry_interval"] is None
+ or prev_row["retry_interval"] < retry_interval
+ ):
self.db_pool.simple_update_one_txn(
txn,
"destinations",
@@ -261,22 +284,6 @@ class TransactionStore(SQLBaseStore):
},
)
- def _start_cleanup_transactions(self):
- return run_as_background_process(
- "cleanup_transactions", self._cleanup_transactions
- )
-
- async def _cleanup_transactions(self) -> None:
- now = self._clock.time_msec()
- month_ago = now - 30 * 24 * 60 * 60 * 1000
-
- def _cleanup_transactions_txn(txn):
- txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
-
- await self.db_pool.runInteraction(
- "_cleanup_transactions", _cleanup_transactions_txn
- )
-
async def store_destination_rooms_entries(
self, destinations: Iterable[str], room_id: str, stream_ordering: int,
) -> None:
@@ -397,7 +404,7 @@ class TransactionStore(SQLBaseStore):
@staticmethod
def _get_catch_up_room_event_ids_txn(
- txn, destination: str, last_successful_stream_ordering: int,
+ txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int,
) -> List[str]:
q = """
SELECT event_id FROM destination_rooms
@@ -412,3 +419,60 @@ class TransactionStore(SQLBaseStore):
)
event_ids = [row[0] for row in txn]
return event_ids
+
+ async def get_catch_up_outstanding_destinations(
+ self, after_destination: Optional[str]
+ ) -> List[str]:
+ """
+ Gets at most 25 destinations which have outstanding PDUs to be caught up,
+        and are not being backed off from.
+
+        Args:
+ after_destination:
+ If provided, all destinations must be lexicographically greater
+ than this one.
+
+ Returns:
+ list of up to 25 destinations with outstanding catch-up.
+ These are the lexicographically first destinations which are
+ lexicographically greater than after_destination (if provided).
+ """
+ time = self.hs.get_clock().time_msec()
+
+ return await self.db_pool.runInteraction(
+ "get_catch_up_outstanding_destinations",
+ self._get_catch_up_outstanding_destinations_txn,
+ time,
+ after_destination,
+ )
+
+ @staticmethod
+ def _get_catch_up_outstanding_destinations_txn(
+ txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
+ ) -> List[str]:
+ q = """
+ SELECT destination FROM destinations
+ WHERE destination IN (
+ SELECT destination FROM destination_rooms
+ WHERE destination_rooms.stream_ordering >
+ destinations.last_successful_stream_ordering
+ )
+ AND destination > ?
+ AND (
+ retry_last_ts IS NULL OR
+ retry_last_ts + retry_interval < ?
+ )
+ ORDER BY destination
+ LIMIT 25
+ """
+ txn.execute(
+ q,
+ (
+ # everything is lexicographically greater than "" so this gives
+ # us the first batch of up to 25.
+ after_destination or "",
+ now_time_ms,
+ ),
+ )
+
+ destinations = [row[0] for row in txn]
+ return destinations
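+
+    # Pagination sketch (illustrative; the loop is hypothetical): walk every
+    # outstanding destination 25 at a time by feeding the last entry of each
+    # batch back in as after_destination:
+    #
+    #     last = None
+    #     while True:
+    #         batch = await store.get_catch_up_outstanding_destinations(last)
+    #         if not batch:
+    #             break
+    #         last = batch[-1]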
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 3b9211a6d2..79b7ece330 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -288,8 +288,6 @@ class UIAuthWorkerStore(SQLBaseStore):
)
return [(row["user_agent"], row["ip"]) for row in rows]
-
-class UIAuthStore(UIAuthWorkerStore):
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
Remove sessions which were last used earlier than the expiration time.
@@ -339,3 +337,7 @@ class UIAuthStore(UIAuthWorkerStore):
iterable=session_ids,
keyvalues={},
)
+
+
+class UIAuthStore(UIAuthWorkerStore):
+ pass
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f2f9a5799a..5a390ff2f6 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -38,7 +38,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, database: DatabasePool, db_conn, hs):
- super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -564,7 +564,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, database: DatabasePool, db_conn, hs):
- super(UserDirectoryStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index 2f7c95fc74..f9575b1f1f 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore):
return
# They are there, delete them.
- self.simple_delete_one_txn(
+ self.db_pool.simple_delete_one_txn(
txn, "erased_users", keyvalues={"user_id": user_id}
)
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 139085b672..acb24e33af 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -181,7 +181,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index e924f1ca3b..0e31cc811a 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -24,7 +24,7 @@ from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStor
from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -52,7 +52,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StateGroupDataStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# Originally the state store used a single DictionaryCache to cache the
# event IDs for the state types in a given state group to avoid hammering
@@ -99,6 +99,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
self._state_group_seq_gen = build_sequence_generator(
self.database_engine, get_max_state_group_txn, "state_group_id_seq"
)
+ self._state_group_seq_gen.check_consistency(
+ db_conn, table="state_groups", id_column="id"
+ )
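+        # (Illustrative note: check_consistency compares the sequence's next
+        # value against the maximum id already in state_groups and fails fast
+        # at startup if the sequence has fallen behind, rather than erroring
+        # later on insert.)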
@cached(max_entries=10000, iterable=True)
async def get_state_group_delta(self, state_group):
@@ -205,7 +208,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
- ) -> Dict[int, StateMap[str]]:
+ ) -> Dict[int, MutableStateMap[str]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 908cbc79e3..d6d632dc10 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -97,3 +97,20 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
"""Gets a string giving the server version. For example: '3.22.0'
"""
...
+
+ @abc.abstractmethod
+ def in_transaction(self, conn: Connection) -> bool:
+ """Whether the connection is currently in a transaction.
+ """
+ ...
+
+ @abc.abstractmethod
+ def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+ """Attempt to set the connections autocommit mode.
+
+ When True queries are run outside of transactions.
+
+ Note: This has no effect on SQLite3, so callers still need to
+ commit/rollback the connections.
+ """
+ ...
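+
+    # Typical call pattern (illustrative sketch): temporarily drop out of
+    # transaction scope, e.g. for statements that cannot run inside one, then
+    # restore the default:
+    #
+    #     engine.attempt_to_set_autocommit(conn, True)
+    #     try:
+    #         ...  # e.g. CREATE INDEX CONCURRENTLY on PostgreSQL
+    #     finally:
+    #         engine.attempt_to_set_autocommit(conn, False)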
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index ff39281f85..7719ac32f7 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -15,7 +15,8 @@
import logging
-from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
+from synapse.storage.engines._base import BaseDatabaseEngine, IncorrectDatabaseSetup
+from synapse.storage.types import Connection
logger = logging.getLogger(__name__)
@@ -119,6 +120,7 @@ class PostgresEngine(BaseDatabaseEngine):
cursor.execute("SET synchronous_commit TO OFF")
cursor.close()
+ db_conn.commit()
@property
def can_native_upsert(self):
@@ -171,3 +173,9 @@ class PostgresEngine(BaseDatabaseEngine):
return "%i.%i" % (numver / 10000, numver % 10000)
else:
return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
+
+ def in_transaction(self, conn: Connection) -> bool:
+ return conn.status != self.module.extensions.STATUS_READY # type: ignore
+
+ def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+ return conn.set_session(autocommit=autocommit) # type: ignore
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 8a0f8c89d1..5db0f0b520 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -17,6 +17,7 @@ import threading
import typing
from synapse.storage.engines import BaseDatabaseEngine
+from synapse.storage.types import Connection
if typing.TYPE_CHECKING:
import sqlite3 # noqa: F401
@@ -86,6 +87,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
db_conn.create_function("rank", 1, _rank)
db_conn.execute("PRAGMA foreign_keys = ON;")
+ db_conn.commit()
def is_deadlock(self, error):
return False
@@ -105,6 +107,14 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
"""
return "%i.%i.%i" % self.module.sqlite_version_info
+ def in_transaction(self, conn: Connection) -> bool:
+ return conn.in_transaction # type: ignore
+
+ def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+ # Twisted doesn't let us set attributes on the connections, so we can't
+ # set the connection to autocommit mode.
+ pass
+
# Following functions taken from: https://github.com/coleifer/peewee
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index d89f6ed128..4d2d88d1f0 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
-from synapse.types import Collection, StateMap
+from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -190,15 +190,16 @@ class EventsPersistenceStorage:
self.persist_events_store = stores.persist_events
self._clock = hs.get_clock()
+ self._instance_name = hs.get_instance_name()
self.is_mine_id = hs.is_mine_id
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler()
async def persist_events(
self,
- events_and_contexts: List[Tuple[EventBase, EventContext]],
+ events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
backfilled: bool = False,
- ) -> int:
+ ) -> RoomStreamToken:
"""
Write events to the database
Args:
@@ -228,11 +229,11 @@ class EventsPersistenceStorage:
defer.gatherResults(deferreds, consumeErrors=True)
)
- return self.main_store.get_current_events_token()
+ return self.main_store.get_room_max_token()
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
- ) -> Tuple[int, int]:
+ ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
"""
Returns:
The stream ordering of `event`, and the stream ordering of the
@@ -246,8 +247,12 @@ class EventsPersistenceStorage:
await make_deferred_yieldable(deferred)
- max_persisted_id = self.main_store.get_current_events_token()
- return (event.internal_metadata.stream_ordering, max_persisted_id)
+ event_stream_id = event.internal_metadata.stream_ordering
+ # stream ordering should have been assigned by now
+ assert event_stream_id
+
+ pos = PersistedEventPosition(self._instance_name, event_stream_id)
+ return pos, self.main_store.get_room_max_token()
def _maybe_start_persisting(self, room_id: str):
async def persisting_queue(item):
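A sketch of the calling convention after this change; `event`, `context` and `earlier_token` stand in for values produced elsewhere:

```python
# `storage` is an EventsPersistenceStorage; event/context come from the
# usual event-creation path, earlier_token is a live RoomStreamToken we held.
pos, max_token = await storage.persist_event(event, context)

# The position now carries the name of the instance that persisted the
# event, which keeps ordering checks correct with sharded persisters.
if pos.persisted_after(earlier_token):
    print("event landed after our token on", pos.instance_name)
```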
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 77de025069..9e3dfe4805 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import imp
import logging
import os
@@ -24,9 +23,10 @@ from typing import Optional, TextIO
import attr
from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Cursor
from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -64,7 +64,7 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
def prepare_database(
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
databases: Collection[str] = ["main", "state"],
@@ -86,7 +86,7 @@ def prepare_database(
"""
try:
- cur = db_conn.cursor()
+ cur = db_conn.cursor(txn_name="prepare_database")
# sqlite does not automatically start transactions for DDL / SELECT statements,
# so we start one before running anything. This ensures that any upgrades
@@ -255,9 +255,7 @@ def _setup_new_database(cur, database_engine, databases):
executescript(cur, entry.absolute_path)
cur.execute(
- database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
- ),
+ "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
(max_current_ver, False),
)
@@ -483,17 +481,13 @@ def _upgrade_existing_database(
# Mark as done.
cur.execute(
- database_engine.convert_param_style(
- "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)"
- ),
+ "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)",
(v, relative_path),
)
cur.execute("DELETE FROM schema_version")
cur.execute(
- database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
- ),
+ "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
(v, True),
)
@@ -529,10 +523,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
schemas to be applied
"""
cur.execute(
- database_engine.convert_param_style(
- "SELECT file FROM applied_module_schemas WHERE module_name = ?"
- ),
- (modname,),
+ "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
)
applied_deltas = {d for d, in cur}
for (name, stream) in names_and_streams:
@@ -550,9 +541,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
# Mark as done.
cur.execute(
- database_engine.convert_param_style(
- "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
- ),
+ "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)",
(modname, name),
)
@@ -624,9 +613,7 @@ def _get_or_create_schema_state(txn, database_engine):
if current_version:
txn.execute(
- database_engine.convert_param_style(
- "SELECT file FROM applied_schema_deltas WHERE version >= ?"
- ),
+ "SELECT file FROM applied_schema_deltas WHERE version >= ?",
(current_version,),
)
applied_deltas = [d for d, in txn]
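The dropped `convert_param_style` wrapping works because the new `LoggingDatabaseConnection` cursor translates placeholders itself. A rough illustration of that kind of translation (not the actual implementation):

```python
def convert_param_style(sql: str) -> str:
    # sqlite3 natively uses `?` placeholders; psycopg2 uses `%s`. Converting
    # at execute() time lets callers write one dialect everywhere.
    return sql.replace("?", "%s")


assert (
    convert_param_style("INSERT INTO schema_version (version, upgraded) VALUES (?,?)")
    == "INSERT INTO schema_version (version, upgraded) VALUES (%s,%s)"
)
```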
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8c4a83a840..f152f63321 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
)
GetRoomsForUserWithStreamOrdering = namedtuple(
- "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering")
+ "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 8f68d968f0..08a69f2f96 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -20,7 +20,7 @@ import attr
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
logger = logging.getLogger(__name__)
@@ -349,7 +349,7 @@ class StateGroupStorage:
async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
- ) -> Dict[int, StateMap[str]]:
+ ) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
@@ -532,7 +532,7 @@ class StateGroupStorage:
def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
- ) -> Awaitable[Dict[int, StateMap[str]]]:
+ ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 2d2b560e74..970bb1b9da 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -61,3 +61,9 @@ class Connection(Protocol):
def rollback(self, *args, **kwargs) -> None:
...
+
+ def __enter__(self) -> "Connection":
+ ...
+
+ def __exit__(self, exc_type, exc_value, traceback) -> bool:
+ ...
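Adding `__enter__`/`__exit__` to the `Connection` protocol lets typed code use standard DB-API transaction scoping; a small sketch (the table is made up):

```python
from synapse.storage.types import Connection


def bump_counter(conn: Connection) -> None:
    # DB-API semantics: commit on clean exit, roll back on exception.
    with conn:
        cur = conn.cursor()
        cur.execute("UPDATE counters SET n = n + 1")  # hypothetical table
```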
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 1de2b91587..d7e40aaa8b 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -12,17 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import contextlib
import heapq
import logging
import threading
from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
+import attr
from typing_extensions import Deque
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.types import Cursor
from synapse.storage.util.sequence import PostgresSequenceGenerator
logger = logging.getLogger(__name__)
@@ -53,7 +55,7 @@ def _load_current_id(db_conn, table, column, step=1):
"""
# debug logging for https://github.com/matrix-org/synapse/issues/7968
logger.info("initialising stream generator for %s(%s)", table, column)
- cur = db_conn.cursor()
+ cur = db_conn.cursor(txn_name="_load_current_id")
if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
@@ -86,7 +88,7 @@ class StreamIdGenerator:
upwards, -1 to grow downwards.
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@@ -101,10 +103,10 @@ class StreamIdGenerator:
)
self._unfinished_ids = deque() # type: Deque[int]
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -113,7 +115,7 @@ class StreamIdGenerator:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_id
@@ -121,12 +123,12 @@ class StreamIdGenerator:
with self._lock:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
- async def get_next_mult(self, n):
+ def get_next_mult(self, n):
"""
Usage:
- with await stream_id_gen.get_next(n) as stream_ids:
+ async with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -140,7 +142,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_ids
@@ -149,7 +151,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
@@ -184,12 +186,16 @@ class MultiWriterIdGenerator:
Args:
db_conn
db
+ stream_name: A name for the stream.
instance_name: The name of this instance.
table: Database table associated with stream.
instance_column: Column that stores the row's writer's instance name
id_column: Column that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
+ writers: A list of known writers to use to populate current positions
+ on startup. Can be empty if nothing uses `get_current_token` or
+ `get_positions` (e.g. the caches stream).
positive: Whether the IDs are positive (true) or negative (false).
When using negative IDs we go backwards from -1 to -2, -3, etc.
"""
@@ -198,16 +204,20 @@ class MultiWriterIdGenerator:
self,
db_conn,
db: DatabasePool,
+ stream_name: str,
instance_name: str,
table: str,
instance_column: str,
id_column: str,
sequence_name: str,
+ writers: List[str],
positive: bool = True,
):
self._db = db
+ self._stream_name = stream_name
self._instance_name = instance_name
self._positive = positive
+ self._writers = writers
self._return_factor = 1 if positive else -1
# We lock as some functions may be called from DB threads.
@@ -216,9 +226,7 @@ class MultiWriterIdGenerator:
# Note: If we are a negative stream then we still store all the IDs as
# positive to make life easier for us, and simply negate the IDs when we
# return them.
- self._current_positions = self._load_current_ids(
- db_conn, table, instance_column, id_column
- )
+ self._current_positions = {} # type: Dict[str, int]
# Set of local IDs that we're still processing. The current position
# should be less than the minimum of this set (if not empty).
@@ -251,30 +259,98 @@ class MultiWriterIdGenerator:
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+ # We check that the table and sequence haven't diverged.
+ self._sequence_gen.check_consistency(
+ db_conn, table=table, id_column=id_column, positive=positive
+ )
+
+ # This goes and fills out the above state from the database.
+ self._load_current_ids(db_conn, table, instance_column, id_column)
+
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
- ) -> Dict[str, int]:
- # If positive stream aggregate via MAX. For negative stream use MIN
- # *and* negate the result to get a positive number.
- sql = """
- SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
- GROUP BY %(instance)s
- """ % {
- "instance": instance_column,
- "id": id_column,
- "table": table,
- "agg": "MAX" if self._positive else "-MIN",
- }
+ ):
+ cur = db_conn.cursor(txn_name="_load_current_ids")
+
+ # Load the current positions of all writers for the stream.
+ if self._writers:
+ # We delete any stale entries in the positions table. This is
+ # important if we add back a writer after a long time; we want to
+ # consider that a "new" writer, rather than using the old stale
+ # entry here.
+ sql = """
+ DELETE FROM stream_positions
+ WHERE
+ stream_name = ?
+ AND instance_name != ALL(?)
+ """
+ cur.execute(sql, (self._stream_name, self._writers))
+
+ sql = """
+ SELECT instance_name, stream_id FROM stream_positions
+ WHERE stream_name = ?
+ """
+ cur.execute(sql, (self._stream_name,))
+
+ self._current_positions = {
+ instance: stream_id * self._return_factor
+ for instance, stream_id in cur
+ if instance in self._writers
+ }
- cur = db_conn.cursor()
- cur.execute(sql)
+ # We set the `_persisted_upto_position` to be the minimum of all current
+ # positions. If empty we use the max stream ID from the DB table.
+ min_stream_id = min(self._current_positions.values(), default=None)
+
+ if min_stream_id is None:
+ # We add a GREATEST here to ensure that the result is always
+ # positive. (This can be a problem for e.g. backfill streams where
+ # the server has never backfilled).
+ sql = """
+ SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+ FROM %(table)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "agg": "MAX" if self._positive else "-MIN",
+ }
+ cur.execute(sql)
+ (stream_id,) = cur.fetchone()
+ self._persisted_upto_position = stream_id
+ else:
+ # If we have a min_stream_id then we pull out everything greater
+ # than it from the DB so that we can prefill
+ # `_known_persisted_positions` and get a more accurate
+ # `_persisted_upto_position`.
+ #
+ # We also check if any of the later rows are from this instance, in
+ # which case we use that for this instance's current position. This
+ # is to handle the case where we didn't finish persisting to the
+ # stream positions table before restart (or the stream position
+ # table otherwise got out of date).
+
+ sql = """
+ SELECT %(instance)s, %(id)s FROM %(table)s
+ WHERE ? %(cmp)s %(id)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "instance": instance_column,
+ "cmp": "<=" if self._positive else ">=",
+ }
+ cur.execute(sql, (min_stream_id * self._return_factor,))
- # `cur` is an iterable over returned rows, which are 2-tuples.
- current_positions = dict(cur)
+ self._persisted_upto_position = min_stream_id
- cur.close()
+ with self._lock:
+ for (instance, stream_id,) in cur:
+ stream_id = self._return_factor * stream_id
+ self._add_persisted_position(stream_id)
- return current_positions
+ if instance == self._instance_name:
+ self._current_positions[instance] = stream_id
+
+ cur.close()
def _load_next_id_txn(self, txn) -> int:
return self._sequence_gen.get_next_id_txn(txn)
@@ -282,59 +358,23 @@ class MultiWriterIdGenerator:
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
return self._sequence_gen.get_next_mult_txn(txn, n)
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
- next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
-
- # Assert the fetched ID is actually greater than what we currently
- # believe the ID to be. If not, then the sequence and table have got
- # out of sync somehow.
- with self._lock:
- assert self._current_positions.get(self._instance_name, 0) < next_id
- self._unfinished_ids.add(next_id)
-
- @contextlib.contextmanager
- def manager():
- try:
- # Multiply by the return factor so that the ID has correct sign.
- yield self._return_factor * next_id
- finally:
- self._mark_id_as_finished(next_id)
+ return _MultiWriterCtxManager(self)
- return manager()
-
- async def get_next_mult(self, n: int):
+ def get_next_mult(self, n: int):
"""
Usage:
- with await stream_id_gen.get_next_mult(5) as stream_ids:
+ async with stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
- next_ids = await self._db.runInteraction(
- "_load_next_mult_id", self._load_next_mult_id_txn, n
- )
- # Assert the fetched ID is actually greater than any ID we've already
- # seen. If not, then the sequence and table have got out of sync
- # somehow.
- with self._lock:
- assert max(self._current_positions.values(), default=0) < min(next_ids)
-
- self._unfinished_ids.update(next_ids)
-
- @contextlib.contextmanager
- def manager():
- try:
- yield [self._return_factor * i for i in next_ids]
- finally:
- for i in next_ids:
- self._mark_id_as_finished(i)
-
- return manager()
+ return _MultiWriterCtxManager(self, n)
def get_next_txn(self, txn: LoggingTransaction):
"""
@@ -352,6 +392,21 @@ class MultiWriterIdGenerator:
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
+ # Update the `stream_positions` table with newly updated stream
+ # ID (unless `self._writers` is empty, in which case we don't
+ # bother, as nothing will read it).
+ #
+ # We only do this on the success path so that the persisted current
+ # position points to a persisted row with the correct instance name.
+ if self._writers:
+ txn.call_after(
+ run_as_background_process,
+ "MultiWriterIdGenerator._update_table",
+ self._db.runInteraction,
+ "MultiWriterIdGenerator._update_table",
+ self._update_stream_positions_table_txn,
+ )
+
return self._return_factor * next_id
def _mark_id_as_finished(self, next_id: int):
@@ -363,7 +418,7 @@ class MultiWriterIdGenerator:
self._unfinished_ids.discard(next_id)
self._finished_ids.add(next_id)
- new_cur = None
+ new_cur = None # type: Optional[int]
if self._unfinished_ids:
# If there are unfinished IDs then the new position will be the
@@ -408,11 +463,22 @@ class MultiWriterIdGenerator:
"""Returns the position of the given writer.
"""
+ # If we don't have an entry for the given instance name, we assume it's a
+ # new writer.
+ #
+ # For new writers we assume their initial position to be the current
+ # persisted up to position. This stops Synapse from doing a full table
+ # scan when a new writer announces itself over replication.
with self._lock:
- return self._return_factor * self._current_positions.get(instance_name, 0)
+ return self._return_factor * self._current_positions.get(
+ instance_name, self._persisted_upto_position
+ )
def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map.
+
+ Note that this won't necessarily include all configured writers if some
+ writers haven't written anything yet.
"""
with self._lock:
@@ -482,3 +548,104 @@ class MultiWriterIdGenerator:
# There was a gap in seen positions, so there is nothing more to
# do.
break
+
+ def _update_stream_positions_table_txn(self, txn: Cursor):
+ """Update the `stream_positions` table with newly persisted position.
+ """
+
+ if not self._writers:
+ return
+
+ # We upsert the value, ensuring on conflict that we always increase the
+ # value (or decrease if stream goes backwards).
+ sql = """
+ INSERT INTO stream_positions (stream_name, instance_name, stream_id)
+ VALUES (?, ?, ?)
+ ON CONFLICT (stream_name, instance_name)
+ DO UPDATE SET
+ stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
+ """ % {
+ "agg": "GREATEST" if self._positive else "LEAST",
+ }
+
+ pos = self.get_current_token_for_writer(self._instance_name)
+ txn.execute(sql, (self._stream_name, self._instance_name, pos))
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+ """Helper class to convert a plain context manager to an async one.
+
+ This is mainly useful if you have a plain context manager but the interface
+ requires an async one.
+ """
+
+ inner = attr.ib()
+
+ async def __aenter__(self):
+ return self.inner.__enter__()
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return self.inner.__exit__(exc_type, exc, tb)
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+ """Async context manager returned by MultiWriterIdGenerator
+ """
+
+ id_gen = attr.ib(type=MultiWriterIdGenerator)
+ multiple_ids = attr.ib(type=Optional[int], default=None)
+ stream_ids = attr.ib(type=List[int], factory=list)
+
+ async def __aenter__(self) -> Union[int, List[int]]:
+ # It's safe to run this in autocommit mode as fetching values from a
+ # sequence ignores transaction semantics anyway.
+ self.stream_ids = await self.id_gen._db.runInteraction(
+ "_load_next_mult_id",
+ self.id_gen._load_next_mult_id_txn,
+ self.multiple_ids or 1,
+ db_autocommit=True,
+ )
+
+ # Assert the fetched ID is actually greater than any ID we've already
+ # seen. If not, then the sequence and table have got out of sync
+ # somehow.
+ with self.id_gen._lock:
+ assert max(self.id_gen._current_positions.values(), default=0) < min(
+ self.stream_ids
+ )
+
+ self.id_gen._unfinished_ids.update(self.stream_ids)
+
+ if self.multiple_ids is None:
+ return self.stream_ids[0] * self.id_gen._return_factor
+ else:
+ return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+ async def __aexit__(self, exc_type, exc, tb):
+ for i in self.stream_ids:
+ self.id_gen._mark_id_as_finished(i)
+
+ if exc_type is not None:
+ return False
+
+ # Update the `stream_positions` table with newly updated stream
+ # ID (unless `self._writers` is empty, in which case we don't
+ # bother, as nothing will read it).
+ #
+ # We only do this on the success path so that the persisted current
+ # position points to a persisted row with the correct instance name.
+ #
+ # We do this in autocommit mode as a) the upsert works correctly outside
+ # transactions and b) it reduces the amount of time the rows are locked
+ # for. If we don't do this then we'll often hit serialization errors due
+ # to the fact we default to REPEATABLE READ isolation levels.
+ if self.id_gen._writers:
+ await self.id_gen._db.runInteraction(
+ "MultiWriterIdGenerator._update_table",
+ self.id_gen._update_stream_positions_table_txn,
+ db_autocommit=True,
+ )
+
+ return False
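Putting the pieces together, the new (non-awaited) `get_next` call sites look roughly like this; `id_gen` is assumed to be a configured `MultiWriterIdGenerator`:

```python
async def write_rows(id_gen):
    # get_next() is no longer a coroutine: it returns an async context
    # manager, so the old `with await ...` form becomes `async with ...`.
    async with id_gen.get_next() as stream_id:
        pass  # ... persist a row using stream_id ...

    async with id_gen.get_next_mult(3) as stream_ids:
        assert len(stream_ids) == 3  # one ID per row to persist
```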
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index ffc1894748..ff2d038ad2 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -13,11 +13,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
+import logging
import threading
from typing import Callable, List, Optional
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.database import LoggingDatabaseConnection
+from synapse.storage.engines import (
+ BaseDatabaseEngine,
+ IncorrectDatabaseSetup,
+ PostgresEngine,
+)
+from synapse.storage.types import Connection, Cursor
+
+logger = logging.getLogger(__name__)
+
+
+_INCONSISTENT_SEQUENCE_ERROR = """
+Postgres sequence '%(seq)s' is inconsistent with associated
+table '%(table)s'. This can happen if Synapse has been downgraded and
+then upgraded again, or due to a bad migration.
+
+To fix this error, shut down Synapse (including any and all workers)
+and run the following SQL:
+
+ SELECT setval('%(seq)s', (
+ %(max_id_sql)s
+ ));
+
+See docs/postgres.md for more information.
+"""
class SequenceGenerator(metaclass=abc.ABCMeta):
@@ -28,6 +52,23 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
"""Gets the next ID in the sequence"""
...
+ @abc.abstractmethod
+ def check_consistency(
+ self,
+ db_conn: LoggingDatabaseConnection,
+ table: str,
+ id_column: str,
+ positive: bool = True,
+ ):
+ """Should be called during start up to test that the current value of
+ the sequence is greater than or equal to the maximum ID in the table.
+
+ This is to handle various cases where the sequence value can get out
+ of sync with the table, e.g. if Synapse gets rolled back to a previous
+ version and then rolled forwards again.
+ """
+ ...
+
class PostgresSequenceGenerator(SequenceGenerator):
"""An implementation of SequenceGenerator which uses a postgres sequence"""
@@ -45,6 +86,54 @@ class PostgresSequenceGenerator(SequenceGenerator):
)
return [i for (i,) in txn]
+ def check_consistency(
+ self,
+ db_conn: LoggingDatabaseConnection,
+ table: str,
+ id_column: str,
+ positive: bool = True,
+ ):
+ txn = db_conn.cursor(txn_name="sequence.check_consistency")
+
+ # First we get the current max ID from the table.
+ table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {
+ "id": id_column,
+ "table": table,
+ "agg": "MAX" if positive else "-MIN",
+ }
+
+ txn.execute(table_sql)
+ row = txn.fetchone()
+ if not row:
+ # Table is empty, so nothing to do.
+ txn.close()
+ return
+
+ # Now we fetch the current value from the sequence and compare with the
+ # above.
+ max_stream_id = row[0]
+ txn.execute(
+ "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
+ )
+ last_value, is_called = txn.fetchone()
+ txn.close()
+
+ # If `is_called` is False then `last_value` is actually the value that
+ # will be generated next, so we decrement to get the true "last value".
+ if not is_called:
+ last_value -= 1
+
+ if max_stream_id > last_value:
+ logger.warning(
+ "Postgres sequence %s is behind table %s: %d < %d",
+ last_value,
+ max_stream_id,
+ )
+ raise IncorrectDatabaseSetup(
+ _INCONSISTENT_SEQUENCE_ERROR
+ % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
+ )
+
GetFirstCallbackType = Callable[[Cursor], int]
@@ -81,6 +170,12 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1
return self._current_max_id
+ def check_consistency(
+ self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ ):
+ # There is nothing to do for in memory sequences
+ pass
+
def build_sequence_generator(
database_engine: BaseDatabaseEngine,
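The `last_value`/`is_called` handling above follows the Postgres sequence protocol; a self-contained restatement with made-up numbers:

```python
# Values as returned by "SELECT last_value, is_called FROM <seq>".
last_value, is_called = 41, False

# If nextval() has never been called, last_value is the value that *will*
# be handed out next, so the true "last issued" value is one less.
if not is_called:
    last_value -= 1

max_stream_id = 56  # MAX(id) from the associated table
if max_stream_id > last_value:
    # Mirrors the IncorrectDatabaseSetup raised above; the fix is to
    # setval() the sequence to the table's max ID.
    raise RuntimeError(
        "sequence (%d) is behind table max (%d)" % (last_value, max_stream_id)
    )
```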
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 0bdf846edf..fdda21d165 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import Optional
@@ -21,6 +20,7 @@ import attr
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
+from synapse.storage.databases.main import DataStore
from synapse.types import StreamToken
logger = logging.getLogger(__name__)
@@ -39,8 +39,9 @@ class PaginationConfig:
limit = attr.ib(type=Optional[int])
@classmethod
- def from_request(
+ async def from_request(
cls,
+ store: "DataStore",
request: SynapseRequest,
raise_invalid_params: bool = True,
default_limit: Optional[int] = None,
@@ -54,13 +55,13 @@ class PaginationConfig:
if from_tok == "END":
from_tok = None # For backwards compat.
elif from_tok:
- from_tok = StreamToken.from_string(from_tok)
+ from_tok = await StreamToken.from_string(store, from_tok)
except Exception:
raise SynapseError(400, "'from' parameter is invalid")
try:
if to_tok:
- to_tok = StreamToken.from_string(to_tok)
+ to_tok = await StreamToken.from_string(store, to_tok)
except Exception:
raise SynapseError(400, "'to' parameter is invalid")
diff --git a/synapse/types.py b/synapse/types.py
index dc09448bdc..5bde67cc07 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -18,7 +18,18 @@ import re
import string
import sys
from collections import namedtuple
-from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ Mapping,
+ MutableMapping,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+)
import attr
from signedjson.key import decode_verify_key_bytes
@@ -26,11 +37,14 @@ from unpaddedbase64 import decode_base64
from synapse.api.errors import Codes, SynapseError
+if TYPE_CHECKING:
+ from synapse.storage.databases.main import DataStore
+
# define a version of typing.Collection that works on python 3.5
if sys.version_info[:3] >= (3, 6, 0):
from typing import Collection
else:
- from typing import Container, Iterable, Sized
+ from typing import Container, Sized
T_co = TypeVar("T_co", covariant=True)
@@ -165,7 +179,9 @@ def get_localpart_from_id(string):
DS = TypeVar("DS", bound="DomainSpecificString")
-class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))):
+class DomainSpecificString(
+ namedtuple("DomainSpecificString", ("localpart", "domain")), metaclass=abc.ABCMeta
+):
"""Common base class among ID/name strings that have a local part and a
domain name, prefixed with a sigil.
@@ -175,8 +191,6 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
'domain' : The domain part of the name
"""
- __metaclass__ = abc.ABCMeta
-
SIGIL = abc.abstractproperty() # type: str # type: ignore
# Deny iteration because it will bite you if you try to create a singleton
@@ -362,7 +376,7 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
return username.decode("ascii")
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, cmp=False)
class RoomStreamToken:
"""Tokens are positions between events. The token "s1" comes after event 1.
@@ -384,6 +398,31 @@ class RoomStreamToken:
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
+
+ There is also a third mode for live tokens where the token starts with "m",
+ which is sometimes used when using sharded event persisters. In this case
+ the events stream is considered to be a set of streams (one for each writer)
+ and the token encodes the vector clock of positions of each writer in their
+ respective streams.
+
+ The format of the token in that case is an initial integer min position,
+ followed by a mapping of instance ID to position, separated by '.' and '~':
+
+ m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ...
+
+ The `min_pos` corresponds to the minimum position all writers have persisted
+ up to, and then only writers that are ahead of that position need to be
+ encoded. An example token is:
+
+ m56~2.58~3.59
+
+ This corresponds to a set of three (or more) writers, where instances 2 and
+ 3 (these are instance IDs that can be looked up in the DB to fetch the more
+ commonly used instance names) are at positions 58 and 59 respectively, and
+ all other instances are at position 56.
+
+ Note: The `RoomStreamToken` cannot have both a topological part and an
+ instance map.
"""
topological = attr.ib(
@@ -392,14 +431,47 @@ class RoomStreamToken:
)
stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
+ instance_map = attr.ib(
+ type=Dict[str, int],
+ factory=dict,
+ validator=attr.validators.deep_mapping(
+ key_validator=attr.validators.instance_of(str),
+ value_validator=attr.validators.instance_of(int),
+ mapping_validator=attr.validators.instance_of(dict),
+ ),
+ )
+
+ def __attrs_post_init__(self):
+ """Validates that both `topological` and `instance_map` aren't set.
+ """
+
+ if self.instance_map and self.topological:
+ raise ValueError(
+ "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
+ )
+
@classmethod
- def parse(cls, string: str) -> "RoomStreamToken":
+ async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
if string[0] == "t":
parts = string[1:].split("-", 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
+ if string[0] == "m":
+ parts = string[1:].split("~")
+ stream = int(parts[0])
+
+ instance_map = {}
+ for part in parts[1:]:
+ key, value = part.split(".")
+ instance_id = int(key)
+ pos = int(value)
+
+ instance_name = await store.get_name_from_instance_id(instance_id)
+ instance_map[instance_name] = pos
+
+ return cls(topological=None, stream=stream, instance_map=instance_map,)
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@@ -413,12 +485,71 @@ class RoomStreamToken:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
- def as_tuple(self) -> Tuple[Optional[int], int]:
+ def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
+ """Return a new token such that if an event is after both this token and
+ the other token, then it's after the returned token too.
+ """
+
+ if self.topological or other.topological:
+ raise Exception("Can't advance topological tokens")
+
+ max_stream = max(self.stream, other.stream)
+
+ instance_map = {
+ instance: max(
+ self.instance_map.get(instance, self.stream),
+ other.instance_map.get(instance, other.stream),
+ )
+ for instance in set(self.instance_map).union(other.instance_map)
+ }
+
+ return RoomStreamToken(None, max_stream, instance_map)
+
+ def as_historical_tuple(self) -> Tuple[int, int]:
+ """Returns a tuple of `(topological, stream)` for historical tokens.
+
+ Raises if not an historical token (i.e. doesn't have a topological part).
+ """
+ if self.topological is None:
+ raise Exception(
+ "Cannot call `RoomStreamToken.as_historical_tuple` on live token"
+ )
+
return (self.topological, self.stream)
- def __str__(self) -> str:
+ def get_stream_pos_for_instance(self, instance_name: str) -> int:
+ """Get the stream position that the given writer was at at this token.
+
+ This only makes sense for "live" tokens that may have a vector clock
+ component, and so asserts that this is a "live" token.
+ """
+ assert self.topological is None
+
+ # If we don't have an entry for the instance we can assume that it was
+ # at `self.stream`.
+ return self.instance_map.get(instance_name, self.stream)
+
+ def get_max_stream_pos(self) -> int:
+ """Get the maximum stream position referenced in this token.
+
+ The corresponding "min" position is, by definition just `self.stream`.
+
+ This is used to handle tokens that have a non-empty `instance_map`, and so
+ reference stream positions after the `self.stream` position.
+ """
+ return max(self.instance_map.values(), default=self.stream)
+
+ async def to_string(self, store: "DataStore") -> str:
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
+ elif self.instance_map:
+ entries = []
+ for name, pos in self.instance_map.items():
+ instance_id = await store.get_id_for_instance(name)
+ entries.append("{}.{}".format(instance_id, pos))
+
+ encoded_map = "~".join(entries)
+ return "m{}~{}".format(self.stream, encoded_map)
else:
return "s%d" % (self.stream,)
@@ -441,48 +572,51 @@ class StreamToken:
START = None # type: StreamToken
@classmethod
- def from_string(cls, string):
+ async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
try:
keys = string.split(cls._SEPARATOR)
while len(keys) < len(attr.fields(cls)):
# i.e. old token from before receipt_key
keys.append("0")
- return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:]))
+ return cls(
+ await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
+ )
except Exception:
raise SynapseError(400, "Invalid Token")
- def to_string(self):
- return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)])
+ async def to_string(self, store: "DataStore") -> str:
+ return self._SEPARATOR.join(
+ [
+ await self.room_key.to_string(store),
+ str(self.presence_key),
+ str(self.typing_key),
+ str(self.receipt_key),
+ str(self.account_data_key),
+ str(self.push_rules_key),
+ str(self.to_device_key),
+ str(self.device_list_key),
+ str(self.groups_key),
+ ]
+ )
@property
def room_stream_id(self):
return self.room_key.stream
- def is_after(self, other):
- """Does this token contain events that the other doesn't?"""
- return (
- (other.room_stream_id < self.room_stream_id)
- or (int(other.presence_key) < int(self.presence_key))
- or (int(other.typing_key) < int(self.typing_key))
- or (int(other.receipt_key) < int(self.receipt_key))
- or (int(other.account_data_key) < int(self.account_data_key))
- or (int(other.push_rules_key) < int(self.push_rules_key))
- or (int(other.to_device_key) < int(self.to_device_key))
- or (int(other.device_list_key) < int(self.device_list_key))
- or (int(other.groups_key) < int(self.groups_key))
- )
-
def copy_and_advance(self, key, new_value) -> "StreamToken":
"""Advance the given key in the token to a new value if and only if the
new value is after the old value.
"""
- new_token = self.copy_and_replace(key, new_value)
if key == "room_key":
- new_id = new_token.room_stream_id
- old_id = self.room_stream_id
- else:
- new_id = int(getattr(new_token, key))
- old_id = int(getattr(self, key))
+ new_token = self.copy_and_replace(
+ "room_key", self.room_key.copy_and_advance(new_value)
+ )
+ return new_token
+
+ new_token = self.copy_and_replace(key, new_value)
+ new_id = int(getattr(new_token, key))
+ old_id = int(getattr(self, key))
+
if old_id < new_id:
return new_token
else:
@@ -492,7 +626,34 @@ class StreamToken:
return attr.evolve(self, **{key: new_value})
-StreamToken.START = StreamToken.from_string("s0_0")
+StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0)
+
+
+@attr.s(slots=True, frozen=True)
+class PersistedEventPosition:
+ """Position of a newly persisted event with instance that persisted it.
+
+ This can be used to test whether the event is persisted before or after a
+ RoomStreamToken.
+ """
+
+ instance_name = attr.ib(type=str)
+ stream = attr.ib(type=int)
+
+ def persisted_after(self, token: RoomStreamToken) -> bool:
+ return token.get_stream_pos_for_instance(self.instance_name) < self.stream
+
+ def to_room_stream_token(self) -> RoomStreamToken:
+ """Converts the position to a room stream token such that events
+ persisted in the same room after this position will be after the
+ returned `RoomStreamToken`.
+
+ Note: no guarantees are made about ordering w.r.t. events in other
+ rooms.
+ """
+ # Doing the naive thing satisfies the desired properties described in
+ # the docstring.
+ return RoomStreamToken(None, self.stream)
class ThirdPartyInstanceID(
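To make the vector-clock behaviour concrete, a small worked example of `copy_and_advance` (the instance names are invented):

```python
from synapse.types import RoomStreamToken

# Two live tokens from different sharded persisters.
a = RoomStreamToken(None, 56, {"persister1": 58})
b = RoomStreamToken(None, 57, {"persister2": 59})

merged = a.copy_and_advance(b)

# The min position advances to the larger of the two...
assert merged.stream == 57
# ...and each writer's position is the max seen across both tokens
# (falling back to the token's min position when absent).
assert merged.get_stream_pos_for_instance("persister1") == 58
assert merged.get_stream_pos_for_instance("persister2") == 59
```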
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 67ce9a5f39..382f0cf3f0 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -449,18 +449,8 @@ class ReadWriteLock:
R = TypeVar("R")
-def _cancelled_to_timed_out_error(value: R, timeout: float) -> R:
- if isinstance(value, failure.Failure):
- value.trap(CancelledError)
- raise defer.TimeoutError(timeout, "Deferred")
- return value
-
-
def timeout_deferred(
- deferred: defer.Deferred,
- timeout: float,
- reactor: IReactorTime,
- on_timeout_cancel: Optional[Callable[[Any, float], Any]] = None,
+ deferred: defer.Deferred, timeout: float, reactor: IReactorTime,
) -> defer.Deferred:
"""The in built twisted `Deferred.addTimeout` fails to time out deferreds
that have a canceller that throws exceptions. This method creates a new
@@ -469,27 +459,21 @@ def timeout_deferred(
(See https://twistedmatrix.com/trac/ticket/9534)
- NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred
+ NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred.
+
+ NOTE: the TimeoutError raised by the resultant deferred is
+ twisted.internet.defer.TimeoutError, which is *different* to the built-in
+ TimeoutError, as well as various other TimeoutErrors you might have imported.
Args:
deferred: The Deferred to potentially timeout.
timeout: Timeout in seconds
reactor: The twisted reactor to use
- on_timeout_cancel: A callable which is called immediately
- after the deferred times out, and not if this deferred is
- otherwise cancelled before the timeout.
- It takes an arbitrary value, which is the value of the deferred at
- that exact point in time (probably a CancelledError Failure), and
- the timeout.
-
- The default callable (if none is provided) will translate a
- CancelledError Failure into a defer.TimeoutError.
Returns:
- A new Deferred.
+ A new Deferred, which will errback with defer.TimeoutError on timeout.
"""
-
new_d = defer.Deferred()
timed_out = [False]
@@ -502,18 +486,23 @@ def timeout_deferred(
except: # noqa: E722, if we throw any exception it'll break time outs
logger.exception("Canceller failed during timeout")
+ # the cancel() call should have set off a chain of errbacks which
+ # will have errbacked new_d, but in case it hasn't, errback it now.
+
if not new_d.called:
- new_d.errback(defer.TimeoutError(timeout, "Deferred"))
+ new_d.errback(defer.TimeoutError("Timed out after %gs" % (timeout,)))
delayed_call = reactor.callLater(timeout, time_it_out)
- def convert_cancelled(value):
- if timed_out[0]:
- to_call = on_timeout_cancel or _cancelled_to_timed_out_error
- return to_call(value, timeout)
+ def convert_cancelled(value: failure.Failure):
+ # if the original deferred was cancelled, and our timeout has fired, then
+ # it was cancelled because of our timeout. Turn the CancelledError
+ # into a TimeoutError.
+ if timed_out[0] and value.check(CancelledError):
+ raise defer.TimeoutError("Timed out after %gs" % (timeout,))
return value
- deferred.addBoth(convert_cancelled)
+ deferred.addErrback(convert_cancelled)
def cancel_timeout(result):
# stop the pending call to cancel the deferred if it's been fired
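Usage after the simplification is unchanged apart from the dropped argument; a short sketch, trapping the Twisted-flavoured TimeoutError:

```python
from twisted.internet import defer, reactor

from synapse.util.async_helpers import timeout_deferred

d = defer.Deferred()  # some operation that may never fire
timed = timeout_deferred(d, timeout=10.0, reactor=reactor)


def on_timeout(f):
    # Note: this is twisted.internet.defer.TimeoutError, not the builtin one.
    f.trap(defer.TimeoutError)


timed.addErrback(on_timeout)
```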
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 631654f297..da24ba0470 100644
--- a/synapse/util/manhole.py
+++ b/synapse/util/manhole.py
@@ -94,7 +94,7 @@ class SynapseManhole(ColoredManhole):
"""Overrides connectionMade to create our own ManholeInterpreter"""
def connectionMade(self):
- super(SynapseManhole, self).connectionMade()
+ super().connectionMade()
# replace the manhole interpreter with our own impl
self.interpreter = SynapseManholeInterpreter(self, self.namespace)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 6e57c1ee72..ffdea0de8d 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -19,7 +19,11 @@ from typing import Any, Callable, Optional, TypeVar, cast
from prometheus_client import Counter
-from synapse.logging.context import LoggingContext, current_context
+from synapse.logging.context import (
+ ContextResourceUsage,
+ LoggingContext,
+ current_context,
+)
from synapse.metrics import InFlightGauge
logger = logging.getLogger(__name__)
@@ -104,27 +108,27 @@ class Measure:
def __init__(self, clock, name):
self.clock = clock
self.name = name
- self._logging_context = None
+ parent_context = current_context()
+ self._logging_context = LoggingContext(
+ "Measure[%s]" % (self.name,), parent_context
+ )
self.start = None
- def __enter__(self):
- if self._logging_context:
+ def __enter__(self) -> "Measure":
+ if self.start is not None:
raise RuntimeError("Measure() objects cannot be re-used")
self.start = self.clock.time()
- parent_context = current_context()
- self._logging_context = LoggingContext(
- "Measure[%s]" % (self.name,), parent_context
- )
self._logging_context.__enter__()
in_flight.register((self.name,), self._update_in_flight)
+ return self
def __exit__(self, exc_type, exc_val, exc_tb):
- if not self._logging_context:
+ if self.start is None:
raise RuntimeError("Measure() block exited without being entered")
duration = self.clock.time() - self.start
- usage = self._logging_context.get_resource_usage()
+ usage = self.get_resource_usage()
in_flight.unregister((self.name,), self._update_in_flight)
self._logging_context.__exit__(exc_type, exc_val, exc_tb)
@@ -140,6 +144,13 @@ class Measure:
except ValueError:
logger.warning("Failed to save metrics! Usage: %s", usage)
+ def get_resource_usage(self) -> ContextResourceUsage:
+ """Get the resources used within this Measure block
+
+ If the Measure block is still active, returns the resource usage so far.
+ """
+ return self._logging_context.get_resource_usage()
+
def _update_in_flight(self, metrics):
"""Gets called when processing in flight metrics
"""
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index bb62db4637..94b59afb38 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -36,7 +36,7 @@ def load_module(provider):
try:
provider_config = provider_class.parse_config(provider.get("config"))
except Exception as e:
- raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
+ raise ConfigError("Failed to parse config for %r: %s" % (provider["module"], e))
return provider_class, provider_config
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 54c046b6e1..72574d3af2 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import print_function
-
import functools
import sys
from typing import Any, Callable, List
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 79869aaa44..a5cc9d0551 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -45,7 +45,7 @@ class NotRetryingDestination(Exception):
"""
msg = "Not retrying server %s." % (destination,)
- super(NotRetryingDestination, self).__init__(msg)
+ super().__init__(msg)
self.retry_last_ts = retry_last_ts
self.retry_interval = retry_interval
diff --git a/synapse/visibility.py b/synapse/visibility.py
index e3da7744d2..527365498e 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -16,7 +16,7 @@
import logging
import operator
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.events.utils import prune_event
from synapse.storage import Storage
from synapse.storage.state import StateFilter
@@ -77,15 +77,14 @@ async def filter_events_for_client(
)
ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
- "m.ignored_user_list", user_id
+ AccountDataTypes.IGNORED_USER_LIST, user_id
)
- # FIXME: This will explode if people upload something incorrect.
- ignore_list = frozenset(
- ignore_dict_content.get("ignored_users", {}).keys()
- if ignore_dict_content
- else []
- )
+ ignore_list = frozenset()
+ if ignore_dict_content:
+ ignored_users_dict = ignore_dict_content.get("ignored_users", {})
+ if isinstance(ignored_users_dict, dict):
+ ignore_list = frozenset(ignored_users_dict.keys())
erased_senders = await storage.main.are_users_erased((e.sender for e in events))
diff --git a/sytest-blacklist b/sytest-blacklist
index b563448016..de9986357b 100644
--- a/sytest-blacklist
+++ b/sytest-blacklist
@@ -34,9 +34,6 @@ New federated private chats get full presence information (SYN-115)
# this requirement from the spec
Inbound federation of state requires event_id as a mandatory paramater
-# Blacklisted until https://github.com/matrix-org/synapse/pull/6486 lands
-Can upload self-signing keys
-
# Blacklisted until MSC2753 is implemented
Local users can peek into world_readable rooms by room ID
We can't peek into rooms with shared history_visibility
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 2e6e7abf1f..8ff1460c0d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -23,6 +23,7 @@ from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
from twisted.internet import defer
+from twisted.internet.defer import Deferred, ensureDeferred
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
@@ -33,7 +34,6 @@ from synapse.crypto.keyring import (
)
from synapse.logging.context import (
LoggingContext,
- PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
@@ -41,6 +41,7 @@ from synapse.storage.keys import FetchKeyResult
from tests import unittest
from tests.test_utils import make_awaitable
+from tests.unittest import logcontext_clean
class MockPerspectiveServer:
@@ -67,55 +68,42 @@ class MockPerspectiveServer:
signedjson.sign.sign_json(res, self.server_name, self.key)
+@logcontext_clean
class KeyringTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- self.mock_perspective_server = MockPerspectiveServer()
- self.http_client = Mock()
-
- config = self.default_config()
- config["trusted_key_servers"] = [
- {
- "server_name": self.mock_perspective_server.server_name,
- "verify_keys": self.mock_perspective_server.get_verify_keys(),
- }
- ]
-
- return self.setup_test_homeserver(
- handlers=None, http_client=self.http_client, config=config
- )
-
- def check_context(self, _, expected):
+ def check_context(self, val, expected):
self.assertEquals(getattr(current_context(), "request", None), expected)
+ return val
def test_verify_json_objects_for_server_awaits_previous_requests(self):
- key1 = signedjson.key.generate_signing_key(1)
+ mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher.get_keys = Mock()
+ kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
- kr = keyring.Keyring(self.hs)
+ # a signed object that we are going to try to validate
+ key1 = signedjson.key.generate_signing_key(1)
json1 = {}
signedjson.sign.sign_json(json1, "server10", key1)
- persp_resp = {
- "server_keys": [
- self.mock_perspective_server.get_signed_key(
- "server10", signedjson.key.get_verify_key(key1)
- )
- ]
- }
- persp_deferred = defer.Deferred()
+ # start off a first set of lookups. We make the mock fetcher block until this
+ # deferred completes.
+ first_lookup_deferred = Deferred()
+
+ async def first_lookup_fetch(keys_to_fetch):
+ self.assertEquals(current_context().request, "context_11")
+ self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
- async def get_perspectives(**kwargs):
- self.assertEquals(current_context().request, "11")
- with PreserveLoggingContext():
- await persp_deferred
- return persp_resp
+ await make_deferred_yieldable(first_lookup_deferred)
+ return {
+ "server10": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+ }
+ }
- self.http_client.post_json.side_effect = get_perspectives
+ mock_fetcher.get_keys.side_effect = first_lookup_fetch
- # start off a first set of lookups
- @defer.inlineCallbacks
- def first_lookup():
- with LoggingContext("11") as context_11:
- context_11.request = "11"
+ async def first_lookup():
+ with LoggingContext("context_11") as context_11:
+ context_11.request = "context_11"
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
@@ -124,7 +112,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# the unsigned json should be rejected pretty quickly
self.assertTrue(res_deferreds[1].called)
try:
- yield res_deferreds[1]
+ await res_deferreds[1]
self.assertFalse("unsigned json didn't cause a failure")
except SynapseError:
pass
@@ -132,45 +120,51 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(res_deferreds[0].called)
res_deferreds[0].addBoth(self.check_context, None)
- yield make_deferred_yieldable(res_deferreds[0])
+ await make_deferred_yieldable(res_deferreds[0])
- # let verify_json_objects_for_server finish its work before we kill the
- # logcontext
- yield self.clock.sleep(0)
+ d0 = ensureDeferred(first_lookup())
- d0 = first_lookup()
-
- # wait a tick for it to send the request to the perspectives server
- # (it first tries the datastore)
- self.pump()
- self.http_client.post_json.assert_called_once()
+ mock_fetcher.get_keys.assert_called_once()
# a second request for a server with outstanding requests
# should block rather than start a second call
- @defer.inlineCallbacks
- def second_lookup():
- with LoggingContext("12") as context_12:
- context_12.request = "12"
- self.http_client.post_json.reset_mock()
- self.http_client.post_json.return_value = defer.Deferred()
+
+ async def second_lookup_fetch(keys_to_fetch):
+ self.assertEquals(current_context().request, "context_12")
+ return {
+ "server10": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+ }
+ }
+
+ mock_fetcher.get_keys.reset_mock()
+ mock_fetcher.get_keys.side_effect = second_lookup_fetch
+ second_lookup_state = [0]
+
+ async def second_lookup():
+ with LoggingContext("context_12") as context_12:
+ context_12.request = "context_12"
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test")]
)
res_deferreds_2[0].addBoth(self.check_context, None)
- yield make_deferred_yieldable(res_deferreds_2[0])
+ second_lookup_state[0] = 1
+ await make_deferred_yieldable(res_deferreds_2[0])
+ second_lookup_state[0] = 2
- # let verify_json_objects_for_server finish its work before we kill the
- # logcontext
- yield self.clock.sleep(0)
-
- d2 = second_lookup()
+ d2 = ensureDeferred(second_lookup())
self.pump()
- self.http_client.post_json.assert_not_called()
+ # the second request should be pending, but the fetcher should not yet have been
+ # called
+ self.assertEqual(second_lookup_state[0], 1)
+ mock_fetcher.get_keys.assert_not_called()
# complete the first request
- persp_deferred.callback(persp_resp)
+ first_lookup_deferred.callback(None)
+
+ # and now both verifications should succeed.
self.get_success(d0)
self.get_success(d2)
@@ -317,6 +311,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher2.get_keys.assert_called_once()
+@logcontext_clean
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index cc52c3dfac..1a3ccb263d 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -321,3 +321,102 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
per_dest_queue._last_successful_stream_ordering,
event_5.internal_metadata.stream_ordering,
)
+
+ @override_config({"send_federation": True})
+ def test_catch_up_on_synapse_startup(self):
+ """
+ Tests the behaviour of get_catch_up_outstanding_destinations and
+ _wake_destinations_needing_catchup.
+ """
+
+ # list of sorted server names (note that there are more servers than the batch
+ # size used in get_catch_up_outstanding_destinations).
+ server_names = ["server%02d" % number for number in range(42)] + ["zzzerver"]
+
+ # ARRANGE:
+ # - a local user (u1)
+ # - a room which u1 is joined to (and remote users @user:serverXX are
+ # joined to)
+
+ # mark the remotes as online
+ self.is_online = True
+
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room_id = self.helper.create_room_as("u1", tok=u1_token)
+
+ for server_name in server_names:
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, room_id, "@user:%s" % server_name, "join"
+ )
+ )
+
+ # create an event
+ self.helper.send(room_id, "deary me!", tok=u1_token)
+
+ # ASSERT:
+ # - All servers are up to date so none should have outstanding catch-up
+ outstanding_when_successful = self.get_success(
+ self.hs.get_datastore().get_catch_up_outstanding_destinations(None)
+ )
+ self.assertEqual(outstanding_when_successful, [])
+
+ # ACT:
+ # - Make the remote servers unreachable
+ self.is_online = False
+
+ # - Mark zzzerver as being backed-off from
+ now = self.clock.time_msec()
+ self.get_success(
+ self.hs.get_datastore().set_destination_retry_timings(
+ "zzzerver", now, now, 24 * 60 * 60 * 1000 # retry in 1 day
+ )
+ )
+
+ # - Send an event
+ self.helper.send(room_id, "can anyone hear me?", tok=u1_token)
+
+ # ASSERT (get_catch_up_outstanding_destinations):
+ # - all remotes are outstanding
+ # - they are returned in batches of 25, in order
+ outstanding_1 = self.get_success(
+ self.hs.get_datastore().get_catch_up_outstanding_destinations(None)
+ )
+
+ self.assertEqual(len(outstanding_1), 25)
+ self.assertEqual(outstanding_1, server_names[0:25])
+
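+ # pass the last destination of the first batch to fetch the next one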
+ outstanding_2 = self.get_success(
+ self.hs.get_datastore().get_catch_up_outstanding_destinations(
+ outstanding_1[-1]
+ )
+ )
+ self.assertNotIn("zzzerver", outstanding_2)
+ self.assertEqual(len(outstanding_2), 17)
+ self.assertEqual(outstanding_2, server_names[25:-1])
+
+ # ACT: call _wake_destinations_needing_catchup
+
+ # patch wake_destination to just count the destinations instead
+ woken = []
+
+ def wake_destination_track(destination):
+ woken.append(destination)
+
+ self.hs.get_federation_sender().wake_destination = wake_destination_track
+
+ # cancel the pre-existing timer for _wake_destinations_needing_catchup
+ # this is because we are calling it manually rather than waiting for it
+ # to be called automatically
+ self.hs.get_federation_sender()._catchup_after_startup_timer.cancel()
+
+ self.get_success(
+ self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0
+ )
+
+ # ASSERT (_wake_destinations_needing_catchup):
+ # - all remotes are woken up, save for zzzerver
+ self.assertNotIn("zzzerver", woken)
+ # - all destinations are woken exactly once; they appear once in woken.
+ self.assertCountEqual(woken, server_names[:-1])
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 6aa322bf3a..4512c51311 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -35,6 +36,17 @@ class DeviceTestCase(unittest.HomeserverTestCase):
# These tests assume that it starts 1000 seconds in.
self.reactor.advance(1000)
+ def test_device_is_created_with_invalid_name(self):
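+ """A display name longer than MAX_DEVICE_DISPLAY_NAME_LEN should be rejected."""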
+ self.get_failure(
+ self.handler.check_device_registered(
+ user_id="@boris:foo",
+ device_id="foo",
+ initial_device_display_name="a"
+ * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1),
+ ),
+ synapse.api.errors.SynapseError,
+ )
+
def test_device_is_created_if_doesnt_exist(self):
res = self.get_success(
self.handler.check_device_registered(
@@ -213,3 +225,84 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
)
self.reactor.advance(1000)
+
+
+class DehydrationTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver("server", http_client=None)
+ self.handler = hs.get_device_handler()
+ self.registration = hs.get_registration_handler()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ return hs
+
+ def test_dehydrate_and_rehydrate_device(self):
+ user_id = "@boris:dehydration"
+
+ self.get_success(self.store.register_user(user_id, "foobar"))
+
+ # First check if we can store and fetch a dehydrated device
+ stored_dehydrated_device_id = self.get_success(
+ self.handler.store_dehydrated_device(
+ user_id=user_id,
+ device_data={"device_data": {"foo": "bar"}},
+ initial_device_display_name="dehydrated device",
+ )
+ )
+
+ retrieved_device_id, device_data = self.get_success(
+ self.handler.get_dehydrated_device(user_id=user_id)
+ )
+
+ self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
+ self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
+
+ # Create a new login for the user, which we will use to rehydrate the device
+ device_id, access_token = self.get_success(
+ self.registration.register_device(
+ user_id=user_id, device_id=None, initial_display_name="new device",
+ )
+ )
+
+ # Trying to claim a nonexistent device should throw an error
+ self.get_failure(
+ self.handler.rehydrate_device(
+ user_id=user_id,
+ access_token=access_token,
+ device_id="not the right device ID",
+ ),
+ synapse.api.errors.NotFoundError,
+ )
+
+ # rehydrating the right device should succeed and change our device ID
+ # to the dehydrated device's ID
+ res = self.get_success(
+ self.handler.rehydrate_device(
+ user_id=user_id,
+ access_token=access_token,
+ device_id=retrieved_device_id,
+ )
+ )
+
+ self.assertEqual(res, {"success": True})
+
+ # make sure that our device ID has changed
+ user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
+
+ self.assertEqual(user_info["device_id"], retrieved_device_id)
+
+ # make sure the device has the display name that was set from the login
+ res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
+
+ self.assertEqual(res["display_name"], "new device")
+
+ # make sure that the device ID that we were initially assigned no longer exists
+ self.get_failure(
+ self.handler.get_device(user_id, device_id),
+ synapse.api.errors.NotFoundError,
+ )
+
+ # make sure that no dehydrated device is available any more
+ ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
+
+ self.assertIsNone(ret)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 210ddcbb88..4e9e3dcbc2 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -30,7 +30,7 @@ from tests import unittest, utils
class E2eKeysHandlerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
- super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
@@ -172,6 +172,71 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
+ def test_fallback_key(self):
+ local_user = "@boris:" + self.hs.hostname
+ device_id = "xyz"
+ fallback_key = {"alg1:k1": "key1"}
+ otk = {"alg1:k2": "key2"}
+
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"org.matrix.msc2732.fallback_keys": fallback_key},
+ )
+ )
+
+ # claiming an OTK when no OTKs are available should return the fallback
+ # key
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ )
+
+ # claiming an OTK again should return the same fallback key
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ )
+
+ # if the user uploads a one-time key, the next claim should fetch the
+ # one-time key, and then go back to the fallback
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": otk}
+ )
+ )
+
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
+ )
+
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ )
+
+ @defer.inlineCallbacks
def test_replace_master_key(self):
"""uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 3362050ce0..7adde9b9de 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -47,7 +47,7 @@ room_keys = {
class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
- super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 89ec5fcb31..b6f436c016 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -21,7 +21,6 @@ from mock import Mock, patch
import attr
import pymacaroons
-from twisted.internet import defer
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
@@ -87,6 +86,13 @@ class TestMappingProvider(OidcMappingProvider):
async def map_user_attributes(self, userinfo, token):
return {"localpart": userinfo["username"], "display_name": None}
+ # Do not include get_extra_attributes to test backwards compatibility paths.
+
+
+class TestMappingProviderExtra(TestMappingProvider):
+ async def get_extra_attributes(self, userinfo, token):
+ return {"phone": userinfo["phone"]}
+
def simple_async_mock(return_value=None, raises=None):
# AsyncMock is not available in Python 3.5; this mimics part of its behaviour
@@ -126,7 +132,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
config = self.default_config()
config["public_baseurl"] = BASE_URL
- oidc_config = config.get("oidc_config", {})
+ oidc_config = {}
oidc_config["enabled"] = True
oidc_config["client_id"] = CLIENT_ID
oidc_config["client_secret"] = CLIENT_SECRET
@@ -135,6 +141,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
oidc_config["user_mapping_provider"] = {
"module": __name__ + ".TestMappingProvider",
}
+
+ # Update this config with what's in the default config so that
+ # override_config works as expected.
+ oidc_config.update(config.get("oidc_config", {}))
config["oidc_config"] = oidc_config
hs = self.setup_test_homeserver(
@@ -165,11 +175,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {"discover": True}})
- @defer.inlineCallbacks
def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
- metadata = yield defer.ensureDeferred(self.handler.load_metadata())
+ metadata = self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.assertEqual(metadata.issuer, ISSUER)
@@ -181,43 +190,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
# subsequent calls should be cached
self.http_client.reset_mock()
- yield defer.ensureDeferred(self.handler.load_metadata())
+ self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
- @defer.inlineCallbacks
def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document."""
- yield defer.ensureDeferred(self.handler.load_metadata())
+ self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
- @defer.inlineCallbacks
def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used."""
- jwks = yield defer.ensureDeferred(self.handler.load_jwks())
+ jwks = self.get_success(self.handler.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
self.assertEqual(jwks, {"keys": []})
# subsequent calls should be cached…
self.http_client.reset_mock()
- yield defer.ensureDeferred(self.handler.load_jwks())
+ self.get_success(self.handler.load_jwks())
self.http_client.get_json.assert_not_called()
# …unless forced
self.http_client.reset_mock()
- yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.get_success(self.handler.load_jwks(force=True))
self.http_client.get_json.assert_called_once_with(JWKS_URI)
# Throw if the JWKS uri is missing
with self.metadata_edit({"jwks_uri": None}):
- with self.assertRaises(RuntimeError):
- yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
# Return empty key set if JWKS are not used
self.handler._scopes = [] # not asking the openid scope
self.http_client.get_json.reset_mock()
- jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ jwks = self.get_success(self.handler.load_jwks(force=True))
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
@@ -280,9 +286,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
h._validate_metadata,
)
- # Tests for configs that the userinfo endpoint
+ # Tests for configs that require the userinfo endpoint
self.assertFalse(h._uses_userinfo)
- h._scopes = [] # do not request the openid scope
+ self.assertEqual(h._user_profile_method, "auto")
+ h._user_profile_method = "userinfo_endpoint"
+ self.assertTrue(h._uses_userinfo)
+
+ # Revert the profile method and do not request the "openid" scope.
+ h._user_profile_method = "auto"
+ h._scopes = []
self.assertTrue(h._uses_userinfo)
self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
@@ -299,11 +311,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
# This should not throw
self.handler._validate_metadata()
- @defer.inlineCallbacks
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["addCookie"])
- url = yield defer.ensureDeferred(
+ url = self.get_success(
self.handler.handle_redirect_request(req, b"http://client/redirect")
)
url = urlparse(url)
@@ -343,20 +354,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(params["nonce"], [nonce])
self.assertEqual(redirect, "http://client/redirect")
- @defer.inlineCallbacks
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
self.handler._render_error = Mock()
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "")
request.args[b"error_description"] = [b"some description"]
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "some description")
- @defer.inlineCallbacks
def test_callback(self):
"""Code callback works and display errors if something went wrong.
@@ -377,7 +386,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "foo",
"preferred_username": "bar",
}
- user_id = UserID("foo", "domain.org")
+ user_id = "@foo:domain.org"
self.handler._render_error = Mock(return_value=None)
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
@@ -394,13 +403,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
- session = self.handler._generate_oidc_session_token(
+ request.getCookie.return_value = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
- request.getCookie.return_value = session
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
@@ -410,10 +418,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
request.getClientIP.return_value = ip_address
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url,
+ user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -427,13 +435,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._map_userinfo_to_user = simple_async_mock(
raises=MappingException()
)
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error")
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
# Handle ID token errors
self.handler._parse_id_token = simple_async_mock(raises=Exception())
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
self.handler._auth_handler.complete_sso_login.reset_mock()
@@ -444,10 +452,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
# With userinfo fetching
self.handler._scopes = [] # do not ask the "openid" scope
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url,
+ user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
@@ -459,17 +467,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Handle userinfo fetching error
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
self.handler._exchange_code = simple_async_mock(
raises=OidcError("invalid_request")
)
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
- @defer.inlineCallbacks
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
self.handler._render_error = Mock(return_value=None)
@@ -478,20 +485,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Missing cookie
request.args = {}
request.getCookie.return_value = None
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("missing_session", "No session cookie found")
# Missing session parameter
request.args = {}
request.getCookie.return_value = "session"
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request", "State parameter is missing")
# Invalid cookie
request.args = {}
request.args[b"state"] = [b"state"]
request.getCookie.return_value = "session"
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_session")
# Mismatching session
@@ -504,18 +511,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args = {}
request.args[b"state"] = [b"mismatching state"]
request.getCookie.return_value = session
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mismatching_session")
# Valid session
request.args = {}
request.args[b"state"] = [b"state"]
request.getCookie.return_value = session
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
@override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
- @defer.inlineCallbacks
def test_exchange_code(self):
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
@@ -524,7 +530,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
)
code = "code"
- ret = yield defer.ensureDeferred(self.handler._exchange_code(code))
+ ret = self.get_success(self.handler._exchange_code(code))
kwargs = self.http_client.request.call_args[1]
self.assertEqual(ret, token)
@@ -546,10 +552,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
body=b'{"error": "foo", "error_description": "bar"}',
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "foo")
- self.assertEqual(exc.exception.error_description, "bar")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "foo")
+ self.assertEqual(exc.value.error_description, "bar")
# Internal server error with no JSON body
self.http_client.request = simple_async_mock(
@@ -557,9 +562,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "server_error")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
self.http_client.request = simple_async_mock(
@@ -569,17 +573,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
body=b'{"error": "internal_server_error"}',
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "internal_server_error")
+
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
self.http_client.request = simple_async_mock(
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "server_error")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
self.http_client.request = simple_async_mock(
@@ -587,9 +590,62 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "some_error")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "some_error")
+
+ @override_config(
+ {
+ "oidc_config": {
+ "user_mapping_provider": {
+ "module": __name__ + ".TestMappingProviderExtra"
+ }
+ }
+ }
+ )
+ def test_extra_attributes(self):
+ """
+ Login while using a mapping provider that implements get_extra_attributes.
+ """
+ token = {
+ "type": "bearer",
+ "id_token": "id_token",
+ "access_token": "access_token",
+ }
+ userinfo = {
+ "sub": "foo",
+ "phone": "1234567",
+ }
+ user_id = "@foo:domain.org"
+ self.handler._exchange_code = simple_async_mock(return_value=token)
+ self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ self.handler._auth_handler.complete_sso_login = simple_async_mock()
+ request = Mock(
+ spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ )
+
+ state = "state"
+ client_redirect_url = "http://client/redirect"
+ request.getCookie.return_value = self.handler._generate_oidc_session_token(
+ state=state,
+ nonce="nonce",
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=None,
+ )
+
+ request.args = {}
+ request.args[b"code"] = [b"code"]
+ request.args[b"state"] = [state.encode("utf-8")]
+
+ request.requestHeaders = Mock(spec=["getRawHeaders"])
+ request.requestHeaders.getRawHeaders.return_value = [b"Browser"]
+ request.getClientIP.return_value = "10.0.0.1"
+
+ self.get_success(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url, {"phone": "1234567"},
+ )
def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
@@ -617,3 +673,38 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
)
self.assertEqual(mxid, "@test_user_2:test")
+
+ # Test if the mxid is already taken
+ store = self.hs.get_datastore()
+ user3 = UserID.from_string("@test_user_3:test")
+ self.get_success(
+ store.register_user(user_id=user3.to_string(), password_hash=None)
+ )
+ userinfo = {"sub": "test3", "username": "test_user_3"}
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
+
+ @override_config({"oidc_config": {"allow_existing_users": True}})
+ def test_map_userinfo_to_existing_user(self):
+ """Existing users can log in with OpenID Connect when allow_existing_users is True."""
+ store = self.hs.get_datastore()
+ user4 = UserID.from_string("@test_user_4:test")
+ self.get_success(
+ store.register_user(user_id=user4.to_string(), password_hash=None)
+ )
+ userinfo = {
+ "sub": "test4",
+ "username": "test_user_4",
+ }
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user_4:test")
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index cb7c0ed51a..702c6aa089 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -413,7 +413,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- event_creation_handler.send_nonmember_event(requester, event, context)
+ event_creation_handler.handle_new_client_event(requester, event, context)
)
# Register a second user, which won't be be in the room (or even have an invite)
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 5604af3795..212484a7fe 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -318,14 +318,14 @@ class FederationClientTests(HomeserverTestCase):
r = self.successResultOf(d)
self.assertEqual(r.code, 200)
- def test_client_headers_no_body(self):
+ @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"])
+ def test_timeout_reading_body(self, method_name: str):
"""
If the HTTP request is connected, but gets no response before being
- timed out, it'll give a ResponseNeverReceived.
+ timed out, it'll give a RequestSendFailed with can_retry.
"""
- d = defer.ensureDeferred(
- self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
- )
+ method = getattr(self.cl, method_name)
+ d = defer.ensureDeferred(method("testserv:8008", "foo/bar", timeout=10000))
self.pump()
@@ -349,7 +349,9 @@ class FederationClientTests(HomeserverTestCase):
self.reactor.advance(10.5)
f = self.failureResultOf(d)
- self.assertIsInstance(f.value, TimeoutError)
+ self.assertIsInstance(f.value, RequestSendFailed)
+ self.assertTrue(f.value.can_retry)
+ self.assertIsInstance(f.value.inner_exception, defer.TimeoutError)
def test_client_requires_trailing_slashes(self):
"""
diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py
new file mode 100644
index 0000000000..a1cf0862d4
--- /dev/null
+++ b/tests/http/test_simple_client.py
@@ -0,0 +1,180 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from netaddr import IPSet
+
+from twisted.internet import defer
+from twisted.internet.error import DNSLookupError
+
+from synapse.http import RequestTimedOutError
+from synapse.http.client import SimpleHttpClient
+from synapse.server import HomeServer
+
+from tests.unittest import HomeserverTestCase
+
+
+class SimpleHttpClientTests(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs: "HomeServer"):
+ # Add a DNS entry for a test server
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ self.cl = hs.get_simple_http_client()
+
+ def test_dns_error(self):
+ """
+ If the DNS lookup returns an error, it will bubble up.
+ """
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv2:8008/foo/bar"))
+ self.pump()
+
+ f = self.failureResultOf(d)
+ self.assertIsInstance(f.value, DNSLookupError)
+
+ def test_client_connection_refused(self):
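+ """A connection failure is passed through to the caller unchanged."""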
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8008)
+ e = Exception("go away")
+ factory.clientConnectionFailed(None, e)
+ self.pump(0.5)
+
+ f = self.failureResultOf(d)
+
+ self.assertIs(f.value, e)
+
+ def test_client_never_connect(self):
+ """
+ If the HTTP request is not connected and is timed out, it'll give a
+ RequestTimedOutError.
+ """
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ self.assertEqual(clients[0][0], "1.2.3.4")
+ self.assertEqual(clients[0][1], 8008)
+
+ # Deferred is still without a result
+ self.assertNoResult(d)
+
+ # Push by enough to time it out
+ self.reactor.advance(120)
+ f = self.failureResultOf(d)
+
+ self.assertIsInstance(f.value, RequestTimedOutError)
+
+ def test_client_connect_no_response(self):
+ """
+ If the HTTP request is connected, but gets no response before being
+ timed out, it'll give a RequestTimedOutError.
+ """
+ d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ self.assertEqual(clients[0][0], "1.2.3.4")
+ self.assertEqual(clients[0][1], 8008)
+
+ conn = Mock()
+ client = clients[0][2].buildProtocol(None)
+ client.makeConnection(conn)
+
+ # Deferred is still without a result
+ self.assertNoResult(d)
+
+ # Push by enough to time it out
+ self.reactor.advance(120)
+ f = self.failureResultOf(d)
+
+ self.assertIsInstance(f.value, RequestTimedOutError)
+
+ def test_client_ip_range_blacklist(self):
+ """Ensure that Synapse does not try to connect to blacklisted IPs"""
+
+ # Add some DNS entries we'll blacklist
+ self.reactor.lookups["internal"] = "127.0.0.1"
+ self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
+ ip_blacklist = IPSet(["127.0.0.0/8", "fe80::/64"])
+
+ cl = SimpleHttpClient(self.hs, ip_blacklist=ip_blacklist)
+
+ # Try making a GET request to a blacklisted IPv4 address
+ # ------------------------------------------------------
+ # Make the request
+ d = defer.ensureDeferred(cl.get_json("http://internal:8008/foo/bar"))
+ self.pump(1)
+
+ # Check that it was unable to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 0)
+
+ self.failureResultOf(d, DNSLookupError)
+
+ # Try making a POST request to a blacklisted IPv6 address
+ # -------------------------------------------------------
+ # Make the request
+ d = defer.ensureDeferred(
+ cl.post_json_get_json("http://internalv6:8008/foo/bar", {})
+ )
+
+ # Move the reactor forwards
+ self.pump(1)
+
+ # Check that it was unable to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 0)
+
+ # Check that it was due to a blacklisted DNS lookup
+ self.failureResultOf(d, DNSLookupError)
+
+ # Try making a GET request to a non-blacklisted IPv4 address
+ # ----------------------------------------------------------
+ # Make the request
+ d = defer.ensureDeferred(cl.get_json("http://testserv:8008/foo/bar"))
+
+ # Nothing has happened yet
+ self.assertNoResult(d)
+
+ # Move the reactor forwards
+ self.pump(1)
+
+ # Check that it was able to resolve the address
+ clients = self.reactor.tcpClients
+ self.assertNotEqual(len(clients), 0)
+
+ # The connection will still fail: nothing answers at this address, so the request times out
+ self.failureResultOf(d, RequestTimedOutError)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 04de0b9dbe..7c790bee7d 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -13,15 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
from tests.unittest import HomeserverTestCase
class ModuleApiTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
- self.module_api = ModuleApi(homeserver, homeserver.get_auth_handler())
+ self.module_api = homeserver.get_module_api()
def test_can_register_user(self):
"""Tests that an external module can register a user"""
@@ -52,3 +59,50 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the displayname was assigned
displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino")
+
+ def test_public_rooms(self):
+ """Tests that a room can be added and removed from the public rooms list,
+ as well as have its public rooms directory state queried.
+ """
+ # Create a user and room to play with
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+ room_id = self.helper.create_room_as(user_id, tok=tok)
+
+ # The room should not currently be in the public rooms directory
+ is_in_public_rooms = self.get_success(
+ self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ room_id
+ )
+ )
+ self.assertFalse(is_in_public_rooms)
+
+ # Let's try adding it to the public rooms directory
+ self.get_success(
+ self.module_api.public_room_list_manager.add_room_to_public_room_list(
+ room_id
+ )
+ )
+
+ # And checking whether it's in there...
+ is_in_public_rooms = self.get_success(
+ self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ room_id
+ )
+ )
+ self.assertTrue(is_in_public_rooms)
+
+ # Let's remove it again
+ self.get_success(
+ self.module_api.public_room_list_manager.remove_room_from_public_room_list(
+ room_id
+ )
+ )
+
+ # Should be gone
+ is_in_public_rooms = self.get_success(
+ self.module_api.public_room_list_manager.room_is_in_public_room_list(
+ room_id
+ )
+ )
+ self.assertFalse(is_in_public_rooms)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index ae60874ec3..81ea985b9f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import Any, Callable, List, Optional, Tuple
import attr
+import hiredis
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
+from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
@@ -27,7 +28,7 @@ from synapse.app.generic_worker import (
GenericWorkerServer,
)
from synapse.http.server import JsonResource
-from synapse.http.site import SynapseRequest
+from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@@ -197,19 +198,37 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
+ # Fake in-memory Redis server that the homeservers can connect to.
+ self._redis_server = FakeRedisPubSubServer()
+
store = self.hs.get_datastore()
self.database_pool = store.db_pool
self.reactor.lookups["testserv"] = "1.2.3.4"
+ self.reactor.lookups["localhost"] = "127.0.0.1"
+
+ # A map from a HS instance to the associated HTTP Site to use for
+ # handling inbound HTTP requests to that instance.
+ self._hs_to_site = {self.hs: self.site}
+
+ if self.hs.config.redis.redis_enabled:
+ # Handle attempts to connect to fake redis server.
+ self.reactor.add_tcp_client_callback(
+ "localhost", 6379, self.connect_any_redis_attempts,
+ )
- self._worker_hs_to_resource = {}
+ self.hs.get_tcp_replication().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
# manually have to go and explicitly set it up each time (plus sometimes
# it is impossible to write the handling explicitly in the tests).
+ #
+ # Register the master replication listener:
self.reactor.add_tcp_client_callback(
- "1.2.3.4", 8765, self._handle_http_replication_attempt
+ "1.2.3.4",
+ 8765,
+ lambda: self._handle_http_replication_attempt(self.hs, 8765),
)
def create_test_json_resource(self):
@@ -253,28 +272,63 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
**kwargs
)
+ # If the instance is in the `instance_map` config then workers may try
+ # and send HTTP requests to it, so we register it with
+ # `_handle_http_replication_attempt` like we do with the master HS.
+ instance_name = worker_hs.get_instance_name()
+ instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
+ if instance_loc:
+ # Ensure the host is one that has a fake DNS entry.
+ if instance_loc.host not in self.reactor.lookups:
+ raise Exception(
+ "Host does not have an IP for instance_map[%r].host = %r"
+ % (instance_name, instance_loc.host,)
+ )
+
+ self.reactor.add_tcp_client_callback(
+ self.reactor.lookups[instance_loc.host],
+ instance_loc.port,
+ lambda: self._handle_http_replication_attempt(
+ worker_hs, instance_loc.port
+ ),
+ )
+
store = worker_hs.get_datastore()
store.db_pool._db_pool = self.database_pool._db_pool
- repl_handler = ReplicationCommandHandler(worker_hs)
- client = ClientReplicationStreamProtocol(
- worker_hs, "client", "test", self.clock, repl_handler,
- )
- server = self.server_factory.buildProtocol(None)
+ # Set up TCP replication between master and the new worker if we don't
+ # have Redis support enabled.
+ if not worker_hs.config.redis_enabled:
+ repl_handler = ReplicationCommandHandler(worker_hs)
+ client = ClientReplicationStreamProtocol(
+ worker_hs, "client", "test", self.clock, repl_handler,
+ )
+ server = self.server_factory.buildProtocol(None)
- client_transport = FakeTransport(server, self.reactor)
- client.makeConnection(client_transport)
+ client_transport = FakeTransport(server, self.reactor)
+ client.makeConnection(client_transport)
- server_transport = FakeTransport(client, self.reactor)
- server.makeConnection(server_transport)
+ server_transport = FakeTransport(client, self.reactor)
+ server.makeConnection(server_transport)
# Set up a resource for the worker
- resource = ReplicationRestResource(self.hs)
+ resource = ReplicationRestResource(worker_hs)
for servlet in self.servlets:
servlet(worker_hs, resource)
- self._worker_hs_to_resource[worker_hs] = resource
+ self._hs_to_site[worker_hs] = SynapseSite(
+ logger_name="synapse.access.http.fake",
+ site_tag="{}-{}".format(
+ worker_hs.config.server.server_name, worker_hs.get_instance_name()
+ ),
+ config=worker_hs.config.server.listeners[0],
+ resource=resource,
+ server_version_string="1",
+ )
+
+ if worker_hs.config.redis.redis_enabled:
+ worker_hs.get_tcp_replication().start_replication(worker_hs)
return worker_hs
@@ -285,7 +339,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return config
def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
- render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
+ render(request, self._hs_to_site[worker_hs].resource, self.reactor)
def replicate(self):
"""Tell the master side of replication that something has happened, and then
@@ -294,9 +348,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
self.pump()
- def _handle_http_replication_attempt(self):
- """Handles a connection attempt to the master replication HTTP
- listener.
+ def _handle_http_replication_attempt(self, hs, repl_port):
+ """Handles a connection attempt to the given HS replication HTTP
+ listener on the given port.
"""
# We should have at least one outbound connection attempt, where the
@@ -305,7 +359,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
self.assertEqual(host, "1.2.3.4")
- self.assertEqual(port, 8765)
+ self.assertEqual(port, repl_port)
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
@@ -315,7 +369,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory
- channel.site = self.site
+ channel.site = self._hs_to_site[hs]
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -333,6 +387,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# inside `connectTCP` before the connection has been passed back to the
# code that requested the TCP connection.
+ def connect_any_redis_attempts(self):
+ """If redis is enabled we need to deal with workers connecting to a
+ redis server. We don't want to use a real Redis server so we use a
+ fake one.
+ """
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "localhost")
+ self.assertEqual(port, 6379)
+
+ client_protocol = client_factory.buildProtocol(None)
+ server_protocol = self._redis_server.buildProtocol(None)
+
+ client_to_server_transport = FakeTransport(
+ server_protocol, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
+
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, server_protocol
+ )
+ server_protocol.makeConnection(server_to_client_transport)
+
+ return client_to_server_transport, server_to_client_transport
+
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@@ -467,3 +547,105 @@ class _PullToPushProducer:
pass
self.stopProducing()
+
+
+class FakeRedisPubSubServer:
+ """A fake Redis server for pub/sub.
+ """
+
+ def __init__(self):
+ self._subscribers = set()
+
+ def add_subscriber(self, conn):
+ """A connection has called SUBSCRIBE
+ """
+ self._subscribers.add(conn)
+
+ def remove_subscriber(self, conn):
+ """A connection has called UNSUBSCRIBE
+ """
+ self._subscribers.discard(conn)
+
+ def publish(self, conn, channel, msg) -> int:
+ """A connection want to publish a message to subscribers.
+ """
+ for sub in self._subscribers:
+ sub.send(["message", channel, msg])
+
+ return len(self._subscribers)
+
+ def buildProtocol(self, addr):
+ return FakeRedisPubSubProtocol(self)
+
+
+class FakeRedisPubSubProtocol(Protocol):
+ """A connection from a client talking to the fake Redis server.
+ """
+
+ def __init__(self, server: FakeRedisPubSubServer):
+ self._server = server
+ self._reader = hiredis.Reader()
+
+ def dataReceived(self, data):
+ self._reader.feed(data)
+
+ # We might get multiple messages in one packet.
+ while True:
+ msg = self._reader.gets()
+
+ if msg is False:
+ # No more messages.
+ return
+
+ if not isinstance(msg, list):
+ # Inbound commands should always be a list
+ raise Exception("Expected redis list")
+
+ self.handle_command(msg[0], *msg[1:])
+
+ def handle_command(self, command, *args):
+ """Received a Redis command from the client.
+ """
+
+ # We currently only support pub/sub.
+ if command == b"PUBLISH":
+ channel, message = args
+ num_subscribers = self._server.publish(self, channel, message)
+ self.send(num_subscribers)
+ elif command == b"SUBSCRIBE":
+ (channel,) = args
+ self._server.add_subscriber(self)
+ self.send(["subscribe", channel, 1])
+ else:
+ raise Exception("Unknown command")
+
+ def send(self, msg):
+ """Send a message back to the client.
+ """
+ raw = self.encode(msg).encode("utf-8")
+
+ self.transport.write(raw)
+ self.transport.flush()
+
+ def encode(self, obj):
+ """Encode an object to its Redis format.
+
+ Supports: strings/bytes, integers and list/tuples.
+ """
+
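+ # This is a subset of the Redis serialization protocol (RESP):
+ # bulk strings are "$<len>\r\n<data>\r\n", integers are ":<value>\r\n",
+ # and arrays are "*<count>\r\n" followed by each encoded element.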
+ if isinstance(obj, bytes):
+ # We assume any bytes are UTF-8-encoded strings.
+ obj = obj.decode("utf-8")
+
+ if isinstance(obj, str):
+ return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
+ if isinstance(obj, int):
+ return ":{val}\r\n".format(val=obj)
+ if isinstance(obj, (list, tuple)):
+ items = "".join(self.encode(a) for a in obj)
+ return "*{len}\r\n{items}".format(len=len(obj), items=items)
+
+ raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
+
+ def connectionLost(self, reason):
+ self._server.remove_subscriber(self)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 561258a356..c0ee1cfbd6 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -20,6 +20,7 @@ from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
+from synapse.types import PersistedEventPosition
from tests.server import FakeTransport
@@ -58,7 +59,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
# Patch up the equality operator for events so that we can check
# whether lists of events match using assertEquals
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
- return super(SlavedEventStoreTestCase, self).setUp()
+ return super().setUp()
def prepare(self, *args, **kwargs):
super().prepare(*args, **kwargs)
@@ -204,10 +205,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
self.replicate()
+
+ expected_pos = PersistedEventPosition(
+ "master", j2.internal_metadata.stream_ordering
+ )
self.check(
"get_rooms_for_user_with_stream_ordering",
(USER_ID_2,),
- {(ROOM_ID, j2.internal_metadata.stream_ordering)},
+ {(ROOM_ID, expected_pos)},
)
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
@@ -293,9 +298,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
# the membership change is only any use to us if the room is in the
# joined_rooms list.
if membership_changes:
- self.assertEqual(
- joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)}
+ expected_pos = PersistedEventPosition(
+ "master", j2.internal_metadata.stream_ordering
)
+ self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)})
event_id = 0
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
new file mode 100644
index 0000000000..6068d14905
--- /dev/null
+++ b/tests/replication/test_sharded_event_persister.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+logger = logging.getLogger(__name__)
+
+
+class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks event persisting sharding works
+ """
+
+ # Event persister sharding requires postgres (due to needing
+ # `MultiWriterIdGenerator`).
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ # Register a user who sends a message that we'll get notified about
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["redis"] = {"enabled": "true"}
+ conf["stream_writers"] = {"events": ["worker1", "worker2"]}
+ conf["instance_map"] = {
+ "worker1": {"host": "testserv", "port": 1001},
+ "worker2": {"host": "testserv", "port": 1002},
+ }
+ return conf
+
+ def test_basic(self):
+ """Simple test to ensure that multiple rooms can be created and joined,
+ and that different rooms get handled by different instances.
+ """
+
+ self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker1"},
+ )
+
+ self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker2"},
+ )
+
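+ # track which of the two workers we have seen persist an event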
+ persisted_on_1 = False
+ persisted_on_2 = False
+
+ store = self.hs.get_datastore()
+
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Keep making new rooms until we see rooms being persisted on both
+ # workers.
+ for _ in range(10):
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other user joins
+ self.helper.join(
+ room=room, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ # The other user sends some messages
+ response = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+ event_id = response["event_id"]
+
+ # The event position includes which instance persisted the event.
+ pos = self.get_success(store.get_position_for_event(event_id))
+
+ persisted_on_1 |= pos.instance_name == "worker1"
+ persisted_on_2 |= pos.instance_name == "worker2"
+
+ if persisted_on_1 and persisted_on_2:
+ break
+
+ self.assertTrue(persisted_on_1)
+ self.assertTrue(persisted_on_2)
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index faa7f381a9..92c9058887 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -221,7 +221,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
request, channel = self.make_request(
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
new file mode 100644
index 0000000000..bf79086f78
--- /dev/null
+++ b/tests/rest/admin/test_event_reports.py
@@ -0,0 +1,382 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import report_event
+
+from tests import unittest
+
+
+class EventReportsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ report_event.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.room_id1 = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok, is_public=True
+ )
+ self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok)
+
+ self.room_id2 = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok, is_public=True
+ )
+ self.helper.join(self.room_id2, user=self.admin_user, tok=self.admin_user_tok)
+
+ # Two rooms and two users. Each user creates and reports five events per room, 20 reports in total.
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id1, user_tok=self.other_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id2, user_tok=self.other_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id1, user_tok=self.admin_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id2, user_tok=self.admin_user_tok,
+ )
+
+ self.url = "/_synapse/admin/v1/event_reports"
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.other_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_default_success(self):
+ """
+ Testing list of reported events
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_limit(self):
+ """
+ Testing list of reported events with limit
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 5)
+ self.assertEqual(channel.json_body["next_token"], 5)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_from(self):
+ """
+ Testing list of reported events with a defined starting point (from)
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 15)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_limit_and_from(self):
+ """
+ Testing list of reported events with a defined starting point and limit
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(channel.json_body["next_token"], 15)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_filter_room(self):
+ """
+ Testing list of reported events with a filter of room
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?room_id=%s" % self.room_id1,
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 10)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["room_id"], self.room_id1)
+
+ def test_filter_user(self):
+ """
+ Testing list of reported events with a filter of user
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?user_id=%s" % self.other_user,
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 10)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["user_id"], self.other_user)
+
+ def test_filter_user_and_room(self):
+ """
+ Testing list of reported events with a filter of user and room
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 5)
+ self.assertEqual(len(channel.json_body["event_reports"]), 5)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["user_id"], self.other_user)
+ self.assertEqual(report["room_id"], self.room_id1)
+
+ def test_valid_search_order(self):
+ """
+ Testing search order: results are ordered by timestamp.
+ """
+
+ # fetch the most recent first, largest timestamp
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=b", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ report = 1
+ while report < len(channel.json_body["event_reports"]):
+ self.assertGreaterEqual(
+ channel.json_body["event_reports"][report - 1]["received_ts"],
+ channel.json_body["event_reports"][report]["received_ts"],
+ )
+ report += 1
+
+ # fetch the oldest first, smallest timestamp
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=f", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ report = 1
+ while report < len(channel.json_body["event_reports"]):
+ self.assertLessEqual(
+ channel.json_body["event_reports"][report - 1]["received_ts"],
+ channel.json_body["event_reports"][report]["received_ts"],
+ )
+ report += 1
+
+ def test_invalid_search_order(self):
+ """
+ Testing that an invalid search order returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual("Unknown direction: bar", channel.json_body["error"])
+
+ def test_limit_is_negative(self):
+ """
+ Testing that a negative limit parameter returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_from_is_negative(self):
+ """
+ Testing that a negative from parameter returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_next_token(self):
+ """
+ Testing that `next_token` appears in the right place
+ """
+
+ # `next_token` does not appear
+ # The number of results equals the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does not appear
+ # The maximum number of results is larger than the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does appear
+ # The maximum number of results is smaller than the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 19)
+ self.assertEqual(channel.json_body["next_token"], 19)
+
+ # Set `from` to the value of `next_token` to request the remaining
+ # entries; `next_token` no longer appears in the response
+ request, channel = self.make_request(
+ "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 1)
+ self.assertNotIn("next_token", channel.json_body)
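+ # A sketch of the paging pattern these assertions pin down (illustrative
+ # only; `get` and `handle` are hypothetical helpers, not part of the test):
+ #
+ #     from_token = 0
+ #     while True:
+ #         body = get("%s?from=%d&limit=5" % (self.url, from_token))
+ #         handle(body["event_reports"])
+ #         if "next_token" not in body:
+ #             break
+ #         from_token = body["next_token"]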
+
+ def _create_event_and_report(self, room_id, user_tok):
+ """Create and report events
+ """
+ resp = self.helper.send(room_id, tok=user_tok)
+ event_id = resp["event_id"]
+
+ request, channel = self.make_request(
+ "POST",
+ "rooms/%s/report/%s" % (room_id, event_id),
+ json.dumps({"score": -100, "reason": "this makes me sad"}),
+ access_token=user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def _check_fields(self, content):
+ """Checks that all attributes are present in a event report
+ """
+ for c in content:
+ self.assertIn("id", c)
+ self.assertIn("received_ts", c)
+ self.assertIn("room_id", c)
+ self.assertIn("event_id", c)
+ self.assertIn("user_id", c)
+ self.assertIn("reason", c)
+ self.assertIn("content", c)
+ self.assertIn("sender", c)
+ self.assertIn("room_alias", c)
+ self.assertIn("event_json", c)
+ self.assertIn("score", c["content"])
+ self.assertIn("reason", c["content"])
+ self.assertIn("auth_events", c["event_json"])
+ self.assertIn("type", c["event_json"])
+ self.assertIn("room_id", c["event_json"])
+ self.assertIn("sender", c["event_json"])
+ self.assertIn("content", c["event_json"])
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index b8b7758d24..98d0623734 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -22,8 +22,8 @@ from mock import Mock
import synapse.rest.admin
from synapse.api.constants import UserTypes
-from synapse.api.errors import HttpResponseException, ResourceLimitError
-from synapse.rest.client.v1 import login
+from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
+from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
from tests import unittest
@@ -874,6 +874,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self._is_erased("@user:test", False)
+ d = self.store.mark_user_erased("@user:test")
+ self.assertIsNone(self.get_success(d))
+ self._is_erased("@user:test", True)
# Attempt to reactivate the user (without a password).
request, channel = self.make_request(
@@ -906,6 +910,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
+ self._is_erased("@user:test", False)
def test_set_user_as_admin(self):
"""
@@ -995,3 +1000,104 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"])
+
+ def _is_erased(self, user_id, expect):
+ """Assert that the user is erased or not
+ """
+ d = self.store.is_user_erased(user_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertFalse(self.get_success(d))
+
+
+class UserMembershipRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url = "/_synapse/admin/v1/users/%s/joined_rooms" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to list the rooms of a user without authentication.
+ """
+ request, channel = self.make_request("GET", self.url, b"{}")
+ self.render(request)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=other_user_token,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not local returns a 400
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
+
+ request, channel = self.make_request(
+ "GET", url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ def test_get_rooms(self):
+ """
+ Tests that a normal lookup for rooms is successful
+ """
+ # Create rooms and join
+ other_user_tok = self.login("user", "pass")
+ number_rooms = 5
+ for n in range(number_rooms):
+ self.helper.create_room_as(self.other_user, tok=other_user_tok)
+
+ # Get rooms
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(number_rooms, channel.json_body["total"])
+ self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
new file mode 100644
index 0000000000..c12518c931
--- /dev/null
+++ b/tests/rest/client/test_third_party_rules.py
@@ -0,0 +1,144 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import threading
+
+from mock import Mock
+
+from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import Requester, StateMap
+
+from tests import unittest
+
+thread_local = threading.local()
+
+
+class ThirdPartyRulesTestModule:
+ def __init__(self, config, module_api):
+ # keep a record of the "current" rules module, so that the test can patch
+ # it if desired.
+ thread_local.rules_module = self
+
+ async def on_create_room(
+ self, requester: Requester, config: dict, is_requester_admin: bool
+ ):
+ return True
+
+ async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ return config
+
+
+def current_rules_module() -> ThirdPartyRulesTestModule:
+ return thread_local.rules_module
+
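+# Tests retrieve the "current" module instance via this helper and then
+# monkeypatch its callbacks; a sketch of the pattern used below:
+#
+#     current_rules_module().check_event_allowed = some_async_callback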
+
+class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+ config["third_party_event_rules"] = {
+ "module": __name__ + ".ThirdPartyRulesTestModule",
+ "config": {},
+ }
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ # Create a user and room to play with during the tests
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ def test_third_party_rules(self):
+ """Tests that a forbidden event is forbidden from being sent, but an allowed one
+ can be sent.
+ """
+ # patch the rules module with a Mock which will return False for some event
+ # types
+ async def check(ev, state):
+ return ev.type != "foo.bar.forbidden"
+
+ callback = Mock(spec=[], side_effect=check)
+ current_rules_module().check_event_allowed = callback
+
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
+ {},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ callback.assert_called_once()
+
+ # there should be various state events in the state arg: do some basic checks
+ state_arg = callback.call_args[0][1]
+ for k in (("m.room.create", ""), ("m.room.member", self.user_id)):
+ self.assertIn(k, state_arg)
+ ev = state_arg[k]
+ self.assertEqual(ev.type, k[0])
+ self.assertEqual(ev.state_key, k[1])
+
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
+ {},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ def test_modify_event(self):
+ """Tests that the module can successfully tweak an event before it is persisted.
+ """
+ # first patch the event checker so that it will modify the event
+ async def check(ev: EventBase, state):
+ ev.content = {"x": "y"}
+ return True
+
+ current_rules_module().check_event_allowed = check
+
+ # now send the event
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+ {"x": "x"},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ event_id = channel.json_body["event_id"]
+
+ # ... and check that it got modified
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["x"], "y")
diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py
deleted file mode 100644
index 8c24add530..0000000000
--- a/tests/rest/client/third_party_rules.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-
-from tests import unittest
-
-
-class ThirdPartyRulesTestModule:
- def __init__(self, config):
- pass
-
- def check_event_allowed(self, event, context):
- if event.type == "foo.bar.forbidden":
- return False
- else:
- return True
-
- @staticmethod
- def parse_config(config):
- return config
-
-
-class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
- servlets = [
- admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def make_homeserver(self, reactor, clock):
- config = self.default_config()
- config["third_party_event_rules"] = {
- "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule",
- "config": {},
- }
-
- self.hs = self.setup_test_homeserver(config=config)
- return self.hs
-
- def test_third_party_rules(self):
- """Tests that a forbidden event is forbidden from being sent, but an allowed one
- can be sent.
- """
- user_id = self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
-
- room_id = self.helper.create_room_as(user_id, tok=tok)
-
- request, channel = self.make_request(
- "PUT",
- "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id,
- {},
- access_token=tok,
- )
- self.render(request)
- self.assertEquals(channel.result["code"], b"200", channel.result)
-
- request, channel = self.make_request(
- "PUT",
- "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id,
- {},
- access_token=tok,
- )
- self.render(request)
- self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py
index 633b7dbda0..ea5a7f3739 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/v1/test_directory.py
@@ -21,6 +21,7 @@ from synapse.types import RoomAlias
from synapse.util.stringutils import random_string
from tests import unittest
+from tests.unittest import override_config
class DirectoryTestCase(unittest.HomeserverTestCase):
@@ -67,10 +68,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.ensure_user_joined_room()
self.set_alias_via_directory(400, alias_length=256)
- def test_state_event_in_room(self):
+ @override_config({"default_room_version": 5})
+ def test_state_event_user_in_v5_room(self):
+ """Test that a regular user can add alias events before room v6"""
self.ensure_user_joined_room()
self.set_alias_via_state_event(200)
+ @override_config({"default_room_version": 6})
+ def test_state_event_v6_room(self):
+ """Test that a regular user can *not* add alias events from room v6"""
+ self.ensure_user_joined_room()
+ self.set_alias_via_state_event(403)
+
def test_directory_in_room(self):
self.ensure_user_joined_room()
self.set_alias_via_directory(200)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2668662c9e..5d987a30c7 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -7,8 +7,9 @@ from mock import Mock
import jwt
import synapse.rest.admin
+from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login, logout
-from synapse.rest.client.v2_alpha import devices
+from synapse.rest.client.v2_alpha import devices, register
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from tests import unittest
@@ -748,3 +749,134 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
channel.json_body["error"],
"JWT validation failed: Signature verification failed",
)
+
+
+AS_USER = "as_user_alice"
+
+
+class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ register.register_servlets,
+ ]
+
+ def register_as_user(self, username):
+ request, channel = self.make_request(
+ b"POST",
+ "/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
+ {"username": username},
+ )
+ self.render(request)
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+
+ self.service = ApplicationService(
+ id="unique_identifier",
+ token="some_token",
+ hostname="example.com",
+ sender="@asbot:example.com",
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {"regex": r"@as_user.*", "exclusive": False}
+ ],
+ ApplicationService.NS_ROOMS: [],
+ ApplicationService.NS_ALIASES: [],
+ },
+ )
+ self.another_service = ApplicationService(
+ id="another__identifier",
+ token="another_token",
+ hostname="example.com",
+ sender="@as2bot:example.com",
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {"regex": r"@as2_user.*", "exclusive": False}
+ ],
+ ApplicationService.NS_ROOMS: [],
+ ApplicationService.NS_ALIASES: [],
+ },
+ )
+
+ self.hs.get_datastore().services_cache.append(self.service)
+ self.hs.get_datastore().services_cache.append(self.another_service)
+ return self.hs
+
+ def test_login_appservice_user(self):
+ """Test that an appservice user can use /login
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": AS_USER},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def test_login_appservice_user_bot(self):
+ """Test that the appservice bot can use /login
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": self.service.sender},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def test_login_appservice_wrong_user(self):
+ """Test that non-as users cannot login with the as token
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": "fibble_wibble"},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ def test_login_appservice_wrong_as(self):
+ """Test that as users cannot login with wrong as token
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": AS_USER},
+ }
+ request, channel = self.make_request(
+ b"POST", LOGIN_URL, params, access_token=self.another_service.token
+ )
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ def test_login_appservice_no_token(self):
+ """Test that users must provide a token when using the appservice
+ login method
+ """
+ self.register_as_user(AS_USER)
+
+ params = {
+ "type": login.LoginRestServlet.APPSERVICE_TYPE,
+ "identifier": {"type": "m.id.user", "user": AS_USER},
+ }
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
+
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 0a567b032f..0d809d25d5 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -905,6 +905,7 @@ class RoomMessageListTestCase(RoomBase):
first_token = self.get_success(
store.get_topological_token_for_event(first_event_id)
)
+ first_token_str = self.get_success(first_token.to_string(store))
# Send a second message in the room, which won't be removed, and which we'll
# use as the marker to purge events before.
@@ -912,6 +913,7 @@ class RoomMessageListTestCase(RoomBase):
second_token = self.get_success(
store.get_topological_token_for_event(second_event_id)
)
+ second_token_str = self.get_success(second_token.to_string(store))
# Send a third event in the room to ensure we don't fall under any edge case
# due to our marker being the latest forward extremity in the room.
@@ -921,7 +923,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -936,7 +942,7 @@ class RoomMessageListTestCase(RoomBase):
pagination_handler._purge_history(
purge_id=purge_id,
room_id=self.room_id,
- token=second_token,
+ token=second_token_str,
delete_local_events=True,
)
)
@@ -946,7 +952,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -960,7 +970,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ first_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 93f899d861..ae2cd67f35 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -732,6 +732,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
def test_next_link_domain_whitelist(self):
"""Tests next_link parameters must fit the whitelist if provided"""
+
+ # Ensure not providing a next_link parameter still works
+ self._request_token(
+ "something@example.com", "some_secret", next_link=None, expect_code=200,
+ )
+
self._request_token(
"something@example.com",
"some_secret",
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index b090bb974c..dcd65c2a50 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -21,7 +21,7 @@ from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase):
def setUp(self):
- super(WellKnownTests, self).setUp()
+ super().setUp()
# replace the JsonResource with a WellKnownResource
self.resource = WellKnownResource(self.hs)
diff --git a/tests/server.py b/tests/server.py
index 61ec670155..f7f5276b21 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -260,7 +260,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return succeed(lookups[name])
self.nameResolver = SimpleResolverComplexifier(FakeResolver())
- super(ThreadedMemoryReactorClock, self).__init__()
+ super().__init__()
def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
p = udp.Port(port, protocol, interface, maxPacketSize, self)
@@ -372,6 +372,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
+ # We've just changed the Databases to run DB transactions on the same
+ # thread, so we need to disable the dedicated thread behaviour.
+ server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
+
return server
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index cb808d4de4..c905a38930 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
# must be done after inserts
database = hs.get_datastores().databases[0]
self.store = ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database, make_conn(database._database_config, database.engine, "test"), hs
)
def tearDown(self):
@@ -132,7 +132,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
db_config = hs.config.get_single_database()
self.store = TestTransactionStore(
- database, make_conn(db_config, self.engine), hs
+ database, make_conn(db_config, self.engine, "test"), hs
)
def _add_service(self, url, as_token, id):
@@ -413,7 +413,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(TestTransactionStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@@ -448,7 +448,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
database = hs.get_datastores().databases[0]
ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database, make_conn(database._database_config, database.engine, "test"), hs
)
@defer.inlineCallbacks
@@ -467,7 +467,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
with self.assertRaises(ConfigError) as cm:
database = hs.get_datastores().databases[0]
ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database,
+ make_conn(database._database_config, database.engine, "test"),
+ hs,
)
e = cm.exception
@@ -491,7 +493,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
with self.assertRaises(ConfigError) as cm:
database = hs.get_datastores().databases[0]
ApplicationServiceStore(
- database, make_conn(database._database_config, database.engine), hs
+ database,
+ make_conn(database._database_config, database.engine, "test"),
+ hs,
)
e = cm.exception
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 40ba652248..eac7e4dcd2 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -56,6 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
engine = create_engine(sqlite_config)
fake_engine = Mock(wraps=engine)
fake_engine.can_native_upsert = False
+ fake_engine.in_transaction.return_value = False
db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
db._db_pool = self.db_pool
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 34ae8c9da7..ecb00f4e02 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -23,7 +23,7 @@ import tests.utils
class DeviceStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
- super(DeviceStoreTestCase, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 949846fe33..3957471f3f 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -52,14 +52,14 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
self.reactor.advance(60 * 60 * 1000)
self.pump(1)
- items = set(
+ items = list(
filter(
lambda x: b"synapse_forward_extremities_" in x,
- generate_latest(REGISTRY).split(b"\n"),
+ generate_latest(REGISTRY, emit_help=False).split(b"\n"),
)
)
- expected = {
+ expected = [
b'synapse_forward_extremities_bucket{le="1.0"} 0.0',
b'synapse_forward_extremities_bucket{le="2.0"} 2.0',
b'synapse_forward_extremities_bucket{le="3.0"} 2.0',
@@ -72,9 +72,12 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
b'synapse_forward_extremities_bucket{le="100.0"} 3.0',
b'synapse_forward_extremities_bucket{le="200.0"} 3.0',
b'synapse_forward_extremities_bucket{le="500.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="+Inf"} 3.0',
- b"synapse_forward_extremities_count 3.0",
- b"synapse_forward_extremities_sum 10.0",
- }
-
+ # per https://docs.google.com/document/d/1KwV0mAXwwbvvifBvDKH_LU1YjyXE_wxCkHNoCGq1GX0/edit#heading=h.wghdjzzh72j9,
+ # "inf" is valid: "this includes variants such as inf"
+ b'synapse_forward_extremities_bucket{le="inf"} 3.0',
+ b"# TYPE synapse_forward_extremities_gcount gauge",
+ b"synapse_forward_extremities_gcount 3.0",
+ b"# TYPE synapse_forward_extremities_gsum gauge",
+ b"synapse_forward_extremities_gsum 10.0",
+ ]
self.assertEqual(items, expected)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 20636fc400..392b08832b 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
from synapse.storage.database import DatabasePool
+from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.unittest import HomeserverTestCase
@@ -43,19 +42,23 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
)
- def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create_id_generator(
+ self, instance_name="master", writers=["master"]
+ ) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ stream_name="test_stream",
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq",
+ writers=writers,
)
- return self.get_success(self.db_pool.runWithConnection(_create))
+ return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int):
"""Insert N rows as the given instance, inserting with stream IDs pulled
@@ -68,6 +71,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
(instance_name,),
)
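+ # Record the newly allocated ID in stream_positions: on PostgreSQL,
+ # lastval() returns the value most recently produced by nextval() on
+ # this connection.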
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+ """,
+ (instance_name,),
+ )
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
@@ -81,6 +91,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
)
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, stream_id, stream_id),
+ )
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
@@ -111,7 +128,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# advanced after we leave the context manager.
async def _get_next_async():
- with await id_gen.get_next() as stream_id:
+ async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
@@ -139,10 +156,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
ctx3 = self.get_success(id_gen.get_next())
ctx4 = self.get_success(id_gen.get_next())
- s1 = ctx1.__enter__()
- s2 = ctx2.__enter__()
- s3 = ctx3.__enter__()
- s4 = ctx4.__enter__()
+ s1 = self.get_success(ctx1.__aenter__())
+ s2 = self.get_success(ctx2.__aenter__())
+ s3 = self.get_success(ctx3.__aenter__())
+ s4 = self.get_success(ctx4.__aenter__())
self.assertEqual(s1, 8)
self.assertEqual(s2, 9)
@@ -152,22 +169,22 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
- ctx2.__exit__(None, None, None)
+ self.get_success(ctx2.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
- ctx1.__exit__(None, None, None)
+ self.get_success(ctx1.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 9})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
- ctx4.__exit__(None, None, None)
+ self.get_success(ctx4.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 9})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
- ctx3.__exit__(None, None, None)
+ self.get_success(ctx3.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 11})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
@@ -179,8 +196,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_rows("first", 3)
self._insert_rows("second", 4)
- first_id_gen = self._create_id_generator("first")
- second_id_gen = self._create_id_generator("second")
+ first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+ second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
@@ -190,7 +207,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# advanced after we leave the context manager.
async def _get_next_async():
- with await first_id_gen.get_next() as stream_id:
+ async with first_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(
@@ -208,7 +225,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# stream ID
async def _get_next_async():
- with await second_id_gen.get_next() as stream_id:
+ async with second_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 9)
self.assertEqual(
@@ -262,7 +279,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
- id_gen = self._create_id_generator("first")
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
@@ -300,14 +317,18 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
- id_gen = self._create_id_generator("first")
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
- with self.get_success(id_gen.get_next()) as stream_id:
- self.assertEqual(stream_id, 6)
- self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ async def _get_next_async():
+ async with id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ self.get_success(_get_next_async())
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
@@ -315,6 +336,115 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
+ def test_restart_during_out_of_order_persistence(self):
+ """Test that restarting a process while another process is writing out
+ of order updates are handled correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ # Persist two rows at once
+ ctx1 = self.get_success(id_gen.get_next())
+ ctx2 = self.get_success(id_gen.get_next())
+
+ s1 = self.get_success(ctx1.__aenter__())
+ s2 = self.get_success(ctx2.__aenter__())
+
+ self.assertEqual(s1, 8)
+ self.assertEqual(s2, 9)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ # We finish persisting the second row before restart
+ self.get_success(ctx2.__aexit__(None, None, None))
+
+ # We simulate a restart of another worker by just creating a new ID gen.
+ id_gen_worker = self._create_id_generator("worker")
+
+ # Restarted worker should not see the second persisted row
+ self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
+ self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
+
+ # Now if we persist the first row then both instances should jump ahead
+ # correctly.
+ self.get_success(ctx1.__aexit__(None, None, None))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
+ id_gen_worker.advance("master", 9)
+ self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
+
+ def test_writer_config_change(self):
+ """Test that changing the writer config correctly works.
+ """
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ # Initial config has two writers
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ self.assertEqual(id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(id_gen.get_current_token_for_writer("second"), 5)
+
+ # The new config removes one of the writers. Note that if a writer is
+ # removed from the config we assume that it has been shut down and has
+ # finished persisting, hence the persisted-up-to position of 5.
+ id_gen_2 = self._create_id_generator("second", writers=["second"])
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), 5)
+ self.assertEqual(id_gen_2.get_current_token_for_writer("second"), 5)
+
+ # This config points to a single, previously unused writer.
+ id_gen_3 = self._create_id_generator("third", writers=["third"])
+ self.assertEqual(id_gen_3.get_persisted_upto_position(), 5)
+
+ # For new writers we assume their initial position to be the current
+ # persisted up to position. This stops Synapse from doing a full table
+ # scan when a new writer comes along.
+ self.assertEqual(id_gen_3.get_current_token_for_writer("third"), 5)
+
+ id_gen_4 = self._create_id_generator("fourth", writers=["third"])
+ self.assertEqual(id_gen_4.get_current_token_for_writer("third"), 5)
+
+ # Check that we get a sane next stream ID with this new config.
+
+ async def _get_next_async():
+ async with id_gen_3.get_next() as stream_id:
+ self.assertEqual(stream_id, 6)
+
+ self.get_success(_get_next_async())
+ self.assertEqual(id_gen_3.get_persisted_upto_position(), 6)
+
+ # If we add back the old "first" then we shouldn't see the persisted up
+ # to position revert back to 3.
+ id_gen_5 = self._create_id_generator("five", writers=["first", "third"])
+ self.assertEqual(id_gen_5.get_persisted_upto_position(), 6)
+ self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6)
+ self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
+
+ def test_sequence_consistency(self):
+ """Test that we error out if the table and sequence diverges.
+ """
+
+ # Prefill with some rows
+ self._insert_row_with_id("master", 3)
+
+ # Now we add a row *without* updating the stream ID
+ def _insert(txn):
+ txn.execute("INSERT INTO foobar VALUES (26, 'master')")
+
+ self.get_success(self.db_pool.runInteraction("_insert", _insert))
+
+ # Creating the ID gen should error
+ with self.assertRaises(IncorrectDatabaseSetup):
+ self._create_id_generator("first")
+
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.
@@ -341,16 +471,20 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
)
- def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create_id_generator(
+ self, instance_name="master", writers=["master"]
+ ) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ stream_name="test_stream",
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq",
+ writers=writers,
positive=False,
)
@@ -364,6 +498,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
txn.execute(
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
)
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, -stream_id, -stream_id),
+ )
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
@@ -373,16 +514,22 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
id_gen = self._create_id_generator()
- with self.get_success(id_gen.get_next()) as stream_id:
- self._insert_row("master", stream_id)
+ async def _get_next_async():
+ async with id_gen.get_next() as stream_id:
+ self._insert_row("master", stream_id)
+
+ self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": -1})
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
- with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
- for stream_id in stream_ids:
- self._insert_row("master", stream_id)
+ async def _get_next_async2():
+ async with id_gen.get_next_mult(3) as stream_ids:
+ for stream_id in stream_ids:
+ self._insert_row("master", stream_id)
+
+ self.get_success(_get_next_async2())
self.assertEqual(id_gen.get_positions(), {"master": -4})
self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
@@ -399,21 +546,27 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests that having multiple instances that get advanced over
federation works corretly.
"""
- id_gen_1 = self._create_id_generator("first")
- id_gen_2 = self._create_id_generator("second")
+ id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
+ id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
- with self.get_success(id_gen_1.get_next()) as stream_id:
- self._insert_row("first", stream_id)
- id_gen_2.advance("first", stream_id)
+ async def _get_next_async():
+ async with id_gen_1.get_next() as stream_id:
+ self._insert_row("first", stream_id)
+ id_gen_2.advance("first", stream_id)
+
+ self.get_success(_get_next_async())
self.assertEqual(id_gen_1.get_positions(), {"first": -1})
self.assertEqual(id_gen_2.get_positions(), {"first": -1})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
- with self.get_success(id_gen_2.get_next()) as stream_id:
- self._insert_row("second", stream_id)
- id_gen_1.advance("second", stream_id)
+ async def _get_next_async2():
+ async with id_gen_2.get_next() as stream_id:
+ self._insert_row("second", stream_id)
+ id_gen_1.advance("second", stream_id)
+
+ self.get_success(_get_next_async2())
self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 643072bbaf..8d97b6d4cd 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -137,6 +137,21 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 1)
+ def test_appservice_user_not_counted_in_mau(self):
+ self.get_success(
+ self.store.register_user(
+ user_id="@appservice_user:server", appservice_id="wibble"
+ )
+ )
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
+
+ d = self.store.upsert_monthly_active_user("@appservice_user:server")
+ self.get_success(d)
+
+ count = self.get_success(self.store.get_monthly_active_count())
+ self.assertEqual(count, 0)
+
def test_user_last_seen_monthly_active(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
@@ -383,7 +398,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.upsert_monthly_active_user(appservice2_user1))
count = self.get_success(self.store.get_monthly_active_count())
- self.assertEqual(count, 4)
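+ # Appservice users are no longer counted towards MAU (see
+ # test_appservice_user_not_counted_in_mau above), so only the single
+ # native user contributes to the count.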
+ self.assertEqual(count, 1)
d = self.store.get_monthly_active_count_by_service()
result = self.get_success(d)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 918387733b..cc1f3c53c5 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -47,12 +47,15 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_storage()
# Get the topological token
- event = self.get_success(
+ token = self.get_success(
store.get_topological_token_for_event(last["event_id"])
)
+ token_str = self.get_success(token.to_string(self.hs.get_datastore()))
# Purge everything before this topological token
- self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
+ self.get_success(
+ storage.purge_events.purge_history(self.room_id, token_str, True)
+ )
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
@@ -74,12 +77,10 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_datastore()
# Set the topological token higher than it should be
- event = self.get_success(
+ token = self.get_success(
storage.get_topological_token_for_event(last["event_id"])
)
- event = "t{}-{}".format(
- *list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
- )
+ event = "t{}-{}".format(token.topological + 1, token.stream + 1)
# Purge everything before this topological token
purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index 7657bddea5..e7aed092c2 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -17,7 +17,7 @@ import resource
import mock
-from synapse.app.homeserver import phone_stats_home
+from synapse.app.phone_stats_home import phone_stats_home
from tests.unittest import HomeserverTestCase
diff --git a/tests/test_state.py b/tests/test_state.py
index 2d58467932..80b0ccbc40 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -125,7 +125,7 @@ class StateGroupStore:
class DictObj(dict):
def __init__(self, **kwargs):
- super(DictObj, self).__init__(kwargs)
+ super().__init__(kwargs)
self.__dict__ = self
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 2d96b0fa8d..fdfb840b62 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -29,8 +29,7 @@ class ToTwistedHandler(logging.Handler):
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
self.tx_log.emit(
- twisted.logger.LogLevel.levelWithName(log_level),
- log_entry.replace("{", r"(").replace("}", r")"),
+ twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
)
diff --git a/tests/unittest.py b/tests/unittest.py
index 128dd4e19c..5c87f6097e 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import gc
import hashlib
import hmac
@@ -23,11 +22,12 @@ import logging
import time
from typing import Optional, Tuple, Type, TypeVar, Union
-from mock import Mock
+from mock import Mock, patch
from canonicaljson import json
from twisted.internet.defer import Deferred, ensureDeferred, succeed
+from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
@@ -92,7 +92,7 @@ class TestCase(unittest.TestCase):
root logger's logging level while that test (case|method) runs."""
def __init__(self, methodName, *args, **kwargs):
- super(TestCase, self).__init__(methodName, *args, **kwargs)
+ super().__init__(methodName, *args, **kwargs)
method = getattr(self, methodName)
@@ -169,6 +169,19 @@ def INFO(target):
return target
+def logcontext_clean(target):
+ """A decorator which marks the TestCase or method as 'logcontext_clean'
+
+ ... i.e., any logcontext errors should cause a test failure
+ """
+
+ def logcontext_error(msg):
+ raise AssertionError("logcontext error: %s" % (msg))
+
+ patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
+ return patcher(target)
+
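+# Example usage (a sketch; the test case name is illustrative):
+#
+#     @logcontext_clean
+#     class MyHomeserverTestCase(HomeserverTestCase):
+#         ...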
+
class HomeserverTestCase(TestCase):
"""
A base TestCase that reduces boilerplate for HomeServer-using test cases.
@@ -228,7 +241,7 @@ class HomeserverTestCase(TestCase):
# create a site to wrap the resource.
self.site = SynapseSite(
logger_name="synapse.access.http.fake",
- site_tag="test",
+ site_tag=self.hs.config.server.server_name,
config=self.hs.config.server.listeners[0],
resource=self.resource,
server_version_string="1",
@@ -463,6 +476,35 @@ class HomeserverTestCase(TestCase):
self.pump()
return self.failureResultOf(d, exc)
+ def get_success_or_raise(self, d, by=0.0):
+ """Drive deferred to completion and return result or raise exception
+ on failure.
+ """
+
+ if inspect.isawaitable(d):
+ deferred = ensureDeferred(d)
+ if not isinstance(deferred, Deferred):
+ return d
+
+ results = [] # type: list
+ deferred.addBoth(results.append)
+
+ self.pump(by=by)
+
+ if not results:
+ self.fail(
+ "Success result expected on {!r}, found no result instead".format(
+ deferred
+ )
+ )
+
+ result = results[0]
+
+ if isinstance(result, Failure):
+ result.raiseException()
+
+ return result
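+
+ # Typical use: drive a deferred whose failure should propagate, e.g.
+ # (as in tests/storage/test_id_generators.py):
+ #     with self.assertRaises(IncorrectDatabaseSetup):
+ #         self.get_success_or_raise(d)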
+
def register_user(self, username, password, admin=False):
"""
Register a user. Requires the Admin API be registered.
@@ -566,7 +608,9 @@ class HomeserverTestCase(TestCase):
if soft_failed:
event.internal_metadata.soft_failed = True
- self.get_success(event_creator.send_nonmember_event(requester, event, context))
+ self.get_success(
+ event_creator.handle_new_client_event(requester, event, context)
+ )
return event.event_id
diff --git a/tests/utils.py b/tests/utils.py
index 4673872f88..af563ffe0f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -38,6 +38,7 @@ from synapse.http.server import HttpServer
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
+from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -88,6 +89,7 @@ def setupdb():
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
+ db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
prepare_database(db_conn, db_engine, None)
db_conn.close()
@@ -276,7 +278,7 @@ def setup_test_homeserver(
hs.setup()
if homeserverToUse.__name__ == "TestHomeServer":
- hs.setup_master()
+ hs.setup_background_tasks()
if isinstance(db_engine, PostgresEngine):
database = hs.get_datastores().databases[0]
diff --git a/tox.ini b/tox.ini
index df473bd234..4d132eff4c 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,13 +2,12 @@
envlist = packaging, py35, py36, py37, py38, check_codestyle, check_isort
[base]
+extras = test
deps =
- mock
python-subunit
junitxml
coverage
coverage-enable-subprocess
- parameterized
# cyptography 2.2 requires setuptools >= 18.5
#
@@ -36,7 +35,7 @@ setenv =
[testenv]
deps =
{[base]deps}
-extras = all
+extras = all, test
whitelist_externals =
sh
@@ -84,7 +83,6 @@ deps =
# Old automat version for Twisted
Automat == 0.3.0
- mock
lxml
coverage
coverage-enable-subprocess
@@ -97,7 +95,7 @@ commands =
/bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "s/psycopg2==2.6//" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install'
# Install Synapse itself. This won't update any libraries.
- pip install -e .
+ pip install -e ".[test]"
{envbindir}/coverage run "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
@@ -118,20 +116,14 @@ commands =
check-manifest
[testenv:check_codestyle]
-skip_install = True
-deps =
- flake8
- flake8-comprehensions
- # We pin so that our tests don't start failing on new releases of black.
- black==19.10b0
+extras = lint
commands =
python -m black --check --diff .
/bin/sh -c "flake8 synapse tests scripts scripts-dev contrib synctl {env:PEP8SUFFIX:}"
{toxinidir}/scripts-dev/config-lint.sh
[testenv:check_isort]
-skip_install = True
-deps = isort==5.0.3
+extras = lint
commands = /bin/sh -c "isort -c --df --sp setup.cfg synapse tests scripts-dev scripts"
[testenv:check-newsfragment]
|