diff --git a/CHANGES.md b/CHANGES.md
index b4e1d25fe0..fcd782fa94 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,9 @@
+Unreleased
+==========
+
+Note that this release includes a change in Synapse to use Redis as a cache ─ as well as a pub/sub mechanism ─ if Redis support is enabled. No action is needed by server administrators, and we do not expect resource usage of the Redis instance to change dramatically.
+
+
Synapse 1.26.0 (2021-01-27)
===========================
diff --git a/UPGRADE.rst b/UPGRADE.rst
index d09dbd4e21..eea0322695 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -85,6 +85,43 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
+Upgrading to v1.27.0
+====================
+
+Changes to HTML templates
+-------------------------
+
+The HTML templates for SSO and email notifications now have `Jinja2's autoescape <https://jinja.palletsprojects.com/en/2.11.x/api/#autoescaping>`_
+enabled for files ending in ``.html``, ``.htm``, and ``.xml``. If you have customised
+these templates and see issues when viewing them you might need to update them.
+It is expected that most configurations will need no changes.
+
+If you have customised the templates *names* for these templates, it is recommended
+to verify they end in ``.html`` to ensure autoescape is enabled.
+
+The above applies to the following templates:
+
+* ``add_threepid.html``
+* ``add_threepid_failure.html``
+* ``add_threepid_success.html``
+* ``notice_expiry.html``
+* ``notice_expiry.html``
+* ``notif_mail.html`` (which, by default, includes ``room.html`` and ``notif.html``)
+* ``password_reset.html``
+* ``password_reset_confirmation.html``
+* ``password_reset_failure.html``
+* ``password_reset_success.html``
+* ``registration.html``
+* ``registration_failure.html``
+* ``registration_success.html``
+* ``sso_account_deactivated.html``
+* ``sso_auth_bad_user.html``
+* ``sso_auth_confirm.html``
+* ``sso_auth_success.html``
+* ``sso_error.html``
+* ``sso_login_idp_picker.html``
+* ``sso_redirect_confirm.html``
+
Upgrading to v1.26.0
====================
diff --git a/changelog.d/9045.misc b/changelog.d/9045.misc
new file mode 100644
index 0000000000..7f1886a0de
--- /dev/null
+++ b/changelog.d/9045.misc
@@ -0,0 +1 @@
+Add tests to `test_user.UsersListTestCase` for List Users Admin API.
\ No newline at end of file
diff --git a/changelog.d/9062.feature b/changelog.d/9062.feature
new file mode 100644
index 0000000000..8b950fa062
--- /dev/null
+++ b/changelog.d/9062.feature
@@ -0,0 +1 @@
+Add admin API for getting and deleting forward extremities for a room.
diff --git a/changelog.d/9121.bugfix b/changelog.d/9121.bugfix
new file mode 100644
index 0000000000..a566878ec0
--- /dev/null
+++ b/changelog.d/9121.bugfix
@@ -0,0 +1 @@
+Fix spurious errors in logs when deleting a non-existant pusher.
diff --git a/changelog.d/9163.bugfix b/changelog.d/9163.bugfix
new file mode 100644
index 0000000000..c51cf6ca80
--- /dev/null
+++ b/changelog.d/9163.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled).
diff --git a/changelog.d/9164.bugfix b/changelog.d/9164.bugfix
new file mode 100644
index 0000000000..1c54a256c1
--- /dev/null
+++ b/changelog.d/9164.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where an internal server error was raised when attempting to preview an HTML document in an unknown character encoding.
diff --git a/changelog.d/9165.bugfix b/changelog.d/9165.bugfix
new file mode 100644
index 0000000000..58db22f484
--- /dev/null
+++ b/changelog.d/9165.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where invalid data could cause errors when calculating the presentable room name for push.
diff --git a/changelog.d/9176.misc b/changelog.d/9176.misc
new file mode 100644
index 0000000000..9c41d7b0f9
--- /dev/null
+++ b/changelog.d/9176.misc
@@ -0,0 +1 @@
+Speed up chain cover calculation when persisting a batch of state events at once.
diff --git a/changelog.d/9188.misc b/changelog.d/9188.misc
new file mode 100644
index 0000000000..7820d09cd0
--- /dev/null
+++ b/changelog.d/9188.misc
@@ -0,0 +1 @@
+Speed up batch insertion when using PostgreSQL.
diff --git a/changelog.d/9190.misc b/changelog.d/9190.misc
new file mode 100644
index 0000000000..1b0cc56a92
--- /dev/null
+++ b/changelog.d/9190.misc
@@ -0,0 +1 @@
+Improve performance of concurrent use of `StreamIDGenerators`.
diff --git a/changelog.d/9191.misc b/changelog.d/9191.misc
new file mode 100644
index 0000000000..b4bc6be13a
--- /dev/null
+++ b/changelog.d/9191.misc
@@ -0,0 +1 @@
+Add some missing source directories to the automatic linting script.
\ No newline at end of file
diff --git a/changelog.d/9198.misc b/changelog.d/9198.misc
new file mode 100644
index 0000000000..a6cb77fbb2
--- /dev/null
+++ b/changelog.d/9198.misc
@@ -0,0 +1 @@
+Precompute joined hosts and store in Redis.
diff --git a/changelog.d/9199.removal b/changelog.d/9199.removal
new file mode 100644
index 0000000000..fbd2916cbf
--- /dev/null
+++ b/changelog.d/9199.removal
@@ -0,0 +1 @@
+The `service_url` parameter in `cas_config` is deprecated in favor of `public_baseurl`.
diff --git a/changelog.d/9200.misc b/changelog.d/9200.misc
new file mode 100644
index 0000000000..5f239ff9da
--- /dev/null
+++ b/changelog.d/9200.misc
@@ -0,0 +1 @@
+Clean-up template loading code.
diff --git a/changelog.d/9209.feature b/changelog.d/9209.feature
new file mode 100644
index 0000000000..ec926e8eb4
--- /dev/null
+++ b/changelog.d/9209.feature
@@ -0,0 +1 @@
+Add an admin API endpoint for shadow-banning users.
diff --git a/changelog.d/9218.bugfix b/changelog.d/9218.bugfix
new file mode 100644
index 0000000000..577fff5497
--- /dev/null
+++ b/changelog.d/9218.bugfix
@@ -0,0 +1 @@
+Fix bug where we sometimes didn't detect that Redis connections had died, causing workers to not see new data.
diff --git a/changelog.d/9222.misc b/changelog.d/9222.misc
new file mode 100644
index 0000000000..37490717b3
--- /dev/null
+++ b/changelog.d/9222.misc
@@ -0,0 +1 @@
+Update `isort` to v5.7.0 to bypass a bug where it would disagree with `black` about formatting.
\ No newline at end of file
diff --git a/changelog.d/9223.misc b/changelog.d/9223.misc
new file mode 100644
index 0000000000..9d44b621c9
--- /dev/null
+++ b/changelog.d/9223.misc
@@ -0,0 +1 @@
+Add type hints to handlers code.
diff --git a/changelog.d/9227.misc b/changelog.d/9227.misc
new file mode 100644
index 0000000000..a6cb77fbb2
--- /dev/null
+++ b/changelog.d/9227.misc
@@ -0,0 +1 @@
+Precompute joined hosts and store in Redis.
diff --git a/changelog.d/9229.bugfix b/changelog.d/9229.bugfix
new file mode 100644
index 0000000000..3ed32291de
--- /dev/null
+++ b/changelog.d/9229.bugfix
@@ -0,0 +1 @@
+Fix a bug where `None` was passed to Synapse modules instead of an empty dictionary if an empty module `config` block was provided in the homeserver config.
\ No newline at end of file
diff --git a/changelog.d/9232.misc b/changelog.d/9232.misc
new file mode 100644
index 0000000000..9d44b621c9
--- /dev/null
+++ b/changelog.d/9232.misc
@@ -0,0 +1 @@
+Add type hints to handlers code.
diff --git a/changelog.d/9235.bugfix b/changelog.d/9235.bugfix
new file mode 100644
index 0000000000..7809c8673b
--- /dev/null
+++ b/changelog.d/9235.bugfix
@@ -0,0 +1 @@
+Fix a bug in the `make_room_admin` admin API where it failed if the admin with the greatest power level was not in the room. Contributed by Pankaj Yadav.
diff --git a/changelog.d/9238.feature b/changelog.d/9238.feature
new file mode 100644
index 0000000000..143a3e14f5
--- /dev/null
+++ b/changelog.d/9238.feature
@@ -0,0 +1 @@
+Add ratelimited to 3PID `/requestToken` API.
diff --git a/changelog.d/9244.doc b/changelog.d/9244.doc
new file mode 100644
index 0000000000..2ad81429fc
--- /dev/null
+++ b/changelog.d/9244.doc
@@ -0,0 +1 @@
+Add notes on integrating with Facebook for SSO login.
diff --git a/changelog.d/9254.misc b/changelog.d/9254.misc
new file mode 100644
index 0000000000..b79b9abbd3
--- /dev/null
+++ b/changelog.d/9254.misc
@@ -0,0 +1 @@
+Fix Debian package building on Ubuntu 16.04 LTS (Xenial).
diff --git a/changelog.d/9255.misc b/changelog.d/9255.misc
new file mode 100644
index 0000000000..f723b8ec4f
--- /dev/null
+++ b/changelog.d/9255.misc
@@ -0,0 +1 @@
+Minor performance improvement during TLS handshake.
diff --git a/changelog.d/9258.feature b/changelog.d/9258.feature
new file mode 100644
index 0000000000..0028f42d26
--- /dev/null
+++ b/changelog.d/9258.feature
@@ -0,0 +1 @@
+Add ratelimits to invites in rooms and to specific users.
diff --git a/changelog.d/9260.misc b/changelog.d/9260.misc
new file mode 100644
index 0000000000..0150e10ea9
--- /dev/null
+++ b/changelog.d/9260.misc
@@ -0,0 +1 @@
+Refactor the generation of summary text for email notifications.
diff --git a/changelog.d/9265.bugfix b/changelog.d/9265.bugfix
new file mode 100644
index 0000000000..34f7bd8ddd
--- /dev/null
+++ b/changelog.d/9265.bugfix
@@ -0,0 +1 @@
+Prevent password hashes from getting dropped if a client failed threepid validation during a User Interactive Auth stage. Removes a workaround for an ancient bug in Riot Web <v0.7.4.
\ No newline at end of file
diff --git a/changelog.d/9270.misc b/changelog.d/9270.misc
new file mode 100644
index 0000000000..908e5ee78b
--- /dev/null
+++ b/changelog.d/9270.misc
@@ -0,0 +1 @@
+Restore PyPy compatibility by not calling CPython-specific GC methods when under PyPy.
diff --git a/changelog.d/9283.feature b/changelog.d/9283.feature
new file mode 100644
index 0000000000..54f133a064
--- /dev/null
+++ b/changelog.d/9283.feature
@@ -0,0 +1 @@
+Add phone home stats for encrypted messages.
diff --git a/debian/build_virtualenv b/debian/build_virtualenv
index cbdde93f96..cf19084a9f 100755
--- a/debian/build_virtualenv
+++ b/debian/build_virtualenv
@@ -33,11 +33,13 @@ esac
# Use --builtin-venv to use the better `venv` module from CPython 3.4+ rather
# than the 2/3 compatible `virtualenv`.
+# Pin pip to 20.3.4 to fix breakage in 21.0 on py3.5 (xenial)
+
dh_virtualenv \
--install-suffix "matrix-synapse" \
--builtin-venv \
--python "$SNAKE" \
- --upgrade-pip \
+ --upgrade-pip-to="20.3.4" \
--preinstall="lxml" \
--preinstall="mock" \
--extra-pip-arg="--no-cache-dir" \
diff --git a/debian/changelog b/debian/changelog
index 1c6308e3a2..1a421a85bd 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,8 +1,18 @@
-matrix-synapse-py3 (1.25.0ubuntu1) UNRELEASED; urgency=medium
+matrix-synapse-py3 (1.26.0+nmu1) UNRELEASED; urgency=medium
+ * Fix build on Ubuntu 16.04 LTS (Xenial).
+
+ -- Dan Callahan <danc@element.io> Thu, 28 Jan 2021 16:21:03 +0000
+
+matrix-synapse-py3 (1.26.0) stable; urgency=medium
+
+ [ Richard van der Hoff ]
* Remove dependency on `python3-distutils`.
- -- Richard van der Hoff <richard@matrix.org> Fri, 15 Jan 2021 12:44:19 +0000
+ [ Synapse Packaging team ]
+ * New synapse release 1.26.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Wed, 27 Jan 2021 12:43:35 -0500
matrix-synapse-py3 (1.25.0) stable; urgency=medium
diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv
index e529293803..0d74630370 100644
--- a/docker/Dockerfile-dhvirtualenv
+++ b/docker/Dockerfile-dhvirtualenv
@@ -27,6 +27,7 @@ RUN env DEBIAN_FRONTEND=noninteractive apt-get install \
wget
# fetch and unpack the package
+# TODO: Upgrade to 1.2.2 once xenial is dropped
RUN mkdir /dh-virtualenv
RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/ac6e1b1.tar.gz
RUN tar -xv --strip-components=1 -C /dh-virtualenv -f /dh-virtualenv.tar.gz
diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md
index 9e560003a9..f34cec1ff7 100644
--- a/docs/admin_api/rooms.md
+++ b/docs/admin_api/rooms.md
@@ -9,6 +9,7 @@
* [Response](#response)
* [Undoing room shutdowns](#undoing-room-shutdowns)
- [Make Room Admin API](#make-room-admin-api)
+- [Forward Extremities Admin API](#forward-extremities-admin-api)
# List Room API
@@ -511,3 +512,55 @@ optionally be specified, e.g.:
"user_id": "@foo:example.com"
}
```
+
+# Forward Extremities Admin API
+
+Enables querying and deleting forward extremities from rooms. When a lot of forward
+extremities accumulate in a room, performance can become degraded. For details, see
+[#1760](https://github.com/matrix-org/synapse/issues/1760).
+
+## Check for forward extremities
+
+To check the status of forward extremities for a room:
+
+```
+ GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+```
+
+A response as follows will be returned:
+
+```json
+{
+ "count": 1,
+ "results": [
+ {
+ "event_id": "$M5SP266vsnxctfwFgFLNceaCo3ujhRtg_NiiHabcdefgh",
+ "state_group": 439,
+ "depth": 123,
+ "received_ts": 1611263016761
+ }
+ ]
+}
+```
+
+## Deleting forward extremities
+
+**WARNING**: Please ensure you know what you're doing and have read
+the related issue [#1760](https://github.com/matrix-org/synapse/issues/1760).
+Under no situations should this API be executed as an automated maintenance task!
+
+If a room has lots of forward extremities, the extra can be
+deleted as follows:
+
+```
+ DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+```
+
+A response as follows will be returned, indicating the amount of forward extremities
+that were deleted.
+
+```json
+{
+ "deleted": 1
+}
+```
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index b3d413cf57..1eb674939e 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -760,3 +760,33 @@ The following fields are returned in the JSON response body:
- ``total`` - integer - Number of pushers.
See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_
+
+Shadow-banning users
+====================
+
+Shadow-banning is a useful tool for moderating malicious or egregiously abusive users.
+A shadow-banned users receives successful responses to their client-server API requests,
+but the events are not propagated into rooms. This can be an effective tool as it
+(hopefully) takes longer for the user to realise they are being moderated before
+pivoting to another account.
+
+Shadow-banning a user should be used as a tool of last resort and may lead to confusing
+or broken behaviour for the client. A shadow-banned user will not receive any
+notification and it is generally more appropriate to ban or kick abusive users.
+A shadow-banned user will be unable to contact anyone on the server.
+
+The API is::
+
+ POST /_synapse/admin/v1/users/<user_id>/shadow_ban
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+An empty JSON dict is returned.
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
+ be local.
diff --git a/docs/openid.md b/docs/openid.md
index f01f46d326..4ba3559e38 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -44,7 +44,7 @@ as follows:
To enable the OpenID integration, you should then add a section to the `oidc_providers`
setting in your configuration file (or uncomment one of the existing examples).
-See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as
+See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as
the text below for example configurations for specific providers.
## Sample configs
@@ -52,11 +52,11 @@ the text below for example configurations for specific providers.
Here are a few configs for providers that should work with Synapse.
### Microsoft Azure Active Directory
-Azure AD can act as an OpenID Connect Provider. Register a new application under
+Azure AD can act as an OpenID Connect Provider. Register a new application under
*App registrations* in the Azure AD management console. The RedirectURI for your
application should point to your matrix server: `[synapse public baseurl]/_synapse/oidc/callback`
-Go to *Certificates & secrets* and register a new client secret. Make note of your
+Go to *Certificates & secrets* and register a new client secret. Make note of your
Directory (tenant) ID as it will be used in the Azure links.
Edit your Synapse config file and change the `oidc_config` section:
@@ -118,7 +118,7 @@ oidc_providers:
```
### [Keycloak][keycloak-idp]
-[Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat.
+[Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat.
Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm.
@@ -194,7 +194,7 @@ Synapse config:
```yaml
oidc_providers:
- - idp_id: auth0
+ - idp_id: auth0
idp_name: Auth0
issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
client_id: "your-client-id" # TO BE FILLED
@@ -310,3 +310,46 @@ oidc_providers:
localpart_template: '{{ user.nickname }}'
display_name_template: '{{ user.name }}'
```
+
+### Facebook
+
+Like Github, Facebook provide a custom OAuth2 API rather than an OIDC-compliant
+one so requires a little more configuration.
+
+0. You will need a Facebook developer account. You can register for one
+ [here](https://developers.facebook.com/async/registration/).
+1. On the [apps](https://developers.facebook.com/apps/) page of the developer
+ console, "Create App", and choose "Build Connected Experiences".
+2. Once the app is created, add "Facebook Login" and choose "Web". You don't
+ need to go through the whole form here.
+3. In the left-hand menu, open "Products"/"Facebook Login"/"Settings".
+ * Add `[synapse public baseurl]/_synapse/oidc/callback` as an OAuth Redirect
+ URL.
+4. In the left-hand menu, open "Settings/Basic". Here you can copy the "App ID"
+ and "App Secret" for use below.
+
+Synapse config:
+
+```yaml
+ - idp_id: facebook
+ idp_name: Facebook
+ idp_brand: "org.matrix.facebook" # optional: styling hint for clients
+ discover: false
+ issuer: "https://facebook.com"
+ client_id: "your-client-id" # TO BE FILLED
+ client_secret: "your-client-secret" # TO BE FILLED
+ scopes: ["openid", "email"]
+ authorization_endpoint: https://facebook.com/dialog/oauth
+ token_endpoint: https://graph.facebook.com/v9.0/oauth/access_token
+ user_profile_method: "userinfo_endpoint"
+ userinfo_endpoint: "https://graph.facebook.com/v9.0/me?fields=id,name,email,picture"
+ user_mapping_provider:
+ config:
+ subject_claim: "id"
+ display_name_template: "{{ user.name }}"
+```
+
+Relevant documents:
+ * https://developers.facebook.com/docs/facebook-login/manually-build-a-login-flow
+ * Using Facebook's Graph API: https://developers.facebook.com/docs/graph-api/using-graph-api/
+ * Reference to the User endpoint: https://developers.facebook.com/docs/graph-api/reference/user
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 15e9746696..dd2981717d 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -824,6 +824,9 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which
# can be more expensive)
+# - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
+# - two for ratelimiting how often invites can be sent in a room or to a
+# specific user.
#
# The defaults are as shown below.
#
@@ -857,7 +860,18 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# remote:
# per_second: 0.01
# burst_count: 3
-
+#
+#rc_3pid_validation:
+# per_second: 0.003
+# burst_count: 5
+#
+#rc_invites:
+# per_room:
+# per_second: 0.3
+# burst_count: 10
+# per_user:
+# per_second: 0.003
+# burst_count: 5
# Ratelimiting settings for incoming federation
#
@@ -1893,10 +1907,6 @@ cas_config:
#
#server_url: "https://cas-server.com"
- # The public URL of the homeserver.
- #
- #service_url: "https://homeserver.domain.com:8448"
-
# The attribute of the CAS response to use as the display name.
#
# If unset, no displayname will be set.
diff --git a/docs/workers.md b/docs/workers.md
index 0da805c333..c36549c621 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -40,6 +40,9 @@ which relays replication commands between processes. This can give a significant
cpu saving on the main process and will be a prerequisite for upcoming
performance improvements.
+If Redis support is enabled Synapse will use it as a shared cache, as well as a
+pub/sub mechanism.
+
See the [Architectural diagram](#architectural-diagram) section at the end for
a visualisation of what this looks like.
diff --git a/mypy.ini b/mypy.ini
index bd99069c81..68a4533973 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -23,39 +23,7 @@ files =
synapse/events/validator.py,
synapse/events/spamcheck.py,
synapse/federation,
- synapse/handlers/_base.py,
- synapse/handlers/account_data.py,
- synapse/handlers/account_validity.py,
- synapse/handlers/admin.py,
- synapse/handlers/appservice.py,
- synapse/handlers/auth.py,
- synapse/handlers/cas_handler.py,
- synapse/handlers/deactivate_account.py,
- synapse/handlers/device.py,
- synapse/handlers/devicemessage.py,
- synapse/handlers/directory.py,
- synapse/handlers/events.py,
- synapse/handlers/federation.py,
- synapse/handlers/identity.py,
- synapse/handlers/initial_sync.py,
- synapse/handlers/message.py,
- synapse/handlers/oidc_handler.py,
- synapse/handlers/pagination.py,
- synapse/handlers/password_policy.py,
- synapse/handlers/presence.py,
- synapse/handlers/profile.py,
- synapse/handlers/read_marker.py,
- synapse/handlers/receipts.py,
- synapse/handlers/register.py,
- synapse/handlers/room.py,
- synapse/handlers/room_list.py,
- synapse/handlers/room_member.py,
- synapse/handlers/room_member_worker.py,
- synapse/handlers/saml_handler.py,
- synapse/handlers/sso.py,
- synapse/handlers/sync.py,
- synapse/handlers/user_directory.py,
- synapse/handlers/ui_auth,
+ synapse/handlers,
synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py,
synapse/http/federation/well_known_resolver.py,
@@ -194,3 +162,9 @@ ignore_missing_imports = True
[mypy-hiredis]
ignore_missing_imports = True
+
+[mypy-josepy.*]
+ignore_missing_imports = True
+
+[mypy-txacme.*]
+ignore_missing_imports = True
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index f328ab57d5..fe2965cd36 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -80,7 +80,8 @@ else
# then lint everything!
if [[ -z ${files+x} ]]; then
# Lint all source code files and directories
- files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark")
+ # Note: this list aims the mirror the one in tox.ini
+ files=("synapse" "docker" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite")
fi
fi
diff --git a/setup.py b/setup.py
index ddbe9f511a..99425d52de 100755
--- a/setup.py
+++ b/setup.py
@@ -96,7 +96,7 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
#
# We pin black so that our tests don't start failing on new releases.
CONDITIONAL_REQUIREMENTS["lint"] = [
- "isort==5.0.3",
+ "isort==5.7.0",
"black==19.10b0",
"flake8-comprehensions",
"flake8",
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index bfac6840e6..618548a305 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -15,13 +15,23 @@
"""Contains *incomplete* type hints for txredisapi.
"""
-
-from typing import List, Optional, Type, Union
+from typing import Any, List, Optional, Type, Union
class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
+ async def ping(self) -> None: ...
+ async def set(
+ self,
+ key: str,
+ value: Any,
+ expire: Optional[int] = None,
+ pexpire: Optional[int] = None,
+ only_if_not_exists: bool = False,
+ only_if_exists: bool = False,
+ ) -> None: ...
+ async def get(self, key: str) -> Any: ...
-class SubscriberProtocol:
+class SubscriberProtocol(RedisProtocol):
def __init__(self, *args, **kwargs): ...
password: Optional[str]
def subscribe(self, channels: Union[str, List[str]]): ...
@@ -40,14 +50,13 @@ def lazyConnection(
convertNumbers: bool = ...,
) -> RedisProtocol: ...
-class SubscriberFactory:
- def buildProtocol(self, addr): ...
-
class ConnectionHandler: ...
class RedisFactory:
continueTrying: bool
handler: RedisProtocol
+ pool: List[RedisProtocol]
+ replyTimeout: Optional[int]
def __init__(
self,
uuid: str,
@@ -60,3 +69,7 @@ class RedisFactory:
replyTimeout: Optional[int] = None,
convertNumbers: Optional[int] = True,
): ...
+ def buildProtocol(self, addr) -> RedisProtocol: ...
+
+class SubscriberFactory(RedisFactory):
+ def __init__(self): ...
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 395e202b89..9840a9d55b 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -16,6 +16,7 @@
import gc
import logging
import os
+import platform
import signal
import socket
import sys
@@ -339,7 +340,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
# rest of time. Doing so means less work each GC (hopefully).
#
# This only works on Python 3.7
- if sys.version_info >= (3, 7):
+ if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7):
gc.collect()
gc.freeze()
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index c38cf8231f..8f86cecb76 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -93,15 +93,20 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
+ daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms()
+ stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms
+ stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages()
+ daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages()
+ stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages
stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
+ daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
+ stats["daily_sent_messages"] = daily_sent_messages
r30_results = await hs.get_datastore().count_r30_users()
for name, count in r30_results.items():
stats["r30_users_" + name] = count
- daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
- stats["daily_sent_messages"] = daily_sent_messages
stats["cache_factor"] = hs.config.caches.global_factor
stats["event_cache_size"] = hs.config.caches.event_cache_size
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 35e5594b73..a851f8801d 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -203,11 +203,28 @@ class Config:
with open(file_path) as file_stream:
return file_stream.read()
+ def read_template(self, filename: str) -> jinja2.Template:
+ """Load a template file from disk.
+
+ This function will attempt to load the given template from the default Synapse
+ template directory.
+
+ Files read are treated as Jinja templates. The templates is not rendered yet
+ and has autoescape enabled.
+
+ Args:
+ filename: A template filename to read.
+
+ Raises:
+ ConfigError: if the file's path is incorrect or otherwise cannot be read.
+
+ Returns:
+ A jinja2 template.
+ """
+ return self.read_templates([filename])[0]
+
def read_templates(
- self,
- filenames: List[str],
- custom_template_directory: Optional[str] = None,
- autoescape: bool = False,
+ self, filenames: List[str], custom_template_directory: Optional[str] = None,
) -> List[jinja2.Template]:
"""Load a list of template files from disk using the given variables.
@@ -215,7 +232,8 @@ class Config:
template directory. If `custom_template_directory` is supplied, that directory
is tried first.
- Files read are treated as Jinja templates. These templates are not rendered yet.
+ Files read are treated as Jinja templates. The templates are not rendered yet
+ and have autoescape enabled.
Args:
filenames: A list of template filenames to read.
@@ -223,16 +241,12 @@ class Config:
custom_template_directory: A directory to try to look for the templates
before using the default Synapse template directory instead.
- autoescape: Whether to autoescape variables before inserting them into the
- template.
-
Raises:
ConfigError: if the file's path is incorrect or otherwise cannot be read.
Returns:
A list of jinja2 templates.
"""
- templates = []
search_directories = [self.default_template_dir]
# The loader will first look in the custom template directory (if specified) for the
@@ -250,7 +264,7 @@ class Config:
# TODO: switch to synapse.util.templates.build_jinja_env
loader = jinja2.FileSystemLoader(search_directories)
- env = jinja2.Environment(loader=loader, autoescape=autoescape)
+ env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),)
# Update the environment with our custom filters
env.filters.update(
@@ -260,12 +274,8 @@ class Config:
}
)
- for filename in filenames:
- # Load the template
- template = env.get_template(filename)
- templates.append(template)
-
- return templates
+ # Load the templates
+ return [env.get_template(filename) for filename in filenames]
class RootConfig:
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 3ccea4b02d..70025b5d60 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -19,6 +19,7 @@ from synapse.config import (
password_auth_providers,
push,
ratelimiting,
+ redis,
registration,
repository,
room_directory,
@@ -53,7 +54,7 @@ class RootConfig:
tls: tls.TlsConfig
database: database.DatabaseConfig
logging: logger.LoggingConfig
- ratelimit: ratelimiting.RatelimitConfig
+ ratelimiting: ratelimiting.RatelimitConfig
media: repository.ContentRepositoryConfig
captcha: captcha.CaptchaConfig
voip: voip.VoipConfig
@@ -81,6 +82,7 @@ class RootConfig:
roomdirectory: room_directory.RoomDirectoryConfig
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
tracer: tracer.TracerConfig
+ redis: redis.RedisConfig
config_classes: List = ...
def __init__(self) -> None: ...
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index cb00958165..9e48f865cc 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -28,9 +28,7 @@ class CaptchaConfig(Config):
"recaptcha_siteverify_api",
"https://www.recaptcha.net/recaptcha/api/siteverify",
)
- self.recaptcha_template = self.read_templates(
- ["recaptcha.html"], autoescape=True
- )[0]
+ self.recaptcha_template = self.read_template("recaptcha.html")
def generate_config_section(self, **kwargs):
return """\
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index c7877b4095..b226890c2a 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -30,7 +30,13 @@ class CasConfig(Config):
if self.cas_enabled:
self.cas_server_url = cas_config["server_url"]
- self.cas_service_url = cas_config["service_url"]
+ public_base_url = cas_config.get("service_url") or self.public_baseurl
+ if public_base_url[-1] != "/":
+ public_base_url += "/"
+ # TODO Update this to a _synapse URL.
+ self.cas_service_url = (
+ public_base_url + "_matrix/client/r0/login/cas/ticket"
+ )
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
self.cas_required_attributes = cas_config.get("required_attributes") or {}
else:
@@ -53,10 +59,6 @@ class CasConfig(Config):
#
#server_url: "https://cas-server.com"
- # The public URL of the homeserver.
- #
- #service_url: "https://homeserver.domain.com:8448"
-
# The attribute of the CAS response to use as the display name.
#
# If unset, no displayname will be set.
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index 6efa59b110..c47f364b14 100644
--- a/synapse/config/consent_config.py
+++ b/synapse/config/consent_config.py
@@ -89,7 +89,7 @@ class ConsentConfig(Config):
def read_config(self, config, **kwargs):
consent_config = config.get("user_consent")
- self.terms_template = self.read_templates(["terms.html"], autoescape=True)[0]
+ self.terms_template = self.read_template("terms.html")
if consent_config is None:
return
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 784b416f95..bb122ef182 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -53,8 +53,7 @@ class OIDCConfig(Config):
"Multiple OIDC providers have the idp_id %r." % idp_id
)
- public_baseurl = self.public_baseurl
- self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
+ self.oidc_callback_url = self.public_baseurl + "_synapse/oidc/callback"
@property
def oidc_enabled(self) -> bool:
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 14b8836197..def33a60ad 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -24,7 +24,7 @@ class RateLimitConfig:
defaults={"per_second": 0.17, "burst_count": 3.0},
):
self.per_second = config.get("per_second", defaults["per_second"])
- self.burst_count = config.get("burst_count", defaults["burst_count"])
+ self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
class FederationRateLimitConfig:
@@ -102,6 +102,20 @@ class RatelimitConfig(Config):
defaults={"per_second": 0.01, "burst_count": 3},
)
+ self.rc_3pid_validation = RateLimitConfig(
+ config.get("rc_3pid_validation") or {},
+ defaults={"per_second": 0.003, "burst_count": 5},
+ )
+
+ self.rc_invites_per_room = RateLimitConfig(
+ config.get("rc_invites", {}).get("per_room", {}),
+ defaults={"per_second": 0.3, "burst_count": 10},
+ )
+ self.rc_invites_per_user = RateLimitConfig(
+ config.get("rc_invites", {}).get("per_user", {}),
+ defaults={"per_second": 0.003, "burst_count": 5},
+ )
+
def generate_config_section(self, **kwargs):
return """\
## Ratelimiting ##
@@ -131,6 +145,9 @@ class RatelimitConfig(Config):
# users are joining rooms the server is already in (this is cheap) vs
# "remote" for when users are trying to join rooms not on the server (which
# can be more expensive)
+ # - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
+ # - two for ratelimiting how often invites can be sent in a room or to a
+ # specific user.
#
# The defaults are as shown below.
#
@@ -164,7 +181,18 @@ class RatelimitConfig(Config):
# remote:
# per_second: 0.01
# burst_count: 3
-
+ #
+ #rc_3pid_validation:
+ # per_second: 0.003
+ # burst_count: 5
+ #
+ #rc_invites:
+ # per_room:
+ # per_second: 0.3
+ # burst_count: 10
+ # per_user:
+ # per_second: 0.003
+ # burst_count: 5
# Ratelimiting settings for incoming federation
#
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 4bfc69cb7a..ac48913a0b 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -176,9 +176,7 @@ class RegistrationConfig(Config):
self.session_lifetime = session_lifetime
# The success template used during fallback auth.
- self.fallback_success_template = self.read_templates(
- ["auth_success.html"], autoescape=True
- )[0]
+ self.fallback_success_template = self.read_template("auth_success.html")
def generate_config_section(self, generate_secrets=False, **kwargs):
if generate_secrets:
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 74b67b230a..14b21796d9 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -125,19 +125,24 @@ class FederationPolicyForHTTPS:
self._no_verify_ssl_context = _no_verify_ssl.getContext()
self._no_verify_ssl_context.set_info_callback(_context_info_cb)
- def get_options(self, host: bytes):
+ self._should_verify = self._config.federation_verify_certificates
+
+ self._federation_certificate_verification_whitelist = (
+ self._config.federation_certificate_verification_whitelist
+ )
+ def get_options(self, host: bytes):
# IPolicyForHTTPS.get_options takes bytes, but we want to compare
# against the str whitelist. The hostnames in the whitelist are already
# IDNA-encoded like the hosts will be here.
ascii_host = host.decode("ascii")
# Check if certificate verification has been enabled
- should_verify = self._config.federation_verify_certificates
+ should_verify = self._should_verify
# Check if we've disabled certificate verification for this host
- if should_verify:
- for regex in self._config.federation_certificate_verification_whitelist:
+ if self._should_verify:
+ for regex in self._federation_certificate_verification_whitelist:
if regex.match(ascii_host):
should_verify = False
break
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d330ae5dbc..40e1451201 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -810,7 +810,7 @@ class FederationClient(FederationBase):
"User's homeserver does not support this room version",
Codes.UNSUPPORTED_ROOM_VERSION,
)
- elif e.code == 403:
+ elif e.code in (403, 429):
raise e.to_synapse_error()
else:
raise
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 604cfd1935..643b26ae6d 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -142,6 +142,8 @@ class FederationSender:
self._wake_destinations_needing_catchup,
)
+ self._external_cache = hs.get_external_cache()
+
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination
@@ -197,22 +199,40 @@ class FederationSender:
if not event.internal_metadata.should_proactively_send():
return
- try:
- # Get the state from before the event.
- # We need to make sure that this is the state from before
- # the event and not from after it.
- # Otherwise if the last member on a server in a room is
- # banned then it won't receive the event because it won't
- # be in the room after the ban.
- destinations = await self.state.get_hosts_in_room_at_events(
- event.room_id, event_ids=event.prev_event_ids()
- )
- except Exception:
- logger.exception(
- "Failed to calculate hosts in room for event: %s",
- event.event_id,
+ destinations = None # type: Optional[Set[str]]
+ if not event.prev_event_ids():
+ # If there are no prev event IDs then the state is empty
+ # and so no remote servers in the room
+ destinations = set()
+ else:
+ # We check the external cache for the destinations, which is
+ # stored per state group.
+
+ sg = await self._external_cache.get(
+ "event_to_prev_state_group", event.event_id
)
- return
+ if sg:
+ destinations = await self._external_cache.get(
+ "get_joined_hosts", str(sg)
+ )
+
+ if destinations is None:
+ try:
+ # Get the state from before the event.
+ # We need to make sure that this is the state from before
+ # the event and not from after it.
+ # Otherwise if the last member on a server in a room is
+ # banned then it won't receive the event because it won't
+ # be in the room after the ban.
+ destinations = await self.state.get_hosts_in_room_at_events(
+ event.room_id, event_ids=event.prev_event_ids()
+ )
+ except Exception:
+ logger.exception(
+ "Failed to calculate hosts in room for event: %s",
+ event.event_id,
+ )
+ return
destinations = {
d
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 8476256a59..5ecb2da1ac 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
import twisted
import twisted.internet.error
@@ -22,6 +23,9 @@ from twisted.web.resource import Resource
from synapse.app import check_bind_error
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
ACME_REGISTER_FAIL_ERROR = """
@@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
class AcmeHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.reactor = hs.get_reactor()
self._acme_domain = hs.config.acme_domain
- async def start_listening(self):
+ async def start_listening(self) -> None:
from synapse.handlers import acme_issuing_service
# Configure logging for txacme, if you need to debug
@@ -85,7 +89,7 @@ class AcmeHandler:
logger.error(ACME_REGISTER_FAIL_ERROR)
raise
- async def provision_certificate(self):
+ async def provision_certificate(self) -> None:
logger.warning("Reprovisioning %s", self._acme_domain)
@@ -110,5 +114,3 @@ class AcmeHandler:
except Exception:
logger.exception("Failed saving!")
raise
-
- return True
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index 7294649d71..ae2a9dd9c2 100644
--- a/synapse/handlers/acme_issuing_service.py
+++ b/synapse/handlers/acme_issuing_service.py
@@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to
imported conditionally.
"""
import logging
+from typing import Dict, Iterable, List
import attr
+import pem
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from josepy import JWKRSA
@@ -36,20 +38,27 @@ from txacme.util import generate_private_key
from zope.interface import implementer
from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTCP
from twisted.python.filepath import FilePath
from twisted.python.url import URL
+from twisted.web.resource import IResource
logger = logging.getLogger(__name__)
-def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
+def create_issuing_service(
+ reactor: IReactorTCP,
+ acme_url: str,
+ account_key_file: str,
+ well_known_resource: IResource,
+) -> AcmeIssuingService:
"""Create an ACME issuing service, and attach it to a web Resource
Args:
reactor: twisted reactor
- acme_url (str): URL to use to request certificates
- account_key_file (str): where to store the account key
- well_known_resource (twisted.web.IResource): web resource for .well-known.
+ acme_url: URL to use to request certificates
+ account_key_file: where to store the account key
+ well_known_resource: web resource for .well-known.
we will attach a child resource for "acme-challenge".
Returns:
@@ -83,18 +92,20 @@ class ErsatzStore:
A store that only stores in memory.
"""
- certs = attr.ib(default=attr.Factory(dict))
+ certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
- def store(self, server_name, pem_objects):
+ def store(
+ self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
+ ) -> defer.Deferred:
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
return defer.succeed(None)
-def load_or_create_client_key(key_file):
+def load_or_create_client_key(key_file: str) -> JWKRSA:
"""Load the ACME account key from a file, creating it if it does not exist.
Args:
- key_file (str): name of the file to use as the account key
+ key_file: name of the file to use as the account key
"""
# this is based on txacme.endpoint.load_or_create_client_key, but doesn't
# hardcode the 'client.key' filename
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6f746711ca..a19c556437 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -568,16 +568,6 @@ class AuthHandler(BaseHandler):
session.session_id, login_type, result
)
except LoginError as e:
- if login_type == LoginType.EMAIL_IDENTITY:
- # riot used to have a bug where it would request a new
- # validation token (thus sending a new email) each time it
- # got a 401 with a 'flows' field.
- # (https://github.com/vector-im/vector-web/issues/2447).
- #
- # Grandfather in the old behaviour for now to avoid
- # breaking old riot deployments.
- raise
-
# this step failed. Merge the error dict into the response
# so that the client can have another go.
errordict = e.error_dict()
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 048523ec94..bd35d1fb87 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -100,11 +100,7 @@ class CasHandler:
Returns:
The URL to use as a "service" parameter.
"""
- return "%s%s?%s" % (
- self._cas_service_url,
- "/_matrix/client/r0/login/cas/ticket",
- urllib.parse.urlencode(args),
- )
+ return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index debb1b4f29..0863154f7a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api import errors
from synapse.api.constants import EventTypes
@@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
@trace
- async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
+ async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
Retrieve the given user's devices
@@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler):
return devices
@trace
- async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
+ async def get_device(self, user_id: str, device_id: str) -> JsonDict:
""" Retrieve the given device
Args:
@@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips(
- device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
+ device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@@ -946,8 +946,8 @@ class DeviceListUpdater:
async def process_cross_signing_key_update(
self,
user_id: str,
- master_key: Optional[Dict[str, Any]],
- self_signing_key: Optional[Dict[str, Any]],
+ master_key: Optional[JsonDict],
+ self_signing_key: Optional[JsonDict],
) -> List[str]:
"""Process the given new master and self-signing key for the given remote user.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 929752150d..8f3a6b35a4 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -16,7 +16,7 @@
# limitations under the License.
import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
+ JsonDict,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
@@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class E2eKeysHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
@@ -78,7 +82,9 @@ class E2eKeysHandler:
)
@trace
- async def query_devices(self, query_body, timeout, from_user_id):
+ async def query_devices(
+ self, query_body: JsonDict, timeout: int, from_user_id: str
+ ) -> JsonDict:
""" Handle a device key query from a client
{
@@ -98,12 +104,14 @@ class E2eKeysHandler:
}
Args:
- from_user_id (str): the user making the query. This is used when
+ from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
"""
- device_keys_query = query_body.get("device_keys", {})
+ device_keys_query = query_body.get(
+ "device_keys", {}
+ ) # type: Dict[str, Iterable[str]]
# separate users by domain.
# make a map from domain to user_id to device_ids
@@ -121,7 +129,8 @@ class E2eKeysHandler:
set_tag("remote_key_query", remote_queries)
# First get local devices.
- failures = {}
+ # A map of destination -> failure response.
+ failures = {} # type: Dict[str, JsonDict]
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
@@ -135,9 +144,10 @@ class E2eKeysHandler:
)
# Now attempt to get any remote devices from our local cache.
- remote_queries_not_in_cache = {}
+ # A map of destination -> user ID -> device IDs.
+ remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries:
- query_list = []
+ query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
@@ -284,15 +294,15 @@ class E2eKeysHandler:
return ret
async def get_cross_signing_keys_from_cache(
- self, query, from_user_id
+ self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]:
"""Get cross-signing keys for users from the database
Args:
- query (Iterable[string]) an iterable of user IDs. A dict whose keys
+ query: an iterable of user IDs. A dict whose keys
are user IDs satisfies this, so the query format used for
query_devices can be used here.
- from_user_id (str): the user making the query. This is used when
+ from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
@@ -315,14 +325,12 @@ class E2eKeysHandler:
if "self_signing" in user_info:
self_signing_keys[user_id] = user_info["self_signing"]
- if (
- from_user_id in keys
- and keys[from_user_id] is not None
- and "user_signing" in keys[from_user_id]
- ):
- # users can see other users' master and self-signing keys, but can
- # only see their own user-signing keys
- user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+ # users can see other users' master and self-signing keys, but can
+ # only see their own user-signing keys
+ if from_user_id:
+ from_user_key = keys.get(from_user_id)
+ if from_user_key and "user_signing" in from_user_key:
+ user_signing_keys[from_user_id] = from_user_key["user_signing"]
return {
"master_keys": master_keys,
@@ -344,9 +352,9 @@ class E2eKeysHandler:
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
- local_query = []
+ local_query = [] # type: List[Tuple[str, Optional[str]]]
- result_dict = {}
+ result_dict = {} # type: Dict[str, Dict[str, dict]]
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
@@ -380,10 +388,14 @@ class E2eKeysHandler:
log_kv(results)
return result_dict
- async def on_federation_query_client_keys(self, query_body):
+ async def on_federation_query_client_keys(
+ self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
+ ) -> JsonDict:
""" Handle a device key query from a federated server
"""
- device_keys_query = query_body.get("device_keys", {})
+ device_keys_query = query_body.get(
+ "device_keys", {}
+ ) # type: Dict[str, Optional[List[str]]]
res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}
@@ -397,31 +409,34 @@ class E2eKeysHandler:
return ret
@trace
- async def claim_one_time_keys(self, query, timeout):
- local_query = []
- remote_queries = {}
+ async def claim_one_time_keys(
+ self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+ ) -> JsonDict:
+ local_query = [] # type: List[Tuple[str, str, str]]
+ remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
- for user_id, device_keys in query.get("one_time_keys", {}).items():
+ for user_id, one_time_keys in query.get("one_time_keys", {}).items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
- for device_id, algorithm in device_keys.items():
+ for device_id, algorithm in one_time_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
domain = get_domain_from_id(user_id)
- remote_queries.setdefault(domain, {})[user_id] = device_keys
+ remote_queries.setdefault(domain, {})[user_id] = one_time_keys
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
results = await self.store.claim_e2e_one_time_keys(local_query)
- json_result = {}
- failures = {}
+ # A map of user ID -> device ID -> key ID -> key.
+ json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+ failures = {} # type: Dict[str, JsonDict]
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
- for key_id, json_bytes in keys.items():
+ for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
- key_id: json_decoder.decode(json_bytes)
+ key_id: json_decoder.decode(json_str)
}
@trace
@@ -468,7 +483,9 @@ class E2eKeysHandler:
return {"one_time_keys": json_result, "failures": failures}
@tag_args
- async def upload_keys_for_user(self, user_id, device_id, keys):
+ async def upload_keys_for_user(
+ self, user_id: str, device_id: str, keys: JsonDict
+ ) -> JsonDict:
time_now = self.clock.time_msec()
@@ -543,8 +560,8 @@ class E2eKeysHandler:
return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user(
- self, user_id, device_id, time_now, one_time_keys
- ):
+ self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+ ) -> None:
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(),
@@ -585,12 +602,14 @@ class E2eKeysHandler:
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
- async def upload_signing_keys_for_user(self, user_id, keys):
+ async def upload_signing_keys_for_user(
+ self, user_id: str, keys: JsonDict
+ ) -> JsonDict:
"""Upload signing keys for cross-signing
Args:
- user_id (string): the user uploading the keys
- keys (dict[string, dict]): the signing keys
+ user_id: the user uploading the keys
+ keys: the signing keys
"""
# if a master key is uploaded, then check it. Otherwise, load the
@@ -667,16 +686,17 @@ class E2eKeysHandler:
return {}
- async def upload_signatures_for_device_keys(self, user_id, signatures):
+ async def upload_signatures_for_device_keys(
+ self, user_id: str, signatures: JsonDict
+ ) -> JsonDict:
"""Upload device signatures for cross-signing
Args:
- user_id (string): the user uploading the signatures
- signatures (dict[string, dict[string, dict]]): map of users to
- devices to signed keys. This is the submission from the user; an
- exception will be raised if it is malformed.
+ user_id: the user uploading the signatures
+ signatures: map of users to devices to signed keys. This is the submission
+ from the user; an exception will be raised if it is malformed.
Returns:
- dict: response to be sent back to the client. The response will have
+ The response to be sent back to the client. The response will have
a "failures" key, which will be a dict mapping users to devices
to errors for the signatures that failed.
Raises:
@@ -719,7 +739,9 @@ class E2eKeysHandler:
return {"failures": failures}
- async def _process_self_signatures(self, user_id, signatures):
+ async def _process_self_signatures(
+ self, user_id: str, signatures: JsonDict
+ ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
"""Process uploaded signatures of the user's own keys.
Signatures of the user's own keys from this API come in two forms:
@@ -731,15 +753,14 @@ class E2eKeysHandler:
signatures (dict[string, dict]): map of devices to signed keys
Returns:
- (list[SignatureListItem], dict[string, dict[string, dict]]):
- a list of signatures to store, and a map of users to devices to failure
- reasons
+ A tuple of a list of signatures to store, and a map of users to
+ devices to failure reasons
Raises:
SynapseError: if the input is malformed
"""
- signature_list = []
- failures = {}
+ signature_list = [] # type: List[SignatureListItem]
+ failures = {} # type: Dict[str, Dict[str, JsonDict]]
if not signatures:
return signature_list, failures
@@ -834,19 +855,24 @@ class E2eKeysHandler:
return signature_list, failures
def _check_master_key_signature(
- self, user_id, master_key_id, signed_master_key, stored_master_key, devices
- ):
+ self,
+ user_id: str,
+ master_key_id: str,
+ signed_master_key: JsonDict,
+ stored_master_key: JsonDict,
+ devices: Dict[str, Dict[str, JsonDict]],
+ ) -> List["SignatureListItem"]:
"""Check signatures of a user's master key made by their devices.
Args:
- user_id (string): the user whose master key is being checked
- master_key_id (string): the ID of the user's master key
- signed_master_key (dict): the user's signed master key that was uploaded
- stored_master_key (dict): our previously-stored copy of the user's master key
- devices (iterable(dict)): the user's devices
+ user_id: the user whose master key is being checked
+ master_key_id: the ID of the user's master key
+ signed_master_key: the user's signed master key that was uploaded
+ stored_master_key: our previously-stored copy of the user's master key
+ devices: the user's devices
Returns:
- list[SignatureListItem]: a list of signatures to store
+ A list of signatures to store
Raises:
SynapseError: if a signature is invalid
@@ -877,25 +903,26 @@ class E2eKeysHandler:
return master_key_signature_list
- async def _process_other_signatures(self, user_id, signatures):
+ async def _process_other_signatures(
+ self, user_id: str, signatures: Dict[str, dict]
+ ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
"""Process uploaded signatures of other users' keys. These will be the
target user's master keys, signed by the uploading user's user-signing
key.
Args:
- user_id (string): the user uploading the keys
- signatures (dict[string, dict]): map of users to devices to signed keys
+ user_id: the user uploading the keys
+ signatures: map of users to devices to signed keys
Returns:
- (list[SignatureListItem], dict[string, dict[string, dict]]):
- a list of signatures to store, and a map of users to devices to failure
+ A list of signatures to store, and a map of users to devices to failure
reasons
Raises:
SynapseError: if the input is malformed
"""
- signature_list = []
- failures = {}
+ signature_list = [] # type: List[SignatureListItem]
+ failures = {} # type: Dict[str, Dict[str, JsonDict]]
if not signatures:
return signature_list, failures
@@ -983,7 +1010,7 @@ class E2eKeysHandler:
async def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: str = None
- ):
+ ) -> Tuple[JsonDict, str, VerifyKey]:
"""Fetch locally or remotely query for a cross-signing public key.
First, attempt to fetch the cross-signing public key from storage.
@@ -997,8 +1024,7 @@ class E2eKeysHandler:
This affects what signatures are fetched.
Returns:
- dict, str, VerifyKey: the raw key data, the key ID, and the
- signedjson verify key
+ The raw key data, the key ID, and the signedjson verify key
Raises:
NotFoundError: if the key is not found
@@ -1135,16 +1161,18 @@ class E2eKeysHandler:
return desired_key, desired_key_id, desired_verify_key
-def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
+def _check_cross_signing_key(
+ key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
+) -> None:
"""Check a cross-signing key uploaded by a user. Performs some basic sanity
checking, and ensures that it is signed, if a signature is required.
Args:
- key (dict): the key data to verify
- user_id (str): the user whose key is being checked
- key_type (str): the type of key that the key should be
- signing_key (VerifyKey): (optional) the signing key that the key should
- be signed with. If omitted, signatures will not be checked.
+ key: the key data to verify
+ user_id: the user whose key is being checked
+ key_type: the type of key that the key should be
+ signing_key: the signing key that the key should be signed with. If
+ omitted, signatures will not be checked.
"""
if (
key.get("user_id") != user_id
@@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
)
-def _check_device_signature(user_id, verify_key, signed_device, stored_device):
+def _check_device_signature(
+ user_id: str,
+ verify_key: VerifyKey,
+ signed_device: JsonDict,
+ stored_device: JsonDict,
+) -> None:
"""Check that a signature on a device or cross-signing key is correct and
matches the copy of the device/key that we have stored. Throws an
exception if an error is detected.
Args:
- user_id (str): the user ID whose signature is being checked
- verify_key (VerifyKey): the key to verify the device with
- signed_device (dict): the uploaded signed device data
- stored_device (dict): our previously stored copy of the device
+ user_id: the user ID whose signature is being checked
+ verify_key: the key to verify the device with
+ signed_device: the uploaded signed device data
+ stored_device: our previously stored copy of the device
Raises:
SynapseError: if the signature was invalid or the sent device is not the
@@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device):
raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
-def _exception_to_failure(e):
+def _exception_to_failure(e: Exception) -> JsonDict:
if isinstance(e, SynapseError):
return {"status": e.code, "errcode": e.errcode, "message": str(e)}
@@ -1218,7 +1251,7 @@ def _exception_to_failure(e):
return {"status": 503, "message": str(e)}
-def _one_time_keys_match(old_key_json, new_key):
+def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
old_key = json_decoder.decode(old_key_json)
# if either is a string rather than an object, they must match exactly
@@ -1239,16 +1272,16 @@ class SignatureListItem:
"""An item in the signature list as used by upload_signatures_for_device_keys.
"""
- signing_key_id = attr.ib()
- target_user_id = attr.ib()
- target_device_id = attr.ib()
- signature = attr.ib()
+ signing_key_id = attr.ib(type=str)
+ target_user_id = attr.ib(type=str)
+ target_device_id = attr.ib(type=str)
+ signature = attr.ib(type=JsonDict)
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""
- def __init__(self, hs, e2e_keys_handler):
+ def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
@@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled.
- self._pending_updates = {}
+ self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
@@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater:
iterable=True,
)
- async def incoming_signing_key_update(self, origin, edu_content):
+ async def incoming_signing_key_update(
+ self, origin: str, edu_content: JsonDict
+ ) -> None:
"""Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list.
Args:
- origin (string): the server that sent the EDU
- edu_content (dict): the contents of the EDU
+ origin: the server that sent the EDU
+ edu_content: the contents of the EDU
"""
user_id = edu_content.pop("user_id")
@@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater:
await self._handle_signing_key_updates(user_id)
- async def _handle_signing_key_updates(self, user_id):
+ async def _handle_signing_key_updates(self, user_id: str) -> None:
"""Actually handle pending updates.
Args:
- user_id (string): the user whose updates we are processing
+ user_id: the user whose updates we are processing
"""
device_handler = self.e2e_keys_handler.device_handler
@@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater:
# This can happen since we batch updates
return
- device_ids = []
+ device_ids = [] # type: List[str]
logger.info("pending updates: %r", pending_updates)
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f01b090772..622cae23be 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, List, Optional
from synapse.api.errors import (
Codes,
@@ -24,8 +25,12 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.opentracing import log_kv, trace
+from synapse.types import JsonDict
from synapse.util.async_helpers import Linearizer
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -37,7 +42,7 @@ class E2eRoomKeysHandler:
The actual payload of the encrypted keys is completely opaque to the handler.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
# Used to lock whenever a client is uploading key data. This prevents collisions
@@ -48,21 +53,27 @@ class E2eRoomKeysHandler:
self._upload_linearizer = Linearizer("upload_room_keys_lock")
@trace
- async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def get_room_keys(
+ self,
+ user_id: str,
+ version: str,
+ room_id: Optional[str] = None,
+ session_id: Optional[str] = None,
+ ) -> List[JsonDict]:
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
Args:
- user_id(str): the user whose keys we're getting
- version(str): the version ID of the backup we're getting keys from
- room_id(string): room ID to get keys for, for None to get keys for all rooms
- session_id(string): session ID to get keys for, for None to get keys for all
+ user_id: the user whose keys we're getting
+ version: the version ID of the backup we're getting keys from
+ room_id: room ID to get keys for, for None to get keys for all rooms
+ session_id: session ID to get keys for, for None to get keys for all
sessions
Raises:
NotFoundError: if the backup version does not exist
Returns:
- A deferred list of dicts giving the session_data and message metadata for
+ A list of dicts giving the session_data and message metadata for
these room keys.
"""
@@ -86,17 +97,23 @@ class E2eRoomKeysHandler:
return results
@trace
- async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def delete_room_keys(
+ self,
+ user_id: str,
+ version: str,
+ room_id: Optional[str] = None,
+ session_id: Optional[str] = None,
+ ) -> JsonDict:
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
Args:
- user_id(str): the user whose backup we're deleting
- version(str): the version ID of the backup we're deleting
- room_id(string): room ID to delete keys for, for None to delete keys for all
+ user_id: the user whose backup we're deleting
+ version: the version ID of the backup we're deleting
+ room_id: room ID to delete keys for, for None to delete keys for all
rooms
- session_id(string): session ID to delete keys for, for None to delete keys
+ session_id: session ID to delete keys for, for None to delete keys
for all sessions
Raises:
NotFoundError: if the backup version does not exist
@@ -128,15 +145,17 @@ class E2eRoomKeysHandler:
return {"etag": str(version_etag), "count": count}
@trace
- async def upload_room_keys(self, user_id, version, room_keys):
+ async def upload_room_keys(
+ self, user_id: str, version: str, room_keys: JsonDict
+ ) -> JsonDict:
"""Bulk upload a list of room keys into a given backup version, asserting
that the given version is the current backup version. room_keys are merged
into the current backup as described in RoomKeysServlet.on_PUT().
Args:
- user_id(str): the user whose backup we're setting
- version(str): the version ID of the backup we're updating
- room_keys(dict): a nested dict describing the room_keys we're setting:
+ user_id: the user whose backup we're setting
+ version: the version ID of the backup we're updating
+ room_keys: a nested dict describing the room_keys we're setting:
{
"rooms": {
@@ -254,14 +273,16 @@ class E2eRoomKeysHandler:
return {"etag": str(version_etag), "count": count}
@staticmethod
- def _should_replace_room_key(current_room_key, room_key):
+ def _should_replace_room_key(
+ current_room_key: Optional[JsonDict], room_key: JsonDict
+ ) -> bool:
"""
Determine whether to replace a given current_room_key (if any)
with a newly uploaded room_key backup
Args:
- current_room_key (dict): Optional, the current room_key dict if any
- room_key (dict): The new room_key dict which may or may not be fit to
+ current_room_key: Optional, the current room_key dict if any
+ room_key : The new room_key dict which may or may not be fit to
replace the current_room_key
Returns:
@@ -286,14 +307,14 @@ class E2eRoomKeysHandler:
return True
@trace
- async def create_version(self, user_id, version_info):
+ async def create_version(self, user_id: str, version_info: JsonDict) -> str:
"""Create a new backup version. This automatically becomes the new
backup version for the user's keys; previous backups will no longer be
writeable to.
Args:
- user_id(str): the user whose backup version we're creating
- version_info(dict): metadata about the new version being created
+ user_id: the user whose backup version we're creating
+ version_info: metadata about the new version being created
{
"algorithm": "m.megolm_backup.v1",
@@ -301,7 +322,7 @@ class E2eRoomKeysHandler:
}
Returns:
- A deferred of a string that gives the new version number.
+ The new version number.
"""
# TODO: Validate the JSON to make sure it has the right keys.
@@ -313,17 +334,19 @@ class E2eRoomKeysHandler:
)
return new_version
- async def get_version_info(self, user_id, version=None):
+ async def get_version_info(
+ self, user_id: str, version: Optional[str] = None
+ ) -> JsonDict:
"""Get the info about a given version of the user's backup
Args:
- user_id(str): the user whose current backup version we're querying
- version(str): Optional; if None gives the most recent version
+ user_id: the user whose current backup version we're querying
+ version: Optional; if None gives the most recent version
otherwise a historical one.
Raises:
NotFoundError: if the requested backup version doesn't exist
Returns:
- A deferred of a info dict that gives the info about the new version.
+ A info dict that gives the info about the new version.
{
"version": "1234",
@@ -346,7 +369,7 @@ class E2eRoomKeysHandler:
return res
@trace
- async def delete_version(self, user_id, version=None):
+ async def delete_version(self, user_id: str, version: Optional[str] = None) -> None:
"""Deletes a given version of the user's e2e_room_keys backup
Args:
@@ -366,17 +389,19 @@ class E2eRoomKeysHandler:
raise
@trace
- async def update_version(self, user_id, version, version_info):
+ async def update_version(
+ self, user_id: str, version: str, version_info: JsonDict
+ ) -> JsonDict:
"""Update the info about a given version of the user's backup
Args:
- user_id(str): the user whose current backup version we're updating
- version(str): the backup version we're updating
- version_info(dict): the new information about the backup
+ user_id: the user whose current backup version we're updating
+ version: the backup version we're updating
+ version_info: the new information about the backup
Raises:
NotFoundError: if the requested backup version doesn't exist
Returns:
- A deferred of an empty dict.
+ An empty dict.
"""
if "version" not in version_info:
version_info["version"] = version
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index fd8de8696d..dbdfd56ff5 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1617,6 +1617,10 @@ class FederationHandler(BaseHandler):
if event.state_key == self._server_notices_mxid:
raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
+ # We retrieve the room member handler here as to not cause a cyclic dependency
+ member_handler = self.hs.get_room_member_handler()
+ member_handler.ratelimit_invite(event.room_id, event.state_key)
+
# keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a
# join dance).
@@ -2093,6 +2097,11 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
+ # If we are going to send this event over federation we precaclculate
+ # the joined hosts.
+ if event.internal_metadata.get_send_on_behalf_of():
+ await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
return context
async def _check_for_soft_fail(
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index df29edeb83..71f11ef94a 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -15,9 +15,13 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import GroupID, get_domain_from_id
+from synapse.types import GroupID, JsonDict, get_domain_from_id
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -56,7 +60,7 @@ def _create_rerouter(func_name):
class GroupsLocalWorkerHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.room_list_handler = hs.get_room_list_handler()
@@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler:
get_group_role = _create_rerouter("get_group_role")
get_group_roles = _create_rerouter("get_group_roles")
- async def get_group_summary(self, group_id, requester_user_id):
+ async def get_group_summary(
+ self, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get the group summary for a group.
If the group is remote we check that the users have valid attestations.
@@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler:
return res
- async def get_users_in_group(self, group_id, requester_user_id):
+ async def get_users_in_group(
+ self, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get users in a group
"""
if self.is_mine_id(group_id):
- res = await self.groups_server_handler.get_users_in_group(
+ return await self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
)
- return res
group_server_name = get_domain_from_id(group_id)
@@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler:
return res
- async def get_joined_groups(self, user_id):
+ async def get_joined_groups(self, user_id: str) -> JsonDict:
group_ids = await self.store.get_joined_groups(user_id)
return {"groups": group_ids}
- async def get_publicised_groups_for_user(self, user_id):
+ async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
if self.hs.is_mine_id(user_id):
result = await self.store.get_publicised_groups_for_user(user_id)
@@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler:
# TODO: Verify attestations
return {"groups": result}
- async def bulk_get_publicised_groups(self, user_ids, proxy=True):
- destinations = {}
+ async def bulk_get_publicised_groups(
+ self, user_ids: Iterable[str], proxy: bool = True
+ ) -> JsonDict:
+ destinations = {} # type: Dict[str, Set[str]]
local_users = set()
for user_id in user_ids:
@@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler:
raise SynapseError(400, "Some user_ids are not local")
results = {}
- failed_results = []
+ failed_results = [] # type: List[str]
for destination, dest_user_ids in destinations.items():
try:
r = await self.transport_client.bulk_get_publicised_groups(
@@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler:
class GroupsLocalHandler(GroupsLocalWorkerHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
# Ensure attestations get renewed
@@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
set_group_join_policy = _create_rerouter("set_group_join_policy")
- async def create_group(self, group_id, user_id, content):
+ async def create_group(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Create a group
"""
@@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
local_attestation = None
remote_attestation = None
else:
- local_attestation = self.attestations.create_attestation(group_id, user_id)
- content["attestation"] = local_attestation
-
- content["user_profile"] = await self.profile_handler.get_profile(user_id)
-
- try:
- res = await self.transport_client.create_group(
- get_domain_from_id(group_id), group_id, user_id, content
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
- except RequestSendFailed:
- raise SynapseError(502, "Failed to contact group server")
-
- remote_attestation = res["attestation"]
- await self.attestations.verify_attestation(
- remote_attestation,
- group_id=group_id,
- user_id=user_id,
- server_name=get_domain_from_id(group_id),
- )
+ raise SynapseError(400, "Unable to create remote groups")
is_publicised = content.get("publicise", False)
token = await self.store.register_user_group_membership(
@@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- async def join_group(self, group_id, user_id, content):
+ async def join_group(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Request to join a group
"""
if self.is_mine_id(group_id):
@@ -365,7 +358,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
- async def accept_invite(self, group_id, user_id, content):
+ async def accept_invite(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
@@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
- async def invite(self, group_id, user_id, requester_user_id, config):
+ async def invite(
+ self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
+ ) -> JsonDict:
"""Invite a user to a group
"""
content = {"requester_user_id": requester_user_id, "config": config}
@@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- async def on_invite(self, group_id, user_id, content):
+ async def on_invite(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""One of our users were invited to a group
"""
# TODO: Support auto join and rejection
@@ -465,8 +464,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {"state": "invite", "user_profile": user_profile}
async def remove_user_from_group(
- self, group_id, user_id, requester_user_id, content
- ):
+ self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Remove a user from a group
"""
if user_id == requester_user_id:
@@ -499,7 +498,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
- async def user_removed_from_group(self, group_id, user_id, content):
+ async def user_removed_from_group(
+ self, group_id: str, user_id: str, content: JsonDict
+ ) -> None:
"""One of our users was removed/kicked from a group
"""
# TODO: Check if user in group
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index f61844d688..4f7137539b 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -27,9 +27,11 @@ from synapse.api.errors import (
HttpResponseException,
SynapseError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient
+from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, Requester
from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64
@@ -57,6 +59,32 @@ class IdentityHandler(BaseHandler):
self._web_client_location = hs.config.invite_client_location
+ # Ratelimiters for `/requestToken` endpoints.
+ self._3pid_validation_ratelimiter_ip = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+ burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+ )
+ self._3pid_validation_ratelimiter_address = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+ burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+ )
+
+ def ratelimit_request_token_requests(
+ self, request: SynapseRequest, medium: str, address: str,
+ ):
+ """Used to ratelimit requests to `/requestToken` by IP and address.
+
+ Args:
+ request: The associated request
+ medium: The type of threepid, e.g. "msisdn" or "email"
+ address: The actual threepid ID, e.g. the phone number or email address
+ """
+
+ self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
+ self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
+
async def threepid_from_creds(
self, id_server: str, creds: Dict[str, str]
) -> Optional[JsonDict]:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9dfeab09cd..e2a7d567fa 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -432,6 +432,8 @@ class EventCreationHandler:
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+ self._external_cache = hs.get_external_cache()
+
async def create_event(
self,
requester: Requester,
@@ -939,6 +941,8 @@ class EventCreationHandler:
await self.action_generator.handle_push_actions_for_event(event, context)
+ await self.cache_joined_hosts_for_event(event)
+
try:
# If we're a worker we need to hit out to the master.
writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -978,6 +982,44 @@ class EventCreationHandler:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
+ async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
+ """Precalculate the joined hosts at the event, when using Redis, so that
+ external federation senders don't have to recalculate it themselves.
+ """
+
+ if not self._external_cache.is_enabled():
+ return
+
+ # We actually store two mappings, event ID -> prev state group,
+ # state group -> joined hosts, which is much more space efficient
+ # than event ID -> joined hosts.
+ #
+ # Note: We have to cache event ID -> prev state group, as we don't
+ # store that in the DB.
+ #
+ # Note: We always set the state group -> joined hosts cache, even if
+ # we already set it, so that the expiry time is reset.
+
+ state_entry = await self.state.resolve_state_groups_for_events(
+ event.room_id, event_ids=event.prev_event_ids()
+ )
+
+ if state_entry.state_group:
+ joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+
+ await self._external_cache.set(
+ "event_to_prev_state_group",
+ event.event_id,
+ state_entry.state_group,
+ expiry_ms=60 * 60 * 1000,
+ )
+ await self._external_cache.set(
+ "get_joined_hosts",
+ str(state_entry.state_group),
+ list(joined_hosts),
+ expiry_ms=60 * 60 * 1000,
+ )
+
async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str
) -> None:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index ee27d99135..07b2187eb1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -126,6 +126,10 @@ class RoomCreationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules()
+ self._invite_burst_count = (
+ hs.config.ratelimiting.rc_invites_per_room.burst_count
+ )
+
async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
) -> str:
@@ -662,6 +666,9 @@ class RoomCreationHandler(BaseHandler):
invite_3pid_list = []
invite_list = []
+ if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count:
+ raise SynapseError(400, "Cannot invite so many users at once")
+
await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override")
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index e001e418f9..d335da6f19 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -85,6 +85,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
)
+ self._invites_per_room_limiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
+ burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
+ )
+ self._invites_per_user_limiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
+ burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
+ )
+
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
# it doesn't store state.
@@ -144,6 +155,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
"""
raise NotImplementedError()
+ def ratelimit_invite(self, room_id: str, invitee_user_id: str):
+ """Ratelimit invites by room and by target user.
+ """
+ self._invites_per_room_limiter.ratelimit(room_id)
+ self._invites_per_user_limiter.ratelimit(invitee_user_id)
+
async def _local_membership_update(
self,
requester: Requester,
@@ -387,8 +404,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise SynapseError(403, "This room has been blocked on this server")
if effective_membership_state == Membership.INVITE:
+ target_id = target.to_string()
+ if ratelimit:
+ self.ratelimit_invite(room_id, target_id)
+
# block any attempts to invite the server notices mxid
- if target.to_string() == self._server_notices_mxid:
+ if target_id == self._server_notices_mxid:
raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
block_invite = False
@@ -412,7 +433,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
block_invite = True
if not await self.spam_checker.user_may_invite(
- requester.user.to_string(), target.to_string(), room_id
+ requester.user.to_string(), target_id, room_id
):
logger.info("Blocking invite due to spam checker")
block_invite = True
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 66f1bbcfc4..94062e79cb 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -15,23 +15,28 @@
import itertools
import logging
-from typing import Iterable
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
+from synapse.events import EventBase
from synapse.storage.state import StateFilter
+from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
@@ -87,13 +92,15 @@ class SearchHandler(BaseHandler):
return historical_room_ids
- async def search(self, user, content, batch=None):
+ async def search(
+ self, user: UserID, content: JsonDict, batch: Optional[str] = None
+ ) -> JsonDict:
"""Performs a full text search for a user.
Args:
- user (UserID)
- content (dict): Search parameters
- batch (str): The next_batch parameter. Used for pagination.
+ user
+ content: Search parameters
+ batch: The next_batch parameter. Used for pagination.
Returns:
dict to be returned to the client with results of search
@@ -186,7 +193,7 @@ class SearchHandler(BaseHandler):
# If doing a subset of all rooms seearch, check if any of the rooms
# are from an upgraded room, and search their contents as well
if search_filter.rooms:
- historical_room_ids = []
+ historical_room_ids = [] # type: List[str]
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -209,8 +216,10 @@ class SearchHandler(BaseHandler):
rank_map = {} # event_id -> rank of event
allowed_events = []
- room_groups = {} # Holds result of grouping by room, if applicable
- sender_group = {} # Holds result of grouping by sender, if applicable
+ # Holds result of grouping by room, if applicable
+ room_groups = {} # type: Dict[str, JsonDict]
+ # Holds result of grouping by sender, if applicable
+ sender_group = {} # type: Dict[str, JsonDict]
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
@@ -254,7 +263,7 @@ class SearchHandler(BaseHandler):
s["results"].append(e.event_id)
elif order_by == "recent":
- room_events = []
+ room_events = [] # type: List[EventBase]
i = 0
pagination_token = batch_token
@@ -418,13 +427,10 @@ class SearchHandler(BaseHandler):
state_results = {}
if include_state:
- rooms = {e.room_id for e in allowed_events}
- for room_id in rooms:
+ for room_id in {e.room_id for e in allowed_events}:
state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
- state_results.values()
-
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise the 'age' will be wrong
@@ -448,9 +454,9 @@ class SearchHandler(BaseHandler):
if state_results:
s = {}
- for room_id, state in state_results.items():
+ for room_id, state_events in state_results.items():
s[room_id] = await self._event_serializer.serialize_events(
- state, time_now
+ state_events, time_now
)
rooms_cat_res["state"] = s
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index a5d67f828f..84af2dde7e 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -13,24 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class SetPasswordHandler(BaseHandler):
"""Handler which deals with changing user account passwords"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
- self._password_policy_handler = hs.get_password_policy_handler()
async def set_password(
self,
@@ -38,7 +40,7 @@ class SetPasswordHandler(BaseHandler):
password_hash: str,
logout_devices: bool,
requester: Optional[Requester] = None,
- ):
+ ) -> None:
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index fb4f70e8e2..b3f9875358 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -14,15 +14,25 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class StateDeltasHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+ async def _get_key_change(
+ self,
+ prev_event_id: Optional[str],
+ event_id: Optional[str],
+ key_name: str,
+ public_value: str,
+ ) -> Optional[bool]:
"""Given two events check if the `key_name` field in content changed
from not matching `public_value` to doing so.
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index dc62b21c06..d261d7cd4e 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -12,13 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
+
+from typing_extensions import Counter as CounterType
from synapse.api.constants import EventTypes, Membership
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -31,7 +37,7 @@ class StatsHandler:
Heavily derived from UserDirectoryHandler
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
@@ -44,7 +50,7 @@ class StatsHandler:
self.stats_enabled = hs.config.stats_enabled
# The current position in the current_state_delta stream
- self.pos = None
+ self.pos = None # type: Optional[int]
# Guard to ensure we only process deltas one at a time
self._is_processing = False
@@ -56,7 +62,7 @@ class StatsHandler:
# we start populating stats
self.clock.call_later(0, self.notify_new_event)
- def notify_new_event(self):
+ def notify_new_event(self) -> None:
"""Called when there may be more deltas to process
"""
if not self.stats_enabled or self._is_processing:
@@ -72,7 +78,7 @@ class StatsHandler:
run_as_background_process("stats.notify_new_event", process)
- async def _unsafe_process(self):
+ async def _unsafe_process(self) -> None:
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
self.pos = await self.store.get_stats_positions()
@@ -110,10 +116,10 @@ class StatsHandler:
)
for room_id, fields in room_count.items():
- room_deltas.setdefault(room_id, {}).update(fields)
+ room_deltas.setdefault(room_id, Counter()).update(fields)
for user_id, fields in user_count.items():
- user_deltas.setdefault(user_id, {}).update(fields)
+ user_deltas.setdefault(user_id, Counter()).update(fields)
logger.debug("room_deltas: %s", room_deltas)
logger.debug("user_deltas: %s", user_deltas)
@@ -131,19 +137,20 @@ class StatsHandler:
self.pos = max_pos
- async def _handle_deltas(self, deltas):
+ async def _handle_deltas(
+ self, deltas: Iterable[JsonDict]
+ ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
"""Called with the state deltas to process
Returns:
- tuple[dict[str, Counter], dict[str, counter]]
Two dicts: the room deltas and the user deltas,
mapping from room/user ID to changes in the various fields.
"""
- room_to_stats_deltas = {}
- user_to_stats_deltas = {}
+ room_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
+ user_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
- room_to_state_updates = {}
+ room_to_state_updates = {} # type: Dict[str, Dict[str, Any]]
for delta in deltas:
typ = delta["type"]
@@ -173,7 +180,7 @@ class StatsHandler:
)
continue
- event_content = {}
+ event_content = {} # type: JsonDict
sender = None
if event_id is not None:
@@ -257,13 +264,13 @@ class StatsHandler:
)
if has_changed_joinedness:
- delta = +1 if membership == Membership.JOIN else -1
+ membership_delta = +1 if membership == Membership.JOIN else -1
user_to_stats_deltas.setdefault(user_id, Counter())[
"joined_rooms"
- ] += delta
+ ] += membership_delta
- room_stats_delta["local_users_in_room"] += delta
+ room_stats_delta["local_users_in_room"] += membership_delta
elif typ == EventTypes.Create:
room_state["is_federatable"] = (
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e919a8f9ed..3f0dfc7a74 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,13 +15,13 @@
import logging
import random
from collections import namedtuple
-from typing import TYPE_CHECKING, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -65,17 +65,17 @@ class FollowerTypingHandler:
)
# map room IDs to serial numbers
- self._room_serials = {}
+ self._room_serials = {} # type: Dict[str, int]
# map room IDs to sets of users currently typing
- self._room_typing = {}
+ self._room_typing = {} # type: Dict[str, Set[str]]
- self._member_last_federation_poke = {}
+ self._member_last_federation_poke = {} # type: Dict[RoomMember, int]
self.wheel_timer = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
self.clock.looping_call(self._handle_timeouts, 5000)
- def _reset(self):
+ def _reset(self) -> None:
"""Reset the typing handler's data caches.
"""
# map room IDs to serial numbers
@@ -86,7 +86,7 @@ class FollowerTypingHandler:
self._member_last_federation_poke = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
- def _handle_timeouts(self):
+ def _handle_timeouts(self) -> None:
logger.debug("Checking for typing timeouts")
now = self.clock.time_msec()
@@ -96,7 +96,7 @@ class FollowerTypingHandler:
for member in members:
self._handle_timeout_for_member(now, member)
- def _handle_timeout_for_member(self, now: int, member: RoomMember):
+ def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
if not self.is_typing(member):
# Nothing to do if they're no longer typing
return
@@ -114,10 +114,10 @@ class FollowerTypingHandler:
# each person typing.
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
- def is_typing(self, member):
+ def is_typing(self, member: RoomMember) -> bool:
return member.user_id in self._room_typing.get(member.room_id, [])
- async def _push_remote(self, member, typing):
+ async def _push_remote(self, member: RoomMember, typing: bool) -> None:
if not self.federation:
return
@@ -148,7 +148,7 @@ class FollowerTypingHandler:
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
- ):
+ ) -> None:
"""Should be called whenever we receive updates for typing stream.
"""
@@ -178,7 +178,7 @@ class FollowerTypingHandler:
async def _send_changes_in_typing_to_remotes(
self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
- ):
+ ) -> None:
"""Process a change in typing of a room from replication, sending EDUs
for any local users.
"""
@@ -194,12 +194,12 @@ class FollowerTypingHandler:
if self.is_mine_id(user_id):
await self._push_remote(RoomMember(room_id, user_id), False)
- def get_current_token(self):
+ def get_current_token(self) -> int:
return self._latest_room_serial
class TypingWriterHandler(FollowerTypingHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
assert hs.config.worker.writers.typing == hs.get_instance_name()
@@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
- self._member_typing_until = {} # clock time we expect to stop
+ # clock time we expect to stop
+ self._member_typing_until = {} # type: Dict[RoomMember, int]
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
"TypingStreamChangeCache", self._latest_room_serial
)
- def _handle_timeout_for_member(self, now: int, member: RoomMember):
+ def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
super()._handle_timeout_for_member(now, member)
if not self.is_typing(member):
@@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler):
self._stopped_typing(member)
return
- async def started_typing(self, target_user, requester, room_id, timeout):
+ async def started_typing(
+ self, target_user: UserID, requester: Requester, room_id: str, timeout: int
+ ) -> None:
target_user_id = target_user.to_string()
auth_user_id = requester.user.to_string()
@@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler):
if was_present:
# No point sending another notification
- return None
+ return
self._push_update(member=member, typing=True)
- async def stopped_typing(self, target_user, requester, room_id):
+ async def stopped_typing(
+ self, target_user: UserID, requester: Requester, room_id: str
+ ) -> None:
target_user_id = target_user.to_string()
auth_user_id = requester.user.to_string()
@@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler):
self._stopped_typing(member)
- def user_left_room(self, user, room_id):
+ def user_left_room(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
if self.is_mine_id(user_id):
member = RoomMember(room_id=room_id, user_id=user_id)
self._stopped_typing(member)
- def _stopped_typing(self, member):
+ def _stopped_typing(self, member: RoomMember) -> None:
if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point
- return None
+ return
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
self._push_update(member=member, typing=False)
- def _push_update(self, member, typing):
+ def _push_update(self, member: RoomMember, typing: bool) -> None:
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
run_as_background_process(
@@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler):
self._push_update_local(member=member, typing=typing)
- async def _recv_edu(self, origin, content):
+ async def _recv_edu(self, origin: str, content: JsonDict) -> None:
room_id = content["room_id"]
user_id = content["user_id"]
@@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler):
self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
self._push_update_local(member=member, typing=content["typing"])
- def _push_update_local(self, member, typing):
+ def _push_update_local(self, member: RoomMember, typing: bool) -> None:
room_set = self._room_typing.setdefault(member.room_id, set())
if typing:
room_set.add(member.user_id)
@@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler):
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
last_id
- )
+ ) # type: Optional[Iterable[str]]
if changed_rooms is None:
changed_rooms = self._room_serials
@@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler):
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
- ):
+ ) -> None:
# The writing process should never get updates from replication.
raise Exception("Typing writer instance got typing info over replication")
class TypingNotificationEventSource:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
# We can't call get_typing_handler here because there's a cycle:
@@ -427,7 +432,7 @@ class TypingNotificationEventSource:
#
self.get_typing_handler = hs.get_typing_handler
- def _make_event_for(self, room_id):
+ def _make_event_for(self, room_id: str) -> JsonDict:
typing = self.get_typing_handler()._room_typing[room_id]
return {
"type": "m.typing",
@@ -462,7 +467,9 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
- async def get_new_events(self, from_key, room_ids, **kwargs):
+ async def get_new_events(
+ self, from_key: int, room_ids: Iterable[str], **kwargs
+ ) -> Tuple[List[JsonDict], int]:
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
handler = self.get_typing_handler()
@@ -478,5 +485,5 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
- def get_current_key(self):
+ def get_current_key(self) -> int:
return self.get_typing_handler()._latest_room_serial
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index d4651c8348..8aedf5072e 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -145,10 +145,6 @@ class UserDirectoryHandler(StateDeltasHandler):
if self.pos is None:
self.pos = await self.store.get_user_directory_stream_pos()
- # If still None then the initial background update hasn't happened yet
- if self.pos is None:
- return None
-
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "user_dir_delta"):
@@ -233,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler):
if change: # The user joined
event = await self.store.get_event(event_id, allow_none=True)
+ # It isn't expected for this event to not exist, but we
+ # don't want the entire background process to break.
+ if event is None:
+ continue
+
profile = ProfileInfo(
avatar_url=event.content.get("avatar_url"),
display_name=event.content.get("displayname"),
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index ab586c318c..0538350f38 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -791,7 +791,7 @@ def tag_args(func):
@wraps(func)
def _tag_args_inner(*args, **kwargs):
- argspec = inspect.getargspec(func)
+ argspec = inspect.getfullargspec(func)
for i, arg in enumerate(argspec.args[1:]):
set_tag("ARG_" + arg, args[i])
set_tag("args", args[len(argspec.args) :])
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 4d875dcb91..8a6dcff30d 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -267,9 +267,21 @@ class Mailer:
fallback_to_members=True,
)
- summary_text = await self.make_summary_text(
- notifs_by_room, state_by_room, notif_events, user_id, reason
- )
+ if len(notifs_by_room) == 1:
+ # Only one room has new stuff
+ room_id = list(notifs_by_room.keys())[0]
+
+ summary_text = await self.make_summary_text_single_room(
+ room_id,
+ notifs_by_room[room_id],
+ state_by_room[room_id],
+ notif_events,
+ user_id,
+ )
+ else:
+ summary_text = await self.make_summary_text(
+ notifs_by_room, state_by_room, notif_events, reason
+ )
template_vars = {
"user_display_name": user_display_name,
@@ -492,139 +504,178 @@ class Mailer:
if "url" in event.content:
messagevars["image_url"] = event.content["url"]
- async def make_summary_text(
+ async def make_summary_text_single_room(
self,
- notifs_by_room: Dict[str, List[Dict[str, Any]]],
- room_state_ids: Dict[str, StateMap[str]],
+ room_id: str,
+ notifs: List[Dict[str, Any]],
+ room_state_ids: StateMap[str],
notif_events: Dict[str, EventBase],
user_id: str,
- reason: Dict[str, Any],
- ):
- if len(notifs_by_room) == 1:
- # Only one room has new stuff
- room_id = list(notifs_by_room.keys())[0]
+ ) -> str:
+ """
+ Make a summary text for the email when only a single room has notifications.
- # If the room has some kind of name, use it, but we don't
- # want the generated-from-names one here otherwise we'll
- # end up with, "new message from Bob in the Bob room"
- room_name = await calculate_room_name(
- self.store, room_state_ids[room_id], user_id, fallback_to_members=False
- )
+ Args:
+ room_id: The ID of the room.
+ notifs: The notifications for this room.
+ room_state_ids: The state map for the room.
+ notif_events: A map of event ID -> notification event.
+ user_id: The user receiving the notification.
+
+ Returns:
+ The summary text.
+ """
+ # If the room has some kind of name, use it, but we don't
+ # want the generated-from-names one here otherwise we'll
+ # end up with, "new message from Bob in the Bob room"
+ room_name = await calculate_room_name(
+ self.store, room_state_ids, user_id, fallback_to_members=False
+ )
- # See if one of the notifs is an invite event for the user
- invite_event = None
- for n in notifs_by_room[room_id]:
- ev = notif_events[n["event_id"]]
- if ev.type == EventTypes.Member and ev.state_key == user_id:
- if ev.content.get("membership") == Membership.INVITE:
- invite_event = ev
- break
-
- if invite_event:
- inviter_member_event_id = room_state_ids[room_id].get(
- ("m.room.member", invite_event.sender)
- )
- inviter_name = invite_event.sender
- if inviter_member_event_id:
- inviter_member_event = await self.store.get_event(
- inviter_member_event_id, allow_none=True
- )
- if inviter_member_event:
- inviter_name = name_from_member_event(inviter_member_event)
-
- if room_name is None:
- return self.email_subjects.invite_from_person % {
- "person": inviter_name,
- "app": self.app_name,
- }
- else:
- return self.email_subjects.invite_from_person_to_room % {
- "person": inviter_name,
- "room": room_name,
- "app": self.app_name,
- }
+ # See if one of the notifs is an invite event for the user
+ invite_event = None
+ for n in notifs:
+ ev = notif_events[n["event_id"]]
+ if ev.type == EventTypes.Member and ev.state_key == user_id:
+ if ev.content.get("membership") == Membership.INVITE:
+ invite_event = ev
+ break
- sender_name = None
- if len(notifs_by_room[room_id]) == 1:
- # There is just the one notification, so give some detail
- event = notif_events[notifs_by_room[room_id][0]["event_id"]]
- if ("m.room.member", event.sender) in room_state_ids[room_id]:
- state_event_id = room_state_ids[room_id][
- ("m.room.member", event.sender)
- ]
- state_event = await self.store.get_event(state_event_id)
- sender_name = name_from_member_event(state_event)
-
- if sender_name is not None and room_name is not None:
- return self.email_subjects.message_from_person_in_room % {
- "person": sender_name,
- "room": room_name,
- "app": self.app_name,
- }
- elif sender_name is not None:
- return self.email_subjects.message_from_person % {
- "person": sender_name,
- "app": self.app_name,
- }
- else:
- # There's more than one notification for this room, so just
- # say there are several
- if room_name is not None:
- return self.email_subjects.messages_in_room % {
- "room": room_name,
- "app": self.app_name,
- }
- else:
- # If the room doesn't have a name, say who the messages
- # are from explicitly to avoid, "messages in the Bob room"
- sender_ids = list(
- {
- notif_events[n["event_id"]].sender
- for n in notifs_by_room[room_id]
- }
- )
-
- member_events = await self.store.get_events(
- [
- room_state_ids[room_id][("m.room.member", s)]
- for s in sender_ids
- ]
- )
-
- return self.email_subjects.messages_from_person % {
- "person": descriptor_from_member_events(member_events.values()),
- "app": self.app_name,
- }
- else:
- # Stuff's happened in multiple different rooms
+ if invite_event:
+ inviter_member_event_id = room_state_ids.get(
+ ("m.room.member", invite_event.sender)
+ )
+ inviter_name = invite_event.sender
+ if inviter_member_event_id:
+ inviter_member_event = await self.store.get_event(
+ inviter_member_event_id, allow_none=True
+ )
+ if inviter_member_event:
+ inviter_name = name_from_member_event(inviter_member_event)
- # ...but we still refer to the 'reason' room which triggered the mail
- if reason["room_name"] is not None:
- return self.email_subjects.messages_in_room_and_others % {
- "room": reason["room_name"],
+ if room_name is None:
+ return self.email_subjects.invite_from_person % {
+ "person": inviter_name,
"app": self.app_name,
}
- else:
- # If the reason room doesn't have a name, say who the messages
- # are from explicitly to avoid, "messages in the Bob room"
- room_id = reason["room_id"]
-
- sender_ids = list(
- {
- notif_events[n["event_id"]].sender
- for n in notifs_by_room[room_id]
- }
- )
- member_events = await self.store.get_events(
- [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids]
- )
+ return self.email_subjects.invite_from_person_to_room % {
+ "person": inviter_name,
+ "room": room_name,
+ "app": self.app_name,
+ }
+
+ if len(notifs) == 1:
+ # There is just the one notification, so give some detail
+ sender_name = None
+ event = notif_events[notifs[0]["event_id"]]
+ if ("m.room.member", event.sender) in room_state_ids:
+ state_event_id = room_state_ids[("m.room.member", event.sender)]
+ state_event = await self.store.get_event(state_event_id)
+ sender_name = name_from_member_event(state_event)
+
+ if sender_name is not None and room_name is not None:
+ return self.email_subjects.message_from_person_in_room % {
+ "person": sender_name,
+ "room": room_name,
+ "app": self.app_name,
+ }
+ elif sender_name is not None:
+ return self.email_subjects.message_from_person % {
+ "person": sender_name,
+ "app": self.app_name,
+ }
- return self.email_subjects.messages_from_person_and_others % {
- "person": descriptor_from_member_events(member_events.values()),
+ # The sender is unknown, just use the room name (or ID).
+ return self.email_subjects.messages_in_room % {
+ "room": room_name or room_id,
+ "app": self.app_name,
+ }
+ else:
+ # There's more than one notification for this room, so just
+ # say there are several
+ if room_name is not None:
+ return self.email_subjects.messages_in_room % {
+ "room": room_name,
"app": self.app_name,
}
+ return await self.make_summary_text_from_member_events(
+ room_id, notifs, room_state_ids, notif_events
+ )
+
+ async def make_summary_text(
+ self,
+ notifs_by_room: Dict[str, List[Dict[str, Any]]],
+ room_state_ids: Dict[str, StateMap[str]],
+ notif_events: Dict[str, EventBase],
+ reason: Dict[str, Any],
+ ) -> str:
+ """
+ Make a summary text for the email when multiple rooms have notifications.
+
+ Args:
+ notifs_by_room: A map of room ID to the notifications for that room.
+ room_state_ids: A map of room ID to the state map for that room.
+ notif_events: A map of event ID -> notification event.
+ reason: The reason this notification is being sent.
+
+ Returns:
+ The summary text.
+ """
+ # Stuff's happened in multiple different rooms
+ # ...but we still refer to the 'reason' room which triggered the mail
+ if reason["room_name"] is not None:
+ return self.email_subjects.messages_in_room_and_others % {
+ "room": reason["room_name"],
+ "app": self.app_name,
+ }
+
+ room_id = reason["room_id"]
+ return await self.make_summary_text_from_member_events(
+ room_id, notifs_by_room[room_id], room_state_ids[room_id], notif_events
+ )
+
+ async def make_summary_text_from_member_events(
+ self,
+ room_id: str,
+ notifs: List[Dict[str, Any]],
+ room_state_ids: StateMap[str],
+ notif_events: Dict[str, EventBase],
+ ) -> str:
+ """
+ Make a summary text for the email when only a single room has notifications.
+
+ Args:
+ room_id: The ID of the room.
+ notifs: The notifications for this room.
+ room_state_ids: The state map for the room.
+ notif_events: A map of event ID -> notification event.
+
+ Returns:
+ The summary text.
+ """
+ # If the room doesn't have a name, say who the messages
+ # are from explicitly to avoid, "messages in the Bob room"
+ sender_ids = {notif_events[n["event_id"]].sender for n in notifs}
+
+ member_events = await self.store.get_events(
+ [room_state_ids[("m.room.member", s)] for s in sender_ids]
+ )
+
+ # There was a single sender.
+ if len(sender_ids) == 1:
+ return self.email_subjects.messages_from_person % {
+ "person": descriptor_from_member_events(member_events.values()),
+ "app": self.app_name,
+ }
+
+ # There was more than one sender, use the first one and a tweaked template.
+ return self.email_subjects.messages_from_person_and_others % {
+ "person": descriptor_from_member_events(list(member_events.values())[:1]),
+ "app": self.app_name,
+ }
+
def make_room_link(self, room_id: str) -> str:
if self.hs.config.email_riot_base_url:
base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
@@ -668,6 +719,15 @@ class Mailer:
def safe_markup(raw_html: str) -> jinja2.Markup:
+ """
+ Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs.
+
+ Args
+ raw_html: Unsafe HTML.
+
+ Returns:
+ A Markup object ready to safely use in a Jinja template.
+ """
return jinja2.Markup(
bleach.linkify(
bleach.clean(
@@ -684,8 +744,13 @@ def safe_markup(raw_html: str) -> jinja2.Markup:
def safe_text(raw_text: str) -> jinja2.Markup:
"""
- Process text: treat it as HTML but escape any tags (ie. just escape the
- HTML) then linkify it.
+ Sanitise text (escape any HTML tags), and then linkify any bare URLs.
+
+ Args
+ raw_text: Unsafe text which might include HTML markup.
+
+ Returns:
+ A Markup object ready to safely use in a Jinja template.
"""
return jinja2.Markup(
bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False))
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index 7e50341d74..04c2c1482c 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -17,7 +17,7 @@ import logging
import re
from typing import TYPE_CHECKING, Dict, Iterable, Optional
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.types import StateMap
@@ -63,7 +63,7 @@ async def calculate_room_name(
m_room_name = await store.get_event(
room_state_ids[(EventTypes.Name, "")], allow_none=True
)
- if m_room_name and m_room_name.content and m_room_name.content["name"]:
+ if m_room_name and m_room_name.content and m_room_name.content.get("name"):
return m_room_name.content["name"]
# does it have a canonical alias?
@@ -74,15 +74,11 @@ async def calculate_room_name(
if (
canon_alias
and canon_alias.content
- and canon_alias.content["alias"]
+ and canon_alias.content.get("alias")
and _looks_like_an_alias(canon_alias.content["alias"])
):
return canon_alias.content["alias"]
- # at this point we're going to need to search the state by all state keys
- # for an event type, so rearrange the data structure
- room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
-
if not fallback_to_members:
return None
@@ -94,7 +90,7 @@ async def calculate_room_name(
if (
my_member_event is not None
- and my_member_event.content["membership"] == "invite"
+ and my_member_event.content.get("membership") == Membership.INVITE
):
if (EventTypes.Member, my_member_event.sender) in room_state_ids:
inviter_member_event = await store.get_event(
@@ -111,6 +107,10 @@ async def calculate_room_name(
else:
return "Room Invite"
+ # at this point we're going to need to search the state by all state keys
+ # for an event type, so rearrange the data structure
+ room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
+
# we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user.
if EventTypes.Member in room_state_bytype_ids:
@@ -120,8 +120,8 @@ async def calculate_room_name(
all_members = [
ev
for ev in member_events.values()
- if ev.content["membership"] == "join"
- or ev.content["membership"] == "invite"
+ if ev.content.get("membership") == Membership.JOIN
+ or ev.content.get("membership") == Membership.INVITE
]
# Sort the member events oldest-first so the we name people in the
# order the joined (it should at least be deterministic rather than
@@ -194,11 +194,7 @@ def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
def name_from_member_event(member_event: EventBase) -> str:
- if (
- member_event.content
- and "displayname" in member_event.content
- and member_event.content["displayname"]
- ):
+ if member_event.content and member_event.content.get("displayname"):
return member_event.content["displayname"]
return member_event.state_key
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
new file mode 100644
index 0000000000..34fa3ff5b3
--- /dev/null
+++ b/synapse/replication/tcp/external_cache.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+from prometheus_client import Counter
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util import json_decoder, json_encoder
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+set_counter = Counter(
+ "synapse_external_cache_set",
+ "Number of times we set a cache",
+ labelnames=["cache_name"],
+)
+
+get_counter = Counter(
+ "synapse_external_cache_get",
+ "Number of times we get a cache",
+ labelnames=["cache_name", "hit"],
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalCache:
+ """A cache backed by an external Redis. Does nothing if no Redis is
+ configured.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ self._redis_connection = hs.get_outbound_redis_connection()
+
+ def _get_redis_key(self, cache_name: str, key: str) -> str:
+ return "cache_v1:%s:%s" % (cache_name, key)
+
+ def is_enabled(self) -> bool:
+ """Whether the external cache is used or not.
+
+ It's safe to use the cache when this returns false, the methods will
+ just no-op, but the function is useful to avoid doing unnecessary work.
+ """
+ return self._redis_connection is not None
+
+ async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
+ """Add the key/value to the named cache, with the expiry time given.
+ """
+
+ if self._redis_connection is None:
+ return
+
+ set_counter.labels(cache_name).inc()
+
+ # txredisapi requires the value to be string, bytes or numbers, so we
+ # encode stuff in JSON.
+ encoded_value = json_encoder.encode(value)
+
+ logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
+
+ return await make_deferred_yieldable(
+ self._redis_connection.set(
+ self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
+ )
+ )
+
+ async def get(self, cache_name: str, key: str) -> Optional[Any]:
+ """Look up a key/value in the named cache.
+ """
+
+ if self._redis_connection is None:
+ return None
+
+ result = await make_deferred_yieldable(
+ self._redis_connection.get(self._get_redis_key(cache_name, key))
+ )
+
+ logger.debug("Got cache result %s %s: %r", cache_name, key, result)
+
+ get_counter.labels(cache_name, result is not None).inc()
+
+ if not result:
+ return None
+
+ # For some reason the integers get magically converted back to integers
+ if isinstance(result, int):
+ return result
+
+ return json_decoder.decode(result)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 317796d5e0..8ea8dcd587 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
from typing import (
+ TYPE_CHECKING,
Any,
Awaitable,
Dict,
@@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import (
TypingStream,
)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -88,7 +92,7 @@ class ReplicationCommandHandler:
back out to connections.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler()
self._store = hs.get_datastore()
@@ -282,13 +286,6 @@ class ReplicationCommandHandler:
if hs.config.redis.redis_enabled:
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
- lazyConnection,
- )
-
- logger.info(
- "Connecting to redis (host=%r port=%r)",
- hs.config.redis_host,
- hs.config.redis_port,
)
# First let's ensure that we have a ReplicationStreamer started.
@@ -299,13 +296,7 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called).
# First create the connection for sending commands.
- outbound_redis_connection = lazyConnection(
- reactor=hs.get_reactor(),
- host=hs.config.redis_host,
- port=hs.config.redis_port,
- password=hs.config.redis.redis_password,
- reconnect=True,
- )
+ outbound_redis_connection = hs.get_outbound_redis_connection()
# Now create the factory/connection for the subscription stream.
self._factory = RedisDirectTcpReplicationClientFactory(
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index bc6ba709a7..fdd087683b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
import logging
from inspect import isawaitable
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Type, cast
import txredisapi
@@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import (
BackgroundProcessLoggingContext,
run_as_background_process,
+ wrap_as_background_process,
)
from synapse.replication.tcp.commands import (
Command,
@@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
immediately after initialisation.
Attributes:
- handler: The command handler to handle incoming commands.
- stream_name: The *redis* stream name to subscribe to and publish from
- (not anything to do with Synapse replication streams).
- outbound_redis_connection: The connection to redis to use to send
+ synapse_handler: The command handler to handle incoming commands.
+ synapse_stream_name: The *redis* stream name to subscribe to and publish
+ from (not anything to do with Synapse replication streams).
+ synapse_outbound_redis_connection: The connection to redis to use to send
commands.
"""
- handler = None # type: ReplicationCommandHandler
- stream_name = None # type: str
- outbound_redis_connection = None # type: txredisapi.RedisProtocol
+ synapse_handler = None # type: ReplicationCommandHandler
+ synapse_stream_name = None # type: str
+ synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
- logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
- await make_deferred_yieldable(self.subscribe(self.stream_name))
+ logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
+ await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
- self.handler.new_connection(self)
+ self.synapse_handler.new_connection(self)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")
# We send out our positions when there is a new connection in case the
# other side missed updates. We do this for Redis connections as the
# otherside won't know we've connected and so won't issue a REPLICATE.
- self.handler.send_positions_to_connection(self)
+ self.synapse_handler.send_positions_to_connection(self)
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
@@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
cmd: received command
"""
- cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+ cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
if not cmd_func:
logger.warning("Unhandled command: %r", cmd)
return
@@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
def connectionLost(self, reason):
logger.info("Lost connection to redis")
super().connectionLost(reason)
- self.handler.lost_connection(self)
+ self.synapse_handler.lost_connection(self)
# mark the logging context as finished
self._logging_context.__exit__(None, None, None)
@@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
await make_deferred_yieldable(
- self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+ self.synapse_outbound_redis_connection.publish(
+ self.synapse_stream_name, encoded_string
+ )
+ )
+
+
+class SynapseRedisFactory(txredisapi.RedisFactory):
+ """A subclass of RedisFactory that periodically sends pings to ensure that
+ we detect dead connections.
+ """
+
+ def __init__(
+ self,
+ hs: "HomeServer",
+ uuid: str,
+ dbid: Optional[int],
+ poolsize: int,
+ isLazy: bool = False,
+ handler: Type = txredisapi.ConnectionHandler,
+ charset: str = "utf-8",
+ password: Optional[str] = None,
+ replyTimeout: int = 30,
+ convertNumbers: Optional[int] = True,
+ ):
+ super().__init__(
+ uuid=uuid,
+ dbid=dbid,
+ poolsize=poolsize,
+ isLazy=isLazy,
+ handler=handler,
+ charset=charset,
+ password=password,
+ replyTimeout=replyTimeout,
+ convertNumbers=convertNumbers,
)
+ hs.get_clock().looping_call(self._send_ping, 30 * 1000)
+
+ @wrap_as_background_process("redis_ping")
+ async def _send_ping(self):
+ for connection in self.pool:
+ try:
+ await make_deferred_yieldable(connection.ping())
+ except Exception:
+ logger.warning("Failed to send ping to a redis connection")
-class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+
+class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
subscribes to a stream.
@@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
):
- super().__init__()
-
- # This sets the password on the RedisFactory base class (as
- # SubscriberFactory constructor doesn't pass it through).
- self.password = hs.config.redis.redis_password
+ super().__init__(
+ hs,
+ uuid="subscriber",
+ dbid=None,
+ poolsize=1,
+ replyTimeout=30,
+ password=hs.config.redis.redis_password,
+ )
- self.handler = hs.get_tcp_replication()
- self.stream_name = hs.hostname
+ self.synapse_handler = hs.get_tcp_replication()
+ self.synapse_stream_name = hs.hostname
- self.outbound_redis_connection = outbound_redis_connection
+ self.synapse_outbound_redis_connection = outbound_redis_connection
def buildProtocol(self, addr):
- p = super().buildProtocol(addr) # type: RedisSubscriber
+ p = super().buildProtocol(addr)
+ p = cast(RedisSubscriber, p)
# We do this here rather than add to the constructor of `RedisSubcriber`
# as to do so would involve overriding `buildProtocol` entirely, however
# the base method does some other things than just instantiating the
# protocol.
- p.handler = self.handler
- p.outbound_redis_connection = self.outbound_redis_connection
- p.stream_name = self.stream_name
- p.password = self.password
+ p.synapse_handler = self.synapse_handler
+ p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
+ p.synapse_stream_name = self.synapse_stream_name
return p
def lazyConnection(
- reactor,
+ hs: "HomeServer",
host: str = "localhost",
port: int = 6379,
dbid: Optional[int] = None,
reconnect: bool = True,
- charset: str = "utf-8",
password: Optional[str] = None,
- connectTimeout: Optional[int] = None,
- replyTimeout: Optional[int] = None,
- convertNumbers: bool = True,
+ replyTimeout: int = 30,
) -> txredisapi.RedisProtocol:
- """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
- reactor.
+ """Creates a connection to Redis that is lazily set up and reconnects if the
+ connections is lost.
"""
- isLazy = True
- poolsize = 1
-
uuid = "%s:%d" % (host, port)
- factory = txredisapi.RedisFactory(
- uuid,
- dbid,
- poolsize,
- isLazy,
- txredisapi.ConnectionHandler,
- charset,
- password,
- replyTimeout,
- convertNumbers,
+ factory = SynapseRedisFactory(
+ hs,
+ uuid=uuid,
+ dbid=dbid,
+ poolsize=1,
+ isLazy=True,
+ handler=txredisapi.ConnectionHandler,
+ password=password,
+ replyTimeout=replyTimeout,
)
factory.continueTrying = reconnect
- for x in range(poolsize):
- reactor.connectTCP(host, port, factory, connectTimeout)
+
+ reactor = hs.get_reactor()
+ reactor.connectTCP(host, port, factory, 30)
return factory.handler
diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html
index a75c73a142..c9bd4bef20 100644
--- a/synapse/res/templates/sso_auth_bad_user.html
+++ b/synapse/res/templates/sso_auth_bad_user.html
@@ -12,7 +12,7 @@
<header>
<h1>That doesn't look right</h1>
<p>
- <strong>We were unable to validate your {{ server_name | e }} account</strong>
+ <strong>We were unable to validate your {{ server_name }} account</strong>
via single sign‑on (SSO), because the SSO Identity
Provider returned different details than when you logged in.
</p>
diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html
index c70d690035..2099c2f1f8 100644
--- a/synapse/res/templates/sso_auth_confirm.html
+++ b/synapse/res/templates/sso_auth_confirm.html
@@ -12,7 +12,7 @@
<header>
<h1>Confirm it's you to continue</h1>
<p>
- A client is trying to {{ description | e }}. To confirm this action
+ A client is trying to {{ description }}. To confirm this action
re-authorize your account with single sign-on.
</p>
<p><strong>
@@ -20,8 +20,8 @@
</strong></p>
</header>
<main>
- <a href="{{ redirect_url | e }}" class="primary-button">
- Continue with {{ idp.idp_name | e }}
+ <a href="{{ redirect_url }}" class="primary-button">
+ Continue with {{ idp.idp_name }}
</a>
</main>
</body>
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index 69b93d65c1..b223ca0f56 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -22,7 +22,7 @@
<header>
<h1>There was an error</h1>
<p>
- <strong id="errormsg">{{ error_description | e }}</strong>
+ <strong id="errormsg">{{ error_description }}</strong>
</p>
<p>
If you are seeing this page after clicking a link sent to you via email,
@@ -35,7 +35,7 @@
</p>
<div id="error_code">
<p><strong>Error code</strong></p>
- <p>{{ error | e }}</p>
+ <p>{{ error }}</p>
</div>
</header>
diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
index 5b38481012..62a640dad2 100644
--- a/synapse/res/templates/sso_login_idp_picker.html
+++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -3,22 +3,22 @@
<head>
<meta charset="UTF-8">
<link rel="stylesheet" href="/_matrix/static/client/login/style.css">
- <title>{{server_name | e}} Login</title>
+ <title>{{ server_name }} Login</title>
</head>
<body>
<div id="container">
- <h1 id="title">{{server_name | e}} Login</h1>
+ <h1 id="title">{{ server_name }} Login</h1>
<div class="login_flow">
<p>Choose one of the following identity providers:</p>
<form>
- <input type="hidden" name="redirectUrl" value="{{redirect_url | e}}">
+ <input type="hidden" name="redirectUrl" value="{{ redirect_url }}">
<ul class="radiobuttons">
{% for p in providers %}
<li>
- <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
- <label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
+ <input type="radio" name="idp" id="prov{{ loop.index }}" value="{{ p.idp_id }}">
+ <label for="prov{{ loop.index }}">{{ p.idp_name }}</label>
{% if p.idp_icon %}
- <img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/>
+ <img src="{{ p.idp_icon | mxc_to_http(32, 32) }}"/>
{% endif %}
</li>
{% endfor %}
diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html
index ce4f573848..d1328a6969 100644
--- a/synapse/res/templates/sso_redirect_confirm.html
+++ b/synapse/res/templates/sso_redirect_confirm.html
@@ -12,11 +12,11 @@
<header>
{% if new_user %}
<h1>Your account is now ready</h1>
- <p>You've made your account on {{ server_name | e }}.</p>
+ <p>You've made your account on {{ server_name }}.</p>
{% else %}
<h1>Log in</h1>
{% endif %}
- <p>Continue to confirm you trust <strong>{{ display_url | e }}</strong>.</p>
+ <p>Continue to confirm you trust <strong>{{ display_url }}</strong>.</p>
</header>
<main>
{% if user_profile.avatar_url %}
@@ -24,13 +24,13 @@
<img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" />
<div class="profile-details">
{% if user_profile.display_name %}
- <div class="display-name">{{ user_profile.display_name | e }}</div>
+ <div class="display-name">{{ user_profile.display_name }}</div>
{% endif %}
- <div class="user-id">{{ user_id | e }}</div>
+ <div class="user-id">{{ user_id }}</div>
</div>
</div>
{% endif %}
- <a href="{{ redirect_url | e }}" class="primary-button">Continue</a>
+ <a href="{{ redirect_url }}" class="primary-button">Continue</a>
</main>
</body>
</html>
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 6f7dc06503..57e0a10837 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
+# Copyright 2020, 2021 The Matrix.org Foundation C.I.C.
+
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,6 +38,7 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
+ ForwardExtremitiesRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
MakeRoomAdminRestServlet,
@@ -51,6 +54,7 @@ from synapse.rest.admin.users import (
PushersRestServlet,
ResetPasswordRestServlet,
SearchUsersRestServlet,
+ ShadowBanRestServlet,
UserAdminServlet,
UserMediaRestServlet,
UserMembershipRestServlet,
@@ -230,6 +234,8 @@ def register_servlets(hs, http_server):
EventReportsRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
MakeRoomAdminRestServlet(hs).register(http_server)
+ ShadowBanRestServlet(hs).register(http_server)
+ ForwardExtremitiesRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index ab7cc9102a..f14915d47e 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -431,7 +431,17 @@ class MakeRoomAdminRestServlet(RestServlet):
if not admin_users:
raise SynapseError(400, "No local admin user in room")
- admin_user_id = admin_users[-1]
+ admin_user_id = None
+
+ for admin_user in reversed(admin_users):
+ if room_state.get((EventTypes.Member, admin_user)):
+ admin_user_id = admin_user
+ break
+
+ if not admin_user_id:
+ raise SynapseError(
+ 400, "No local admin user in room",
+ )
pl_content = power_levels.content
else:
@@ -499,3 +509,60 @@ class MakeRoomAdminRestServlet(RestServlet):
)
return 200, {}
+
+
+class ForwardExtremitiesRestServlet(RestServlet):
+ """Allows a server admin to get or clear forward extremities.
+
+ Clearing does not require restarting the server.
+
+ Clear forward extremities:
+ DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+
+ Get forward_extremities:
+ GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.store = hs.get_datastore()
+
+ async def resolve_room_id(self, room_identifier: str) -> str:
+ """Resolve to a room ID, if necessary."""
+ if RoomID.is_valid(room_identifier):
+ resolved_room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
+ resolved_room_id = room_id.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+ if not resolved_room_id:
+ raise SynapseError(
+ 400, "Unknown room ID or room alias %s" % room_identifier
+ )
+ return resolved_room_id
+
+ async def on_DELETE(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ room_id = await self.resolve_room_id(room_identifier)
+
+ deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
+ return 200, {"deleted": deleted_count}
+
+ async def on_GET(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ room_id = await self.resolve_room_id(room_identifier)
+
+ extremities = await self.store.get_forward_extremities_for_room(room_id)
+ return 200, {"count": len(extremities), "results": extremities}
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f39e3d6d5c..68c3c64a0d 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet):
The parameter `deactivated` can be used to include deactivated users.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
+
+ if start < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter from must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if limit < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter limit must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
user_id = parse_string(request, "user_id", default=None)
name = parse_string(request, "name", default=None)
guests = parse_boolean(request, "guests", default=True)
@@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet):
start, limit, user_id, name, guests, deactivated
)
ret = {"users": users, "total": total}
- if len(users) >= limit:
+ if (start + limit) < total:
ret["next_token"] = str(start + len(users))
return 200, ret
@@ -875,3 +890,39 @@ class UserTokenRestServlet(RestServlet):
)
return 200, {"access_token": token}
+
+
+class ShadowBanRestServlet(RestServlet):
+ """An admin API for shadow-banning a user.
+
+ A shadow-banned users receives successful responses to their client-server
+ API requests, but the events are not propagated into rooms.
+
+ Shadow-banning a user should be used as a tool of last resort and may lead
+ to confusing or broken behaviour for the client.
+
+ Example:
+
+ POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
+ {}
+
+ 200 OK
+ {}
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Only local users can be shadow-banned")
+
+ await self.store.set_shadow_banned(UserID.from_string(user_id), True)
+
+ return 200, {}
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 65e68d641b..a84a2fb385 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/email/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.datastore = hs.get_datastore()
@@ -103,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
@@ -379,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
@@ -430,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
super().__init__()
self.store = self.hs.get_datastore()
@@ -458,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(
+ request, "msisdn", msisdn
+ )
+
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index b093183e79..10e1891174 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -126,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
@@ -205,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(
+ request, "msisdn", msisdn
+ )
+
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"msisdn", msisdn
)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 31a41e4a27..f71a03a12d 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -300,6 +300,7 @@ class FileInfo:
thumbnail_height (int)
thumbnail_method (str)
thumbnail_type (str): Content type of thumbnail, e.g. image/png
+ thumbnail_length (int): The size of the media file, in bytes.
"""
def __init__(
@@ -312,6 +313,7 @@ class FileInfo:
thumbnail_height=None,
thumbnail_method=None,
thumbnail_type=None,
+ thumbnail_length=None,
):
self.server_name = server_name
self.file_id = file_id
@@ -321,6 +323,7 @@ class FileInfo:
self.thumbnail_height = thumbnail_height
self.thumbnail_method = thumbnail_method
self.thumbnail_type = thumbnail_type
+ self.thumbnail_length = thumbnail_length
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index a632099167..bf3be653aa 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -386,7 +386,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"""
Check whether the URL should be downloaded as oEmbed content instead.
- Params:
+ Args:
url: The URL to check.
Returns:
@@ -403,7 +403,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"""
Request content from an oEmbed endpoint.
- Params:
+ Args:
endpoint: The oEmbed API endpoint.
url: The URL to pass to the API.
@@ -692,27 +692,51 @@ class PreviewUrlResource(DirectServeJsonResource):
def decode_and_calc_og(
body: bytes, media_uri: str, request_encoding: Optional[str] = None
) -> Dict[str, Optional[str]]:
+ """
+ Calculate metadata for an HTML document.
+
+ This uses lxml to parse the HTML document into the OG response. If errors
+ occur during processing of the document, an empty response is returned.
+
+ Args:
+ body: The HTML document, as bytes.
+ media_url: The URI used to download the body.
+ request_encoding: The character encoding of the body, as a string.
+
+ Returns:
+ The OG response as a dictionary.
+ """
# If there's no body, nothing useful is going to be found.
if not body:
return {}
from lxml import etree
+ # Create an HTML parser. If this fails, log and return no metadata.
try:
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
- tree = etree.fromstring(body, parser)
- og = _calc_og(tree, media_uri)
+ except LookupError:
+ # blindly consider the encoding as utf-8.
+ parser = etree.HTMLParser(recover=True, encoding="utf-8")
+ except Exception as e:
+ logger.warning("Unable to create HTML parser: %s" % (e,))
+ return {}
+
+ def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
+ # Attempt to parse the body. If this fails, log and return no metadata.
+ tree = etree.fromstring(body_attempt, parser)
+ return _calc_og(tree, media_uri)
+
+ # Attempt to parse the body. If this fails, log and return no metadata.
+ try:
+ return _attempt_calc_og(body)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
- parser = etree.HTMLParser(recover=True, encoding=request_encoding)
- tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
- og = _calc_og(tree, media_uri)
-
- return og
+ return _attempt_calc_og(body.decode("utf-8", "ignore"))
-def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
+def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d6880f2e6e..d653a58be9 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,7 +16,7 @@
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
from twisted.web.http import Request
@@ -106,31 +106,17 @@ class ThumbnailResource(DirectServeJsonResource):
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
-
- if thumbnail_infos:
- thumbnail_info = self._select_thumbnail(
- width, height, method, m_type, thumbnail_infos
- )
-
- file_info = FileInfo(
- server_name=None,
- file_id=media_id,
- url_cache=media_info["url_cache"],
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- )
-
- t_type = file_info.thumbnail_type
- t_length = thumbnail_info["thumbnail_length"]
-
- responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(request, responder, t_type, t_length)
- else:
- logger.info("Couldn't find any generated thumbnails")
- respond_404(request)
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_id,
+ url_cache=media_info["url_cache"],
+ server_name=None,
+ )
async def _select_or_generate_local_thumbnail(
self,
@@ -276,26 +262,64 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_info["filesystem_id"],
+ url_cache=None,
+ server_name=server_name,
+ )
+ async def _select_and_respond_with_thumbnail(
+ self,
+ request: Request,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ thumbnail_infos: List[Dict[str, Any]],
+ file_id: str,
+ url_cache: Optional[str] = None,
+ server_name: Optional[str] = None,
+ ) -> None:
+ """
+ Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+ request: The incoming request.
+ desired_width: The desired width, the returned thumbnail may be larger than this.
+ desired_height: The desired height, the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of dictionaries of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: The URL cache value.
+ server_name: The server name, if this is a remote thumbnail.
+ """
if thumbnail_infos:
- thumbnail_info = self._select_thumbnail(
- width, height, method, m_type, thumbnail_infos
+ file_info = self._select_thumbnail(
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ thumbnail_infos,
+ file_id,
+ url_cache,
+ server_name,
)
- file_info = FileInfo(
- server_name=server_name,
- file_id=media_info["filesystem_id"],
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- )
-
- t_type = file_info.thumbnail_type
- t_length = thumbnail_info["thumbnail_length"]
+ if not file_info:
+ logger.info("Couldn't find a thumbnail matching the desired inputs")
+ respond_404(request)
+ return
responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(request, responder, t_type, t_length)
+ await respond_with_responder(
+ request, responder, file_info.thumbnail_type, file_info.thumbnail_length
+ )
else:
logger.info("Failed to find any generated thumbnails")
respond_404(request)
@@ -306,67 +330,117 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height: int,
desired_method: str,
desired_type: str,
- thumbnail_infos,
- ) -> dict:
+ thumbnail_infos: List[Dict[str, Any]],
+ file_id: str,
+ url_cache: Optional[str],
+ server_name: Optional[str],
+ ) -> Optional[FileInfo]:
+ """
+ Choose an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+ desired_width: The desired width, the returned thumbnail may be larger than this.
+ desired_height: The desired height, the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of dictionaries of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: The URL cache value.
+ server_name: The server name, if this is a remote thumbnail.
+
+ Returns:
+ The thumbnail which best matches the desired parameters.
+ """
+ desired_method = desired_method.lower()
+
+ # The chosen thumbnail.
+ thumbnail_info = None
+
d_w = desired_width
d_h = desired_height
- if desired_method.lower() == "crop":
+ if desired_method == "crop":
+ # Thumbnails that match equal or larger sizes of desired width/height.
crop_info_list = []
+ # Other thumbnails.
crop_info_list2 = []
for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info["thumbnail_method"] != "crop":
+ continue
+
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
- t_method = info["thumbnail_method"]
- if t_method == "crop":
- aspect_quality = abs(d_w * t_h - d_h * t_w)
- min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
- size_quality = abs((d_w - t_w) * (d_h - t_h))
- type_quality = desired_type != info["thumbnail_type"]
- length_quality = info["thumbnail_length"]
- if t_w >= d_w or t_h >= d_h:
- crop_info_list.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
+ aspect_quality = abs(d_w * t_h - d_h * t_w)
+ min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
+ size_quality = abs((d_w - t_w) * (d_h - t_h))
+ type_quality = desired_type != info["thumbnail_type"]
+ length_quality = info["thumbnail_length"]
+ if t_w >= d_w or t_h >= d_h:
+ crop_info_list.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
)
- else:
- crop_info_list2.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
+ )
+ else:
+ crop_info_list2.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
)
+ )
if crop_info_list:
- return min(crop_info_list)[-1]
- else:
- return min(crop_info_list2)[-1]
- else:
+ thumbnail_info = min(crop_info_list)[-1]
+ elif crop_info_list2:
+ thumbnail_info = min(crop_info_list2)[-1]
+ elif desired_method == "scale":
+ # Thumbnails that match equal or larger sizes of desired width/height.
info_list = []
+ # Other thumbnails.
info_list2 = []
+
for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info["thumbnail_method"] != "scale":
+ continue
+
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
- t_method = info["thumbnail_method"]
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
- if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
+ if t_w >= d_w or t_h >= d_h:
info_list.append((size_quality, type_quality, length_quality, info))
- elif t_method == "scale":
+ else:
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
if info_list:
- return min(info_list)[-1]
- else:
- return min(info_list2)[-1]
+ thumbnail_info = min(info_list)[-1]
+ elif info_list2:
+ thumbnail_info = min(info_list2)[-1]
+
+ if thumbnail_info:
+ return FileInfo(
+ file_id=file_id,
+ url_cache=url_cache,
+ server_name=server_name,
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
+ thumbnail_length=thumbnail_info["thumbnail_length"],
+ )
+
+ # No matching thumbnail was found.
+ return None
diff --git a/synapse/server.py b/synapse/server.py
index 9cdda83aa1..9bdd3177d7 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -103,6 +103,7 @@ from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
+from synapse.replication.tcp.external_cache import ExternalCache
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
@@ -128,6 +129,8 @@ from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
+ from txredisapi import RedisProtocol
+
from synapse.handlers.oidc_handler import OidcHandler
from synapse.handlers.saml_handler import SamlHandler
@@ -716,6 +719,33 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_account_data_handler(self) -> AccountDataHandler:
return AccountDataHandler(self)
+ @cache_in_self
+ def get_external_cache(self) -> ExternalCache:
+ return ExternalCache(self)
+
+ @cache_in_self
+ def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
+ if not self.config.redis.redis_enabled:
+ return None
+
+ # We only want to import redis module if we're using it, as we have
+ # `txredisapi` as an optional dependency.
+ from synapse.replication.tcp.redis import lazyConnection
+
+ logger.info(
+ "Connecting to redis (host=%r port=%r) for external cache",
+ self.config.redis_host,
+ self.config.redis_port,
+ )
+
+ return lazyConnection(
+ hs=self,
+ host=self.config.redis_host,
+ port=self.config.redis_port,
+ password=self.config.redis.redis_password,
+ reconnect=True,
+ )
+
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 84f59c7d85..3bd9ff8ca0 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -310,6 +310,7 @@ class StateHandler:
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
+ entry = None
else:
# otherwise, we'll need to resolve the state across the prev_events.
@@ -340,9 +341,13 @@ class StateHandler:
current_state_ids=state_ids_before_event,
)
- # XXX: can we update the state cache entry for the new state group? or
- # could we set a flag on resolve_state_groups_for_events to tell it to
- # always make a state group?
+ # Assign the new state group to the cached state entry.
+ #
+ # Note that this can race in that we could generate multiple state
+ # groups for the same state entry, but that is just inefficient
+ # rather than dangerous.
+ if entry and entry.state_group is None:
+ entry.state_group = state_group_before_event
#
# now if it's not a state event, we're done
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index c7220bc778..d2ba4bd2fc 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -262,6 +262,12 @@ class LoggingTransaction:
return self.txn.description
def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
+ """Similar to `executemany`, except `txn.rowcount` will not be correct
+ afterwards.
+
+ More efficient than `executemany` on PostgreSQL
+ """
+
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index ae561a2da3..5d0845588c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore
from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
+from .events_forward_extremities import EventForwardExtremitiesStore
from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
@@ -118,6 +119,7 @@ class DataStore(
UIAuthStore,
CacheInvalidationWorkerStore,
ServerMetricsStore,
+ EventForwardExtremitiesStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9097677648..659d8f245f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore):
DELETE FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ?
"""
- txn.executemany(sql, ((row[0], row[1]) for row in rows))
+ txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
logger.info("Pruned %d device list outbound pokes", count)
@@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Delete older entries in the table, as we really only care about
# when the latest change happened.
- txn.executemany(
+ txn.execute_batch(
"""
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c128889bf9..309f1e865b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -634,7 +634,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Optional[Dict[str, dict]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -724,7 +724,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
- ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+ ) -> Dict[str, Dict[str, Dict[str, str]]]:
"""Take a list of one time keys out of the database.
Args:
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 1b657191a9..438383abe1 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
VALUES (?, ?, ?, ?, ?, ?)
"""
- txn.executemany(
+ txn.execute_batch(
sql,
(
_gen_entry(user_id, actions)
@@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
],
)
- txn.executemany(
+ txn.execute_batch(
"""
UPDATE event_push_summary
SET notif_count = ?, unread_count = ?, stream_ordering = ?
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5db7d7aaa8..ccda9f1caa 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -473,8 +473,9 @@ class PersistEventsStore:
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
)
- @staticmethod
+ @classmethod
def _add_chain_cover_index(
+ cls,
txn,
db_pool: DatabasePool,
event_to_room_id: Dict[str, str],
@@ -614,60 +615,17 @@ class PersistEventsStore:
if not events_to_calc_chain_id_for:
return
- # We now calculate the chain IDs/sequence numbers for the events. We
- # do this by looking at the chain ID and sequence number of any auth
- # event with the same type/state_key and incrementing the sequence
- # number by one. If there was no match or the chain ID/sequence
- # number is already taken we generate a new chain.
- #
- # We need to do this in a topologically sorted order as we want to
- # generate chain IDs/sequence numbers of an event's auth events
- # before the event itself.
- chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
- new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
- for event_id in sorted_topologically(
- events_to_calc_chain_id_for, event_to_auth_chain
- ):
- existing_chain_id = None
- for auth_id in event_to_auth_chain.get(event_id, []):
- if event_to_types.get(event_id) == event_to_types.get(auth_id):
- existing_chain_id = chain_map[auth_id]
- break
-
- new_chain_tuple = None
- if existing_chain_id:
- # We found a chain ID/sequence number candidate, check its
- # not already taken.
- proposed_new_id = existing_chain_id[0]
- proposed_new_seq = existing_chain_id[1] + 1
- if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
- already_allocated = db_pool.simple_select_one_onecol_txn(
- txn,
- table="event_auth_chains",
- keyvalues={
- "chain_id": proposed_new_id,
- "sequence_number": proposed_new_seq,
- },
- retcol="event_id",
- allow_none=True,
- )
- if already_allocated:
- # Mark it as already allocated so we don't need to hit
- # the DB again.
- chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
- else:
- new_chain_tuple = (
- proposed_new_id,
- proposed_new_seq,
- )
-
- if not new_chain_tuple:
- new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
-
- chains_tuples_allocated.add(new_chain_tuple)
-
- chain_map[event_id] = new_chain_tuple
- new_chain_tuples[event_id] = new_chain_tuple
+ # Allocate chain ID/sequence numbers to each new event.
+ new_chain_tuples = cls._allocate_chain_ids(
+ txn,
+ db_pool,
+ event_to_room_id,
+ event_to_types,
+ event_to_auth_chain,
+ events_to_calc_chain_id_for,
+ chain_map,
+ )
+ chain_map.update(new_chain_tuples)
db_pool.simple_insert_many_txn(
txn,
@@ -794,6 +752,137 @@ class PersistEventsStore:
],
)
+ @staticmethod
+ def _allocate_chain_ids(
+ txn,
+ db_pool: DatabasePool,
+ event_to_room_id: Dict[str, str],
+ event_to_types: Dict[str, Tuple[str, str]],
+ event_to_auth_chain: Dict[str, List[str]],
+ events_to_calc_chain_id_for: Set[str],
+ chain_map: Dict[str, Tuple[int, int]],
+ ) -> Dict[str, Tuple[int, int]]:
+ """Allocates, but does not persist, chain ID/sequence numbers for the
+ events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
+ for info on args)
+ """
+
+ # We now calculate the chain IDs/sequence numbers for the events. We do
+ # this by looking at the chain ID and sequence number of any auth event
+ # with the same type/state_key and incrementing the sequence number by
+ # one. If there was no match or the chain ID/sequence number is already
+ # taken we generate a new chain.
+ #
+ # We try to reduce the number of times that we hit the database by
+ # batching up calls, to make this more efficient when persisting large
+ # numbers of state events (e.g. during joins).
+ #
+ # We do this by:
+ # 1. Calculating for each event which auth event will be used to
+ # inherit the chain ID, i.e. converting the auth chain graph to a
+ # tree that we can allocate chains on. We also keep track of which
+ # existing chain IDs have been referenced.
+ # 2. Fetching the max allocated sequence number for each referenced
+ # existing chain ID, generating a map from chain ID to the max
+ # allocated sequence number.
+ # 3. Iterating over the tree and allocating a chain ID/seq no. to the
+ # new event, by incrementing the sequence number from the
+ # referenced event's chain ID/seq no. and checking that the
+ # incremented sequence number hasn't already been allocated (by
+ # looking in the map generated in the previous step). We generate a
+ # new chain if the sequence number has already been allocated.
+ #
+
+ existing_chains = set() # type: Set[int]
+ tree = [] # type: List[Tuple[str, Optional[str]]]
+
+ # We need to do this in a topologically sorted order as we want to
+ # generate chain IDs/sequence numbers of an event's auth events before
+ # the event itself.
+ for event_id in sorted_topologically(
+ events_to_calc_chain_id_for, event_to_auth_chain
+ ):
+ for auth_id in event_to_auth_chain.get(event_id, []):
+ if event_to_types.get(event_id) == event_to_types.get(auth_id):
+ existing_chain_id = chain_map.get(auth_id)
+ if existing_chain_id:
+ existing_chains.add(existing_chain_id[0])
+
+ tree.append((event_id, auth_id))
+ break
+ else:
+ tree.append((event_id, None))
+
+ # Fetch the current max sequence number for each existing referenced chain.
+ sql = """
+ SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
+ WHERE %s
+ GROUP BY chain_id
+ """
+ clause, args = make_in_list_sql_clause(
+ db_pool.engine, "chain_id", existing_chains
+ )
+ txn.execute(sql % (clause,), args)
+
+ chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int]
+
+ # Allocate the new events chain ID/sequence numbers.
+ #
+ # To reduce the number of calls to the database we don't allocate a
+ # chain ID number in the loop, instead we use a temporary `object()` for
+ # each new chain ID. Once we've done the loop we generate the necessary
+ # number of new chain IDs in one call, replacing all temporary
+ # objects with real allocated chain IDs.
+
+ unallocated_chain_ids = set() # type: Set[object]
+ new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]]
+ for event_id, auth_event_id in tree:
+ # If we reference an auth_event_id we fetch the allocated chain ID,
+ # either from the existing `chain_map` or the newly generated
+ # `new_chain_tuples` map.
+ existing_chain_id = None
+ if auth_event_id:
+ existing_chain_id = new_chain_tuples.get(auth_event_id)
+ if not existing_chain_id:
+ existing_chain_id = chain_map[auth_event_id]
+
+ new_chain_tuple = None # type: Optional[Tuple[Any, int]]
+ if existing_chain_id:
+ # We found a chain ID/sequence number candidate, check its
+ # not already taken.
+ proposed_new_id = existing_chain_id[0]
+ proposed_new_seq = existing_chain_id[1] + 1
+
+ if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
+ new_chain_tuple = (
+ proposed_new_id,
+ proposed_new_seq,
+ )
+
+ # If we need to start a new chain we allocate a temporary chain ID.
+ if not new_chain_tuple:
+ new_chain_tuple = (object(), 1)
+ unallocated_chain_ids.add(new_chain_tuple[0])
+
+ new_chain_tuples[event_id] = new_chain_tuple
+ chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
+
+ # Generate new chain IDs for all unallocated chain IDs.
+ newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+ txn, len(unallocated_chain_ids)
+ )
+
+ # Map from potentially temporary chain ID to real chain ID
+ chain_id_to_allocated_map = dict(
+ zip(unallocated_chain_ids, newly_allocated_chain_ids)
+ ) # type: Dict[Any, int]
+ chain_id_to_allocated_map.update((c, c) for c in existing_chains)
+
+ return {
+ event_id: (chain_id_to_allocated_map[chain_id], seq)
+ for event_id, (chain_id, seq) in new_chain_tuples.items()
+ }
+
def _persist_transaction_ids_txn(
self,
txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index e46e44ba54..5ca4fa6817 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -139,8 +139,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- INSERT_CLUMP_SIZE = 1000
-
def reindex_txn(txn):
sql = (
"SELECT stream_ordering, event_id, json FROM events"
@@ -178,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
- for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
- clump = update_rows[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
+ txn.execute_batch(sql, update_rows)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
@@ -210,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- INSERT_CLUMP_SIZE = 1000
-
def reindex_search_txn(txn):
sql = (
"SELECT stream_ordering, event_id FROM events"
@@ -256,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
- for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
- clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
+ txn.execute_batch(sql, rows_to_update)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
new file mode 100644
index 0000000000..0ac1da9c35
--- /dev/null
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class EventForwardExtremitiesStore(SQLBaseStore):
+ async def delete_forward_extremities_for_room(self, room_id: str) -> int:
+ """Delete any extra forward extremities for a room.
+
+ Invalidates the "get_latest_event_ids_in_room" cache if any forward
+ extremities were deleted.
+
+ Returns count deleted.
+ """
+
+ def delete_forward_extremities_for_room_txn(txn):
+ # First we need to get the event_id to not delete
+ sql = """
+ SELECT event_id FROM event_forward_extremities
+ INNER JOIN events USING (room_id, event_id)
+ WHERE room_id = ?
+ ORDER BY stream_ordering DESC
+ LIMIT 1
+ """
+ txn.execute(sql, (room_id,))
+ rows = txn.fetchall()
+ try:
+ event_id = rows[0][0]
+ logger.debug(
+ "Found event_id %s as the forward extremity to keep for room %s",
+ event_id,
+ room_id,
+ )
+ except KeyError:
+ msg = "No forward extremity event found for room %s" % room_id
+ logger.warning(msg)
+ raise SynapseError(400, msg)
+
+ # Now delete the extra forward extremities
+ sql = """
+ DELETE FROM event_forward_extremities
+ WHERE event_id != ? AND room_id = ?
+ """
+
+ txn.execute(sql, (event_id, room_id))
+ logger.info(
+ "Deleted %s extra forward extremities for room %s",
+ txn.rowcount,
+ room_id,
+ )
+
+ if txn.rowcount > 0:
+ # Invalidate the cache
+ self._invalidate_cache_and_stream(
+ txn, self.get_latest_event_ids_in_room, (room_id,),
+ )
+
+ return txn.rowcount
+
+ return await self.db_pool.runInteraction(
+ "delete_forward_extremities_for_room",
+ delete_forward_extremities_for_room_txn,
+ )
+
+ async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
+ """Get list of forward extremities for a room."""
+
+ def get_forward_extremities_for_room_txn(txn):
+ sql = """
+ SELECT event_id, state_group, depth, received_ts
+ FROM event_forward_extremities
+ INNER JOIN event_to_state_groups USING (event_id)
+ INNER JOIN events USING (room_id, event_id)
+ WHERE room_id = ?
+ """
+
+ txn.execute(sql, (room_id,))
+ return self.db_pool.cursor_to_dict(txn)
+
+ return await self.db_pool.runInteraction(
+ "get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
+ )
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 283c8a5e22..e017177655 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_origin = ? AND media_id = ?"
)
- txn.executemany(
+ txn.execute_batch(
sql,
(
(time_ms, media_origin, media_id)
@@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_id = ?"
)
- txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
+ txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
@@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
def _delete_url_cache_txn(txn):
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache", _delete_url_cache_txn
@@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_media_txn(txn):
sql = "DELETE FROM local_media_repository WHERE media_id = ?"
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index ab18cc4d79..92e65aa640 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -88,6 +88,62 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(x[0] - 1) * x[1] for x in res if x[1]
)
+ async def count_daily_e2ee_messages(self):
+ """
+ Returns an estimate of the number of messages sent in the last day.
+
+ If it has been significantly less or more than one day since the last
+ call to this function, it will return None.
+ """
+
+ def _count_messages(txn):
+ sql = """
+ SELECT COALESCE(COUNT(*), 0) FROM events
+ WHERE type = 'm.room.encrypted'
+ AND stream_ordering > ?
+ """
+ txn.execute(sql, (self.stream_ordering_day_ago,))
+ (count,) = txn.fetchone()
+ return count
+
+ return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
+
+ async def count_daily_sent_e2ee_messages(self):
+ def _count_messages(txn):
+ # This is good enough as if you have silly characters in your own
+ # hostname then thats your own fault.
+ like_clause = "%:" + self.hs.hostname
+
+ sql = """
+ SELECT COALESCE(COUNT(*), 0) FROM events
+ WHERE type = 'm.room.encrypted'
+ AND sender LIKE ?
+ AND stream_ordering > ?
+ """
+
+ txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
+ (count,) = txn.fetchone()
+ return count
+
+ return await self.db_pool.runInteraction(
+ "count_daily_sent_e2ee_messages", _count_messages
+ )
+
+ async def count_daily_active_e2ee_rooms(self):
+ def _count(txn):
+ sql = """
+ SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+ WHERE type = 'm.room.encrypted'
+ AND stream_ordering > ?
+ """
+ txn.execute(sql, (self.stream_ordering_day_ago,))
+ (count,) = txn.fetchone()
+ return count
+
+ return await self.db_pool.runInteraction(
+ "count_daily_active_e2ee_rooms", _count
+ )
+
async def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 5d668aadb2..ecfc9f20b1 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
)
# Update backward extremeties
- txn.executemany(
+ txn.execute_batch(
"INSERT INTO event_backward_extremities (room_id, event_id)"
" VALUES (?, ?)",
[(room_id, event_id) for event_id, in new_backwards_extrems],
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index bc7621b8d6..2687ef3e43 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -344,7 +344,9 @@ class PusherStore(PusherWorkerStore):
txn, self.get_if_user_has_pusher, (user_id,)
)
- self.db_pool.simple_delete_one_txn(
+ # It is expected that there is exactly one pusher to delete, but
+ # if it isn't there (or there are multiple) delete them all.
+ self.db_pool.simple_delete_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 14c0878d81..8405dd460f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -360,6 +360,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
+ async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None:
+ """Sets whether a user shadow-banned.
+
+ Args:
+ user: user ID of the user to test
+ shadow_banned: true iff the user is to be shadow-banned, false otherwise.
+ """
+
+ def set_shadow_banned_txn(txn):
+ self.db_pool.simple_update_one_txn(
+ txn,
+ table="users",
+ keyvalues={"name": user.to_string()},
+ updatevalues={"shadow_banned": shadow_banned},
+ )
+ # In order for this to apply immediately, clear the cache for this user.
+ tokens = self.db_pool.simple_select_onecol_txn(
+ txn,
+ table="access_tokens",
+ keyvalues={"user_id": user.to_string()},
+ retcol="token",
+ )
+ for token in tokens:
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_access_token, (token,)
+ )
+
+ await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
+
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name as user_id,
@@ -1124,7 +1153,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
FROM user_threepids
"""
- txn.executemany(sql, [(id_server,) for id_server in id_servers])
+ txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
if id_servers:
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index dcdaf09682..92382bed28 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
"max_stream_id_exclusive", self._stream_order_on_start + 1
)
- INSERT_CLUMP_SIZE = 1000
-
def add_membership_profile_txn(txn):
sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
@@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
UPDATE room_memberships SET display_name = ?, avatar_url = ?
WHERE event_id = ? AND room_id = ?
"""
- for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
- clump = to_update[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(to_update_sql, clump)
+ txn.execute_batch(to_update_sql, to_update)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
index f35c70b699..9e8f35c1d2 100644
--- a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
+++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
# { "ignored_users": "@someone:example.org": {} }
ignored_users = content.get("ignored_users", {})
if isinstance(ignored_users, dict) and ignored_users:
- cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
+ cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
# Add indexes after inserting data for efficiency.
logger.info("Adding constraints to ignored_users table")
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index e34fce6281..f5e7d9ef98 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -63,7 +64,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
- txn.executemany(sql, args)
+ txn.execute_batch(sql, args)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
@@ -75,7 +76,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
- txn.executemany(sql, args)
+ txn.execute_batch(sql, args)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore):
async def search_rooms(
self,
- room_ids: List[str],
+ room_ids: Collection[str],
search_term: str,
keys: List[str],
limit,
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 0cdb3ec1f7..d421d18f8d 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -15,11 +15,12 @@
# limitations under the License.
import logging
-from collections import Counter
from enum import Enum
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple
+from typing_extensions import Counter
+
from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership
@@ -319,7 +320,9 @@ class StatsStore(StateDeltasStore):
return slice_list
@cached()
- async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
+ async def get_earliest_token_for_stats(
+ self, stats_type: str, id: str
+ ) -> Optional[int]:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@@ -339,7 +342,7 @@ class StatsStore(StateDeltasStore):
)
async def bulk_update_stats_delta(
- self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
+ self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int
) -> None:
"""Bulk update stats tables for a given stream_id and updates the stats
incremental position.
@@ -665,7 +668,7 @@ class StatsStore(StateDeltasStore):
async def get_changes_room_total_events_and_bytes(
self, min_pos: int, max_pos: int
- ) -> Dict[str, Dict[str, int]]:
+ ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
"""Fetches the counts of events in the given range of stream IDs.
Args:
@@ -683,18 +686,19 @@ class StatsStore(StateDeltasStore):
max_pos,
)
- def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
+ def get_changes_room_total_events_and_bytes_txn(
+ self, txn, low_pos: int, high_pos: int
+ ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
"""Gets the total_events and total_event_bytes counts for rooms and
senders, in a range of stream_orderings (including backfilled events).
Args:
txn
- low_pos (int): Low stream ordering
- high_pos (int): High stream ordering
+ low_pos: Low stream ordering
+ high_pos: High stream ordering
Returns:
- tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
- room and user deltas for total_events/total_event_bytes in the
+ The room and user deltas for total_events/total_event_bytes in the
format of `stats_id` -> fields
"""
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index ef11f1c3b3..7b9729da09 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -540,7 +540,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
desc="get_user_in_directory",
)
- async def update_user_directory_stream_pos(self, stream_id: str) -> None:
+ async def update_user_directory_stream_pos(self, stream_id: int) -> None:
await self.db_pool.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0e31cc811a..89cdc84a9c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
logger.info("[purge] removing redundant state groups")
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM state_groups_state WHERE state_group = ?",
((sg,) for sg in state_groups_to_delete),
)
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM state_groups WHERE id = ?",
((sg,) for sg in state_groups_to_delete),
)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index bb84c0d792..71ef5a72dc 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -15,12 +15,11 @@
import heapq
import logging
import threading
-from collections import deque
+from collections import OrderedDict
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple, Union
import attr
-from typing_extensions import Deque
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -101,7 +100,13 @@ class StreamIdGenerator:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
- self._unfinished_ids = deque() # type: Deque[int]
+
+ # We use this as an ordered set, as we want to efficiently append items,
+ # remove items and get the first item. Since we insert IDs in order, the
+ # insertion ordering will ensure its in the correct ordering.
+ #
+ # The key and values are the same, but we never look at the values.
+ self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int]
def get_next(self):
"""
@@ -113,7 +118,7 @@ class StreamIdGenerator:
self._current += self._step
next_id = self._current
- self._unfinished_ids.append(next_id)
+ self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@@ -121,7 +126,7 @@ class StreamIdGenerator:
yield next_id
finally:
with self._lock:
- self._unfinished_ids.remove(next_id)
+ self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -140,7 +145,7 @@ class StreamIdGenerator:
self._current += n * self._step
for next_id in next_ids:
- self._unfinished_ids.append(next_id)
+ self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@@ -149,7 +154,7 @@ class StreamIdGenerator:
finally:
with self._lock:
for next_id in next_ids:
- self._unfinished_ids.remove(next_id)
+ self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -162,7 +167,7 @@ class StreamIdGenerator:
"""
with self._lock:
if self._unfinished_ids:
- return self._unfinished_ids[0] - self._step
+ return next(iter(self._unfinished_ids)) - self._step
return self._current
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index c780ade077..0ec4dc2918 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -70,6 +70,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
...
@abc.abstractmethod
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ """Get the next `n` IDs in the sequence"""
+ ...
+
+ @abc.abstractmethod
def check_consistency(
self,
db_conn: "LoggingDatabaseConnection",
@@ -219,6 +224,17 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1
return self._current_max_id
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ first_id = self._current_max_id + 1
+ self._current_max_id += n
+ return [first_id + i for i in range(n)]
+
def check_consistency(
self,
db_conn: Connection,
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 1ee61851e4..09b094ded7 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -49,7 +49,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
module = importlib.import_module(module)
provider_class = getattr(module, clz)
- module_config = provider.get("config")
+ # Load the module config. If None, pass an empty dictionary instead
+ module_config = provider.get("config") or {}
try:
provider_config = provider_class.parse_config(module_config)
except jsonschema.ValidationError as e:
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 0b24b89a2e..74503112f5 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -16,7 +16,7 @@ import logging
from unittest import TestCase
from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
@@ -191,6 +191,97 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
+ @unittest.override_config(
+ {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invite_by_room_ratelimit(self):
+ """Tests that invites from federation in a room are actually rate-limited.
+ """
+ other_server = "otherserver"
+ other_user = "@otheruser:" + other_server
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ def create_invite_for(local_user):
+ return event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": other_user,
+ "state_key": local_user,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ for i in range(3):
+ self.get_success(
+ self.handler.on_invite_request(
+ other_server,
+ create_invite_for("@user-%d:test" % (i,)),
+ room_version,
+ )
+ )
+
+ self.get_failure(
+ self.handler.on_invite_request(
+ other_server, create_invite_for("@user-4:test"), room_version,
+ ),
+ exc=LimitExceededError,
+ )
+
+ @unittest.override_config(
+ {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invite_by_user_ratelimit(self):
+ """Tests that invites from federation to a particular user are
+ actually rate-limited.
+ """
+ other_server = "otherserver"
+ other_user = "@otheruser:" + other_server
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ def create_invite():
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+ return event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": other_user,
+ "state_key": "@user:test",
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ for i in range(3):
+ event = create_invite()
+ self.get_success(
+ self.handler.on_invite_request(other_server, event, event.room_version,)
+ )
+
+ event = create_invite()
+ self.get_failure(
+ self.handler.on_invite_request(other_server, event, event.room_version,),
+ exc=LimitExceededError,
+ )
+
def _build_and_send_join_event(self, other_server, other_user, room_id):
join_event = self.get_success(
self.handler.on_make_join_request(other_server, room_id, other_user)
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 961bf09de9..c4e1e7ed85 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -187,6 +187,36 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about those messages
self._check_for_mail()
+ def test_multiple_rooms(self):
+ # We want to test multiple notifications from multiple rooms, so we pause
+ # processing of push while we send messages.
+ self.pusher._pause_processing()
+
+ # Create a simple room with multiple other users
+ rooms = [
+ self.helper.create_room_as(self.user_id, tok=self.access_token),
+ self.helper.create_room_as(self.user_id, tok=self.access_token),
+ ]
+
+ for r, other in zip(rooms, self.others):
+ self.helper.invite(
+ room=r, src=self.user_id, tok=self.access_token, targ=other.id
+ )
+ self.helper.join(room=r, user=other.id, tok=other.token)
+
+ # The other users send some messages
+ self.helper.send(rooms[0], body="Hi!", tok=self.others[0].token)
+ self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
+ self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
+
+ # Nothing should have happened yet, as we're paused.
+ assert not self.email_attempts
+
+ self.pusher._resume_processing()
+
+ # We should get emailed about those messages
+ self._check_for_mail()
+
def test_encrypted_message(self):
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.helper.invite(
diff --git a/tests/push/test_presentable_names.py b/tests/push/test_presentable_names.py
new file mode 100644
index 0000000000..aff563919d
--- /dev/null
+++ b/tests/push/test_presentable_names.py
@@ -0,0 +1,229 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Optional, Tuple
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.events import FrozenEvent
+from synapse.push.presentable_names import calculate_room_name
+from synapse.types import StateKey, StateMap
+
+from tests import unittest
+
+
+class MockDataStore:
+ """
+ A fake data store which stores a mapping of state key to event content.
+ (I.e. the state key is used as the event ID.)
+ """
+
+ def __init__(self, events: Iterable[Tuple[StateKey, dict]]):
+ """
+ Args:
+ events: A state map to event contents.
+ """
+ self._events = {}
+
+ for i, (event_id, content) in enumerate(events):
+ self._events[event_id] = FrozenEvent(
+ {
+ "event_id": "$event_id",
+ "type": event_id[0],
+ "sender": "@user:test",
+ "state_key": event_id[1],
+ "room_id": "#room:test",
+ "content": content,
+ "origin_server_ts": i,
+ },
+ RoomVersions.V1,
+ )
+
+ async def get_event(
+ self, event_id: StateKey, allow_none: bool = False
+ ) -> Optional[FrozenEvent]:
+ assert allow_none, "Mock not configured for allow_none = False"
+
+ return self._events.get(event_id)
+
+ async def get_events(self, event_ids: Iterable[StateKey]):
+ # This is cheating since it just returns all events.
+ return self._events
+
+
+class PresentableNamesTestCase(unittest.HomeserverTestCase):
+ USER_ID = "@test:test"
+ OTHER_USER_ID = "@user:test"
+
+ def _calculate_room_name(
+ self,
+ events: StateMap[dict],
+ user_id: str = "",
+ fallback_to_members: bool = True,
+ fallback_to_single_member: bool = True,
+ ):
+ # This isn't 100% accurate, but works with MockDataStore.
+ room_state_ids = {k[0]: k[0] for k in events}
+
+ return self.get_success(
+ calculate_room_name(
+ MockDataStore(events),
+ room_state_ids,
+ user_id or self.USER_ID,
+ fallback_to_members,
+ fallback_to_single_member,
+ )
+ )
+
+ def test_name(self):
+ """A room name event should be used."""
+ events = [
+ ((EventTypes.Name, ""), {"name": "test-name"}),
+ ]
+ self.assertEqual("test-name", self._calculate_room_name(events))
+
+ # Check if the event content has garbage.
+ events = [((EventTypes.Name, ""), {"foo": 1})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ events = [((EventTypes.Name, ""), {"name": 1})]
+ self.assertEqual(1, self._calculate_room_name(events))
+
+ def test_canonical_alias(self):
+ """An canonical alias should be used."""
+ events = [
+ ((EventTypes.CanonicalAlias, ""), {"alias": "#test-name:test"}),
+ ]
+ self.assertEqual("#test-name:test", self._calculate_room_name(events))
+
+ # Check if the event content has garbage.
+ events = [((EventTypes.CanonicalAlias, ""), {"foo": 1})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ events = [((EventTypes.CanonicalAlias, ""), {"alias": "test-name"})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ def test_invite(self):
+ """An invite has special behaviour."""
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ((EventTypes.Member, self.OTHER_USER_ID), {"displayname": "Other User"}),
+ ]
+ self.assertEqual("Invite from Other User", self._calculate_room_name(events))
+ self.assertIsNone(
+ self._calculate_room_name(events, fallback_to_single_member=False)
+ )
+ # Ensure this logic is skipped if we don't fallback to members.
+ self.assertIsNone(self._calculate_room_name(events, fallback_to_members=False))
+
+ # Check if the event content has garbage.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+ ]
+ self.assertEqual("Invite from @user:test", self._calculate_room_name(events))
+
+ # No member event for sender.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ]
+ self.assertEqual("Room Invite", self._calculate_room_name(events))
+
+ def test_no_members(self):
+ """Behaviour of an empty room."""
+ events = []
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ # Note that events with invalid (or missing) membership are ignored.
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+ ((EventTypes.Member, "@foo:test"), {"membership": "foo"}),
+ ]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ def test_no_other_members(self):
+ """Behaviour of a room with no other members in it."""
+ events = [
+ (
+ (EventTypes.Member, self.USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Me"},
+ ),
+ ]
+ self.assertEqual("Me", self._calculate_room_name(events))
+
+ # Check if the event content has no displayname.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual("@test:test", self._calculate_room_name(events))
+
+ # 3pid invite, use the other user (who is set as the sender).
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual(
+ "nobody", self._calculate_room_name(events, user_id=self.OTHER_USER_ID)
+ )
+
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+ ((EventTypes.ThirdPartyInvite, self.OTHER_USER_ID), {}),
+ ]
+ self.assertEqual(
+ "Inviting email address",
+ self._calculate_room_name(events, user_id=self.OTHER_USER_ID),
+ )
+
+ def test_one_other_member(self):
+ """Behaviour of a room with a single other member."""
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Other User"},
+ ),
+ ]
+ self.assertEqual("Other User", self._calculate_room_name(events))
+ self.assertIsNone(
+ self._calculate_room_name(events, fallback_to_single_member=False)
+ )
+
+ # Check if the event content has no displayname and is an invite.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.INVITE},
+ ),
+ ]
+ self.assertEqual("@user:test", self._calculate_room_name(events))
+
+ def test_other_members(self):
+ """Behaviour of a room with multiple other members."""
+ # Two other members.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Other User"},
+ ),
+ ((EventTypes.Member, "@foo:test"), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual("Other User and @foo:test", self._calculate_room_name(events))
+
+ # Three or more other members.
+ events.append(
+ ((EventTypes.Member, "@fourth:test"), {"membership": Membership.INVITE})
+ )
+ self.assertEqual("Other User and 2 others", self._calculate_room_name(events))
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 1f4b5ca2ac..4a841f5bb8 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -29,7 +29,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"type": "m.room.history_visibility",
"sender": "@user:test",
"state_key": "",
- "room_id": "@room:test",
+ "room_id": "#room:test",
"content": content,
},
RoomVersions.V1,
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 3379189785..d5dce1f83f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -212,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Fake in memory Redis server that servers can connect to.
self._redis_server = FakeRedisPubSubServer()
+ # We may have an attempt to connect to redis for the external cache already.
+ self.connect_any_redis_attempts()
+
store = self.hs.get_datastore()
self.database_pool = store.db_pool
@@ -401,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
fake one.
"""
clients = self.reactor.tcpClients
- self.assertEqual(len(clients), 1)
- (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
- self.assertEqual(host, "localhost")
- self.assertEqual(port, 6379)
+ while clients:
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "localhost")
+ self.assertEqual(port, 6379)
- client_protocol = client_factory.buildProtocol(None)
- server_protocol = self._redis_server.buildProtocol(None)
+ client_protocol = client_factory.buildProtocol(None)
+ server_protocol = self._redis_server.buildProtocol(None)
- client_to_server_transport = FakeTransport(
- server_protocol, self.reactor, client_protocol
- )
- client_protocol.makeConnection(client_to_server_transport)
-
- server_to_client_transport = FakeTransport(
- client_protocol, self.reactor, server_protocol
- )
- server_protocol.makeConnection(server_to_client_transport)
+ client_to_server_transport = FakeTransport(
+ server_protocol, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
- return client_to_server_transport, server_to_client_transport
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, server_protocol
+ )
+ server_protocol.makeConnection(server_to_client_transport)
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@@ -624,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol):
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])
+
+ # Since we use SET/GET to cache things we can safely no-op them.
+ elif command == b"SET":
+ self.send("OK")
+ elif command == b"GET":
+ self.send(None)
else:
raise Exception("Unknown command")
@@ -645,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol):
# We assume bytes are just unicode strings.
obj = obj.decode("utf-8")
+ if obj is None:
+ return "$-1\r\n"
if isinstance(obj, str):
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
if isinstance(obj, int):
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 04599c2fcf..ee05ee60bc 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -28,6 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import devices, sync
+from synapse.types import JsonDict
from tests import unittest
from tests.test_utils import make_awaitable
@@ -468,13 +469,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- self.user1 = self.register_user(
- "user1", "pass1", admin=False, displayname="Name 1"
- )
- self.user2 = self.register_user(
- "user2", "pass2", admin=False, displayname="Name 2"
- )
-
def test_no_auth(self):
"""
Try to list users without authentication.
@@ -488,6 +482,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error is returned.
"""
+ self._create_users(1)
other_user_token = self.login("user1", "pass1")
channel = self.make_request("GET", self.url, access_token=other_user_token)
@@ -499,6 +494,8 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
List all users, including deactivated users.
"""
+ self._create_users(2)
+
channel = self.make_request(
"GET",
self.url + "?deactivated=true",
@@ -511,14 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(3, channel.json_body["total"])
# Check that all fields are available
- for u in channel.json_body["users"]:
- self.assertIn("name", u)
- self.assertIn("is_guest", u)
- self.assertIn("admin", u)
- self.assertIn("user_type", u)
- self.assertIn("deactivated", u)
- self.assertIn("displayname", u)
- self.assertIn("avatar_url", u)
+ self._check_fields(channel.json_body["users"])
def test_search_term(self):
"""Test that searching for a users works correctly"""
@@ -549,6 +539,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# Check that users were returned
self.assertTrue("users" in channel.json_body)
+ self._check_fields(channel.json_body["users"])
users = channel.json_body["users"]
# Check that the expected number of users were returned
@@ -561,25 +552,30 @@ class UsersListTestCase(unittest.HomeserverTestCase):
u = users[0]
self.assertEqual(expected_user_id, u["name"])
+ self._create_users(2)
+
+ user1 = "@user1:test"
+ user2 = "@user2:test"
+
# Perform search tests
- _search_test(self.user1, "er1")
- _search_test(self.user1, "me 1")
+ _search_test(user1, "er1")
+ _search_test(user1, "me 1")
- _search_test(self.user2, "er2")
- _search_test(self.user2, "me 2")
+ _search_test(user2, "er2")
+ _search_test(user2, "me 2")
- _search_test(self.user1, "er1", "user_id")
- _search_test(self.user2, "er2", "user_id")
+ _search_test(user1, "er1", "user_id")
+ _search_test(user2, "er2", "user_id")
# Test case insensitive
- _search_test(self.user1, "ER1")
- _search_test(self.user1, "NAME 1")
+ _search_test(user1, "ER1")
+ _search_test(user1, "NAME 1")
- _search_test(self.user2, "ER2")
- _search_test(self.user2, "NAME 2")
+ _search_test(user2, "ER2")
+ _search_test(user2, "NAME 2")
- _search_test(self.user1, "ER1", "user_id")
- _search_test(self.user2, "ER2", "user_id")
+ _search_test(user1, "ER1", "user_id")
+ _search_test(user2, "ER2", "user_id")
_search_test(None, "foo")
_search_test(None, "bar")
@@ -587,6 +583,179 @@ class UsersListTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo", "user_id")
_search_test(None, "bar", "user_id")
+ def test_invalid_parameter(self):
+ """
+ If parameters are invalid, an error is returned.
+ """
+
+ # negative limit
+ channel = self.make_request(
+ "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # negative from
+ channel = self.make_request(
+ "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # invalid guests
+ channel = self.make_request(
+ "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ # invalid deactivated
+ channel = self.make_request(
+ "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ def test_limit(self):
+ """
+ Testing list of users with limit
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 5)
+ self.assertEqual(channel.json_body["next_token"], "5")
+ self._check_fields(channel.json_body["users"])
+
+ def test_from(self):
+ """
+ Testing list of users with a defined starting point (from)
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 15)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["users"])
+
+ def test_limit_and_from(self):
+ """
+ Testing list of users with a defined starting point and limit
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(channel.json_body["next_token"], "15")
+ self.assertEqual(len(channel.json_body["users"]), 10)
+ self._check_fields(channel.json_body["users"])
+
+ def test_next_token(self):
+ """
+ Testing that `next_token` appears at the right place
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ # `next_token` does not appear
+ # Number of results is the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), number_users)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does not appear
+ # Number of max results is larger than the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), number_users)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does appear
+ # Number of max results is smaller than the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 19)
+ self.assertEqual(channel.json_body["next_token"], "19")
+
+ # Check
+ # Set `from` to value of `next_token` for request remaining entries
+ # `next_token` does not appear
+ channel = self.make_request(
+ "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 1)
+ self.assertNotIn("next_token", channel.json_body)
+
+ def _check_fields(self, content: JsonDict):
+ """Checks that the expected user attributes are present in content
+ Args:
+ content: List that is checked for content
+ """
+ for u in content:
+ self.assertIn("name", u)
+ self.assertIn("is_guest", u)
+ self.assertIn("admin", u)
+ self.assertIn("user_type", u)
+ self.assertIn("deactivated", u)
+ self.assertIn("displayname", u)
+ self.assertIn("avatar_url", u)
+
+ def _create_users(self, number_users: int):
+ """
+ Create a number of users
+ Args:
+ number_users: Number of users to be created
+ """
+ for i in range(1, number_users + 1):
+ self.register_user(
+ "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i,
+ )
+
class DeactivateAccountTestCase(unittest.HomeserverTestCase):
@@ -2211,3 +2380,67 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
+
+
+class ShadowBanRestTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+
+ self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote(
+ self.other_user
+ )
+
+ def test_no_auth(self):
+ """
+ Try to get information of an user without authentication.
+ """
+ channel = self.make_request("POST", self.url)
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_not_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ channel = self.make_request("POST", self.url, access_token=other_user_token)
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that shadow-banning for a user that is not a local returns a 400
+ """
+ url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
+
+ channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+
+ def test_success(self):
+ """
+ Shadow-banning should succeed for an admin.
+ """
+ # The user starts off as not shadow-banned.
+ other_user_token = self.login("user", "pass")
+ result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+ self.assertFalse(result.shadow_banned)
+
+ channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual({}, channel.json_body)
+
+ # Ensure the user is shadow-banned (and the cache was cleared).
+ result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+ self.assertTrue(result.shadow_banned)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index e689c3fbea..0ebdf1415b 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -18,6 +18,7 @@ import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+from synapse.types import UserID
from tests import unittest
@@ -31,12 +32,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase):
self.store = self.hs.get_datastore()
self.get_success(
- self.store.db_pool.simple_update(
- table="users",
- keyvalues={"name": self.banned_user_id},
- updatevalues={"shadow_banned": True},
- desc="shadow_ban",
- )
+ self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True)
)
self.other_user_id = self.register_user("otheruser", "pass")
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index d4e3165436..2548b3a80c 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -616,6 +616,41 @@ class RoomMemberStateTestCase(RoomBase):
self.assertEquals(json.loads(content), channel.json_body)
+class RoomInviteRatelimitTestCase(RoomBase):
+ user_id = "@sid1:red"
+
+ servlets = [
+ admin.register_servlets,
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ @unittest.override_config(
+ {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invites_by_rooms_ratelimit(self):
+ """Tests that invites in a room are actually rate-limited."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ for i in range(3):
+ self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,))
+
+ self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429)
+
+ @unittest.override_config(
+ {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invites_by_users_ratelimit(self):
+ """Tests that invites to a specific user are actually rate-limited."""
+
+ for i in range(3):
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.invite(room_id, self.user_id, "@other-users:red")
+
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index cb87b80e33..177dc476da 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -24,7 +24,7 @@ import pkg_resources
import synapse.rest.admin
from synapse.api.constants import LoginType, Membership
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, HttpResponseException
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
@@ -112,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password)
+ @override_config({"rc_3pid_validation": {"burst_count": 3}})
+ def test_ratelimit_by_email(self):
+ """Test that we ratelimit /requestToken for the same email.
+ """
+ old_password = "monkey"
+ new_password = "kangeroo"
+
+ user_id = self.register_user("kermit", old_password)
+ self.login("kermit", old_password)
+
+ email = "test1@example.com"
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ def reset(ip):
+ client_secret = "foobar"
+ session_id = self._request_token(email, client_secret, ip)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ self._reset_password(new_password, session_id, client_secret)
+
+ self.email_attempts.clear()
+
+ # We expect to be able to make three requests before getting rate
+ # limited.
+ #
+ # We change IPs to ensure that we're not being ratelimited due to the
+ # same IP
+ reset("127.0.0.1")
+ reset("127.0.0.2")
+ reset("127.0.0.3")
+
+ with self.assertRaises(HttpResponseException) as cm:
+ reset("127.0.0.4")
+
+ self.assertEqual(cm.exception.code, 429)
+
def test_basic_password_reset_canonicalise_email(self):
"""Test basic password reset flow
Request password reset with different spelling
@@ -239,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(session_id)
- def _request_token(self, email, client_secret):
+ def _request_token(self, email, client_secret, ip="127.0.0.1"):
channel = self.make_request(
"POST",
b"account/password/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
+ client_ip=ip,
)
- self.assertEquals(200, channel.code, channel.result)
+
+ if channel.code != 200:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"],
+ )
return channel.json_body["sid"]
@@ -509,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_address_trim(self):
self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
+ @override_config({"rc_3pid_validation": {"burst_count": 3}})
+ def test_ratelimit_by_ip(self):
+ """Tests that adding emails is ratelimited by IP
+ """
+
+ # We expect to be able to set three emails before getting ratelimited.
+ self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
+ self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
+ self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
+
+ with self.assertRaises(HttpResponseException) as cm:
+ self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
+
+ self.assertEqual(cm.exception.code, 429)
+
def test_add_email_if_disabled(self):
"""Test adding email to profile when doing so is disallowed
"""
@@ -777,7 +847,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
body["next_link"] = next_link
channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
- self.assertEquals(expect_code, channel.code, channel.result)
+
+ if channel.code != expect_code:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"],
+ )
return channel.json_body.get("sid")
@@ -823,10 +897,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def _add_email(self, request_email, expected_email):
"""Test adding an email to profile
"""
+ previous_email_attempts = len(self.email_attempts)
+
client_secret = "foobar"
session_id = self._request_token(request_email, client_secret)
- self.assertEquals(len(self.email_attempts), 1)
+ self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
link = self._get_link_from_email()
self._validate_token(link)
@@ -855,4 +931,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
+
+ threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
+ self.assertIn(expected_email, threepids)
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index ae2b32b131..a6c6985173 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -202,7 +202,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
config = self.default_config()
config["media_store_path"] = self.media_store_path
- config["thumbnail_requirements"] = {}
config["max_image_pixels"] = 2000000
provider_config = {
@@ -313,15 +312,39 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self):
+ """Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
"crop", self.test_image.expected_cropped, self.test_image.expected_found
)
def test_thumbnail_scale(self):
+ """Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
"scale", self.test_image.expected_scaled, self.test_image.expected_found
)
+ def test_invalid_type(self):
+ """An invalid thumbnail type is never available."""
+ self._test_thumbnail("invalid", None, False)
+
+ @unittest.override_config(
+ {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
+ )
+ def test_no_thumbnail_crop(self):
+ """
+ Override the config to generate only scaled thumbnails, but request a cropped one.
+ """
+ self._test_thumbnail("crop", None, False)
+
+ @unittest.override_config(
+ {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
+ )
+ def test_no_thumbnail_scale(self):
+ """
+ Override the config to generate only cropped thumbnails, but request a scaled one.
+ """
+ self._test_thumbnail("scale", None, False)
+
def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
channel = make_request(
diff --git a/tests/server.py b/tests/server.py
index 5a85d5fe7f..6419c445ec 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -47,6 +47,7 @@ class FakeChannel:
site = attr.ib(type=Site)
_reactor = attr.ib()
result = attr.ib(type=dict, default=attr.Factory(dict))
+ _ip = attr.ib(type=str, default="127.0.0.1")
_producer = None
@property
@@ -120,7 +121,7 @@ class FakeChannel:
def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
# causing us to record the MAU
- return address.IPv4Address("TCP", "127.0.0.1", 3423)
+ return address.IPv4Address("TCP", self._ip, 3423)
def getHost(self):
return None
@@ -196,6 +197,7 @@ def make_request(
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
+ client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
Make a web request using the given method, path and content, and render it
@@ -223,6 +225,9 @@ def make_request(
will pump the reactor until the the renderer tells the channel the request
is finished.
+ client_ip: The IP to use as the requesting IP. Useful for testing
+ ratelimiting.
+
Returns:
channel
"""
@@ -250,7 +255,7 @@ def make_request(
if isinstance(content, str):
content = content.encode("utf8")
- channel = FakeChannel(site, reactor)
+ channel = FakeChannel(site, reactor, ip=client_ip)
req = request(channel)
req.content = BytesIO(content)
diff --git a/tests/test_preview.py b/tests/test_preview.py
index c19facc1cb..0c6cbbd921 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -261,3 +261,32 @@ class PreviewUrlTestCase(unittest.TestCase):
html = ""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEqual(og, {})
+
+ def test_invalid_encoding(self):
+ """An invalid character encoding should be ignored and treated as UTF-8, if possible."""
+ html = """
+ <html>
+ <head><title>Foo</title></head>
+ <body>
+ Some text.
+ </body>
+ </html>
+ """
+ og = decode_and_calc_og(
+ html, "http://example.com/test.html", "invalid-encoding"
+ )
+ self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
+
+ def test_invalid_encoding2(self):
+ """A body which doesn't match the sent character encoding."""
+ # Note that this contains an invalid UTF-8 sequence in the title.
+ html = b"""
+ <html>
+ <head><title>\xff\xff Foo</title></head>
+ <body>
+ Some text.
+ </body>
+ </html>
+ """
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+ self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
diff --git a/tests/unittest.py b/tests/unittest.py
index bbd295687c..767d5d6077 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -386,6 +386,7 @@ class HomeserverTestCase(TestCase):
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
+ client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
Create a SynapseRequest at the path using the method and containing the
@@ -410,6 +411,9 @@ class HomeserverTestCase(TestCase):
custom_headers: (name, value) pairs to add as request headers
+ client_ip: The IP to use as the requesting IP. Useful for testing
+ ratelimiting.
+
Returns:
The FakeChannel object which stores the result of the request.
"""
@@ -426,6 +430,7 @@ class HomeserverTestCase(TestCase):
content_is_form,
await_result,
custom_headers,
+ client_ip,
)
def setup_test_homeserver(self, *args, **kwargs):
diff --git a/tests/utils.py b/tests/utils.py
index 022223cf24..68033d7535 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -157,6 +157,7 @@ def default_config(name, parse=False):
"local": {"per_second": 10000, "burst_count": 10000},
"remote": {"per_second": 10000, "burst_count": 10000},
},
+ "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
"saml2_enabled": False,
"default_identity_server": None,
"key_refresh_interval": 24 * 60 * 60 * 1000,
diff --git a/tox.ini b/tox.ini
index 0479186348..9ff70fe312 100644
--- a/tox.ini
+++ b/tox.ini
@@ -26,7 +26,8 @@ deps =
pip>=10 ; python_version >= '3.6'
pip>=10,<21.0 ; python_version < '3.6'
-# directories/files we run the linters on
+# directories/files we run the linters on.
+# if you update this list, make sure to do the same in scripts-dev/lint.sh
lint_targets =
setup.py
synapse
|